Spaces:
Runtime error
Runtime error
File size: 4,780 Bytes
4bdab37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from pettingzoo.classic.chess.chess_utils import *
import re
from pettingzoo.classic import chess_v5
from chatarena.environments.base import Environment, TimeStep
from typing import List, Dict, Union
from ..message import Message, MessagePool
from ..config import EnvironmentConfig
# Coordinates are limited to 0-7 so that only squares that exist on an 8x8
# board can match; "8" or "9" would otherwise yield off-board squares and an
# action index outside the environment's action mask.
_MOVE_PATTERN = re.compile(r"Move \(([0-7]), ([0-7])\) to \(([0-7]), ([0-7])\)")


def action_string_to_alphazero_format(action: str, player_index: int) -> int:
    """Convert a textual move into the integer action index used by the env.

    Args:
        action: Text starting with ``"Move (x1, y1) to (x2, y2)"`` where every
            coordinate is a single digit in the range 0-7. Trailing text after
            the move is ignored because the pattern is only anchored at the
            start of the string.
        player_index: Index of the moving player. For index 1 the coordinates
            are flipped (``7 - c``) before encoding.

    Returns:
        ``x1 * 8 * 73 + y1 * 73 + move_plane`` for a well-formed move string,
        or ``-1`` when the string does not match the expected format or names
        a coordinate outside 0-7.
    """
    match = _MOVE_PATTERN.match(action)
    if not match:
        return -1

    x1, y1, x2, y2 = (int(coord) for coord in match.groups())
    if player_index == 1:
        # NOTE(review): this flips both files and ranks. Confirm this matches
        # the board orientation PettingZoo presents to the second player
        # (python-chess's square_mirror flips ranks only).
        x1, y1, x2, y2 = 7 - x1, 7 - y1, 7 - x2, 7 - y2

    # `chess` and `get_move_plane` are expected to come from the star import
    # of pettingzoo.classic.chess.chess_utils at the top of the file.
    move = chess.Move(from_square=8 * y1 + x1, to_square=8 * y2 + x2, promotion=None)
    move_plane = get_move_plane(move)
    return x1 * 8 * 73 + y1 * 73 + move_plane
class PettingzooChess(Environment):
    """ChatArena environment wrapping PettingZoo's ``chess_v5`` game.

    Players submit moves as text (``"Move (x1, y1) to (x2, y2)"``). Each move
    is converted to the environment's integer action index and applied to the
    wrapped game. Board renderings and player moves are stored in a message
    pool, which acts as the observable state.
    """

    type_name = "pettingzoo:chess"

    def __init__(self, player_names: List[str], **kwargs):
        """Create the wrapped chess game and reset it.

        Args:
            player_names: Names of the two players; index 0 moves first.
            **kwargs: Forwarded to the base ``Environment`` constructor.
        """
        super().__init__(player_names=player_names, **kwargs)
        self.env = chess_v5.env(render_mode="ansi")

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()
        self._terminal = False
        self.reset()

    def reset(self):
        """Reset the game, the turn counters and the message pool.

        Returns:
            A ``TimeStep`` holding the (empty) observation plus the reward and
            terminal flag reported by the wrapped environment after reset.
        """
        self.env.reset()
        self.current_player = 0
        self.turn = 0
        self.message_pool.reset()

        obs_dict, reward, terminal, truncation, info = self.env.last()
        observation = self.get_observation()
        self._terminal = terminal
        return TimeStep(observation=observation, reward=reward, terminal=terminal)

    def get_next_player(self) -> str:
        """Return the name of the player whose turn it is."""
        return self.player_names[self.current_player]

    def get_observation(self, player_name=None) -> List[Message]:
        """Return the messages visible to ``player_name``.

        With ``player_name`` set to ``None`` every message in the pool is
        returned; otherwise only the messages visible to that player up to
        the next turn.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(player_name, turn=self.turn + 1)

    def _moderator_speak(self, text: str, visible_to: Union[str, List[str]] = "all"):
        """Append a message authored by the moderator to the message pool."""
        message = Message(agent_name="Moderator", content=text, turn=self.turn, visible_to=visible_to)
        self.message_pool.append_message(message)

    def is_terminal(self) -> bool:
        """Return whether the game has ended, as of the latest reset/step."""
        return self._terminal

    def step(self, player_name: str, action: str) -> TimeStep:
        """Apply ``action`` for ``player_name`` and advance to the next player.

        Args:
            player_name: Must be the player returned by ``get_next_player``.
            action: Move text in the ``"Move (x1, y1) to (x2, y2)"`` format.

        Returns:
            A ``TimeStep`` with all messages so far, a dict mapping each
            player name to its reward, and the terminal flag.

        Raises:
            AssertionError: If it is not ``player_name``'s turn.
            ValueError: If ``action`` cannot be parsed into a move.
        """
        assert player_name == self.get_next_player(), f"Wrong player! It is {self.get_next_player()} turn."

        # Convert the action to the AlphaZero format *before* touching the
        # message pool, so a malformed action leaves the state untouched.
        alphazero_move = action_string_to_alphazero_format(action, self.current_player)
        if alphazero_move == -1:
            raise ValueError(f"Invalid action: {action}")

        self._moderator_speak("\n" + self.env.render())
        message = Message(agent_name=player_name, content=action, turn=self.turn)
        self.message_pool.append_message(message)

        self.env.step(alphazero_move)

        # The outcome must be read after the move is applied; reading it
        # beforehand describes the previous position and never reports the
        # end of the game caused by this move.
        _, _, terminated, truncated, _ = self.env.last()
        terminal = bool(terminated or truncated)
        self._terminal = terminal  # Update the terminal state

        # Assumes player_names are ordered like env.possible_agents
        # (player index 0 is the agent that moves first) - TODO confirm.
        env_rewards = self.env.rewards
        reward = {
            name: env_rewards.get(agent, 0)
            for name, agent in zip(self.player_names, self.env.possible_agents)
        }

        self.current_player = 1 - self.current_player
        self.turn += 1

        return TimeStep(observation=self.get_observation(), reward=reward, terminal=terminal)

    def check_action(self, action: str, agent_name: str) -> bool:
        """Return True when ``action`` parses and is legal for the current player.

        ``agent_name`` is accepted for interface compatibility but not used;
        legality is checked for the player whose turn it currently is.
        """
        alphazero_move = action_string_to_alphazero_format(action, self.current_player)
        if alphazero_move == -1:
            return False

        action_mask = self.env.last()[0]["action_mask"]
        # Guard against an index outside the mask instead of raising.
        if not 0 <= alphazero_move < len(action_mask):
            return False
        return bool(action_mask[alphazero_move] != 0)

    def print(self):
        """Print the wrapped environment's rendering of the board."""
        print(self.env.render())
def test_chess_environment():
    """Smoke-test ``PettingzooChess`` by playing a short scripted opening.

    Each scripted move is first validated with ``check_action`` and then
    applied with ``step``; the reward, terminal flag and board are printed
    after every move.
    """
    env = PettingzooChess(["player1", "player2"])
    env.reset()
    assert env.get_next_player() == "player1"
    env.print()

    # Intended move sequence: 1. e4 e5 2. Nf3 Nc6
    scripted_moves = (
        "Move (4, 1) to (4, 3)",
        "Move (4, 6) to (4, 4)",
        "Move (6, 0) to (5, 2)",
        "Move (1, 7) to (2, 5)",
    )
    for move_text in scripted_moves:
        assert env.check_action(move_text, env.get_next_player())
        result = env.step(env.get_next_player(), move_text)
        print(result.reward)
        print(result.terminal)
        env.print()
if __name__ == "__main__":
    # Construct a bare environment (kept from the original script; unused).
    env = chess_v5.env()

    # Show the action index produced for one example move string.
    print(action_string_to_alphazero_format("Move (0, 1) to (0, 3)", 0))

    # Run the scripted smoke test.
    test_chess_environment()
|