# Source: adhot-discussion / chatarena / environments / pettingzoo_tictactoe.py
# Uploaded by xa6 ("Upload folder using huggingface_hub"), commit 4bdab37.
import re
from pettingzoo.classic import tictactoe_v3
from chatarena.environments.base import Environment, TimeStep
from typing import List, Union
from ..message import Message, MessagePool
def action_string_to_action(action: str) -> int:
    """Parse a move string such as ``"X: (2, 3)"`` into a tic-tac-toe action index.

    The expected format is ``"<symbol>: (<row>, <column>)"`` where ``<symbol>`` is
    ``X`` or ``O`` and ``<row>`` / ``<column>`` are single 1-based digits. Only the
    start of ``action`` has to match (``re.match``); trailing text is ignored. The
    symbol is required by the pattern but does not affect the returned index.

    The index is ``(row - 1) + (column - 1) * 3``, i.e. cells are numbered
    column-major, which matches the cell numbering documented for PettingZoo's
    ``tictactoe_v3``::

        0 | 3 | 6
        1 | 4 | 7
        2 | 5 | 8

    Args:
        action: The raw move string.

    Returns:
        The action index in ``0..8``, or ``-1`` if ``action`` does not match the
        format or either coordinate is outside ``1..3``.
    """
    match = re.match(r"(X|O): \((\d), (\d)\)", action)
    if not match:
        return -1

    _symbol, row_text, column_text = match.groups()
    row, column = int(row_text), int(column_text)
    if not (1 <= row <= 3 and 1 <= column <= 3):
        return -1

    # Convert 1-based (row, column) to PettingZoo's column-major cell index.
    return (row - 1) + (column - 1) * 3
class PettingzooTicTacToe(Environment):
    """Two-player tic-tac-toe backed by PettingZoo's ``tictactoe_v3`` environment.

    Moves are strings such as ``"X: (1, 3)"`` (1-based row and column) parsed by
    ``action_string_to_action``. ``player_names[0]`` moves first. After every move
    the moderator posts the rendered board to the message pool, and the message
    pool is what ``get_observation`` hands back.
    """

    type_name = "pettingzoo:tictactoe"

    def __init__(self, player_names: List[str], **kwargs):
        super().__init__(player_names=player_names, **kwargs)
        self.env = tictactoe_v3.env()

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()
        self._terminal = False
        self.reset()

    def reset(self):
        """Start a new game and clear the message history.

        Returns:
            A ``TimeStep`` with the (now empty) message history, a reward dict
            keyed by player name (the same shape ``step`` returns; expected to be
            all zero for a fresh game), and the terminal flag.
        """
        self.env.reset()
        self.current_player = 0
        self.turn = 0
        self.message_pool.reset()

        _, _, terminal, truncation, _ = self.env.last()
        self._terminal = terminal or truncation
        return TimeStep(
            observation=self.get_observation(),
            reward=self._player_rewards(),
            terminal=self._terminal,
        )

    def _player_rewards(self) -> dict:
        """Return ``{player_name: reward}`` for the most recent environment step.

        Reads ``env.rewards`` (keyed by PettingZoo agent id) instead of the reward
        returned by ``env.last()``, which belongs to the agent about to act rather
        than the one that just moved. ``player_names[i]`` is paired with
        ``env.possible_agents[i]``; this assumes the environment's turn order
        matches ``player_names`` (``reset`` starts at ``current_player = 0``).
        Agents absent from ``env.rewards`` get 0.
        """
        env_rewards = self.env.rewards
        return {
            name: env_rewards.get(agent, 0)
            for name, agent in zip(self.player_names, self.env.possible_agents)
        }

    def get_next_player(self) -> str:
        """Return the name of the player whose turn it is."""
        return self.player_names[self.current_player]

    def get_observation(self, player_name=None) -> List[Message]:
        """Return the message history.

        With no ``player_name`` every message in the pool is returned; otherwise
        the pool's visible messages for that player, queried with
        ``turn=self.turn + 1`` (presumably so messages from the current turn are
        included — verify against ``MessagePool.get_visible_messages``).
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(player_name, turn=self.turn + 1)

    def _moderator_speak(self, text: str, visible_to: Union[str, List[str]] = "all"):
        """Append a message from "Moderator" carrying ``text`` at the current turn."""
        message = Message(agent_name="Moderator", content=text, turn=self.turn, visible_to=visible_to)
        self.message_pool.append_message(message)

    def is_terminal(self) -> bool:
        """Return whether the game ended as of the last ``reset``/``step``."""
        return self._terminal

    def step(self, player_name: str, action: str) -> TimeStep:
        """Apply ``player_name``'s move and hand the turn to the other player.

        Args:
            player_name: Must equal ``get_next_player()``.
            action: A move string understood by ``action_string_to_action``.

        Returns:
            A ``TimeStep`` with the full message history, a reward dict mapping
            every player name to the reward the environment assigned it on this
            move, and whether the game is now over.

        Raises:
            ValueError: If ``action`` cannot be parsed. The move is not recorded
                in the message pool in that case.
        """
        assert player_name == self.get_next_player(), f"Wrong player! It is {self.get_next_player()} turn."

        # Validate before recording, so a rejected move leaves no trace in the pool.
        action_index = action_string_to_action(action)
        if action_index == -1:
            raise ValueError(f"Invalid action: {action}")

        message = Message(agent_name=player_name, content=action, turn=self.turn)
        self.message_pool.append_message(message)

        self.env.step(action_index)
        # env.last() describes the agent that acts NEXT. Use it only for the board
        # and the end-of-game flags; per-player rewards come from env.rewards so the
        # mover is credited with its own reward (not its opponent's).
        obs_dict, _, terminal, truncation, _ = self.env.last()
        self._terminal = terminal or truncation
        reward = self._player_rewards()

        self.current_player = 1 - self.current_player
        self.turn += 1
        # obs_dict is from the new current player's perspective, which is what
        # render_ansi's X/O mapping expects.
        self._moderator_speak("\n" + self.render_ansi(obs_dict["observation"]))
        return TimeStep(observation=self.get_observation(), reward=reward, terminal=self._terminal)

    def check_action(self, action: str, agent_name: str) -> bool:
        """Return whether ``action`` parses and targets a cell the environment allows.

        Legality is read from the ``action_mask`` in the observation of the agent
        about to act. ``agent_name`` is currently unused.
        """
        action_index = action_string_to_action(action)
        if action_index == -1:
            return False
        action_mask = self.env.last()[0]["action_mask"]
        if action_mask[action_index] == 0:
            return False
        return True

    def render_ansi(self, observation):
        """Render a board observation as three text rows such as ``| X | _ | O |``.

        Axes 0 and 1 are swapped before iterating, so the board is assumed to be
        stored column-major. Plane ``self.current_player`` is drawn as ``X`` and the
        other plane as ``O``. This gives the right symbols only when
        ``observation`` was taken from ``self.current_player``'s perspective (as
        ``step`` and ``print`` do), assuming PettingZoo puts the observing agent's
        marks in plane 0.
        """
        x_plane = self.current_player
        o_plane = 1 - self.current_player
        rows = []
        for row in observation.transpose(1, 0, 2):
            cells = []
            for cell in row:
                if cell[x_plane] == 1:
                    cells.append("X")
                elif cell[o_plane] == 1:
                    cells.append("O")
                else:
                    cells.append("_")
            rows.append("|" + "".join(f" {symbol} |" for symbol in cells) + "\n")
        return "".join(rows)

    def print(self):
        """Print the current board to stdout."""
        obs_dict = self.env.last()[0]
        print(self.render_ansi(obs_dict["observation"]))