"""Wrapper to convert a ChatArena environment into a PettingZoo compatible environment."""
# pyright: reportGeneralTypeIssues=false, reportUnboundVariable=false, reportOptionalMemberAccess=false
from __future__ import annotations
import functools
import string
from typing import List
from chatarena.environments import Environment
from chatarena.environments.base import TimeStep
from chatarena.message import Message
from gymnasium import spaces
from gymnasium.utils import EzPickle
from pettingzoo import AECEnv
from pettingzoo.utils.env import AgentID, ObsType
from chatarena.environments.umshini.debate import create_debate_env
from chatarena.environments.umshini.symmetric_content_moderation import (
create_content_moderation_env,
)
from chatarena.environments.umshini.symmetric_deception import create_deception_env
CHAR_SET = string.printable
class PettingZooCompatibilityV0(AECEnv, EzPickle):
"""This compatibility wrapper converts a ChatArena environment into a PettingZoo environment.
Unique arguments for umshini environments: env_name topic, moderation_policy, restricted_action, player_names, round_length
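
    Example (a minimal sketch; the topic string and settings are illustrative, and any API keys
    required by the underlying ChatArena agents are assumed to be configured separately):

        >>> env = PettingZooCompatibilityV0(
        ...     env_name="debate",
        ...     topic="Student loan debt should be forgiven",
        ...     round_length=4,
        ... )
        >>> env.reset()
        >>> first_agent = env.agent_selection  # name of the agent expected to act next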
"""
metadata = {
"render_modes": ["human"],
"name": "PettingZooCompatibilityV0",
"is_parallelizable": False,
"render_fps": 2,
}
def __init__(
self,
env: Environment | None = None,
env_name: str | None = None,
topic: str | None = None,
moderation_policy: str | None = None,
restricted_action: str | None = None,
player_names: list | None = None,
round_length: int | None = 8,
string_observation: bool | None = True,
character_limit: int | None = 4000,
render_mode: str | None = None,
save_json: bool | None = False,
        disable_judging: bool | None = True,
):
"""Wrapper to convert a ChatArena environment into a PettingZoo environment.
Args:
env (chatarena.environments.Environment): chatarena arena to wrap
env_name (str): name of chatarena environment to load (options: "debate", "content_moderation", "deception")
topic (Optional[str]): topic for debate environment
moderation_policy (Optional[str]): moderation policy for content moderation environment
restricted_action (Optional[str]): restricted action for deception environment
player_names (Optional[str]): name of players in the environment
round_length (Optional[int]): number of rounds before swapping roles for symmetric envs, total rounds for asymmetric envs
string_observation (Optional[bool]): send observations as a single string (rather than a dict)
character_limit (Optional[int]): maximum number of characters for observations and actions
render_mode (Optional[str]): rendering mode
save_json (Optional[bool]): flag to save a json file to the disk containing a chat log
"""
        EzPickle.__init__(
            self,
            env,
            env_name,
            topic,
            moderation_policy,
            restricted_action,
            player_names,
            round_length,
            string_observation,
            character_limit,
            render_mode,
            save_json,
            disable_judging,
        )
super().__init__()
if env is None and env_name is None:
raise TypeError(
"ChatArena Environment or environment name must be specified"
)
elif env is not None:
self._env = env
if hasattr(env, "topic"):
                self.topic = env.topic
self.max_turns = round_length
elif hasattr(env, "moderation_policy"):
self.moderation_policy = env.moderation_policy
self.max_turns = round_length * 2
elif hasattr(env, "restricted_action"):
self.restricted_action = env.restricted_action
self.max_turns = round_length * 2
elif env_name is not None:
if env_name == "debate":
assert topic is not None, "topic must be specified for debate env"
self._env = create_debate_env(
topic=topic, player_names=player_names, round_length=round_length, disable_judging=disable_judging
)
self.topic = topic
self.max_turns = round_length
elif env_name == "content_moderation":
assert (
moderation_policy is not None
), "moderation policy must be specified for content moderation env"
self._env = create_content_moderation_env(
moderation_policy=moderation_policy,
player_names=player_names,
round_length=round_length,
disable_judging=disable_judging,
)
self.moderation_policy = moderation_policy
self.max_turns = round_length * 2
elif env_name == "deception":
assert (
restricted_action is not None
), "restricted action must be specified for deception env"
self._env = create_deception_env(
restricted_action=restricted_action,
player_names=player_names,
round_length=round_length,
disable_judging=disable_judging,
)
self.restricted_action = restricted_action
self.max_turns = round_length * 2
else:
raise TypeError(
f"Environment not found: {env_name}. Options: debate, content_moderation, deception"
)
else:
raise TypeError(
"Only one environment argument may be specified: either env or env_name."
)
# Reset the underlying ChatArena environment
self._env.reset()
# Arguments
self.string_observation = string_observation
self.character_limit = character_limit
self.render_mode = render_mode
self.save_json = save_json
# PettingZoo arguments
self.possible_agents = list(self._env.player_names)
self.all_agents = [
"Moderator",
self.possible_agents[0],
self.possible_agents[1],
]
self.observations = {agent: {} for agent in self.possible_agents}
self.rewards = {agent: {} for agent in self.possible_agents}
self.terminations = {agent: {} for agent in self.possible_agents}
self.truncations = {agent: {} for agent in self.possible_agents}
self.infos = {
agent: {"turn": 0, "obs_dict": {}, "new_messages": [], "all_messages": []}
for agent in self.possible_agents
}
# Custom attributes for housekeeping
self.total_rewards = {agent: 0.0 for agent in self.possible_agents}
self.current_turn = 0
@functools.lru_cache(maxsize=None)
def observation_space(self, agent: AgentID):
"""observation_space.
        Returns a text space sized by the configured character limit, or a dict of per-agent text
        spaces, depending on the string_observation flag.
Args:
agent (AgentID): agent
"""
if self.string_observation:
observation_space = spaces.Text(
max_length=self.character_limit, min_length=0, charset=CHAR_SET
)
else:
observation_space = spaces.Dict(
{
agent: spaces.Text(
max_length=self.character_limit, min_length=0, charset=CHAR_SET
)
for agent in self.all_agents
}
)
return observation_space
@functools.lru_cache(maxsize=None)
def action_space(self, agent: AgentID):
"""action_space.
Get the action space from the underlying environment.
Action space currently only supports messages to all players, but could be extended to support private messages.
Args:
agent (AgentID): agent
Returns:
space
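
        Example (illustrative; assumes the wrapper has been constructed, and any printable string
        up to character_limit characters is a valid action):

            >>> env.action_space("Player 1").contains("I think that...")
            True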
"""
return spaces.Text(
max_length=self.character_limit, min_length=0, charset=CHAR_SET
)
def render(self):
"""render.
Print the current game state.
"""
if not hasattr(self, "initial_timestep"):
raise UserWarning(
"You must reset the environment using reset() before calling render()."
)
if self.render_mode == "human":
new_messages = self.infos[self.agent_selection].get("new_messages")
if new_messages is None:
raise Exception("New messages not found")
else:
for message in new_messages:
print(
f"[{message.agent_name}->{message.visible_to}]: {message.content}\n"
)
def observe(self, agent: AgentID) -> ObsType:
"""observe.
Args:
agent (AgentID): agent (e.g., "Player 1")
Returns:
observation
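
        Example (sketch; assumes the wrapper has been reset; with string_observation=True the new
        messages for the current turn are concatenated as "<agent_name>: <content>"):

            >>> obs = env.observe("Player 1")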
"""
# When PettingZoo agents die, they are removed from the info dict (as well as obs, cumulative rewards, termination, truncation)
if agent not in self.agents:
return None
# Observations and infos are calculated in step(), but need to be calculated before the first step() call
        elif not isinstance(agent, str):
raise TypeError("AgentID must be a string")
elif self.observations[agent] != {}:
return self.observations[agent]
else:
# get only the messages that this agent can see
messages = self._env.get_observation(agent)
# calculate current turn
if len(messages) > 0:
self.current_turn = messages[-1].turn
else:
self.current_turn = 0
# filter to only new messages for this agent (observation is limited to only the current message)
new_messages = [m for m in messages if m.turn == self.current_turn]
# string observation (optional flag)
if self.string_observation is True:
observation = ""
for m in new_messages:
observation += f"{m.agent_name}: {m.content}"
# dict observation
else:
observation = {m.agent_name: m.content for m in new_messages}
            # We return info as ChatArena Message objects, as strings, and as a dictionary, to allow for maximum flexibility.
            # The dict avoids having to parse the message string to determine the agent, which can fail when
            # LLMs repeat the agent name (a common occurrence in testing), so it may be worth making it the default return type.
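            # Illustrative shapes (values are placeholders):
            #   info["obs_dict"] -> {"Player 1": "<message content>"}
            #   string observation -> "Player 1: <message content>"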
self.infos[agent]["turn"] = self.current_turn
self.infos[agent]["new_messages"] = new_messages
self.infos[agent]["all_messages"] = messages
self.infos[agent]["obs_dict"] = {
m.agent_name: m.content for m in new_messages
}
self.infos[agent]["player_name"] = self.agent_selection
# info: generate string of full chat log
if self.string_observation is True:
all_messages_string = ""
for m in messages:
all_messages_string += f"[{m.agent_name}->all]: {m.content}\n"
self.infos[agent]["all_messages_string"] = all_messages_string
# info: environment specific information
if hasattr(self, "restricted_action"):
self.infos[agent]["restricted_action"] = self.restricted_action
if hasattr(self, "moderation_policy"):
self.infos[agent]["moderation_policy"] = self.moderation_policy
if hasattr(self, "topic"):
self.infos[agent]["topic"] = self.topic
return observation
def close(self):
"""close."""
msg_lst: List[Message] = self._env.message_pool.get_all_messages()
formatted_state = [{"name": m.agent_name, "turn": m.turn, "text": m.content} for m in msg_lst]
if self.save_json:
import json
import os
from pathlib import Path
Path("env_logs").mkdir(exist_ok=True)
os.chdir("env_logs")
files = os.listdir()
files = [f for f in files if f.startswith(self.metadata["name"]) and f.endswith(".json")]
            json_path = self.metadata["name"] + str(len(files)) + ".json"
            with open(json_path, "w") as f:
                json.dump(formatted_state, f)
            print(f"Chat log has been saved to disk: {json_path}")
else:
return formatted_state
def _unravel_timestep(self, timestep: TimeStep):
# get observation
messages = timestep.observation
# calculate current turn
if len(messages) > 0:
self.current_turn = messages[-1].turn
else:
self.current_turn = 0
# filter to only new messages (observation is limited to only the current message)
new_messages = [m for m in messages if m.turn == self.current_turn]
# string observation (optional flag)
if self.string_observation is True:
observation = ""
for m in new_messages:
observation += f"{m.agent_name}: {m.content}"
# dict observation
else:
observation = {m.agent_name: m.content for m in new_messages}
# get rewards
rewards = timestep.reward
# get termination
termination = timestep.terminal
# get truncation
truncation = (
self.current_turn >= self.max_turns
) # pyright: ignore[reportGeneralTypeIssues]
info = {}
info["turn"] = self.current_turn
info["new_messages"] = new_messages
info["all_messages"] = messages
info["obs_dict"] = {m.agent_name: m.content for m in new_messages}
info["player_name"] = self.agent_selection
# info: generate string of full chat log
if self.string_observation is True:
all_messages_string = ""
for m in messages:
all_messages_string += f"[{m.agent_name}->all]: {m.content}\n"
info["all_messages_string"] = all_messages_string
# info: environment specific information
if hasattr(self, "restricted_action"):
info["restricted_action"] = self.restricted_action
if hasattr(self, "moderation_policy"):
info["moderation_policy"] = self.moderation_policy
if hasattr(self, "topic"):
info["topic"] = self.topic
return observation, rewards, termination, truncation, info
def reset(
self,
return_info: bool | None = False,
seed: int | None = None,
options: dict | None = None,
):
"""reset.
Args:
            return_info (Optional[bool]): flag to return info as well as observation
            seed (Optional[int]): random seed
            options (Optional[Dict]): additional reset options
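
        Example (sketch; assumes the wrapper was constructed as in the class docstring):

            >>> env.reset()
            >>> obs = env.observe(env.agent_selection)
            >>> info = env.infos[env.agent_selection]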
"""
# reset our custom attributes
self.current_turn = 0
self.total_rewards = {agent: 0.0 for agent in self.possible_agents}
# reset the ChatArena environment
self.initial_timestep = self._env.reset()
# reset the PettingZoo wrapper
self.agents = self.possible_agents[:]
self.observations = {agent: {} for agent in self.agents}
self._cumulative_rewards = {agent: 0.0 for agent in self.agents}
self.rewards = self.initial_timestep.reward
self.terminations = {agent: False for agent in self.agents}
self.truncations = {agent: False for agent in self.agents}
# info keys: turn, new_messages, all_messages, obs_dict, player_name, all_messages_string, restricted_action, moderation_policy, topic
self.infos = {
agent: {}
for agent in self.possible_agents
}
# get the first player
self._agent_selector = self._env.agent_selector
self.agent_selection = self._agent_selector.reset()
# get the first observation
observation = self.observe(self.agent_selection)
info = self.infos[self.agent_selection]
# render the environment (print the initial scenario text)
if self.render_mode is not None:
self.render()
def step(self, action: str):
"""Steps.
Steps the agent with an action.
Args:
action (str): action
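
        Example (sketch of a typical interaction loop; generate_response is a placeholder for
        whatever policy or LLM produces the next message):

            >>> for agent in env.agent_iter():
            ...     obs, reward, termination, truncation, info = env.last()
            ...     if termination or truncation:
            ...         response = None
            ...     else:
            ...         response = generate_response(obs)
            ...     env.step(response)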
"""
if (
self.truncations[self.agent_selection]
or self.terminations[self.agent_selection]
):
return self._was_dead_step(action)
agent = self.agent_selection
timestep = self._env.step(player_name=agent, action=action)
observation, reward, termination, truncation, info = self._unravel_timestep(
timestep
)
# add moderator messages to info so they are rendered
# some environments (e.g., debate) have the moderator announce the winner as the last message
if termination or truncation:
if info["all_messages"][-1].agent_name == "Moderator":
info["new_messages"].append(info["all_messages"][-2])
        # Account for the moderator interjecting statements such as "roles are being swapped".
        # On the first turn the moderator's message is already rendered, so we don't add it again here.
if info["turn"] > 1:
if len(info["all_messages"]) > 1 and info["all_messages"][-2].agent_name == "Moderator":
info["new_messages"].append(info["all_messages"][-2])
self.observations[agent] = observation
self.rewards = reward
self.terminations[agent] = termination
self.truncations[agent] = truncation
self.infos[agent] = info
# If we receive a termination or truncation signal from either agent, the game is over
if termination:
self.terminations = {agent: True for agent in self.possible_agents}
if truncation:
self.truncations = {agent: True for agent in self.possible_agents}
# Update total rewards for each agent (in one timestep both agents can get rewards/penalties)
self.total_rewards[agent] += self._cumulative_rewards[agent]
# Reset PettingZoo cumulative_rewards attribute (tracks accumulated rewards for an agent since its previous action)
self._cumulative_rewards[agent] = 0
if self.render_mode is not None:
self.render()
# Get the next agent in PettingZoo, and iterate the underlying environment (used for reward calculations)
self.agent_selection = self._agent_selector.next()
# Adds current step rewards to _cumulative_rewards
self._accumulate_rewards()