diff --git a/ex2.py b/ex2.py
index b66645f..1c06516 100644
--- a/ex2.py
+++ b/ex2.py
@@ -5,7 +5,6 @@ import gymnasium as gym
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import matplotlib.pyplot as plt
 
 
 class DQN(nn.Module):
@@ -18,10 +17,11 @@ class DQN(nn.Module):
         """
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(n_states, 64), nn.ReLU(),
-            nn.Linear(64, n_actions)
+            nn.Linear(n_states, 128), nn.ReLU(),
+            nn.Linear(128, n_actions)
         )
+
     def forward(self, x):
         """
@@ -38,7 +38,7 @@ def epsilon_greedy(epsilon: float, s, policy_net: DQN, n_actions: int) -> int:
     return torch.argmax(policy_net(s)).item()
 
 
-def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_target_every=20) -> DQN:
+def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_target_every=20):
     env = gym.make('CartPole-v1')
     n_states, n_actions = env.observation_space.shape[0], env.action_space.n
 
@@ -47,16 +47,15 @@ def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_targe
     target_net.load_state_dict(policy_net.state_dict())  # same weights at start
     target_net.eval()
 
-    optimizer = optim.Adam(policy_net.parameters(), lr=1e-2) #1e-3
+    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
     gamma = 0.99        # discount factor
     epsilon = 1.0       # initial exploration rate
-    eps_min = 0.05      # minimum exploration rate
-    eps_decay = 0.995   # epsilon decay factor
-    memory = deque(maxlen=5_000)
+    eps_min = 0.01      # minimum exploration rate
+    eps_decay = 0.999   # epsilon decay factor
+    memory = deque(maxlen=100_000)
     batch_size = 64
 
     for ep in range(episodes):
-        print(f'Episode: {ep}/{episodes}')
         s, _ = env.reset()
         s = torch.tensor(s, dtype=torch.float32)
         done, total_r = False, 0
@@ -109,12 +108,14 @@ def show(weights_path='cartpole_dqn.pth') -> None:
     s, _ = env.reset()
     s = torch.tensor(s, dtype=torch.float32)
     done = False
+    total_r = 0.0
     while not done:
         a = torch.argmax(qnet(s)).item()
         s_, r, done, _, _ = env.step(a)
+        total_r += r
        s = torch.tensor(s_, dtype=torch.float32)
     env.close()
-    print('Demonstration finished.')
+    print(f'Demonstration finished. {total_r:.1f}')
 
 
 if __name__ == '__main__':
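Reviewer note: a rough sanity check of the new exploration schedule, assuming epsilon is multiplied by eps_decay once per episode (the decay step itself is outside these hunks, so this sketch is illustrative, not code from ex2.py):

    import math

    # Hypothetical helper, not part of ex2.py: episodes needed for epsilon to
    # decay from 1.0 down to eps_min under per-episode multiplicative decay.
    def episodes_to_floor(eps_min: float, eps_decay: float) -> int:
        return math.ceil(math.log(eps_min) / math.log(eps_decay))

    print(episodes_to_floor(0.05, 0.995))  # old schedule: ~598 episodes
    print(episodes_to_floor(0.01, 0.999))  # new schedule: ~4603 episodes; with
                                           # episodes=2_000, epsilon ends near
                                           # 0.999**2000 ≈ 0.14 and never hits eps_min

Under that assumption, the slower decay keeps exploration high for the whole 2_000-episode run, which may be intentional alongside the larger replay buffer but is worth confirming.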