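"""
One-step actor-critic (TD(0)) on CartPole-v1: the Actor network outputs a
softmax policy over the discrete actions and the Critic network estimates the
state value V(s); both are updated online after every environment step, the
trained actor weights are saved to disk, then replayed with rendering.
"""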
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


class Actor(nn.Module):
    """
    The actor network: maps a state to a probability distribution over actions.
    """

    def __init__(self, n_states, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, n_actions)
        )

    def forward(self, x):
        return torch.softmax(self.net(x), dim=-1)


class Critic(nn.Module):
    """
    The critic network: maps a state to a scalar state-value estimate V(s).
    """

    def __init__(self, n_states):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)


def train_and_save(weights_path="cartpole_actor_critic.pth", episodes=500):
    env = gym.make("CartPole-v1")
    n_states, n_actions = env.observation_space.shape[0], env.action_space.n

    # Actor & critic networks
    actor_net = Actor(n_states, n_actions)
    critic_net = Critic(n_states)

    # Hyperparameters and optimizers
    optimizer_actor = optim.Adam(actor_net.parameters(), lr=1e-3)
    optimizer_critic = optim.Adam(critic_net.parameters(), lr=5e-4)
    gamma = 0.99

    for ep in range(episodes):
        # Current state as given by the environment
        s, _ = env.reset()
        s = torch.tensor(s, dtype=torch.float32)

        # Bookkeeping variables
        done, total_r = False, 0

        log_probs = []
        td_errors = []

        while not done:
            # Actor: choose an action
            action_probs = actor_net(s)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Environment: perform the action
            ns, r, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ns = torch.tensor(ns, dtype=torch.float32)
            total_r += r

            # Critic: compute the TD error
            with torch.no_grad():
                value_ns = critic_net(ns) if not done else torch.tensor([0.0])  # force V(ns) = 0 when the episode is over
            value_n = critic_net(s)

            td_error = r + gamma * value_ns - value_n  # no detach here: the critic needs this gradient
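
            # One-step actor-critic update derived from the TD error above:
            #   td_error    = r + gamma * V(s') - V(s)
            #   actor loss  = -log pi(a|s) * td_error   (td_error detached: only the actor receives this gradient)
            #   critic loss = td_error^2                (MSE pushing V(s) toward the bootstrapped target)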
            # Actor loss
            actor_loss = -log_prob * td_error.detach()  # detach td_error for the actor update
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # Critic loss
            critic_loss = td_error.pow(2).mean()  # MSE
            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            print("value_n:", value_n.item(), "value_ns:", value_ns.item(), "td_error:", td_error.item())

            log_probs.append(log_prob)
            td_errors.append(td_error)

            # Move to the next state
            s = ns

        print(f'Episode {ep + 1}: total reward {total_r:.1f}')

    # Release the environment resources
    env.close()

    # Save the actor weights
    torch.save(actor_net.state_dict(), weights_path)
    print(f'Training finished. Weights saved to {weights_path}')
    return actor_net


def show(weights_path="cartpole_actor_critic.pth"):
    env = gym.make("CartPole-v1", render_mode="human")
    actor_net = Actor(env.observation_space.shape[0], env.action_space.n)
    actor_net.load_state_dict(torch.load(weights_path))
    actor_net.eval()
    s, _ = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    done = False
    while not done:
        # Greedy policy: pick the most probable action
        with torch.no_grad():
            action_probs = actor_net(s)
        action = torch.argmax(action_probs).item()
        s_, r, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        s = torch.tensor(s_, dtype=torch.float32)
    env.close()
    print('Demonstration finished.')


if __name__ == '__main__':
    trained_model = train_and_save()
    show()