Spaces:
Build error
Build error
File size: 7,755 Bytes
01523b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import asyncio
import datetime
import logging
from typing import Any, Dict, List, Optional, Set
# from agentverse.agents.agent import Agent
from agentverse.agents.simulation_agent.conversation import BaseAgent
# from agentverse.environments.simulation_env.rules.base import Rule
from agentverse.environments.simulation_env.rules.base import SimulationRule as Rule
from agentverse.message import Message
from .. import env_registry as EnvironmentRegistry
from ..base import BaseEnvironment
@EnvironmentRegistry.register("pokemon")
class PokemonEnvironment(BaseEnvironment):
    """
    An environment for the Pokémon demo.

    Agents occupy named locations. The environment advances either by a
    routine step (a set of agents act concurrently and the in-world clock
    moves forward) or by answering a player who addresses a single agent.

    Args:
        agents: List of agents
        locations_to_agents: Mapping from location name to the names of the
            agents currently in that location
        time: In-world clock, advanced by 5 minutes per routine step
        rule: Rule for the environment
        max_turns: Maximum number of turns
        cnt_turn: Current turn number
        last_messages: Messages from last turn
        rule_params: Variables set by the rule
    """

    agents: List[BaseAgent]
    locations_to_agents: Dict[str, Set[str]]
    time: datetime.datetime = datetime.datetime(2021, 1, 1, 8, 0, 0)
    rule: Rule
    max_turns: int = 10
    cnt_turn: int = 0
    # NOTE(review): these mutable defaults are only safe if BaseEnvironment
    # copies field defaults per instance (as pydantic models do) -- confirm.
    last_messages: List[Message] = []
    rule_params: Dict = {}

    def __init__(
        self,
        rule: Dict[str, Any],
        locations: List[Dict[str, Any]],
        **kwargs,
    ):
        """Build the rule and the initial location occupancy from config.

        Args:
            rule: Rule config. Each of the ``order``, ``visibility``,
                ``selector``, ``updater`` and ``describer`` sub-configs is
                optional and falls back to a default ``type``.
            locations: Location configs, each providing a ``name`` and the
                ``init_agents`` initially placed there.
            **kwargs: Forwarded unchanged to ``BaseEnvironment``.
        """
        rule_config = rule
        rule = Rule(
            rule_config.get("order", {"type": "sequential"}),
            rule_config.get("visibility", {"type": "all"}),
            rule_config.get("selector", {"type": "basic"}),
            rule_config.get("updater", {"type": "basic"}),
            rule_config.get("describer", {"type": "basic"}),
        )
        locations_to_agents = {
            loc["name"]: set(loc["init_agents"]) for loc in locations
        }
        super().__init__(
            rule=rule,
            locations_to_agents=locations_to_agents,
            **kwargs,
        )

    async def step(
        self,
        is_player: bool = False,
        player_content: Optional[str] = None,
        receiver: Optional[str] = None,
        receiver_id: Optional[int] = None,
        agent_ids: Optional[List[int]] = None,
    ) -> List[Message]:
        """Run one step of the environment.

        Args:
            is_player: If True, answer a player message instead of running a
                routine step.
            player_content: Text of the player's message (player steps only).
            receiver: Name of the agent the player addresses.
            receiver_id: Index of that agent; resolved from ``receiver`` when
                omitted.
            agent_ids: Indices of the agents that act in a routine step; when
                omitted, the rule's order chooses them.

        Returns:
            The messages produced during this step.
        """
        if is_player:
            return await self._respond_to_player(
                player_content, receiver, receiver_id
            )
        return await self._routine_step(agent_ids)

    async def _routine_step(
        self, agent_ids: Optional[List[int]] = None
    ) -> List[Message]:
        """Let the chosen agents act concurrently, then advance the clock.

        Args:
            agent_ids: Indices into ``self.agents``. When None, the rule's
                order picks the next agents.

        Returns:
            The messages kept by the rule's selector.
        """
        self.rule.update_visible_agents(self)
        if agent_ids is None:
            agent_ids = self.rule.get_next_agent_idx(self)
        # One description per agent, indexed like self.agents
        env_descriptions = self.rule.get_env_description(self)
        messages = await asyncio.gather(
            *[self.agents[i].astep(env_descriptions[i]) for i in agent_ids]
        )
        # Some rules will select certain messages from all the messages
        selected_messages = self.rule.select_message(self, messages)
        # Record this turn's messages before update_memory, which presumably
        # reads them from self.last_messages
        self.last_messages = selected_messages
        self.rule.update_memory(self)
        self.print_messages(selected_messages)
        self.cnt_turn += 1
        self.time += datetime.timedelta(minutes=5)
        return selected_messages

    async def _respond_to_player(
        self,
        player_content: Optional[str] = None,
        receiver: Optional[str] = None,
        receiver_id: Optional[int] = None,
    ) -> List[Message]:
        """Have a single agent answer a message from the player.

        Args:
            player_content: Text of the player's message.
            receiver: Name of the addressed agent.
            receiver_id: Index of the addressed agent in ``self.agents``;
                resolved from ``receiver`` when None.

        Returns:
            The addressed agent's reply messages. The clock is not advanced.

        Raises:
            ValueError: If ``receiver_id`` is None and no agent is named
                ``receiver``.
        """
        if receiver_id is None:
            for agent in self.agents:
                if agent.name == receiver:
                    # NOTE(review): agent_id is used as an index into
                    # self.agents below; this assumes it equals the agent's
                    # position in the list -- confirm.
                    receiver_id = agent.agent_id
                    break
            else:
                raise ValueError(
                    f"No agent named {receiver!r} in the environment"
                )
        agent_ids = [receiver_id]
        agent_name = (
            receiver if receiver is not None else self.agents[receiver_id].name
        )
        # The player character is always called "Brenden" in this demo
        player_message = Message(
            sender="Brenden", content=player_content, receiver=[agent_name]
        )
        self.rule.update_visible_agents(self)
        env_descriptions = self.rule.get_env_description(self, player_content)
        messages = await asyncio.gather(
            *[self.agents[i].astep(env_descriptions[i]) for i in agent_ids]
        )
        # No selector pass here: the player's message and the reply are both
        # recorded so update_memory sees the whole exchange
        self.last_messages = [player_message, *messages]
        self.rule.update_memory(self)
        self.print_messages(messages)
        self.cnt_turn += 1
        return messages

    def update_state(self, agent_location: Dict[str, str]):
        """Move agents to new locations.

        Args:
            agent_location: Mapping from agent name to the location it now
                occupies.

        Raises:
            KeyError: If a target location is unknown. That agent is left
                where it was.
        """
        for agent_name, location in agent_location.items():
            # Look the destination up first so an unknown location does not
            # leave the agent removed from everywhere.
            destination = self.locations_to_agents[location]
            # An agent is in at most one location (get_agent_to_location maps
            # each agent to a single place), so leave any previous one first.
            # discard() also tolerates agents currently in no location.
            for occupants in self.locations_to_agents.values():
                occupants.discard(agent_name)
            destination.add(agent_name)

    def get_agent_to_location(self) -> Dict[str, str]:
        """Invert ``locations_to_agents`` into an agent-name -> location map.

        Agents not present in any location are absent from the result.
        """
        return {
            agent: location
            for location, agent_names in self.locations_to_agents.items()
            for agent in agent_names
        }

    def print_messages(self, messages: List[Message]) -> None:
        """Log each non-None message as ``sender: content`` at INFO level."""
        for message in messages:
            if message is not None:
                logging.info("%s: %s", message.sender, message.content)

    def reset(self) -> None:
        """Reset the turn counter, the rule, and every agent."""
        self.cnt_turn = 0
        self.rule.reset()
        for agent in self.agents:
            agent.reset()

    def is_done(self) -> bool:
        """Return True once ``max_turns`` turns have been taken."""
        return self.cnt_turn >= self.max_turns

    def get_test_messages(self) -> List[Message]:
        """Return a fixed batch of sample agent messages (Speak and MoveTo)."""
        messages = [
            Message(
                content='{"to": "Birch", "action": "Speak", "text": "Hi!!!"}',
                sender="May",
                receiver={"May", "Birch"},
                tool_response=[],
            ),
            Message(
                content='{"to": "May", "text": "Good morning, May! How is your research going?", "action": "Speak"}',
                sender="Birch",
                receiver={"May", "Birch"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Pokémon Center", "action": "MoveTo"}',
                sender="Steven",
                receiver={"Steven"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Shop", "last_time": "10 minutes", "action": "MoveTo"}',
                sender="Maxie",
                receiver={"Maxie"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Pok\\u00e9mon Center", "action": "MoveTo"}',
                sender="Archie",
                receiver={"Archie"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Shop", "action": "MoveTo"}',
                sender="Joseph",
                receiver={"Joseph"},
                tool_response=[],
            ),
        ]
        return messages
|