refactor: add 2 blank lines between functions in tp7
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 24s
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 24s
This commit is contained in:
@@ -41,6 +41,7 @@ class A2CAgent:
|
||||
self.model = ActorCritic(state_dim, action_dim).to(self.device)
|
||||
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
||||
|
||||
|
||||
def select_action(self, state):
|
||||
state = torch.FloatTensor(state).to(self.device)
|
||||
dist, _ = self.model(state)
|
||||
@@ -48,6 +49,7 @@ class A2CAgent:
|
||||
log_prob = dist.log_prob(action).sum(dim=-1)
|
||||
return action.cpu().numpy(), log_prob
|
||||
|
||||
|
||||
def compute_returns(self, rewards, masks, next_value):
|
||||
R = next_value
|
||||
returns = []
|
||||
@@ -56,6 +58,7 @@ class A2CAgent:
|
||||
returns.insert(0, R)
|
||||
return returns
|
||||
|
||||
|
||||
def update(self, trajectory, next_state):
|
||||
states = torch.FloatTensor([t[0] for t in trajectory]).to(self.device)
|
||||
actions = torch.FloatTensor([t[1] for t in trajectory]).to(self.device)
|
||||
@@ -80,6 +83,7 @@ class A2CAgent:
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
|
||||
def train(self, max_steps=2000, update_every=5):
|
||||
state, _ = self.env.reset()
|
||||
trajectory = []
|
||||
|
||||
Reference in New Issue
Block a user