From d3500bff480a12e551029a34fc8a0125de44ccb8 Mon Sep 17 00:00:00 2001
From: Namu
Date: Sat, 4 Oct 2025 22:59:09 +0200
Subject: [PATCH] feat: add tp3

---
 requirements.txt |   Bin 612 -> 1258 bytes
 tp3.py           |   147 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 147 insertions(+)
 create mode 100644 tp3.py

diff --git a/requirements.txt b/requirements.txt
index 095319227af6fd3287017f07fb76f2a89f5ca5aa..6f5d8f92362f192d9c4e46a69fb71699253cfb65 100644
GIT binary patch
literal 1258
[base85 payload garbled in transit; not reproduced]

delta 49
[base85 payload garbled in transit; not reproduced]

diff --git a/tp3.py b/tp3.py
new file mode 100644
--- /dev/null
+++ b/tp3.py
@@ -0,0 +1,147 @@
[opening of the hunk lost in transit: imports, the DQN class, ACTION_SET, and the first half of train_and_save; see the reconstruction sketch after the patch]
+            if len(memory) >= batch_size:
+                batch = random.sample(memory, batch_size)
+                s_b, a_b, r_b, ns_b, d_b = zip(*batch)
+                s_b = torch.stack(s_b)
+                ns_b = torch.stack(ns_b)
+
+                # Q-values predicted by the policy network for the chosen actions
+                q_pred = policy_net(s_b).gather(1, torch.tensor(a_b).unsqueeze(1)).squeeze()
+
+                # Bellman targets computed with the frozen target network
+                with torch.no_grad():
+                    q_next = target_net(ns_b).max(1)[0]
+                    q_target = torch.tensor(r_b, dtype=torch.float32) + \
+                        gamma * q_next * (1 - torch.tensor(d_b, dtype=torch.float32))
+
+                # MSE loss between predicted and target Q-values
+                loss = ((q_pred - q_target) ** 2).mean()
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        # Decay epsilon to gradually reduce exploration
+        epsilon = max(eps_min, epsilon * eps_decay)
+
+        # Periodically synchronise the target network with the policy network
+        if (ep + 1) % update_target_every == 0:
+            target_net.load_state_dict(policy_net.state_dict())
+
+        if (ep + 1) % 20 == 0:
+            print(f'Episode {ep + 1}: total reward {total_r:.1f}, epsilon {epsilon:.2f}')
+
+    env.close()
+
+    # Save the trained policy network
+    torch.save(policy_net.state_dict(), weights_path)
+    print(f'Training finished. Weights saved to {weights_path}')
+    return policy_net  # trained Q-network
+
+
+def show(weights_path="humanoid_dqn.pth") -> None:
+    """
+    Load the trained Q-network and run a single rendered episode to
+    demonstrate the learned policy.
+    :param weights_path: path to the saved network weights
+    :return: None
+    """
+    env = gym.make("Humanoid-v5", render_mode="human")
+    qnet = DQN()
+    qnet.load_state_dict(torch.load(weights_path))
+    qnet.eval()
+
+    s, _ = env.reset()
+    s = torch.tensor(s, dtype=torch.float32)
+    done = False
+    total_r = 0.0
+    while not done:
+        # Greedy action: index of the largest Q-value
+        a = torch.argmax(qnet(s)).item()
+        s_, r, terminated, truncated, _ = env.step(ACTION_SET[a])
+        done = terminated or truncated
+        s = torch.tensor(s_, dtype=torch.float32)
+        total_r += r
+    env.close()
+    print(f'Demonstration finished. Reward: {total_r:.2f}')
+
+
+if __name__ == '__main__':
+    trained_model = train_and_save()
+    show()
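Note on the tp3.py hunk: its opening lines did not survive, so the patch above resumes midway through train_and_save. The sketch below shows, at a minimum, what that opening would have to provide for the surviving lines to run: the imports, a small DQN network, a discretized ACTION_SET for Humanoid-v5's 17-dimensional continuous torque vector, and the start of the training loop. The layer sizes, the exact action discretization, and every hyperparameter value here are assumptions, not the original code; only the identifiers (DQN, ACTION_SET, train_and_save, policy_net, target_net, optimizer, memory, batch_size, gamma, epsilon, eps_min, eps_decay, update_target_every, weights_path) are taken from the surviving lines.

    # Reconstruction sketch, NOT the original file head. Sizes and
    # hyperparameters below are assumptions.
    import random
    from collections import deque

    import gymnasium as gym
    import numpy as np
    import torch
    import torch.nn as nn

    OBS_DIM = 348   # assumed Humanoid-v5 default observation size
    ACT_DIM = 17    # Humanoid torques, each in [-0.4, 0.4]

    # Crude discretization so a DQN can act at all: a few fixed torque
    # patterns (assumed; the original ACTION_SET may differ).
    ACTION_SET = [np.full(ACT_DIM, v, dtype=np.float32)
                  for v in (-0.4, -0.2, 0.0, 0.2, 0.4)]


    class DQN(nn.Module):
        """Small MLP mapping an observation to one Q-value per discrete action."""

        def __init__(self, obs_dim: int = OBS_DIM, n_actions: int = len(ACTION_SET)):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, 256), nn.ReLU(),
                nn.Linear(256, 256), nn.ReLU(),
                nn.Linear(256, n_actions),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)


    def train_and_save(weights_path="humanoid_dqn.pth", episodes=200) -> DQN:
        """Train a DQN on a discretized Humanoid-v5 and save its weights."""
        env = gym.make("Humanoid-v5")
        policy_net = DQN()
        target_net = DQN()
        target_net.load_state_dict(policy_net.state_dict())
        optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-3)

        memory = deque(maxlen=50_000)   # replay buffer of (s, a, r, s', done)
        batch_size, gamma = 64, 0.99
        epsilon, eps_min, eps_decay = 1.0, 0.05, 0.995
        update_target_every = 10        # episodes between target-network syncs

        for ep in range(episodes):
            s, _ = env.reset()
            s = torch.tensor(s, dtype=torch.float32)
            done = False
            total_r = 0.0
            while not done:
                # Epsilon-greedy action selection over the discrete action set
                if random.random() < epsilon:
                    a = random.randrange(len(ACTION_SET))
                else:
                    with torch.no_grad():
                        a = torch.argmax(policy_net(s)).item()
                s_, r, terminated, truncated, _ = env.step(ACTION_SET[a])
                done = terminated or truncated
                ns = torch.tensor(s_, dtype=torch.float32)
                memory.append((s, a, float(r), ns, float(done)))
                s = ns
                total_r += r
                # ... the surviving part of the patch continues here with the
                # replay-buffer update, epsilon decay, and target-network sync.

The sketch mirrors the structure implied by the surviving code: a per-step replay update, per-episode epsilon decay, and a periodic hard copy of the policy network into the target network. Discretizing Humanoid's 17-dimensional continuous action space into a handful of fixed torque vectors is a strong simplification; the actual ACTION_SET used in the patch cannot be recovered from the surviving lines.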
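Usage note: once humanoid_dqn.pth exists, the rendered demo can be run without retraining (assuming tp3.py is in the current directory):

    python -c "import tp3; tp3.show()"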