tp2-iaavancee/ex2.py

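"""Deep Q-Network (DQN) on Gymnasium's CartPole-v1: train a Q-network with a
target network, epsilon-greedy exploration and a replay buffer, save the
weights to disk, then replay one greedy episode with human rendering."""
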
import random
from collections import deque
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim


class DQN(nn.Module):
    def __init__(self, n_states=4, n_actions=2):
        """
        Q-network for CartPole: maps the 4-dimensional state vector to one
        Q-value per action (push the cart left or right).

        :param n_states: size of the observation vector
        :param n_actions: number of discrete actions
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 64), nn.ReLU(),
            nn.Linear(64, n_actions)
        )

    def forward(self, x):
        """
        :param x: a single state vector or a batch of states
        :return: Q-values for each action
        """
        return self.net(x)


def epsilon_greedy(epsilon: float, s, policy_net: DQN, n_actions: int) -> int:
    """With probability epsilon pick a random action, otherwise the greedy one."""
    if random.random() < epsilon:
        return random.randrange(n_actions)
    return torch.argmax(policy_net(s)).item()


def train_and_save(weights_path="cartpole_dqn.pth", episodes=2_000, update_target_every=20) -> DQN:
    env = gym.make('CartPole-v1')
    n_states, n_actions = env.observation_space.shape[0], env.action_space.n

    policy_net = DQN(n_states, n_actions)                 # Q-network
    target_net = DQN(n_states, n_actions)                 # target network
    target_net.load_state_dict(policy_net.state_dict())   # same weights at start
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=1e-2)  # 1e-3
    gamma = 0.99        # discount factor
    epsilon = 1.0       # initial exploration rate
    eps_min = 0.05      # minimum exploration rate
    eps_decay = 0.995   # epsilon decay factor

    memory = deque(maxlen=5_000)  # replay buffer
    batch_size = 64
    for ep in range(episodes):
        print(f'Episode: {ep}/{episodes}')
        s, _ = env.reset()
        s = torch.tensor(s, dtype=torch.float32)
        done, total_r = False, 0
        while not done:
            a = epsilon_greedy(epsilon, s, policy_net, n_actions)
            # Gymnasium returns (obs, reward, terminated, truncated, info);
            # the episode ends when either flag is set.
            ns, r, terminated, truncated, _ = env.step(a)
            done = terminated or truncated
            ns = torch.tensor(ns, dtype=torch.float32)
            # Only 'terminated' marks a true terminal state; a truncated
            # episode should still bootstrap from the next state.
            memory.append((s, a, r, ns, terminated))
            s, total_r = ns, total_r + r
            # Learn from a random minibatch once the replay buffer is large enough.
            if len(memory) >= batch_size:
                batch = random.sample(memory, batch_size)
                s_b, a_b, r_b, ns_b, d_b = zip(*batch)
                s_b = torch.stack(s_b)
                ns_b = torch.stack(ns_b)
                q_pred = policy_net(s_b).gather(1, torch.tensor(a_b).unsqueeze(1)).squeeze()
                with torch.no_grad():
                    q_next = target_net(ns_b).max(1)[0]
                # d_b holds the 'terminated' flags: no bootstrap from terminal states.
                q_target = torch.tensor(r_b, dtype=torch.float32) + \
                           gamma * q_next * (1 - torch.tensor(d_b, dtype=torch.float32))
                loss = ((q_pred - q_target) ** 2).mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Decay exploration once per episode.
        epsilon = max(eps_min, epsilon * eps_decay)

        # Periodically sync the target network with the policy network.
        if (ep + 1) % update_target_every == 0:
            target_net.load_state_dict(policy_net.state_dict())
        if (ep + 1) % 20 == 0:
            print(f'Episode {ep + 1}: total reward {total_r:.1f}, epsilon {epsilon:.2f}')

    env.close()
    torch.save(policy_net.state_dict(), weights_path)
    print(f'Training finished. Weights saved to {weights_path}')
    return policy_net  # the trained Q-network


def show(weights_path='cartpole_dqn.pth') -> None:
    env = gym.make('CartPole-v1', render_mode='human')
    qnet = DQN()
    qnet.load_state_dict(torch.load(weights_path))
    qnet.eval()

    s, _ = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    done = False
    while not done:
        # Always act greedily during the demonstration.
        a = torch.argmax(qnet(s)).item()
        s_, r, terminated, truncated, _ = env.step(a)
        done = terminated or truncated
        s = torch.tensor(s_, dtype=torch.float32)
    env.close()
    print('Demonstration finished.')


if __name__ == '__main__':
    trained_model = train_and_save()
    show()