Spaces:
Sleeping
Sleeping
from collections import namedtuple | |
import numpy as np | |
def convert(dictionary):
    """Turn a mapping into an immutable, attribute-accessible record.

    A namedtuple type called ``GenericDict`` is created with one field per
    key of ``dictionary`` (in the mapping's iteration order), and a single
    instance of it is returned, populated with the corresponding values.

    Args:
        dictionary: mapping whose keys are usable as namedtuple field names.

    Returns:
        A ``GenericDict`` namedtuple instance; ``convert({"a": 1}).a == 1``.
    """
    field_names = tuple(dictionary.keys())
    record_type = namedtuple('GenericDict', field_names)
    return record_type(**dictionary)
class MultiAgentEnv(object):
    """Base interface for a multi-agent environment.

    Almost every method here raises ``NotImplementedError`` and is meant to
    be overridden.  ``get_env_info`` additionally reads ``self.n_agents`` and
    ``self.episode_limit``, which this base class never assigns, so a
    subclass has to provide them.
    """

    def __init__(self, batch_size=None, **kwargs):
        """Store the environment arguments and, if seeded, a random state.

        Args:
            batch_size: accepted but not used by this base constructor.
            **kwargs: must contain ``"env_args"`` (a ``KeyError`` is raised
                otherwise).  A dict value is turned into an attribute-style
                record via ``convert``; any other object is stored as given.
        """
        # Unpack arguments from sacred
        env_args = kwargs["env_args"]
        self.args = convert(env_args) if isinstance(env_args, dict) else env_args

        seed_value = getattr(self.args, "seed", None)
        if seed_value is None:
            # No seed supplied: neither ``self.seed`` nor ``self.rs`` is set.
            return

        # NOTE: this instance attribute hides the ``seed()`` method defined
        # further down for this particular instance.
        self.seed = seed_value
        # Initialise the numpy random state from the seed.
        self.rs = np.random.RandomState(seed_value)

    def step(self, actions):
        """Apply ``actions``; returns reward, terminated, info."""
        raise NotImplementedError

    def get_obs(self):
        """Return the observations of all agents as a list."""
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Return the observation of the agent identified by ``agent_id``."""
        raise NotImplementedError

    def get_obs_size(self):
        """Return the shape of an observation."""
        raise NotImplementedError

    def get_state(self):
        """Return the environment state (to be implemented by subclasses)."""
        raise NotImplementedError

    def get_state_size(self):
        """Return the shape of the state."""
        raise NotImplementedError

    def get_avail_actions(self):
        """Return the available actions (to be implemented by subclasses)."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Return the actions currently available to agent ``agent_id``."""
        raise NotImplementedError

    def get_total_actions(self):
        """Return the total number of actions an agent could ever take."""
        # TODO: This is only suitable for a discrete 1 dimensional action
        # space for each agent
        raise NotImplementedError

    def get_stats(self):
        """Return environment statistics (to be implemented by subclasses)."""
        raise NotImplementedError

    # TODO: Temp hack
    def get_agg_stats(self, stats):
        """Return aggregated statistics; the base version ignores ``stats``
        and always yields a new empty dict."""
        return dict()

    def reset(self):
        """Reset the environment; returns initial observations and states."""
        raise NotImplementedError

    def render(self):
        """Render the environment (to be implemented by subclasses)."""
        raise NotImplementedError

    def close(self):
        """Close the environment (to be implemented by subclasses)."""
        raise NotImplementedError

    def seed(self, seed):
        """Seed the environment (to be implemented by subclasses)."""
        raise NotImplementedError

    def get_env_info(self):
        """Collect the environment's sizes and limits into one dict.

        Returns:
            dict with keys ``state_shape``, ``obs_shape``, ``n_actions``,
            ``n_agents`` and ``episode_limit``, filled from the matching
            getter methods and instance attributes.
        """
        return {
            "state_shape": self.get_state_size(),
            "obs_shape": self.get_obs_size(),
            "n_actions": self.get_total_actions(),
            "n_agents": self.n_agents,
            "episode_limit": self.episode_limit,
        }