dqn_agent.py
import copy
import random
from collections import deque

import torch
from torch import nn

class DQN_Agent:
    def __init__(self, seed, layer_sizes, lr, sync_freq, exp_replay_size):
        torch.manual_seed(seed)
        # Online (main) network and a periodically synced target network
        self.q_net = self.build_nn(layer_sizes)
        self.target_net = copy.deepcopy(self.q_net)
        self.q_net.cuda()
        self.target_net.cuda()
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        # Copy the online weights into the target network every `sync_freq` training steps
        self.network_sync_freq = sync_freq
        self.network_sync_counter = 0
        self.gamma = torch.tensor(0.95).float().cuda()
        self.experience_replay = deque(maxlen=exp_replay_size)

    def build_nn(self, layer_sizes):
        # Fully-connected network: Tanh on hidden layers, linear (Identity) output layer
        assert len(layer_sizes) > 1
        layers = []
        for index in range(len(layer_sizes) - 1):
            linear = nn.Linear(layer_sizes[index], layer_sizes[index + 1])
            act = nn.Tanh() if index < len(layer_sizes) - 2 else nn.Identity()
            layers += (linear, act)
        return nn.Sequential(*layers)

    def load_pretrained_model(self, model_path):
        self.q_net.load_state_dict(torch.load(model_path))

    def save_trained_model(self, model_path="cartpole-dqn.pth"):
        torch.save(self.q_net.state_dict(), model_path)

    def get_action(self, state, action_space_len, epsilon):
        # No gradients are needed here: this method is only used during
        # experience collection and inference
        with torch.no_grad():
            Qp = self.q_net(torch.from_numpy(state).float().cuda())
        Q, A = torch.max(Qp, axis=0)
        # Epsilon-greedy exploration: take a random action with probability epsilon
        A = A if torch.rand(1,).item() > epsilon else torch.randint(0, action_space_len, (1,))
        return A

    def get_q_next(self, state):
        # Max Q-value over actions for the next state, taken from the target network
        with torch.no_grad():
            qp = self.target_net(state)
        q, _ = torch.max(qp, axis=1)
        return q

    def collect_experience(self, experience):
        self.experience_replay.append(experience)

    def sample_from_experience(self, sample_size):
        if len(self.experience_replay) < sample_size:
            sample_size = len(self.experience_replay)
        sample = random.sample(self.experience_replay, sample_size)
        # Each stored experience is a (state, action, reward, next_state) tuple
        s = torch.tensor([exp[0] for exp in sample]).float()
        a = torch.tensor([exp[1] for exp in sample]).float()
        rn = torch.tensor([exp[2] for exp in sample]).float()
        sn = torch.tensor([exp[3] for exp in sample]).float()
        return s, a, rn, sn

    def train(self, batch_size):
        s, a, rn, sn = self.sample_from_experience(sample_size=batch_size)
        # Periodically copy the online weights into the target network
        if self.network_sync_counter == self.network_sync_freq:
            self.target_net.load_state_dict(self.q_net.state_dict())
            self.network_sync_counter = 0

        # Predicted return of the current state using the main network
        qp = self.q_net(s.cuda())
        pred_return, _ = torch.max(qp, axis=1)

        # Target return: reward plus discounted max Q of the next state (target network)
        q_next = self.get_q_next(sn.cuda())
        target_return = rn.cuda() + self.gamma * q_next

        loss = self.loss_fn(pred_return, target_return)
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        self.optimizer.step()

        self.network_sync_counter += 1
        return loss.item()