diff --git a/ex2.py b/ex2.py
index 1c06516..6f9a3fc 100644
--- a/ex2.py
+++ b/ex2.py
@@ -119,5 +119,5 @@ def show(weights_path='cartpole_dqn.pth') -> None:
 
 
 if __name__ == '__main__':
-    trained_model = train_and_save()
+    #trained_model = train_and_save()
     show()
diff --git a/tp5.py b/tp5.py
new file mode 100644
index 0000000..d419f8a
--- /dev/null
+++ b/tp5.py
@@ -0,0 +1,146 @@
+import gymnasium as gym
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib.pyplot as plt
+
+
+class Actor(nn.Module):
+    """
+    The actor DNN (policy network)
+    """
+
+    def __init__(self, n_states, n_actions):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_states, 128), nn.ReLU(),
+            nn.Linear(128, 64), nn.ReLU(),
+            nn.Linear(64, n_actions)
+        )
+
+    def forward(self, x):
+        return torch.softmax(self.net(x), dim=-1)
+
+
+class Critic(nn.Module):
+    """
+    The critic DNN (state-value network)
+    """
+
+    def __init__(self, n_states):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_states, 64), nn.ReLU(),
+            nn.Linear(64, 1)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def train_and_save(weights_path="cartpole_actor_critic.pth", episodes=500):
+    env = gym.make("CartPole-v1")
+    n_states, n_actions = env.observation_space.shape[0], env.action_space.n
+
+    # Define the actor & critic networks
+    actor_net = Actor(n_states, n_actions)
+    critic_net = Critic(n_states)
+
+    # Hyperparameters and optimizers
+    optimizer_actor = optim.Adam(actor_net.parameters(), lr=1e-3)
+    optimizer_critic = optim.Adam(critic_net.parameters(), lr=5e-4)
+    gamma = 0.99
+
+    for ep in range(episodes):
+        # Current state returned by the environment
+        s, _ = env.reset()
+        s = torch.tensor(s, dtype=torch.float32)
+
+        # Per-episode bookkeeping variables
+        done, total_r = False, 0
+
+        log_probs = []
+        td_errors = []
+
+        while not done:
+            # Actor: choose an action
+            action_probs = actor_net(s)
+            dist = torch.distributions.Categorical(action_probs)
+            action = dist.sample()
+            log_prob = dist.log_prob(action)
+
+            # Environment: perform the action
+            ns, r, terminated, truncated, _ = env.step(action.item())
+            done = terminated or truncated
+            ns = torch.tensor(ns, dtype=torch.float32)
+            total_r += r
+
+            # Critic: compute the TD error
+            with torch.no_grad():
+                value_ns = critic_net(ns) if not done else torch.tensor([0.0])  # force V(next state) = 0 when the episode is over
+            value_n = critic_net(s)
+
+            td_error = r + gamma * value_ns - value_n  # no detach here: the critic needs this gradient
+
+            # Actor loss
+            actor_loss = -log_prob * td_error.detach()  # detach td_error for the actor update
+            optimizer_actor.zero_grad()
+            actor_loss.backward()
+            optimizer_actor.step()
+
+            # Critic loss
+            critic_loss = td_error.pow(2).mean()  # MSE
+            optimizer_critic.zero_grad()
+            critic_loss.backward()
+            optimizer_critic.step()
+
+            print("value_n:", value_n.item(), "value_ns:", value_ns.item(), "td_error:", td_error.item())
+
+            log_probs.append(log_prob.item())  # store plain floats so matplotlib can plot them
+            td_errors.append(td_error.item())
+
+            # Update the state
+            s = ns
+
+        print(f'Episode {ep + 1}: total reward {total_r:.1f}')
+
+    fig, axes = plt.subplots(1, 2)
+    axes[0].plot(range(len(log_probs)), log_probs)
+    axes[0].set_title('log probability')
+
+    axes[1].plot(range(len(td_errors)), td_errors)
+    axes[1].set_title('TD errors')
+
+    plt.show()
+
+    # Release the environment resources
+    env.close()
+
+    # Save the weights
+    torch.save(actor_net.state_dict(), weights_path)
+    print(f'Training finished. Weights saved to {weights_path}')
+    return actor_net
+
+
+def show(weights_path="cartpole_actor_critic.pth"):
+    env = gym.make("CartPole-v1", render_mode="human")
+    actor_net = Actor(env.observation_space.shape[0], env.action_space.n)
+    actor_net.load_state_dict(torch.load(weights_path))
+    actor_net.eval()
+    s, _ = env.reset()
+    s = torch.tensor(s, dtype=torch.float32)
+    done = False
+    while not done:
+        with torch.no_grad():
+            action_probs = actor_net(s)
+        action = torch.argmax(action_probs).item()
+        s_, r, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        s = torch.tensor(s_, dtype=torch.float32)
+    env.close()
+    print('Demonstration finished.')
+
+
+if __name__ == '__main__':
+    trained_model = train_and_save()
+    show()
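For reference, the heart of the training loop above is a one-step (TD(0)) actor-critic update. Below is a minimal, purely illustrative sketch of a single update, assuming the patch is applied so that Actor and Critic can be imported from tp5.py; the state, next state, and reward values are placeholders, not taken from a real environment:

import torch
from tp5 import Actor, Critic  # assumes tp5.py from the patch above is on the import path

actor, critic = Actor(n_states=4, n_actions=2), Critic(n_states=4)
s = torch.zeros(4)             # illustrative CartPole state
ns = torch.full((4,), 0.01)    # illustrative next state
r, gamma = 1.0, 0.99           # illustrative reward and the discount factor used above

dist = torch.distributions.Categorical(actor(s))
action = dist.sample()

# delta = r + gamma * V(s') - V(s); the gradient flows only through V(s), for the critic
td_error = r + gamma * critic(ns).detach() - critic(s)
actor_loss = -dist.log_prob(action) * td_error.detach()  # policy-gradient step with delta as the advantage
critic_loss = td_error.pow(2).mean()                      # pull V(s) toward the bootstrapped target

In the patch, td_error.detach() is what keeps the actor update from back-propagating into the critic, while the squared TD error is what trains the critic itself.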