import math
from typing import List, Dict, Any, Tuple
from collections import namedtuple

import torch
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR

from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer
from ding.utils import POLICY_REGISTRY


@POLICY_REGISTRY.register('pc_bfs')
class ProcedureCloningBFSPolicy(Policy):

    def default_model(self) -> Tuple[str, List[str]]:
        return 'pc_bfs', ['ding.model.template.procedure_cloning']

    config = dict(
        type='pc',
        cuda=False,
        on_policy=False,
        continuous=False,
        max_bfs_steps=100,
        learn=dict(
            update_per_collect=1,
            batch_size=32,
            learning_rate=1e-5,
            lr_decay=False,
            decay_epoch=30,
            decay_rate=0.1,
            warmup_lr=1e-4,
            warmup_epoch=3,
            optimizer='SGD',
            momentum=0.9,
            weight_decay=1e-4,
        ),
        collect=dict(
            unroll_len=1,
            noise=False,
            noise_sigma=0.2,
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        eval=dict(),
        other=dict(replay_buffer=dict(replay_buffer_size=10000)),
    )

    def _init_learn(self):
        assert self._cfg.learn.optimizer in ['SGD', 'Adam']
        if self._cfg.learn.optimizer == 'SGD':
            self._optimizer = SGD(
                self._model.parameters(),
                lr=self._cfg.learn.learning_rate,
                weight_decay=self._cfg.learn.weight_decay,
                momentum=self._cfg.learn.momentum
            )
        elif self._cfg.learn.optimizer == 'Adam':
            if self._cfg.learn.weight_decay is None:
                self._optimizer = Adam(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                )
            else:
                self._optimizer = AdamW(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                    weight_decay=self._cfg.learn.weight_decay
                )
        if self._cfg.learn.lr_decay:

            def lr_scheduler_fn(epoch):
                if epoch <= self._cfg.learn.warmup_epoch:
                    return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
                else:
                    ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
                    return math.pow(self._cfg.learn.decay_rate, ratio)

            self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
        self._timer = EasyTimer(cuda=True)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()
        self._max_bfs_steps = self._cfg.max_bfs_steps
        self._maze_size = self._cfg.maze_size
        self._num_actions = self._cfg.num_actions

        self._loss = nn.CrossEntropyLoss()

    def process_states(self, observations, maze_maps):
        """Returns [B, W, W, 3] binary values. Channels are (wall; goal; obs)"""
        loc = torch.nn.functional.one_hot(
            (observations[:, 0] * self._maze_size + observations[:, 1]).long(),
            self._maze_size * self._maze_size,
        ).long()
        loc = torch.reshape(loc, [observations.shape[0], self._maze_size, self._maze_size])
        states = torch.cat([maze_maps, loc], dim=-1).long()
        return states

    def _forward_learn(self, data):
        if self._cuda:
            collated_data = to_device(data, self._device)
        else:
            collated_data = data
        observations = collated_data['obs'],
        bfs_input_maps, bfs_output_maps = collated_data['bfs_in'].long(), collated_data['bfs_out'].long()
        states = observations
        bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).float()

        bfs_states = torch.cat([
            states,
            bfs_input_onehot,
        ], dim=-1)
        logits = self._model(bfs_states)['logit']
        logits = logits.flatten(0, -2)
        labels = bfs_output_maps.flatten(0, -1)

        loss = self._loss(logits, labels)
        preds = torch.argmax(logits, dim=-1)
        acc = torch.sum((preds == labels)) / preds.shape[0]

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        pred_loss = loss.item()

        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {'cur_lr': cur_lr, 'total_loss': pred_loss, 'acc': acc}

    def _monitor_vars_learn(self):
        return ['cur_lr', 'total_loss', 'acc']

    def _init_eval(self):
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data):
        if self._cuda:
            data = to_device(data, self._device)
        max_len = self._max_bfs_steps
        data_id = list(data.keys())
        output = {}

        for ii in data_id:
            states = data[ii].unsqueeze(0)
            bfs_input_maps = self._num_actions * torch.ones([1, self._maze_size, self._maze_size]).long()
            if self._cuda:
                bfs_input_maps = to_device(bfs_input_maps, self._device)
            xy = torch.where(states[:, :, :, -1] == 1)
            observation = (xy[1][0].item(), xy[2][0].item())

            i = 0
            while bfs_input_maps[0, observation[0], observation[1]].item() == self._num_actions and i < max_len:
                bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).long()

                bfs_states = torch.cat([
                    states,
                    bfs_input_onehot,
                ], dim=-1)
                logits = self._model(bfs_states)['logit']
                bfs_input_maps = torch.argmax(logits, dim=-1)
                i += 1
            output[ii] = bfs_input_maps[0, observation[0], observation[1]]
            if self._cuda:
                output[ii] = {'action': to_device(output[ii], 'cpu'), 'info': {}}
            if output[ii]['action'].item() == self._num_actions:
                output[ii]['action'] = torch.randint(low=0, high=self._num_actions, size=[1])[0]
        return output

    def _init_collect(self) -> None:
        raise NotImplementedError

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        raise NotImplementedError

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        raise NotImplementedError

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        raise NotImplementedError