import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


# ——— Neural networks ———
class Actor(nn.Module):
    """Gaussian policy: outputs the mean and standard deviation of the action distribution."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.mu_head = nn.Linear(256, action_dim)
        self.log_std_head = nn.Linear(256, action_dim)

    def forward(self, state):
        x = self.net(state)
        mu = self.mu_head(x)
        # Clamp the log-std to keep the distribution numerically stable.
        log_std = torch.clamp(self.log_std_head(x), -20, 2)
        std = torch.exp(log_std)
        return mu, std


class Critic(nn.Module):
    """State-value function V(s)."""

    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, state):
        return self.net(state)


# ——— GAE ———
def compute_gae(rewards, values, gamma, lam, next_value):
    """Generalized Advantage Estimation:
        delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * lam * A_{t+1}
    """
    # Detach so the targets do not backpropagate into the critic here.
    values = [v.detach() for v in values] + [next_value.detach()]
    gae = 0.0
    returns = []
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        returns.insert(0, gae + values[t])
    # Stack instead of torch.tensor() on a list of tensors; result has shape (T,).
    returns = torch.stack(returns).squeeze(-1)
    advantages = returns - torch.stack(values[:-1]).squeeze(-1)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages


# ——— Training ———
def train_and_save():
    env = gym.make("Pusher-v5")
    actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
    critic = Critic(env.observation_space.shape[0])
    optimizerA = optim.Adam(actor.parameters(), lr=1e-4)
    optimizerC = optim.Adam(critic.parameters(), lr=1e-4)
    gamma = 0.99
    lam = 0.95
    nb_episodes = 2000

    # Action bounds are constant, so build the clamp tensors once.
    low = torch.tensor(env.action_space.low, dtype=torch.float32)
    high = torch.tensor(env.action_space.high, dtype=torch.float32)

    rewards_history = []
    advantages_history = []
    critic_preds = []
    td_errors = []

    for episode in range(nb_episodes):
        state, _ = env.reset()
        done = False
        log_probs = []
        values = []
        rewards = []
        entropies = []

        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32)
            mu, std = actor(state_tensor)
            dist = torch.distributions.Normal(mu, std)
            action = dist.rsample()

            # Clamp to respect the environment's action limits.
            action_clamped = torch.clamp(action, low, high)

            next_state, reward, terminated, truncated, _ = env.step(action_clamped.detach().numpy())
            done = terminated or truncated
            reward_scaled = reward / 10.0  # scale rewards to stabilize learning

            value = critic(state_tensor)
            log_prob = dist.log_prob(action).sum(dim=-1)
            entropy = dist.entropy().sum(dim=-1)

            log_probs.append(log_prob)
            values.append(value)
            rewards.append(reward_scaled)
            entropies.append(entropy)

            state = next_state

        # Bootstrap value for GAE: zero on a true termination, critic estimate on truncation.
        state_tensor = torch.tensor(state, dtype=torch.float32)
        next_value = torch.zeros(1) if terminated else critic(state_tensor).detach()

        # ——— GAE ———
        returns, advantages = compute_gae(rewards, values, gamma, lam, next_value)

        # ——— Actor update ———
        log_probs = torch.stack(log_probs)
        entropies = torch.stack(entropies)
        actor_loss = -(log_probs * advantages.detach()).mean() - 0.02 * entropies.mean()  # entropy bonus with a reduced coefficient
        optimizerA.zero_grad()
        actor_loss.backward()
        optimizerA.step()

        # ——— Critic update ———
        critic_loss = (returns - torch.stack(values).squeeze(-1)).pow(2).mean()
        optimizerC.zero_grad()
        critic_loss.backward()
        optimizerC.step()

        total_reward = sum(rewards)  # sum of scaled rewards
        rewards_history.append(total_reward)
        advantages_history.append(advantages.mean().item())
        critic_preds.append(torch.stack(values).mean().item())
        td_errors.append((returns - torch.stack(values).squeeze(-1)).mean().item())

        print(f"Episode {episode}, Reward: {total_reward:.2f}")

        # ——— Diagnostic plots every 500 episodes ———
        if episode % 500 == 0 and episode != 0:
            fig, axes = plt.subplots(1, 4, figsize=(20, 4))
            axes[0].plot(rewards_history, label='Rewards')
            axes[0].set_title('Rewards')
            axes[0].legend()
            axes[1].plot(advantages_history, label='Advantages', color='orange')
            axes[1].set_title('Advantages')
            axes[1].legend()
            axes[2].plot(critic_preds, label='Critic Prediction', color='green')
            axes[2].plot(rewards_history, label='Actual Reward', color='red', linestyle='--')
            axes[2].set_title('Critic vs Reward')
            axes[2].legend()
            axes[3].plot(td_errors, label='TD Error', color='purple')
            axes[3].set_title('TD Error')
            axes[3].legend()
            plt.suptitle(f'Episode {episode}')
            plt.tight_layout()
            plt.show()

    torch.save(actor.state_dict(), "a2c_pusher.pth")


# ——— Demonstration ———
def show(weights_path="a2c_pusher.pth"):
    env = gym.make("Pusher-v5", render_mode="human")
    actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
    actor.load_state_dict(torch.load(weights_path))
    actor.eval()

    state, _ = env.reset()
    done = False
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32)
        with torch.no_grad():
            mu, _ = actor(state_tensor)
        action = mu.numpy()  # act deterministically with the mean action
        next_state, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        state = next_state

    env.close()
    print("Demonstration finished.")


if __name__ == "__main__":
    train_and_save()
    show()
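

# ——— Optional: quantitative evaluation (added sketch) ———
# This helper is an addition, not part of the original script: a minimal sketch,
# assuming the Actor defined above and the "a2c_pusher.pth" checkpoint written by
# train_and_save(). It runs the deterministic policy (the mean action, clipped to
# the action bounds) without rendering and reports the mean undiscounted return,
# which makes it easier to compare checkpoints than watching the viewer.
# It is defined after the __main__ block so the original entry point is unchanged;
# call it manually (e.g. from an interactive session) if desired.
def evaluate(weights_path="a2c_pusher.pth", n_episodes=10):
    env = gym.make("Pusher-v5")
    actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
    actor.load_state_dict(torch.load(weights_path))
    actor.eval()

    episode_returns = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done = False
        total = 0.0
        while not done:
            with torch.no_grad():
                mu, _ = actor(torch.tensor(state, dtype=torch.float32))
            action = np.clip(mu.numpy(), env.action_space.low, env.action_space.high)
            state, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
        episode_returns.append(total)

    env.close()
    print(f"Mean return over {n_episodes} episodes: {np.mean(episode_returns):.2f}")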