feat: add tp5
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 24s

This commit is contained in:
Namu
2025-10-12 16:17:06 +02:00
parent d3500bff48
commit 4c3b81b779
2 changed files with 147 additions and 1 deletion

ex2.py (2 changed lines)

@@ -119,5 +119,5 @@ def show(weights_path='cartpole_dqn.pth') -> None:
 if __name__ == '__main__':
-    trained_model = train_and_save()
+    #trained_model = train_and_save()
     show()

tp5.py (new file, 146 lines)

@@ -0,0 +1,146 @@
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


class Actor(nn.Module):
    """
    The action DNN
    """
    def __init__(self, n_states, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, n_actions)
        )

    def forward(self, x):
        return torch.softmax(self.net(x), dim=-1)


class Critic(nn.Module):
    """
    The critic DNN
    """
    def __init__(self, n_states):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)
def train_and_save(weights_path="cartpole_actor_critic.pth", episodes=500):
    env = gym.make("CartPole-v1")
    n_states, n_actions = env.observation_space.shape[0], env.action_space.n

    # Define the actor & critic DNNs
    actor_net = Actor(n_states, n_actions)
    critic_net = Critic(n_states)

    # Hyperparameters and optimizers
    optimizer_actor = optim.Adam(actor_net.parameters(), lr=1e-3)
    optimizer_critic = optim.Adam(critic_net.parameters(), lr=5e-4)
    gamma = 0.99

    for ep in range(episodes):
        # Current state given by the environment
        s, _ = env.reset()
        s = torch.tensor(s, dtype=torch.float32)

        # Bookkeeping variables
        done, total_r = False, 0
        log_probs = []
        td_errors = []

        while not done:
            # Actor: choose an action
            action_probs = actor_net(s)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Environment: apply the action
            ns, r, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ns = torch.tensor(ns, dtype=torch.float32)
            total_r += r

            # Critic: compute the TD error
            with torch.no_grad():
                value_ns = critic_net(ns) if not done else torch.tensor([0.0])  # force V(s') = 0 when the episode is over
            value_n = critic_net(s)
            td_error = r + gamma * value_ns - value_n  # no detach here: the critic needs this gradient

            # Actor loss
            actor_loss = -log_prob * td_error.detach()  # detach td_error for the actor
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # Critic loss
            critic_loss = td_error.pow(2).mean()  # MSE
            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            print("value_n:", value_n.item(), "value_ns:", value_ns.item(), "td_error:", td_error.item())
            log_probs.append(log_prob.item())   # store plain floats so matplotlib can plot them
            td_errors.append(td_error.item())

            # Update the state
            s = ns

        print(f'Episode {ep + 1}: total reward {total_r:.1f}')

    # Plot the per-step log probabilities and TD errors of the last episode
    fig, axes = plt.subplots(1, 2)
    axes[0].plot(range(len(log_probs)), log_probs)
    axes[0].set_title('log probability')
    axes[1].plot(range(len(td_errors)), td_errors)
    axes[1].set_title('TD errors')
    plt.show()

    # Release the environment resources
    env.close()

    # Save the weights
    torch.save(actor_net.state_dict(), weights_path)
    print(f'Training finished. Weights saved to {weights_path}')
    return actor_net
def show(weights_path="cartpole_actor_critic.pth"):
    env = gym.make("CartPole-v1", render_mode="human")
    actor_net = Actor(env.observation_space.shape[0], env.action_space.n)
    actor_net.load_state_dict(torch.load(weights_path))
    actor_net.eval()

    s, _ = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    done = False
    while not done:
        with torch.no_grad():
            action_probs = actor_net(s)
            action = torch.argmax(action_probs).item()
        s_, r, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        s = torch.tensor(s_, dtype=torch.float32)
    env.close()
    print('Demonstration finished.')
if __name__ == '__main__':
    trained_model = train_and_save()
    show()
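
For context, the update implemented in tp5.py is the one-step (TD) actor-critic rule: the TD error r + gamma * V(s') - V(s) serves as the advantage estimate, the critic is regressed on its square, and the actor's loss is -log pi(a|s) times the detached TD error. Below is a minimal standalone sketch of that single update on stand-in tensors; all names and values here are illustrative and not part of the commit.

# Illustrative one-step actor-critic update on stand-in tensors (not from tp5.py).
import torch

gamma = 0.99
r = 1.0                                                # reward observed at this step
value_s = torch.tensor([0.5], requires_grad=True)      # stand-in for critic(s)
value_ns = torch.tensor([0.6])                         # stand-in for critic(s'), no gradient needed
log_prob = torch.tensor([-0.7], requires_grad=True)    # stand-in for log pi(a|s)

td_error = r + gamma * value_ns - value_s              # advantage estimate, keeps the critic's gradient
actor_loss = -log_prob * td_error.detach()             # policy-gradient term, TD error held constant
critic_loss = td_error.pow(2).mean()                   # MSE pushing V(s) toward r + gamma * V(s')

actor_loss.backward()                                  # d(actor_loss)/d(log_prob) = -td_error
critic_loss.backward()                                 # d(critic_loss)/d(value_s) = -2 * td_error
print(log_prob.grad, value_s.grad)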