xa6's picture
Upload folder using huggingface_hub
4bdab37
from typing import List, Dict, Union
import random
import re
from .base import Environment, TimeStep
from ..message import Message, MessagePool
from ..agent import SIGNAL_END_OF_CONVERSATION
from ..config import EnvironmentConfig
DEFAULT_TOPIC_CODES = {
"Fruits": [
"Apple",
"Banana",
"Orange",
"Grape",
"Strawberry",
"Pineapple",
"Mango",
"Watermelon",
],
"Animals": [
"Lion",
"Elephant",
"Giraffe",
"Monkey",
"Zebra",
"Tiger",
"Bear",
"Kangaroo",
],
"Sports": [
"Soccer",
"Basketball",
"Tennis",
"Baseball",
"Swimming",
"Cycling",
"Volleyball",
"Golf",
],
"Countries": [
"United States",
"Canada",
"Brazil",
"United Kingdom",
"France",
"Germany",
"Japan",
"Australia",
],
}
class Chameleon(Environment):
type_name = "chameleon"
def __init__(self, player_names: List[str], topic_codes: Dict[str, List[str]] = None, **kwargs):
super().__init__(player_names=player_names, topic_codes=topic_codes, **kwargs)
if topic_codes is None:
topic_codes = DEFAULT_TOPIC_CODES
self.topic_codes = topic_codes
# The "state" of the environment is maintained by the message pool
self.message_pool = MessagePool()
# Randomly sample a topic, code and chameleon player
self.topic = None
self.code = None
self.chameleon_name = None
self.non_chameleon_names = None
# Game states
self._current_turn = 0
self._next_player_idx = 0
self._current_phase = "give clues" # "give clues", "accuse", "guess"
self._players_votes = None
self._initialized = False
self.reset() # To initialize the game (select topic, code, chameleon)
def get_next_player(self) -> str:
"""
get the next player
"""
if self._current_phase != "guess":
return self.player_names[self._next_player_idx]
else:
return self.chameleon_name
def reset(self):
"""
sample topic, code and chameleon code
"""
self.topic = random.choice(list(self.topic_codes.keys()))
self.code = random.choice(self.topic_codes[self.topic])
self.chameleon_name = random.choice(self.player_names)
self.non_chameleon_names = [name for name in self.player_names if name != self.chameleon_name]
self._current_turn = 0
self._next_player_idx = 0
self._current_phase = "give clues"
self.message_pool.reset()
self._moderator_speak(f"Now the game starts! The topic is: {self.topic}")
self._moderator_speak(f"You are not chameleon. The word is: {self.code}",
visible_to=self.non_chameleon_names)
self._moderator_speak(f"You are the chameleon!", visible_to=self.chameleon_name)
self._moderator_speak(
f"Now everyone gives one clue (but don't give away the secret word). "
f"You cannot repeat what others has said. We will start with {self.player_names[0]}.")
self._current_turn = 1
self._players_votes = {name: 0 for name in self.player_names}
self._initialized = True
init_timestep = TimeStep(observation=self.get_observation(),
reward=self.get_zero_rewards(),
terminal=False)
return init_timestep
def print(self):
self.message_pool.print()
def get_observation(self, player_name=None) -> List[Message]:
"""
get observation for the player
"""
if player_name is None:
return self.message_pool.get_all_messages()
else:
return self.message_pool.get_visible_messages(player_name, turn=self._current_turn)
def _text2vote(self, text) -> str:
"""
convert text to vote, return a player's name
"""
# lower = text.lower().replace("[", "").replace("]", "").replace(".", "")
text = text.lower()
for name in self.player_names:
candidates = [name.lower(), name.lower().replace(" ", ""), name.lower().replace(" ", "_")]
if any([candidate in text for candidate in candidates]):
return name
return ""
def _is_true_code(self, text) -> bool:
"""
Check whether the text is the true code
"""
# Get the word enclosed by quote marks with regex
pattern = r"\"(.+?)\""
match = re.search(pattern, text)
if match:
return match.group(1).lower().replace(" ", "") == self.code.lower().replace(" ", "")
else:
# if no quote marks, check whether the last k words match the code
words = text.split()
if len(words) >= len(self.code.split()):
guessed_term = "".join(words[-len(self.code.split()):]).lower().replace(".", "")
return guessed_term == self.code.lower().replace(" ", "").replace(".", "")
else:
return False
def _moderator_speak(self, text: str, visible_to: Union[str, List[str]] = "all"):
"""
moderator say something
"""
message = Message(agent_name="Moderator", content=text, turn=self._current_turn, visible_to=visible_to)
self.message_pool.append_message(message)
def get_rewards(self, chameleon_win: bool) -> Dict[str, float]:
"""
get rewards for each player
"""
rewards = {}
for name in self.player_names:
# The winner gets 1, the loser gets 0
rewards[name] = float((name == self.chameleon_name) == chameleon_win)
return rewards
def is_terminal(self) -> bool:
"""
check if the conversation is over
"""
# If the last message is the signal, then the conversation is over
if self.message_pool.last_message.content.startswith(SIGNAL_END_OF_CONVERSATION):
return True
def step(self, player_name: str, action: str) -> TimeStep:
"""
step function that is called by the arena
Args:
player_name: the name of the player that takes the action
action: the action that the agents wants to take
"""
# If not initialized, reset the environment
if not self._initialized:
self.reset()
# self.message_pool.print()
# print(f"Chameleon: {self.chameleon_name}, Code: {self.code}, Topic: {self.topic}")
assert player_name == self.get_next_player(), f"Wrong player! It is {self.get_next_player()} turn."
if self._current_phase == "give clues":
message = Message(agent_name=player_name, content=action, turn=self._current_turn)
self.message_pool.append_message(message)
# Update the counters
self._current_turn += 1
if self._next_player_idx < len(self.player_names) - 1:
self._next_player_idx += 1
else:
self._next_player_idx = 0
self._current_phase = "accuse"
self._moderator_speak("Now vote which of the other players (excluding yourself) is the chameleon. "
"You cannot vote for yourself.")
self._current_turn += 1
timestep = TimeStep(observation=self.get_observation(),
reward=self.get_zero_rewards(),
terminal=False) # Return all the messages
elif self._current_phase == "accuse":
message = Message(agent_name=player_name, content=action, turn=self._current_turn,
visible_to=[player_name])
self.message_pool.append_message(message)
vote = self._text2vote(action)
if vote in self.player_names:
self._players_votes[vote] += 1
if self._next_player_idx < len(self.player_names) - 1:
self._next_player_idx += 1
rewards = self.get_zero_rewards()
terminal = False
else:
# print(self._players_votes)
accuse_correct, even_vote = True, False
max_vote_player = max(self._players_votes, key=self._players_votes.get)
# detach if other players has the same number of votes
for name, vote in self._players_votes.items():
if name != max_vote_player and vote == self._players_votes[max_vote_player]:
accuse_correct, even_vote = False, True
if max_vote_player != self.chameleon_name:
accuse_correct = False
if not accuse_correct:
if even_vote:
self._moderator_speak(f"There are even votes. The accusation does not stand. "
f"{self.chameleon_name} is the chameleon. {self.chameleon_name} won the game!")
else:
self._moderator_speak(
f"The most-voted player is {max_vote_player}. The accusation is incorrect. "
f"{self.chameleon_name} is the chameleon. {self.chameleon_name} won the game!"
)
rewards = self.get_rewards(chameleon_win=True)
terminal = True
else:
self._moderator_speak(f"The accusation is correct! {self.chameleon_name} is the chameleon! "
f"Now {self.chameleon_name} can guess the secret code. "
"You should say: I guess the code is \"...\"")
self._current_phase = "guess"
rewards = self.get_zero_rewards()
terminal = False
self._current_turn += 1
timestep = TimeStep(observation=self.get_observation(), reward=rewards, terminal=terminal)
elif self._current_phase == "guess":
message = Message(agent_name=player_name, content=action, turn=self._current_turn,
visible_to=player_name)
self.message_pool.append_message(message)
if self._is_true_code(action):
self._moderator_speak(f"{player_name} guessed the code correctly! The secret word is {self.code}. "
f"{self.chameleon_name} won!")
rewards = self.get_rewards(chameleon_win=True)
else:
self._moderator_speak(f"{player_name} guessed the code wrong! The secret word is {self.code}. "
f"{self.non_chameleon_names} won!")
rewards = self.get_rewards(chameleon_win=False)
timestep = TimeStep(observation=self.get_observation(),
reward=rewards,
terminal=True)
else:
raise ValueError(f"Unknown phase: {self._current_phase}")
# Check if the player signals the end of the conversation
if self.is_terminal():
timestep.terminal = True
return timestep