diff --git a/ex2.py b/ex2.py
index f0573d7..1c06516 100644
--- a/ex2.py
+++ b/ex2.py
@@ -17,8 +17,8 @@ class DQN(nn.Module):
         """
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(n_states, 64), nn.ReLU(),
-            nn.Linear(64, n_actions)
+            nn.Linear(n_states, 128), nn.ReLU(),
+            nn.Linear(128, n_actions)
         )
 
 
@@ -47,12 +47,12 @@ def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_targe
     target_net.load_state_dict(policy_net.state_dict())  # same weights at start
     target_net.eval()
 
-    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)  # <- error here
+    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
     gamma = 0.99       # discount factor
     epsilon = 1.0      # initial exploration rate
-    eps_min = 0.05     # minimum exploration rate
-    eps_decay = 0.995  # epsilon decay factor
-    memory = deque(maxlen=5000)
+    eps_min = 0.01     # minimum exploration rate
+    eps_decay = 0.999  # epsilon decay factor
+    memory = deque(maxlen=100_000)
     batch_size = 64
 
     for ep in range(episodes):
@@ -108,14 +108,16 @@ def show(weights_path='cartpole_dqn.pth') -> None:
     s, _ = env.reset()
     s = torch.tensor(s, dtype=torch.float32)
     done = False
+    total_r = 0.0
     while not done:
         a = torch.argmax(qnet(s)).item()
         s_, r, done, _, _ = env.step(a)
+        total_r += r
         s = torch.tensor(s_, dtype=torch.float32)
     env.close()
-    print('Demonstration finished.')
+    print(f'Demonstration finished. {total_r:.1f}')
 
 
 if __name__ == '__main__':
-    #trained_model = train_and_save()
+    trained_model = train_and_save()
     show()
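
A quick check on the new exploration schedule, assuming epsilon is multiplied by eps_decay once per episode (the decay step itself sits outside the hunks shown, so this is an assumption about train_and_save): the old schedule (decay 0.995, floor 0.05) bottoms out after roughly 600 episodes, while the new one (decay 0.999, floor 0.01) needs about 4,600 and therefore never reaches eps_min within the default episodes=2_000. A minimal sketch:

import math

# Episodes needed for epsilon to decay from eps_start down to eps_min,
# assuming one multiplicative decay step per episode.
def episodes_to_floor(eps_min: float, eps_decay: float, eps_start: float = 1.0) -> int:
    return math.ceil(math.log(eps_min / eps_start) / math.log(eps_decay))

print(episodes_to_floor(0.05, 0.995))  # old schedule: 598 episodes
print(episodes_to_floor(0.01, 0.999))  # new schedule: 4603 episodes
print(f"{0.999 ** 2_000:.3f}")         # epsilon after 2_000 episodes: 0.135

So under the new settings the agent is still exploring about 13% of the time when training ends; if that is not intended, episodes would need to grow along with the gentler decay.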
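
One caveat in the unchanged part of show(): the 5-tuple unpacking indicates the gymnasium API, where env.step returns (obs, reward, terminated, truncated, info), and the loop binds only terminated to done. A well-trained agent would therefore keep stepping past CartPole-v1's 500-step time limit instead of exiting. Below is a minimal sketch of an evaluation loop that also honours truncation; the evaluate helper is hypothetical, not part of ex2.py:

import gymnasium as gym
import torch

def evaluate(qnet: torch.nn.Module, env_id: str = "CartPole-v1") -> float:
    """Run one greedy episode and return the total reward."""
    env = gym.make(env_id, render_mode="human")
    s, _ = env.reset()
    total_r, done = 0.0, False
    while not done:
        with torch.no_grad():
            a = torch.argmax(qnet(torch.tensor(s, dtype=torch.float32))).item()
        s, r, terminated, truncated, _ = env.step(a)
        total_r += r
        done = terminated or truncated  # stop on pole fall OR time-limit truncation
    env.close()
    return total_r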