Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
from ding.utils import MODEL_REGISTRY | |
from .qmix import QMix | |
class MADQN(nn.Module):
    """Pair of QMix networks: a ``current`` one and a ``cooperation`` one.

    ``current`` is a QMix network built on ``obs_shape``. ``cooperation`` is a
    second QMix network built on ``global_obs_shape`` when ``global_cooperation``
    is True, otherwise on ``obs_shape``. ``forward`` dispatches to one of them.

    Every constructor argument except ``global_cooperation`` is forwarded
    unchanged to QMix; see QMix for the exact meaning of each.
    """

    def __init__(
        self,
        agent_num: int,
        obs_shape: int,
        action_shape: int,
        hidden_size_list: list,
        global_obs_shape: int = None,
        mixer: bool = False,
        global_cooperation: bool = True,
        lstm_type: str = 'gru',
        dueling: bool = False
    ) -> None:
        """Build the ``current`` and ``cooperation`` QMix networks.

        Arguments:
            - agent_num: Passed to both QMix networks.
            - obs_shape: Observation shape of ``current`` (and of ``cooperation``
              when ``global_cooperation`` is False).
            - action_shape: Passed to both QMix networks.
            - hidden_size_list: Passed to both QMix networks.
            - global_obs_shape: Passed to both QMix networks as their global
              observation shape; also the observation shape of ``cooperation``
              when ``global_cooperation`` is True.
            - mixer: Passed to both QMix networks.
            - global_cooperation: If True, ``cooperation`` is built on (and fed)
              the global state instead of the per-agent observation.
            - lstm_type: Passed to both QMix networks (default ``'gru'``).
            - dueling: Passed to both QMix networks.

        Raises:
            - ValueError: If ``global_cooperation`` is True but ``global_obs_shape``
              is None, because the cooperation network would have no input shape.
        """
        super().__init__()
        # Fail early with a clear message instead of deep inside QMix construction.
        if global_cooperation and global_obs_shape is None:
            raise ValueError("global_obs_shape must be provided when global_cooperation is True")

        # Arguments shared verbatim by both networks; only obs_shape differs.
        shared_kwargs = dict(
            agent_num=agent_num,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling,
        )
        # Creation order (current, then cooperation) is kept so parameter
        # initialisation and state_dict layout match the previous version.
        self.current = QMix(obs_shape=obs_shape, **shared_kwargs)

        self.global_cooperation = global_cooperation
        cooperation_obs_shape = global_obs_shape if global_cooperation else obs_shape
        self.cooperation = QMix(obs_shape=cooperation_obs_shape, **shared_kwargs)

    def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
        """Run either the ``current`` or the ``cooperation`` QMix network on ``data``.

        Arguments:
            - data: Input forwarded to the selected QMix network. When both
              ``cooperation`` and ``self.global_cooperation`` are True it must
              provide ``data['obs']['global_state']``, which is substituted for
              ``data['obs']['agent_state']`` in the input given to ``cooperation``.
            - cooperation: If True use the ``cooperation`` network, else ``current``.
            - single_step: Passed through to the selected network.

        Returns:
            - output: Whatever the selected QMix network returns.

        Note:
            The caller's ``data`` is never modified. The substitution is done on
            shallow copies of ``data`` and ``data['obs']``, so the same batch can be
            fed to both networks in any order.
        """
        if not cooperation:
            return self.current(data, single_step=single_step)

        if self.global_cooperation:
            import copy

            # Shallow copies: only the two dict levels are duplicated; the
            # contained values are shared, not cloned.
            obs = copy.copy(data['obs'])
            obs['agent_state'] = obs['global_state']
            data = copy.copy(data)
            data['obs'] = obs
        return self.cooperation(data, single_step=single_step)