Spaces:
Sleeping
Sleeping
from urllib.parse import urlparse, parse_qs | |
def _behavior_to_agent_id(behavior_name: str, unique_id: int) -> str:
    """Return the agent id for agent ``unique_id`` of ``behavior_name``.

    The id is the behavior name with ``?agent_id=<unique_id>`` appended, so the
    behavior name can later be recovered by cutting at that marker.
    """
    return "{}?agent_id={}".format(behavior_name, unique_id)
def _agent_id_to_behavior(agent_id: str) -> str:
    """Return the behavior-name prefix of an agent id.

    Everything before the first ``?agent_id=`` marker is returned; when the
    marker is absent the input is returned unchanged.
    """
    behavior_name, _marker, _suffix = agent_id.partition("?agent_id=")
    return behavior_name
def _unwrap_batch_steps(batch_steps, behavior_name):
    """Flatten one behavior's ``(decision_batch, termination_batch)`` pair into per-agent dicts.

    ``batch_steps`` is unpacked into a decision batch and a termination batch
    (presumably ML-Agents ``DecisionSteps`` / ``TerminalSteps`` -- confirm with
    the caller). Both batches are read through ``agent_id``, ``obs``,
    ``reward``, ``group_id`` and ``group_reward``. The decision batch also
    supplies ``action_mask`` (may be ``None``), and the termination batch
    supplies ``interrupted``. Each agent id is built with
    ``_behavior_to_agent_id(behavior_name, raw_id)``.

    Returns a 7-tuple ``(agents, obs, dones, rewards, cumulative_rewards,
    infos, id_map)``:

    * ``agents``: decision agent ids followed by terminal agent ids.
    * ``obs``: per agent, the list of that agent's row from every observation
      array. A one-element list is unwrapped to its single element. When the
      decision batch has an action mask, decision agents instead map to a dict
      ``{"observation": [...], "action_mask": [...]}``.
    * ``dones``: ``True`` for terminal agents, ``False`` for decision agents.
    * ``rewards``: per-agent reward.
    * ``cumulative_rewards``: a shallow copy of ``rewards``.
    * ``infos``: ``behavior_name``, ``group_id`` and ``group_reward`` for every
      agent, plus ``interrupted`` for terminal agents.
    * ``id_map``: decision agent id -> row index within the decision batch.

    If an id appears in both batches, ``obs``, ``dones`` and ``rewards`` take
    the decision values, ``infos`` takes the terminal values, and ``agents``
    lists the id twice.
    """
    decision_batch, termination_batch = batch_steps
    decision_id = [
        _behavior_to_agent_id(behavior_name, i) for i in decision_batch.agent_id
    ]
    termination_id = [
        _behavior_to_agent_id(behavior_name, i) for i in termination_batch.agent_id
    ]
    # NOTE(review): plain concatenation -- an id present in both batches is listed twice.
    agents = decision_id + termination_id

    # Terminal entries are inserted first, so decision entries overwrite them on overlap.
    obs = {
        agent_id: [batch_obs[i] for batch_obs in termination_batch.obs]
        for i, agent_id in enumerate(termination_id)
    }
    if decision_batch.action_mask is not None:
        obs.update(
            {
                agent_id: {
                    "observation": [batch_obs[i] for batch_obs in decision_batch.obs],
                    "action_mask": [mask[i] for mask in decision_batch.action_mask],
                }
                for i, agent_id in enumerate(decision_id)
            }
        )
    else:
        obs.update(
            {
                agent_id: [batch_obs[i] for batch_obs in decision_batch.obs]
                for i, agent_id in enumerate(decision_id)
            }
        )
    # Unwrap single-observation lists. Masked dicts always have 2 keys and pass
    # through unchanged. An empty list (no observation arrays) is kept as-is
    # instead of raising IndexError on v[0].
    obs = {k: v[0] if len(v) == 1 else v for k, v in obs.items()}

    dones = dict.fromkeys(termination_id, True)
    dones.update(dict.fromkeys(decision_id, False))

    rewards = {
        agent_id: termination_batch.reward[i]
        for i, agent_id in enumerate(termination_id)
    }
    rewards.update(
        {agent_id: decision_batch.reward[i] for i, agent_id in enumerate(decision_id)}
    )
    cumulative_rewards = dict(rewards)

    infos = {}
    for i, agent_id in enumerate(decision_id):
        infos[agent_id] = {
            "behavior_name": behavior_name,
            "group_id": decision_batch.group_id[i],
            "group_reward": decision_batch.group_reward[i],
        }
    # Terminal records are written after decision records, so here the terminal
    # batch wins on overlap (the opposite of obs/dones/rewards above).
    # NOTE(review): confirm this asymmetry is intended.
    for i, agent_id in enumerate(termination_id):
        infos[agent_id] = {
            "behavior_name": behavior_name,
            "group_id": termination_batch.group_id[i],
            "group_reward": termination_batch.group_reward[i],
            "interrupted": termination_batch.interrupted[i],
        }

    id_map = {agent_id: i for i, agent_id in enumerate(decision_id)}
    return agents, obs, dones, rewards, cumulative_rewards, infos, id_map
def _parse_behavior(full_behavior):
    """Split a behavior string such as ``"Striker?team=1"`` into ``(name, team_id)``.

    The name is the URL path component as parsed by ``urlparse``. ``team_id``
    is the first ``team`` query value converted with ``int``, or 0 when no
    ``team`` parameter is present.
    """
    components = urlparse(full_behavior)
    query_params = parse_qs(components.query)
    team_id: int = int(query_params["team"][0]) if "team" in query_params else 0
    return components.path, team_id