Spaces:
Build error
Build error
File size: 4,646 Bytes
01523b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import asyncio
import logging
from typing import Any, Dict, List
import json
from agentverse.agents.simulation_agent.conversation import BaseAgent
# from agentverse.environments.simulation_env.rules.base import Rule
from agentverse.environments.simulation_env.rules.base import SimulationRule as Rule
from agentverse.message import Message
from .. import env_registry as EnvironmentRegistry
from ..base import BaseEnvironment
from agentverse.initialization import load_tools
@EnvironmentRegistry.register("sde_team")
class SdeTeamEnvironment(BaseEnvironment):
"""
A basic environment implementing the logic of conversation to craft code.
Args:
agents: List of agents
rule: Rule for the environment
max_turns: Maximum number of turns
cnt_turn: Current turn number
last_messages: Messages from last turn
rule_params: Variables set by the rule
"""
agents: List[BaseAgent]
rule: Rule
max_turns: int = 10
cnt_turn: int = 0
last_messages: List[Message] = []
rule_params: Dict = {}
task_name: str = "test"
def __init__(self, rule, **kwargs):
rule_config = rule
order_config = rule_config.get("order", {"type": "sde_team"})
visibility_config = rule_config.get("visibility", {"type": "base"})
selector_config = rule_config.get("selector", {"type": "sde_team"})
updater_config = rule_config.get("updater", {"type": "sde_team"})
describer_config = rule_config.get("describer", {"type": "base"})
rule = Rule(
order_config,
visibility_config,
selector_config,
updater_config,
describer_config,
)
super().__init__(rule=rule, **kwargs)
self.rule_params["first_round"] = True
self.rule_params["end_flag"] = False
# # Test code
self.rule_params["name_to_tools"] = {
tool.name: tool
for tool in load_tools(
[
{
"tool_name": "code_interpreter",
"tool_url": "http://127.0.0.1:8079/tools/code_interpreter/",
}
]
)
}
tool = self.rule_params["name_to_tools"]["execute_unit_tests"]
# print(type(tool))
# d = {
# "func_impl": "def f(x):\n\treturn x + 1",
# "tests": ["assert f(1) == 2"]
# }
# # input_str = json.dumps(d)
# json.loads(input_str)
# tool.run(input_str, verbose=True)
# exit()
async def step(self) -> List[Message]:
"""Run one step of the environment"""
# Get the next agent index
agent_ids = self.rule.get_next_agent_idx(self) # order
# Generate current environment description
# env_descriptions = self.rule.get_env_description(self) # describer
# # Generate the next message
# messages = await asyncio.gather(
# *[self.agents[i].astep(env_descriptions[i]) for i in agent_ids]
# ) # call chatgpt api
messages = await asyncio.gather(*[self.agents[i].astep("") for i in agent_ids])
# Track the messages to get the role of the sender
self.last_messages = messages
# Some rules will select certain messages from all the messages
selected_messages = self.rule.select_message(self, messages) # selector
self.last_messages = selected_messages
self.print_messages(selected_messages)
# Update the memory of the agents
self.rule.update_memory(self) # updater: update memory
# Update the set of visible agents for each agent
self.rule.update_visible_agents(self) # change receiver
self.cnt_turn += 1
return selected_messages
def print_messages(self, messages: List[Message]) -> None:
for message in messages:
if message is not None:
logging.info(f"{message.sender}: {message.content}")
def reset(self) -> None:
"""Reset the environment"""
self.cnt_turn = 0
self.rule.reset()
for agent in self.agents:
agent.reset()
def is_done(self) -> bool:
"""Check if the environment is done"""
if self.cnt_turn >= self.max_turns or self.rule_params["end_flag"]:
# with open("record_human_eval.txt", "a") as f:
# wd = dict()
# wd['task_id'] = self.task_name
# wd['code'] = self.rule_params['code']
# f.write(json.dumps(wd))
return True
return False
|