Spaces:
Sleeping
Sleeping
from typing import Union, Dict, Optional | |
import torch | |
import torch.nn as nn | |
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY | |
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead | |
class MAVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for \
        multi-agent, such as MAPPO(https://arxiv.org/abs/2103.01955). This model supports discrete and \
        continuous action space. The MAVAC is composed of four parts: ``actor_encoder``, ``critic_encoder``, \
        ``actor_head`` and ``critic_head``. Encoders are used to extract the feature from various observation \
        (in this implementation both encoders are ``nn.Identity``). Heads are used to predict the corresponding \
        value or action logit.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            agent_num: int,
            actor_head_hidden_size: int = 256,
            actor_head_layer_num: int = 2,
            critic_head_hidden_size: int = 512,
            critic_head_layer_num: int = 1,
            action_space: str = 'discrete',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            bound_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Init the MAVAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent, such as 8. \
                It is passed to ``nn.Linear`` as ``in_features`` after ``squeeze``, so it should reduce to an int.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space, such as 8. \
                It is passed to ``nn.Linear`` as ``in_features`` after ``squeeze``, so it should reduce to an int.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape for single agent, such as 6 \
                or [2, 3, 3].
            - agent_num (:obj:`int`): This parameter is temporarily reserved and currently unused. It may be \
                required for subsequent changes to the model.
            - actor_head_hidden_size (:obj:`int`): The output size of the linear projection applied to the agent \
                observation and the ``hidden_size`` of the ``actor_head`` network, defaults to 256.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute \
                action, defaults to 2.
            - critic_head_hidden_size (:obj:`int`): The output size of the linear projection applied to the global \
                observation and the ``hidden_size`` of the ``critic_head`` network, defaults to 512.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network to compute \
                the state value, defaults to 1.
            - action_space (:obj:`str`): The type of action space, one of ['discrete', 'continuous']. It selects \
                ``DiscreteHead`` or ``ReparameterizationHead`` as the actor head.
            - activation (:obj:`Optional[nn.Module]`): The activation function used after the linear projections \
                and inside the heads. If ``None``, ``nn.ReLU()`` is used.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
            - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \
                ``ReparameterizationHead`` for more details. In MAPPO it defaults to ``independent``, which means \
                state-independent sigma parameters.
            - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, \
                defaults to ``None``, which means no bound.
        """
        super(MAVAC, self).__init__()
        agent_obs_shape = squeeze(agent_obs_shape)
        global_obs_shape = squeeze(global_obs_shape)
        action_shape = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
        self.action_space = action_space
        if activation is None:
            # Honour the documented fallback instead of inserting ``None`` into ``nn.Sequential``.
            activation = nn.ReLU()
        # Encoder Type
        # We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
        # In SMAC task it can obviously improve the performance.
        # Users can change the model according to their own needs.
        self.actor_encoder = nn.Identity()
        self.critic_encoder = nn.Identity()
        # Head Type
        # NOTE: submodules are created (and registered) in the same order as before, critic before actor,
        # so parameter initialization order and ``state_dict`` keys are unchanged.
        self.critic_head = nn.Sequential(
            nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
            RegressionHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        )
        assert self.action_space in ['discrete', 'continuous'], self.action_space
        if self.action_space == 'discrete':
            self.actor_head = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
                DiscreteHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'continuous':
            self.actor_head = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type=sigma_type,
                    activation=activation,
                    norm_type=norm_type,
                    bound_type=bound_type
                )
            )
        # Group encoder + head for convenience (e.g. ``self.critic.parameters()``). The same submodules are
        # also registered under their own names, so ``print(self)`` shows them twice.
        self.actor = nn.ModuleList([self.actor_encoder, self.actor_head])
        self.critic = nn.ModuleList([self.critic_encoder, self.critic_head])

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            MAVAC forward computation graph, input observation tensor to predict state value or action logit. \
            ``mode`` includes ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
            Different ``mode`` will forward with different network modules to get different outputs and save \
            computation.
        Arguments:
            - inputs (:obj:`Dict`): The input dict including observation and related info, \
                whose key-values vary from different ``mode``.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph, whose key-values vary from \
                different ``mode``.
        Examples (Actor):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
        Examples (Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
        Examples (Actor-Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for actor part, \
            predicting action logit with agent observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['agent_state', 'action_mask'].
                - agent_state: (:obj:`torch.Tensor`): Each agent local state(obs).
                - action_mask: (:obj:`torch.Tensor`): Required when ``action_space`` is discrete; entries equal \
                    to 0 mark illegal actions. Ignored for continuous action space.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the forward computation graph for actor, including ``logit``.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The predicted action logit. For discrete action space, it is a \
                real-valued tensor over the possible action choices, with illegal actions set to -99999999. For \
                continuous action space, it is the actor head's output (the mu and sigma of the Gaussian \
                distribution), and the number of mu and sigma is the same as the number of continuous actions.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
              and M is ``agent_num``.
        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        if self.action_space == 'discrete':
            action_mask = x['action_mask']
            output = self.actor_head(self.actor_encoder(x['agent_state']))
            # Give illegal actions a huge negative logit so they get ~0 probability. ``masked_fill`` is
            # out-of-place, so the head's output tensor is not mutated inside the autograd graph.
            logit = output['logit'].masked_fill(action_mask == 0.0, -99999999)
        elif self.action_space == 'continuous':
            # Continuous: pass the head's full output (mu and sigma) through unchanged.
            logit = self.actor_head(self.actor_encoder(x['agent_state']))
        return {'logit': logit}

    def compute_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for critic part. \
            Predict state value with global observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['global_state'].
                - global_state: (:obj:`torch.Tensor`): Global state(obs).
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for critic, \
                including ``value``.
        ReturnsKeys:
            - value (:obj:`torch.Tensor`): The predicted state value tensor (the ``pred`` output of the critic head).
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
        """
        x = self.critic_encoder(x['global_state'])
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for both actor and critic part, input observation to predict action \
            logit and state value.
        Arguments:
            - x (:obj:`Dict`): The input dict contains ``agent_state``, ``global_state`` and, for discrete action \
                space, ``action_mask``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for both actor and critic, \
                including ``logit`` and ``value``.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The action logit, as returned by ``compute_actor``.
            - value (:obj:`torch.Tensor`): The state value, as returned by ``compute_critic``.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
              and M is ``agent_num``.
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        logit = self.compute_actor(x)['logit']
        value = self.compute_critic(x)['value']
        return {'logit': logit, 'value': value}