File size: 1,574 Bytes
52dd602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# Based on: https://github.com/djbyrne/core_rl/blob/master/algos/dqn/model.py
import numpy as np
import torch
class Agent:
    """Epsilon-greedy action selector built around a Q-value network.

    With probability ``epsilon`` a random action is drawn from
    ``action_space``; otherwise the action with the largest Q-value
    predicted by ``net`` is returned.

    Args:
        net: Callable network exposing ``eval()`` (e.g. a
            ``torch.nn.Module``) that maps a batch of states to a 2-D
            tensor of Q-values (one row per state).
        action_space: Object exposing ``sample()``; only needed when
            ``epsilon`` is greater than zero.
        exploration_initial_eps: Epsilon value at step 0 of the schedule.
        exploration_decay: Multiplicative per-step decay factor.
        exploration_final_eps: Lower bound of the schedule.
    """

    def __init__(self,
                 net,
                 action_space=None,
                 exploration_initial_eps=None,
                 exploration_decay=None,
                 exploration_final_eps=None):
        self.net = net
        self.action_space = action_space
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_decay = exploration_decay
        self.exploration_final_eps = exploration_final_eps
        # The agent starts fully greedy; exploration only begins once
        # update_epsilon() has been called (or epsilon is set directly).
        self.epsilon = 0.

    def __call__(self, state, device=torch.device('cpu')):
        """Return an epsilon-greedy action for ``state``."""
        if np.random.random() < self.epsilon:
            return self.get_random_action()
        return self.get_action(state, device)

    def get_random_action(self):
        """Return an action sampled from ``action_space``."""
        return self.action_space.sample()

    def get_action(self, state, device=torch.device('cpu')) -> int:
        """Return the greedy action (argmax of the Q-values) for ``state``.

        Args:
            state: A single observation. Non-tensor input is wrapped in a
                list so it gains a leading batch dimension of size 1;
                tensors are used as given and must therefore already be
                batched with exactly one row.
            device: Device the network lives on. The state is moved there
                unless it is the CPU.

        Returns:
            Index of the highest-valued action as a Python ``int``.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.tensor([state])
        if device.type != 'cpu':
            # .to() handles any accelerator type, not only CUDA.
            state = state.to(device)
        # Action selection never needs gradients; skipping autograd avoids
        # building a throw-away graph on every environment step.
        # NOTE: eval() leaves the network in evaluation mode afterwards;
        # training code has to call net.train() itself if it needs it.
        with torch.no_grad():
            q_values = self.net.eval()(state)
        _, action = torch.max(q_values, dim=1)
        return int(action.item())

    def update_epsilon(self, step):
        """Set ``epsilon`` from the exponential decay schedule and return it.

        Computes ``final + (initial - final) * decay ** step`` and clamps
        the result from below at ``exploration_final_eps``.

        Args:
            step: Current step; used as the exponent of the decay factor.

        Returns:
            The updated epsilon value.
        """
        eps_span = self.exploration_initial_eps - self.exploration_final_eps
        decayed = (self.exploration_final_eps
                   + eps_span * self.exploration_decay ** step)
        self.epsilon = max(self.exploration_final_eps, decayed)
        return self.epsilon
|