import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


class Actor(nn.Module):
    """The actor DNN: maps a state to a probability distribution over actions."""

    def __init__(self, n_states, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, n_actions)
        )

    def forward(self, x):
        return torch.softmax(self.net(x), dim=-1)


class Critic(nn.Module):
    """The critic DNN: estimates the state value V(s)."""

    def __init__(self, n_states):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)


def train_and_save(weights_path="cartpole_actor_critic.pth", episodes=500):
    env = gym.make("CartPole-v1")
    n_states, n_actions = env.observation_space.shape[0], env.action_space.n

    # Define the actor & critic DNNs
    actor_net = Actor(n_states, n_actions)
    critic_net = Critic(n_states)

    # Hyperparameters and optimizers
    optimizer_actor = optim.Adam(actor_net.parameters(), lr=1e-3)
    optimizer_critic = optim.Adam(critic_net.parameters(), lr=5e-4)
    gamma = 0.99

    for ep in range(episodes):
        # Current state returned by the environment
        s, _ = env.reset()
        s = torch.tensor(s, dtype=torch.float32)

        # Bookkeeping variables (collected but not used by the per-step updates)
        done, total_r = False, 0
        log_probs = []
        td_errors = []

        while not done:
            # Actor: choose an action
            action_probs = actor_net(s)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Environment: take the action
            ns, r, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ns = torch.tensor(ns, dtype=torch.float32)
            total_r += r

            # Critic: compute the TD error
            with torch.no_grad():
                # Force V(ns) = 0 when the episode is over
                value_ns = critic_net(ns) if not done else torch.tensor([0.0])
            value_n = critic_net(s)
            # No detach here: the gradient must flow through value_n to the critic
            td_error = r + gamma * value_ns - value_n

            # Actor loss (detach td_error so the actor update does not touch the critic)
            actor_loss = -log_prob * td_error.detach()
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # Critic loss (MSE on the TD error)
            critic_loss = td_error.pow(2).mean()
            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            print("value_n:", value_n.item(), "value_ns:", value_ns.item(),
                  "td_error:", td_error.item())

            log_probs.append(log_prob)
            td_errors.append(td_error)

            # Move to the next state
            s = ns

        print(f'Episode {ep + 1}: total reward {total_r:.1f}')

    # Release the resources held by the environment
    env.close()

    # Save the actor weights
    torch.save(actor_net.state_dict(), weights_path)
    print(f'Training finished. Weights saved to {weights_path}')
    return actor_net


def show(weights_path="cartpole_actor_critic.pth"):
    env = gym.make("CartPole-v1", render_mode="human")
    actor_net = Actor(env.observation_space.shape[0], env.action_space.n)
    actor_net.load_state_dict(torch.load(weights_path))
    actor_net.eval()

    s, _ = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    done = False
    while not done:
        with torch.no_grad():
            action_probs = actor_net(s)
        # Greedy action for the demonstration
        action = torch.argmax(action_probs).item()
        s_, r, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        s = torch.tensor(s_, dtype=torch.float32)

    env.close()
    print('Demonstration finished.')


if __name__ == '__main__':
    trained_model = train_and_save()
    show()
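
# ---------------------------------------------------------------------------
# Optional sketch, not part of the original training loop: matplotlib is
# imported above but never used, so this is a minimal, hypothetical helper
# showing how per-episode rewards could be plotted if train_and_save() were
# extended to collect them in a list. The name `plot_rewards`, its arguments,
# and the output filename are assumptions for illustration only.
def plot_rewards(episode_rewards, save_path="cartpole_rewards.png"):
    """Plot the total reward per episode and save the figure to disk."""
    plt.figure()
    plt.plot(episode_rewards)
    plt.xlabel("Episode")
    plt.ylabel("Total reward")
    plt.title("CartPole actor-critic training curve")
    plt.savefig(save_path)
    plt.close()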