From 140ac03222b2c358c6067e2206186f828515a277 Mon Sep 17 00:00:00 2001
From: Namu
Date: Mon, 20 Oct 2025 14:02:01 +0200
Subject: [PATCH] refactor: add 2 blank lines between functions in tp7

---
 tp7.py             | 4 +++-
 tp7_gpt_exemple.py | 4 ++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/tp7.py b/tp7.py
index b1f0e47..84fd763 100644
--- a/tp7.py
+++ b/tp7.py
@@ -25,6 +25,7 @@ class Actor(nn.Module):
         std = torch.exp(log_std)
         return mu, std
 
+
 class Critic(nn.Module):
     def __init__(self, state_dim):
         super().__init__()
@@ -53,6 +54,7 @@ def compute_gae(rewards, values, gamma, lam, next_value):
     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
     return returns, advantages
 
+
 # ——— Entraînement ———
 def train_and_save():
     env = gym.make("Pusher-v5")
@@ -151,7 +153,7 @@ def train_and_save():
     torch.save(actor.state_dict(), "a2c_pusher.pth")
 
 
-# ——— Démonstration ———
+
 def show(weights_path="a2c_pusher.pth"):
     env = gym.make("Pusher-v5", render_mode="human")
     actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
diff --git a/tp7_gpt_exemple.py b/tp7_gpt_exemple.py
index 5768983..8818164 100644
--- a/tp7_gpt_exemple.py
+++ b/tp7_gpt_exemple.py
@@ -41,6 +41,7 @@ class A2CAgent:
         self.model = ActorCritic(state_dim, action_dim).to(self.device)
         self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
 
+
     def select_action(self, state):
         state = torch.FloatTensor(state).to(self.device)
         dist, _ = self.model(state)
@@ -48,6 +49,7 @@ class A2CAgent:
         log_prob = dist.log_prob(action).sum(dim=-1)
         return action.cpu().numpy(), log_prob
 
+
     def compute_returns(self, rewards, masks, next_value):
         R = next_value
         returns = []
@@ -56,6 +58,7 @@ class A2CAgent:
         returns.insert(0, R)
         return returns
 
+
     def update(self, trajectory, next_state):
         states = torch.FloatTensor([t[0] for t in trajectory]).to(self.device)
         actions = torch.FloatTensor([t[1] for t in trajectory]).to(self.device)
@@ -80,6 +83,7 @@ class A2CAgent:
         loss.backward()
         self.optimizer.step()
 
+
     def train(self, max_steps=2000, update_every=5):
         state, _ = self.env.reset()
         trajectory = []
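
For reference, a minimal sketch of the spacing rule applied above (PEP 8 suggests two blank lines around top-level def and class statements); the names below are placeholder examples, not code from tp7.py or tp7_gpt_exemple.py:

import math


def circle_area(radius):
    # Two blank lines separate a top-level definition from the code above it.
    return math.pi * radius ** 2


def circle_circumference(radius):
    # ...and two blank lines also separate consecutive top-level functions.
    return 2 * math.pi * radius


print(circle_area(1.0), circle_circumference(1.0))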