import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal


# -----------------------------
# Actor-Critic network
# -----------------------------
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Shared feature extractor for both heads
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        # Actor head: mean of a Gaussian policy with a state-independent log std
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_logstd = nn.Parameter(torch.zeros(action_dim))
        # Critic head: scalar state-value estimate
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.shared(x)
        mean = self.actor_mean(x)
        logstd = self.actor_logstd.expand_as(mean)
        dist = Normal(mean, logstd.exp())
        value = self.critic(x)
        return dist, value


# -----------------------------
# A2C agent
# -----------------------------
class A2CAgent:
    def __init__(self, env_name, gamma=0.99, lr=1e-3):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = gym.make(env_name)
        self.gamma = gamma

        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]

        self.model = ActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def select_action(self, state):
        state = torch.FloatTensor(state).to(self.device)
        dist, _ = self.model(state)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action.cpu().numpy(), log_prob

    def compute_returns(self, rewards, masks, next_value):
        # Discounted returns, bootstrapped from next_value and cut off
        # at episode boundaries via the masks
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.insert(0, R)
        return returns

    def update(self, trajectory, next_state):
        states = torch.FloatTensor(np.array([t[0] for t in trajectory])).to(self.device)
        actions = torch.FloatTensor(np.array([t[1] for t in trajectory])).to(self.device)
        log_probs = torch.stack([t[2] for t in trajectory]).to(self.device)
        rewards = [t[3] for t in trajectory]
        masks = [t[4] for t in trajectory]

        # Bootstrap from the value of the state following the last transition
        with torch.no_grad():
            _, next_value = self.model(torch.FloatTensor(next_state).to(self.device))
            next_value = next_value.squeeze()

        returns = self.compute_returns(rewards, masks, next_value)
        returns = torch.stack(returns).to(self.device)

        dist, values = self.model(states)
        advantages = returns - values.squeeze()

        # Policy gradient on the actor, squared advantage on the critic
        actor_loss = -(log_probs * advantages.detach()).mean()
        critic_loss = advantages.pow(2).mean()
        loss = actor_loss + 0.5 * critic_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, max_steps=2000, update_every=5):
        state, _ = self.env.reset()
        trajectory = []

        for step in range(max_steps):
            action, log_prob = self.select_action(state)
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            mask = 0.0 if done else 1.0  # stop bootstrapping at episode boundaries

            trajectory.append((state, action, log_prob, reward, mask))

            # Restart the episode when it ends; otherwise continue from next_state
            if done:
                state, _ = self.env.reset()
            else:
                state = next_state

            if (step + 1) % update_every == 0:
                self.update(trajectory, next_state)
                trajectory = []

            if (step + 1) % 100 == 0:
                print(f"Step {step + 1}, reward: {reward}")


# -----------------------------
# Run training
# -----------------------------
if __name__ == "__main__":
    agent = A2CAgent("Pusher-v5")
    agent.train(max_steps=2000)