feat: add tp5
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 24s
ex2.py (2 changed lines)
@@ -119,5 +119,5 @@ def show(weights_path='cartpole_dqn.pth') -> None:


if __name__ == '__main__':
-    trained_model = train_and_save()
+    #trained_model = train_and_save()
    show()
tp5.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


class Actor(nn.Module):
    """
    The actor network: outputs a probability distribution over actions (the policy).
    """

    def __init__(self, n_states, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, n_actions)
        )

    def forward(self, x):
        return torch.softmax(self.net(x), dim=-1)


class Critic(nn.Module):
    """
    The critic network: estimates the state value V(s).
    """

    def __init__(self, n_states):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)


def train_and_save(weights_path="cartpole_actor_critic.pth", episodes=500):
    env = gym.make("CartPole-v1")
    n_states, n_actions = env.observation_space.shape[0], env.action_space.n

    # Build the actor & critic networks
    actor_net = Actor(n_states, n_actions)
    critic_net = Critic(n_states)

    # Hyperparameters and optimizers
    optimizer_actor = optim.Adam(actor_net.parameters(), lr=1e-3)
    optimizer_critic = optim.Adam(critic_net.parameters(), lr=5e-4)
    gamma = 0.99

    for ep in range(episodes):
        # Current state returned by the environment
        s, _ = env.reset()
        s = torch.tensor(s, dtype=torch.float32)

        # Bookkeeping variables
        done, total_r = False, 0

        log_probs = []
        td_errors = []

        while not done:
            # Actor: pick an action
            action_probs = actor_net(s)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Environment: apply the action
            ns, r, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ns = torch.tensor(ns, dtype=torch.float32)
            total_r += r

            # Critic: compute the TD error
            with torch.no_grad():
                value_ns = critic_net(ns) if not done else torch.tensor([0.0])  # force V(ns) = 0 when the episode is over
            value_n = critic_net(s)

            td_error = r + gamma * value_ns - value_n  # no detach here: the critic needs the gradient through value_n

            # Actor loss
            actor_loss = -log_prob * td_error.detach()  # detach td_error for the actor update
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # Critic loss
            critic_loss = td_error.pow(2).mean()  # MSE
            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            print("value_n:", value_n.item(), "value_ns:", value_ns.item(), "td_error:", td_error.item())

            # Store plain floats so matplotlib can plot them later
            log_probs.append(log_prob.item())
            td_errors.append(td_error.item())

            # Move to the next state
            s = ns

        print(f'Episode {ep + 1}: total reward {total_r:.1f}')

    # Plot the per-step log probabilities and TD errors of the last episode
    fig, axes = plt.subplots(1, 2)
    axes[0].plot(range(len(log_probs)), log_probs)
    axes[0].set_title('log probability')

    axes[1].plot(range(len(td_errors)), td_errors)
    axes[1].set_title('TD errors')

    plt.show()

    # Release the environment resources
    env.close()

    # Save the weights
    torch.save(actor_net.state_dict(), weights_path)
    print(f'Training finished. Weights saved to {weights_path}')
    return actor_net


def show(weights_path="cartpole_actor_critic.pth"):
    env = gym.make("CartPole-v1", render_mode="human")
    actor_net = Actor(env.observation_space.shape[0], env.action_space.n)
    actor_net.load_state_dict(torch.load(weights_path))
    actor_net.eval()
    s, _ = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    done = False
    while not done:
        with torch.no_grad():
            action_probs = actor_net(s)
            action = torch.argmax(action_probs).item()
        s_, r, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        s = torch.tensor(s_, dtype=torch.float32)
    env.close()
    print('Demonstration finished.')


if __name__ == '__main__':
    trained_model = train_and_save()
    show()
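
As a quick way to exercise the two new networks outside the training loop, a minimal sanity-check sketch follows (assuming tp5.py sits at the repository root and is importable; the sizes 4 and 2 below are CartPole-v1's observation and action dimensions):

import torch
from tp5 import Actor, Critic

# CartPole-v1: 4 observation dimensions, 2 discrete actions
actor, critic = Actor(4, 2), Critic(4)
s = torch.zeros(4)  # dummy state

probs = actor(s)    # softmax output: 2 probabilities summing to 1
value = critic(s)   # state-value estimate, shape [1]
print(probs, probs.sum().item(), value.item())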