diff --git a/ex2.py b/ex2.py index f0573d7..b66645f 100644 --- a/ex2.py +++ b/ex2.py @@ -5,6 +5,7 @@ import gymnasium as gym import torch import torch.nn as nn import torch.optim as optim +import matplotlib.pyplot as plt class DQN(nn.Module): @@ -21,7 +22,6 @@ class DQN(nn.Module): nn.Linear(64, n_actions) ) - def forward(self, x): """ @@ -38,7 +38,7 @@ def epsilon_greedy(epsilon: float, s, policy_net: DQN, n_actions: int) -> int: return torch.argmax(policy_net(s)).item() -def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_target_every=20): +def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_target_every=20) -> DQN: env = gym.make('CartPole-v1') n_states, n_actions = env.observation_space.shape[0], env.action_space.n @@ -47,15 +47,16 @@ def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_targe target_net.load_state_dict(policy_net.state_dict()) # same weights at start target_net.eval() - optimizer = optim.Adam(policy_net.parameters(), lr=1e-3) # <- erreur ici + optimizer = optim.Adam(policy_net.parameters(), lr=1e-2) #1e-3 gamma = 0.99 # discount factor epsilon = 1.0 # Fréquence d'exploration initiale eps_min = 0.05 # Fréquence d'exploration minimale eps_decay = 0.995 # Facteur de réduction d'epsilon - memory = deque(maxlen=5000) + memory = deque(maxlen=5_000) batch_size = 64 for ep in range(episodes): + print(f'Episode: {ep}/{episodes}') s, _ = env.reset() s = torch.tensor(s, dtype=torch.float32) done, total_r = False, 0 @@ -117,5 +118,5 @@ def show(weights_path='cartpole_dqn.pth') -> None: if __name__ == '__main__': - #trained_model = train_and_save() + trained_model = train_and_save() show()