from typing import Union, Dict, Optional
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead


@MODEL_REGISTRY.register('mavac')
class MAVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for \
        multi-agent, such as MAPPO(https://arxiv.org/abs/2103.01955). This model now supports discrete and \
        continuous action space. The MAVAC is composed of four parts: ``actor_encoder``, ``critic_encoder``, \
        ``actor_head`` and ``critic_head``. Encoders are used to extract the feature from various observation. \
        Heads are used to predict corresponding value or action logit.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        agent_obs_shape: Union[int, SequenceType],
        global_obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        agent_num: int,
        actor_head_hidden_size: int = 256,
        actor_head_layer_num: int = 2,
        critic_head_hidden_size: int = 512,
        critic_head_layer_num: int = 1,
        action_space: str = 'discrete',
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        sigma_type: Optional[str] = 'independent',
        bound_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Init the MAVAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent, \
                such as 8 or [4, 84, 84].
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape for single agent, such as 6 \
                or [2, 3, 3].
            - agent_num (:obj:`int`): This parameter is temporarily reserved. This parameter may be required for \
                subsequent changes to the model
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
                to 256, it must match the last element of ``agent_obs_shape``.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
                to 512, it must match the last element of ``global_obs_shape``.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output for \
                critic's nn.
            - action_space (:obj:`Union[int, SequenceType]`): The type of different action spaces, including \
                ['discrete', 'continuous'], then will instantiate corresponding head, including ``DiscreteHead`` \
                and ``ReparameterizationHead``.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` the after \
                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
            - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \
                ``ding.torch_utils.network.dreamer.ReparameterizationHead`` for more details, in MAPPO, it defaults \
                to ``independent``, which means state-independent sigma parameters.
            - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
                to ``None``, which means no bound.
        """
        super(MAVAC, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
        self.action_space = action_space
        # Encoder Type
        # We directly connect the Head after a Liner layer instead of using the 3-layer FCEncoder.
        # In SMAC task it can obviously improve the performance.
        # Users can change the model according to their own needs.
        self.actor_encoder = nn.Identity()
        self.critic_encoder = nn.Identity()
        # Head Type
        self.critic_head = nn.Sequential(
            nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
            RegressionHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        )
        assert self.action_space in ['discrete', 'continuous'], self.action_space
        if self.action_space == 'discrete':
            self.actor_head = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
                DiscreteHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'continuous':
            self.actor_head = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type=sigma_type,
                    activation=activation,
                    norm_type=norm_type,
                    bound_type=bound_type
                )
            )
        # must use list, not nn.ModuleList
        self.actor = [self.actor_encoder, self.actor_head]
        self.critic = [self.critic_encoder, self.critic_head]
        # for convenience of call some apis(such as: self.critic.parameters()), but may cause
        # misunderstanding when print(self)
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            MAVAC forward computation graph, input observation tensor to predict state value or action logit. \
            ``mode`` includes ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
            Different ``mode`` will forward with different network modules to get different outputs and save \
            computation.
        Arguments:
            - inputs (:obj:`Dict`): The input dict including observation and related info, \
                whose key-values vary from different ``mode``.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph, whose key-values vary from \
                different ``mode``.

        Examples (Actor):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> actor_outputs = model(inputs,'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])

        Examples (Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> critic_outputs = model(inputs,'compute_critic')
            >>> assert actor_outputs['value'].shape == torch.Size([10, 8])

        Examples (Actor-Critic):
            >>> model = MAVAC(64, 64)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> outputs = model(inputs,'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8, 14])
            >>> assert outputs['logit'].shape == torch.Size([10, 8])

        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for actor part, \
            predicting action logit with agent observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['agent_state', 'action_mask'(optional)].
                - agent_state: (:obj:`torch.Tensor`): Each agent local state(obs).
                - action_mask(optional): (:obj:`torch.Tensor`): When ``action_space`` is discrete, action_mask needs \
                    to be provided to mask illegal actions.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the forward computation graph for actor, including ``logit``.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
                the same dimension real-value ranged tensor of possible action choices, and for continuous action \
                space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
                same as the number of continuous actions.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
              and M is ``agent_num``.

        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> actor_outputs = model(inputs,'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])

        """
        if self.action_space == 'discrete':
            action_mask = x['action_mask']
            x = x['agent_state']
            x = self.actor_encoder(x)
            x = self.actor_head(x)
            logit = x['logit']
            logit[action_mask == 0.0] = -99999999
        elif self.action_space == 'continuous':
            x = x['agent_state']
            x = self.actor_encoder(x)
            x = self.actor_head(x)
            logit = x
        return {'logit': logit}

    def compute_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for critic part. \
            Predict state value with global observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['global_state'].
                - global_state: (:obj:`torch.Tensor`): Global state(obs).
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for critic, \
                including ``value``.
        ReturnsKeys:
            - value (:obj:`torch.Tensor`): The predicted state value tensor.
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.

        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> critic_outputs = model(inputs,'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
        """

        x = self.critic_encoder(x['global_state'])
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for both actor and critic part, input observation to predict action \
            logit and state value.
        Arguments:
            - x (:obj:`Dict`): The input dict contains ``agent_state``, ``global_state`` and other related info.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for both actor and critic, \
                including ``logit`` and ``value``.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
            - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
              and M is ``agent_num``.
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch sizeand M is ``agent_num``.

        Examples:
            >>> model = MAVAC(64, 64)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> outputs = model(inputs,'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        logit = self.compute_actor(x)['logit']
        value = self.compute_critic(x)['value']
        return {'logit': logit, 'value': value}