From 3ca647c6025b863eab6b7ad0a9053b60f1584030 Mon Sep 17 00:00:00 2001
From: Namu
Date: Thu, 25 Sep 2025 13:37:07 +0200
Subject: [PATCH] Robot works: DQN agent for Acrobot-v1

---
 .gitignore       |   8 +++
 requirements.txt | Bin 0 -> 612 bytes
 tp2.py           | 139 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 147 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 requirements.txt
 create mode 100644 tp2.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3c3b8e4
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+# .idea
+.idea/
+
+# venv
+.venv/
+
+# weights
+*.pth
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..095319227af6fd3287017f07fb76f2a89f5ca5aa
GIT binary patch
literal 612
[base85-encoded binary payload omitted; requirements.txt was committed as a binary (non-UTF-8) file]