refactor: add 2 spaces in tp7 between function
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 24s
tp7.py (4 lines changed)
@@ -25,6 +25,7 @@ class Actor(nn.Module):
         std = torch.exp(log_std)
         return mu, std
 
+
 class Critic(nn.Module):
     def __init__(self, state_dim):
         super().__init__()
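For orientation, the hunk above sits at the boundary between the Actor and Critic classes; only std = torch.exp(log_std), return mu, std, and the first lines of Critic are visible in the diff. A minimal sketch of what a Gaussian-policy Actor and a value Critic with these interfaces could look like (hidden sizes, layer names, and the state-independent log_std parameter are assumptions, not the file's actual code):

import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
        )
        self.mu_head = nn.Linear(hidden, action_dim)
        # Assumed: state-independent log standard deviation, one value per action dim.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, state):
        mu = self.mu_head(self.body(state))
        std = torch.exp(self.log_std)      # same exp(log_std) pattern as in the hunk
        return mu, std

class Critic(nn.Module):
    def __init__(self, state_dim, hidden=64):
        super().__init__()
        self.v = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, state):
        return self.v(state).squeeze(-1)   # scalar state value V(s)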
@@ -53,6 +54,7 @@ def compute_gae(rewards, values, gamma, lam, next_value):
     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
     return returns, advantages
 
+
 # ——— Entraînement ———
 def train_and_save():
     env = gym.make("Pusher-v5")
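The hunk above ends compute_gae(rewards, values, gamma, lam, next_value); only the advantage-normalization line and the return statement are visible. A sketch of a standard GAE(lambda) computation with that signature, assuming a single rollout with no terminal masking (the loop body is an assumption):

import torch

def compute_gae(rewards, values, gamma, lam, next_value):
    # rewards, values: 1-D tensors of length T; next_value: bootstrap V(s_T).
    values = torch.cat([values, next_value.reshape(1)])
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    # Visible in the diff: normalize advantages to stabilize the policy update.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages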
@@ -151,7 +153,7 @@ def train_and_save():
 
     torch.save(actor.state_dict(), "a2c_pusher.pth")
 
 # ——— Démonstration ———
 def show(weights_path="a2c_pusher.pth"):
     env = gym.make("Pusher-v5", render_mode="human")
     actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
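The hunk above spans the end of training (torch.save of the actor weights) and the start of show(), which rebuilds the Actor for rendering. A sketch of how show() typically continues after the visible lines, assuming import gymnasium as gym and the Actor class from tp7.py (everything after the Actor construction is an assumption):

import gymnasium as gym
import torch

def show(weights_path="a2c_pusher.pth"):
    env = gym.make("Pusher-v5", render_mode="human")
    actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
    actor.load_state_dict(torch.load(weights_path))
    actor.eval()

    state, _ = env.reset()
    done = False
    while not done:
        with torch.no_grad():
            mu, _ = actor(torch.FloatTensor(state))   # act with the mean action
        state, _, terminated, truncated, _ = env.step(mu.numpy())
        done = terminated or truncated
    env.close()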
@@ -41,6 +41,7 @@ class A2CAgent:
         self.model = ActorCritic(state_dim, action_dim).to(self.device)
         self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
 
+
     def select_action(self, state):
         state = torch.FloatTensor(state).to(self.device)
         dist, _ = self.model(state)
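This hunk and the next one show select_action almost completely; only the sampling step between dist, _ = self.model(state) and the log-prob line is missing. A sketch of the full method under the assumption that the action is simply sampled from the returned distribution (shown as it would sit inside class A2CAgent):

import torch

class A2CAgent:  # only the relevant method is sketched here
    def select_action(self, state):
        state = torch.FloatTensor(state).to(self.device)
        dist, _ = self.model(state)
        action = dist.sample()                        # assumed sampling step
        log_prob = dist.log_prob(action).sum(dim=-1)  # sum over action dims
        return action.cpu().numpy(), log_prob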
@@ -48,6 +49,7 @@ class A2CAgent:
         log_prob = dist.log_prob(action).sum(dim=-1)
         return action.cpu().numpy(), log_prob
 
+
     def compute_returns(self, rewards, masks, next_value):
         R = next_value
         returns = []
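The hunk above shows the head of compute_returns(rewards, masks, next_value), and the next hunk shows returns.insert(0, R) inside its loop. A standalone sketch of the usual bootstrapped discounted-return recursion these lines imply (the gamma value is an assumption):

def compute_returns(rewards, masks, next_value, gamma=0.99):
    # masks[t] is 0.0 where the episode ended at step t, 1.0 otherwise,
    # so the bootstrap value is cut off at terminal states.
    R = next_value
    returns = []
    for reward, mask in zip(reversed(rewards), reversed(masks)):
        R = reward + gamma * R * mask
        returns.insert(0, R)
    return returns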
@@ -56,6 +58,7 @@ class A2CAgent:
             returns.insert(0, R)
         return returns
 
+
     def update(self, trajectory, next_state):
         states = torch.FloatTensor([t[0] for t in trajectory]).to(self.device)
         actions = torch.FloatTensor([t[1] for t in trajectory]).to(self.device)
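The hunk above shows update unpacking the stored trajectory into state and action tensors. A sketch of the batching this indexing implies, assuming each trajectory entry is a (state, action, reward, mask, log_prob) tuple (the tuple layout beyond indices 0 and 1 is an assumption):

import torch

def batch_trajectory(trajectory, device="cpu"):
    # trajectory: list of (state, action, reward, mask, log_prob) tuples.
    states = torch.FloatTensor([t[0] for t in trajectory]).to(device)
    actions = torch.FloatTensor([t[1] for t in trajectory]).to(device)
    rewards = [t[2] for t in trajectory]
    masks = [t[3] for t in trajectory]
    return states, actions, rewards, masks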
@@ -80,6 +83,7 @@ class A2CAgent:
         loss.backward()
         self.optimizer.step()
 
+
     def train(self, max_steps=2000, update_every=5):
         state, _ = self.env.reset()
         trajectory = []
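The last hunks show only the tail of update (loss.backward(), self.optimizer.step()) and the head of train(). A sketch of how an A2C loss is commonly assembled right before those visible lines, using the same log_prob(...).sum(dim=-1) convention as select_action (the function name, coefficients, and shapes are assumptions):

import torch
import torch.nn.functional as F

def a2c_loss(dist, values, actions, returns, value_coef=0.5, entropy_coef=0.01):
    # dist: policy distribution from the model; values: predicted V(s), shape (T,).
    returns = torch.stack(list(returns)).detach()
    advantages = returns - values
    log_probs = dist.log_prob(actions).sum(dim=-1)
    actor_loss = -(log_probs * advantages.detach()).mean()
    critic_loss = F.mse_loss(values, returns)
    entropy = dist.entropy().sum(dim=-1).mean()
    # Combined objective; backward() and optimizer.step() follow as in the hunk.
    return actor_loss + value_coef * critic_loss - entropy_coef * entropy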