Spaces:
Sleeping
Sleeping
from typing import Dict, Union | |
import torch | |
import torch.nn as nn | |
from functools import reduce | |
from ding.torch_utils import one_hot, MLP | |
from ding.utils import squeeze, list_split, MODEL_REGISTRY, SequenceType | |
from .q_learning import DRQN | |
class COMAActorNetwork(nn.Module):
    """
    Overview:
        Decentralized actor network in COMA algorithm. A single ``DRQN`` module is applied to every agent: \
        the batch and agent axes are merged before the call and separated again afterwards.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: Union[SequenceType, None] = None,
    ):
        """
        Overview:
            Initialize COMA actor network
        Arguments:
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size, default to [128, 128, 64] when ``None``
        """
        super(COMAActorNetwork, self).__init__()
        # Build the default per instance instead of sharing one mutable list object between all instances.
        if hidden_size_list is None:
            hidden_size_list = [128, 128, 64]
        self.main = DRQN(obs_shape, action_shape, hidden_size_list)

    def forward(self, inputs: Dict) -> Dict:
        """
        Overview:
            The forward computation graph of COMA actor network
        Arguments:
            - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state']
                - obs (:obj:`dict`): observation dict with keys ['agent_state', 'action_mask']
                    - agent_state (:obj:`torch.Tensor`): each agent local state(obs), shaped \
                        :math:`(T, B, A, N)` or, for a single timestep, :math:`(B, A, N)`
                    - action_mask (:obj:`torch.Tensor`): the masked action, returned unchanged
                - prev_state (:obj:`list`): the previous hidden state, one inner list per batch sample \
                    holding one entry per agent
        Returns:
            - output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask']
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state``
        ReturnsKeys:
            - necessary: ``logit``, ``next_state``, ``action_mask``
        Shapes:
            - logit (:obj:`torch.Tensor`): :math:`(T, B, A, M)`, or :math:`(B, A, M)` when the input had no \
                time axis; ``M`` is whatever trailing size ``DRQN`` produces (presumably ``action_shape``)
            - next_state (:obj:`list`): the ``DRQN`` states regrouped by ``list_split`` in chunks of ``A``
        Examples:
            >>> T, B, A, N = 4, 8, 3, 32
            >>> embedding_dim = 64
            >>> action_dim = 6
            >>> data = torch.randn(T, B, A, N)
            >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
            >>> prev_state = [[None for _ in range(A)] for _ in range(B)]
            >>> for t in range(T):
            >>>     inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
            >>>     outputs = model(inputs)
            >>>     logit, prev_state = outputs['logit'], outputs['next_state']
        """
        agent_state = inputs['obs']['agent_state']
        prev_state = inputs['prev_state']
        # A 3-dim input (B, A, N) is one timestep: add a leading T axis now and strip it from the logit later.
        if len(agent_state.shape) == 3:
            agent_state = agent_state.unsqueeze(0)
            unsqueeze_flag = True
        else:
            unsqueeze_flag = False
        T, B, A = agent_state.shape[:3]
        # Merge batch and agent axes so that the shared network sees B * A independent sequences.
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        # Flatten the nested per-sample state lists in the same sample-major order as the reshape above.
        # A single pass avoids the repeated list copies of pairwise concatenation and never aliases the input.
        prev_state = [state for sample_states in prev_state for state in sample_states]
        output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
        logit, next_state = output['logit'], output['next_state']
        # Regroup the flat state list per batch sample (A entries each).
        next_state, _ = list_split(next_state, step=A)
        logit = logit.reshape(T, B, A, -1)
        if unsqueeze_flag:
            logit = logit.squeeze(0)
        return {'logit': logit, 'next_state': next_state, 'action_mask': inputs['obs']['action_mask']}
class COMACriticNetwork(nn.Module):
    """
    Overview:
        Centralized critic network in COMA algorithm. For every agent it concatenates the global state, \
        the agent's own observation, the last actions and the other agents' current actions, then feeds \
        the result to an MLP that outputs one value per action.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            input_size: int,
            action_shape: int,
            hidden_size: int = 128,
    ):
        """
        Overview:
            initialize COMA critic network
        Arguments:
            - input_size (:obj:`int`): the size of the feature vector built by ``_preprocess_data``
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size (:obj:`int`): the hidden size of the MLP, default to 128
        """
        super(COMACriticNetwork, self).__init__()
        self.action_shape = action_shape
        self.act = nn.ReLU()
        self.mlp = nn.Sequential(
            MLP(input_size, hidden_size, hidden_size, 2, activation=self.act), nn.Linear(hidden_size, action_shape)
        )

    def forward(self, data: Dict) -> Dict:
        """
        Overview:
            forward computation graph of COMA critic network
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'action']
                - obs (:obj:`dict`): observation dict with keys ['agent_state', 'global_state']
                    - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
                    - global_state (:obj:`torch.Tensor`): global state(obs)
                - action (:obj:`torch.Tensor`): the action index taken by each agent
        Returns:
            - output (:obj:`dict`): output data dict with keys ['q_value']
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``
        ReturnsKeys:
            - necessary: ``q_value``
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMACriticNetwork(
            >>>     obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data)
        """
        x = self._preprocess_data(data)
        q = self.mlp(x)
        return {'q_value': q}

    def _preprocess_data(self, data: Dict) -> torch.Tensor:
        """
        Overview:
            preprocess data so that it can be used by the MLP net
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'action']
                - obs (:obj:`dict`): observation dict with keys ['agent_state', 'global_state']
                    - agent_state (:obj:`torch.Tensor`): each agent local state(obs); its last \
                        ``action_shape + agent_num`` features are sliced off as last action and agent id
                    - global_state (:obj:`torch.Tensor`): global state(obs)
                - action (:obj:`torch.Tensor`): the action index taken by each agent
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``
        Returns:
            - x (:obj:`torch.Tensor`): the data can be used by MLP net, concatenating along the last axis \
                ``global_state``, ``agent_state``, ``last_action``, ``action``, ``agent_id``
        """
        agent_state_ori, global_state = data['obs']['agent_state'], data['obs']['global_state']
        t_size, batch_size, agent_num = agent_state_ori.shape[:3]
        # split obs, last_action and agent_id, which are taken from the tail of the last axis
        last_action_start = -self.action_shape - agent_num
        agent_id_start = -agent_num
        agent_state = agent_state_ori[..., :last_action_start]
        last_action = agent_state_ori[..., last_action_start:agent_id_start]
        agent_id = agent_state_ori[..., agent_id_start:]
        # Every agent receives the last actions of all agents, flattened into one vector.
        last_action = last_action.reshape(t_size, batch_size, 1, -1).expand(-1, -1, agent_num, -1)
        # Joint one-hot action of all agents, kept with a singleton agent axis for broadcasting.
        joint_action = one_hot(data['action'], self.action_shape)  # T, B, A, N
        joint_action = joint_action.reshape(t_size, batch_size, 1, agent_num * self.action_shape)
        # Row i zeroes the action slots of agent i, so each agent only sees the other agents' actions.
        # The mask is created directly on the target device instead of on CPU followed by a copy.
        others_mask = 1 - torch.eye(agent_num, device=joint_action.device)
        others_mask = others_mask.repeat_interleave(self.action_shape, dim=1)  # A, A*N
        action = others_mask * joint_action  # broadcast to T, B, A, A*N
        global_state = global_state.unsqueeze(2).expand(-1, -1, agent_num, -1)
        x = torch.cat([global_state, agent_state, last_action, action, agent_id], -1)
        return x
class COMA(nn.Module):
    """
    Overview:
        The network of COMA algorithm, which is QAC-type actor-critic. It bundles a \
        ``COMAActorNetwork`` and a ``COMACriticNetwork`` and selects one of them by forward mode.
    Interface:
        ``__init__``, ``forward``
    Properties:
        - mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic``
    """

    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self, agent_num: int, obs_shape: Dict, action_shape: Union[int, SequenceType],
            actor_hidden_size_list: SequenceType
    ) -> None:
        """
        Overview:
            initialize COMA network
        Arguments:
            - agent_num (:obj:`int`): the number of agent
            - obs_shape (:obj:`Dict`): the observation information, including agent_state and \
                global_state
            - action_shape (:obj:`Union[int, SequenceType]`): the dimension of action shape
            - actor_hidden_size_list (:obj:`SequenceType`): the list of hidden size of the actor; its last \
                entry is also used as the hidden size of the critic
        """
        super(COMA, self).__init__()
        action_shape = squeeze(action_shape)
        agent_obs_size = squeeze(obs_shape['agent_state'])
        global_obs_size = squeeze(obs_shape['global_state'])
        # Width of the two action-related parts of the critic input.
        action_feature_size = agent_num * action_shape + (agent_num - 1) * action_shape
        critic_input_size = agent_obs_size + global_obs_size + action_feature_size
        critic_hidden_size = actor_hidden_size_list[-1]
        self.actor = COMAActorNetwork(agent_obs_size, action_shape, actor_hidden_size_list)
        self.critic = COMACriticNetwork(critic_input_size, action_shape, critic_hidden_size)

    def forward(self, inputs: Dict, mode: str) -> Dict:
        """
        Overview:
            forward computation graph of COMA network: dispatch ``inputs`` to the actor or to the critic
        Arguments:
            - inputs (:obj:`dict`): input data dict, passed unchanged to the selected sub-network
                - obs (:obj:`dict`): observation dict; the examples below use ``agent_state`` and \
                    ``action_mask`` for the actor, ``agent_state`` and ``global_state`` for the critic
                - prev_state (:obj:`list`): the previous hidden state (actor only)
                - action (:obj:`torch.Tensor`): the action taken by each agent (critic only)
            - mode (:obj:`str`): one of the names listed in ``mode``
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state``, ``action_mask`` }, ``action``, ``prev_state``
        ReturnsKeys:
            - necessary:
                - compute_critic: ``q_value``
                - compute_actor: ``logit``, ``next_state``, ``action_mask``
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)` in the examples below
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)` in the examples below
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)` in the examples below
            - prev_state (:obj:`list`): :math:`[[state for _ in range(A)] for _ in range(B)]`
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMA(
            >>>     agent_num=agent_num,
            >>>     obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )),
            >>>     action_shape=action_dim,
            >>>     actor_hidden_size_list=[128, 64],
            >>> )
            >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)]
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'action_mask': None,
            >>>     },
            >>>     'prev_state': prev_state,
            >>> }
            >>> output = coma_model(data, mode='compute_actor')
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data, mode='compute_critic')
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        # The two branches are mutually exclusive; any other mode falls through without a return value.
        if mode == 'compute_critic':
            return self.critic(inputs)
        if mode == 'compute_actor':
            return self.actor(inputs)