Spaces:
Runtime error
Runtime error
from typing import List, Union | |
from .base import TimeStep, Environment | |
from ..message import Message, MessagePool | |
from ..agent import Moderator, SIGNAL_END_OF_CONVERSATION | |
from ..config import EnvironmentConfig, AgentConfig | |
class Conversation(Environment):
    """
    Turn-based fully observable conversation environment.

    Players speak in round-robin order over ``player_names``. When ``parallel``
    is False every message advances the turn counter; when ``parallel`` is True
    all players of one round share the same turn number, and the turn advances
    only after the last player of the round has spoken.
    """

    type_name = "conversation"

    def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
        """
        Args:
            player_names: names of the players, in speaking order.
            parallel: if True, one turn spans a full round of all players;
                if False, each message is its own turn.
            **kwargs: forwarded to ``Environment.__init__``.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)
        self.parallel = parallel

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        self._current_turn = 0
        self._next_player_idx = 0

    def reset(self):
        """
        Reset the turn/speaker counters and clear the message pool.

        Returns:
            A TimeStep with an empty observation, zero rewards and terminal=False.
        """
        self._current_turn = 0
        self._next_player_idx = 0
        self.message_pool.reset()

        init_timestep = TimeStep(observation=[],
                                 reward=self.get_zero_rewards(),
                                 terminal=False)
        return init_timestep

    def to_config(self) -> EnvironmentConfig:
        """Serialize this environment's constructor arguments into an EnvironmentConfig."""
        return EnvironmentConfig(env_type=self.type_name, player_names=self.player_names, parallel=self.parallel)

    def print(self):
        """Print the conversation by delegating to ``MessagePool.print``."""
        self.message_pool.print()

    def get_next_player(self) -> str:
        """Return the name of the player who speaks next."""
        return self.player_names[self._next_player_idx]

    def get_observation(self, player_name=None) -> List[Message]:
        """
        Get the observation for a player.

        Args:
            player_name: if None, every message in the pool is returned;
                otherwise only the messages the pool reports as visible to
                this player at the current turn.

        Returns:
            A list of Message objects.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(player_name, turn=self._current_turn)

    def is_terminal(self) -> bool:
        """
        Check whether the conversation is over.

        Returns:
            True if the last message starts with SIGNAL_END_OF_CONVERSATION,
            otherwise False (never None).
        """
        last_message = self.message_pool.last_message
        # NOTE(review): assumes `last_message` is None when the pool is empty
        # (e.g. right after reset) — confirm against MessagePool.
        if last_message is None:
            return False
        # If the last message is the signal, then the conversation is over
        return last_message.content.startswith(SIGNAL_END_OF_CONVERSATION)

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agent wants to take

        Returns:
            A TimeStep whose observation is every message in the pool, with
            zero rewards and the current terminal status.
        """
        message = Message(agent_name=player_name, content=action, turn=self._current_turn)
        self.message_pool.append_message(message)

        # Advance to the next speaker *before* deciding whether the turn is over.
        self._next_player_idx = (self._next_player_idx + 1) % self.num_players
        # Sequential mode: every message is its own turn. Parallel mode: the turn
        # ends only when the speaker index wraps back to 0, i.e. after the last
        # player of the round, so all players in a round share one turn number.
        if not self.parallel or self._next_player_idx == 0:
            self._current_turn += 1

        timestep = TimeStep(observation=self.get_observation(),
                            reward=self.get_zero_rewards(),
                            terminal=self.is_terminal())  # Return all the messages
        return timestep
class ModeratedConversation(Conversation):
    """
    Turn-based conversation environment with a moderator.

    Next speaker order is either parallel or round-robin. The moderator is a
    special agent that is given all messages and can decide whether the
    conversation is over. Its own messages are posted with the visibility
    given by ``moderator_visibility``.
    """

    type_name = "moderated_conversation"

    def __init__(self, player_names: List[str], moderator: Union[Moderator, AgentConfig],
                 parallel: bool = False, moderator_visibility="all", moderator_period=None, **kwargs):
        """
        Args:
            player_names: names of the players, in speaking order.
            moderator: a Moderator instance, or an AgentConfig from which one
                is built via ``Moderator.from_config``.
            parallel: see ``Conversation``.
            moderator_visibility: used as ``visible_to`` for every moderator message.
            moderator_period: "turn" makes the moderator respond after every
                player message; "round" makes it respond only after the last
                player of a round. Defaults to "round" when ``parallel`` is
                True, else "turn".
                NOTE(review): any other value is stored as-is and means the
                moderator never responds in ``step``.
            **kwargs: forwarded to ``Conversation.__init__``.

        Raises:
            ValueError: if ``moderator`` is neither an AgentConfig nor a Moderator.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        if isinstance(moderator, AgentConfig):
            moderator = Moderator.from_config(moderator)
        elif not isinstance(moderator, Moderator):
            raise ValueError("moderator must be either an AgentConfig or a Moderator instance.")

        self.moderator = moderator
        self.moderator_visibility = moderator_visibility
        if moderator_period is None:
            moderator_period = "round" if parallel else "turn"
        self.moderator_period = moderator_period

    def to_config(self) -> EnvironmentConfig:
        """Serialize this environment, including its moderator settings, into an EnvironmentConfig."""
        # This environment has extra config arguments (the moderator settings) that need to be handled specially
        return EnvironmentConfig(env_type=self.type_name, player_names=self.player_names, parallel=self.parallel,
                                 moderator=self.moderator.to_config(), moderator_visibility=self.moderator_visibility,
                                 moderator_period=self.moderator_period)

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Posts the player's message and, depending on ``moderator_period``, lets
        the moderator respond and judge whether the conversation has ended.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agent wants to take

        Returns:
            A TimeStep whose observation is every message in the pool, with
            zero rewards and a boolean terminal flag.
        """
        message = Message(agent_name=player_name, content=action, turn=self._current_turn)
        self.message_pool.append_message(message)

        # Round-robin order for the next player
        self._next_player_idx = (self._next_player_idx + 1) % self.num_players

        moderator_speaks = (self.moderator_period == "turn"
                            or (self.moderator_period == "round" and self._next_player_idx == 0))
        if moderator_speaks:
            # Moderator's turn
            moderator_history = self.message_pool.get_all_messages()
            moderator_response = self.moderator(moderator_history)
            moderator_message = Message(agent_name=self.moderator.name,
                                        content=moderator_response,
                                        turn=self._current_turn,
                                        visible_to=self.moderator_visibility)
            self.message_pool.append_message(moderator_message)
            # NOTE(review): `moderator_history` was fetched before the moderator
            # message was appended; whether it now includes that message depends
            # on get_all_messages returning the live list or a copy — confirm.
            terminal = self.moderator.is_terminal(moderator_history) or self.is_terminal()
        else:
            terminal = self.is_terminal()
        # An is_terminal implementation may fall through and return None;
        # coerce so TimeStep.terminal is always a real bool.
        terminal = bool(terminal)

        # Update the counters: sequential mode ends a turn on every message,
        # parallel mode only once the speaker index has wrapped back to 0.
        if not self.parallel or self._next_player_idx == 0:
            self._current_turn += 1

        timestep = TimeStep(observation=self.get_observation(),
                            reward=self.get_zero_rewards(),
                            terminal=terminal)  # Return all the messages
        return timestep