# Based on : https://github.com/djbyrne/core_rl/blob/master/algos/dqn/model.py

import numpy as np
import torch


class Agent:
    """Epsilon-greedy agent that picks actions from a Q-value network.

    With probability ``epsilon`` the agent samples a random action from
    ``action_space``; otherwise it returns the index of the largest value
    produced by ``net`` for the given state. ``epsilon`` starts at 0 (purely
    greedy) and only changes when ``update_epsilon`` is called.
    """

    def __init__(self,
                 net,
                 action_space=None,
                 exploration_initial_eps=None,
                 exploration_decay=None,
                 exploration_final_eps=None):
        """Store the network and the exploration schedule parameters.

        Args:
            net: Callable module exposing ``eval()``; called with a batched
                state tensor and expected to return values indexable along
                ``dim=1``.
            action_space: Object exposing ``sample()``; required only when a
                random action is drawn.
            exploration_initial_eps: Epsilon value at step 0 of the schedule.
            exploration_decay: Per-step multiplicative decay base.
            exploration_final_eps: Lower bound of the schedule.
        """
        self.net = net
        self.action_space = action_space
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_decay = exploration_decay
        self.exploration_final_eps = exploration_final_eps
        # Greedy until update_epsilon() is called for the first time.
        self.epsilon = 0.

    def __call__(self, state, device=torch.device('cpu')):
        """Return an epsilon-greedy action for ``state``."""
        if np.random.random() < self.epsilon:
            action = self.get_random_action()
        else:
            action = self.get_action(state, device)

        return action

    def get_random_action(self):
        """Return an action drawn from ``action_space.sample()``."""
        return self.action_space.sample()

    def get_action(self, state, device=torch.device('cpu')):
        """Return the greedy action (argmax over ``dim=1``) as an ``int``.

        Args:
            state: A ``torch.Tensor`` (used as given), or any other value, which
                is wrapped in a leading dimension of size 1 before being
                fed to the network.
            device: ``torch.device`` (or device string) the state is moved to
                before the forward pass.

        Returns:
            The selected action index as a Python ``int``. The network output
            must yield exactly one index (``.item()`` is called on it).
        """
        if not isinstance(state, torch.Tensor):
            if isinstance(state, np.ndarray):
                # Direct conversion avoids the slow list-of-ndarrays path of
                # torch.tensor([...]) while keeping the array's dtype.
                state = torch.as_tensor(state).unsqueeze(0)
            else:
                state = torch.tensor([state])

        # Accept device strings too, and support any device type (not just
        # CUDA) by using the generic .to() instead of .cuda().
        device = torch.device(device)
        if device.type != 'cpu':
            state = state.to(device)

        # NOTE: eval() switches the network to evaluation mode and leaves it
        # there; callers that train the same module must call train() again.
        # no_grad() skips building an autograd graph for pure inference.
        with torch.no_grad():
            q_values = self.net.eval()(state)
        _, action = torch.max(q_values, dim=1)
        return int(action.item())

    def update_epsilon(self, step):
        """Set and return epsilon for the given ``step``.

        Computes ``final + (initial - final) * decay ** step`` and clamps the
        result from below at ``exploration_final_eps``.
        """
        final_eps = self.exploration_final_eps
        eps_range = self.exploration_initial_eps - final_eps
        decayed = final_eps + eps_range * self.exploration_decay**step
        self.epsilon = max(final_eps, decayed)
        return self.epsilon