File size: 4,780 Bytes
4bdab37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from pettingzoo.classic.chess.chess_utils import *
import re
from pettingzoo.classic import chess_v5

from chatarena.environments.base import Environment, TimeStep
from typing import List, Dict, Union

from ..message import Message, MessagePool
from ..config import EnvironmentConfig


def action_string_to_alphazero_format(action: str, player_index: int) -> int:
    """Convert a textual move into a PettingZoo chess (AlphaZero-style) action index.

    The expected format is ``"Move (x1, y1) to (x2, y2)"``, where ``x`` is the file
    (0 = a ... 7 = h) and ``y`` is the rank (0 = rank 1 ... 7 = rank 8), always in
    absolute board coordinates regardless of which side is moving.  Only the
    start of ``action`` has to match; trailing text is ignored.

    Args:
        action: The move string.
        player_index: 0 for the first player (white), 1 for the second (black).

    Returns:
        The flat action index ``(x1 * 8 + y1) * 73 + plane`` as seen from the
        moving side's perspective, or ``-1`` if the string does not parse, a
        coordinate is outside 0-7, or the displacement cannot be encoded.
    """
    pattern = r"Move \(([0-7]), ([0-7])\) to \(([0-7]), ([0-7])\)"
    match = re.match(pattern, action)
    if not match:
        return -1

    x1, y1, x2, y2 = (int(coord) for coord in match.groups())

    # The 73-plane encoding only covers queen-line moves (which also carry
    # pawn/king/queen-promotion moves) and knight jumps; anything else,
    # including a zero-length "move", has no valid action index.
    dx, dy = x2 - x1, y2 - y1
    is_queen_line = (dx, dy) != (0, 0) and (dx == 0 or dy == 0 or abs(dx) == abs(dy))
    is_knight_jump = {abs(dx), abs(dy)} == {1, 2}
    if not (is_queen_line or is_knight_jump):
        return -1

    if player_index == 1:
        # PettingZoo encodes black's moves on chess.Board.mirror(), which flips
        # the board vertically (ranks only). Flipping the files as well would
        # encode a different move (e.g. e7e5 would become d7d5).
        y1, y2 = 7 - y1, 7 - y2

    # python-chess square index is rank * 8 + file. Promotion is left as None,
    # so only queen promotions (encoded as plain queen-line moves) are expressible.
    move = chess.Move(from_square=8 * y1 + x1, to_square=8 * y2 + x2, promotion=None)
    move_plane = get_move_plane(move)

    return (x1 * 8 + y1) * 73 + move_plane


class PettingzooChess(Environment):
    """Two-player chess environment backed by PettingZoo's ``chess_v5``.

    ``player_names[0]`` moves first (white) and ``player_names[1]`` second (black).
    Moves are submitted as strings in the format accepted by
    ``action_string_to_alphazero_format``. The conversation history is kept in
    ``self.message_pool``; the board itself lives in the wrapped PettingZoo env.
    """

    type_name = "pettingzoo:chess"

    def __init__(self, player_names: List[str], **kwargs):
        super().__init__(player_names=player_names, **kwargs)
        # "ansi" render mode: env.render() yields a string that is posted to the chat.
        self.env = chess_v5.env(render_mode="ansi")

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()
        self._terminal = False
        self.reset()

    def _player_rewards(self) -> Dict[str, float]:
        """Map each player name to the reward PettingZoo assigned on the last step.

        Assumes ``player_names`` is ordered like ``env.possible_agents`` (the first
        name belongs to the agent that moves first), the same assumption the
        ``current_player`` bookkeeping makes.
        """
        rewards = self.env.rewards
        return {
            name: rewards.get(agent, 0)
            for name, agent in zip(self.player_names, self.env.possible_agents)
        }

    def reset(self):
        """Start a new game, clear the conversation and return the initial TimeStep."""
        self.env.reset()
        self.current_player = 0
        self.turn = 0
        self.message_pool.reset()

        _, _, termination, truncation, _ = self.env.last()
        self._terminal = termination or truncation
        return TimeStep(
            observation=self.get_observation(),
            reward=self._player_rewards(),
            terminal=self._terminal,
        )

    def get_next_player(self) -> str:
        """Return the name of the player whose turn it is."""
        return self.player_names[self.current_player]

    def get_observation(self, player_name=None) -> List[Message]:
        """Return all messages, or only those visible to ``player_name`` if given."""
        if player_name is None:
            return self.message_pool.get_all_messages()
        else:
            return self.message_pool.get_visible_messages(player_name, turn=self.turn + 1)

    def _moderator_speak(self, text: str, visible_to: Union[str, List[str]] = "all"):
        """
        moderator say something
        """
        message = Message(agent_name="Moderator", content=text, turn=self.turn, visible_to=visible_to)
        self.message_pool.append_message(message)

    def is_terminal(self) -> bool:
        """Return True once the game has ended (termination or truncation)."""
        return self._terminal

    def step(self, player_name: str, action: str) -> TimeStep:
        """Apply ``player_name``'s move and return the resulting TimeStep.

        The current board is posted by the moderator and the raw action is
        recorded in the message pool before the move is applied. Legality is not
        checked here; use ``check_action`` first (how chess_v5 treats an illegal
        index is up to its wrappers — TODO confirm).

        Raises:
            AssertionError: if it is not ``player_name``'s turn.
            ValueError: if ``action`` cannot be converted into an action index.
        """
        assert player_name == self.get_next_player(), f"Wrong player! It is {self.get_next_player()} turn."
        self._moderator_speak("\n" + self.env.render())

        message = Message(agent_name=player_name, content=action, turn=self.turn)
        self.message_pool.append_message(message)
        # Convert the action to the AlphaZero format
        alphazero_move = action_string_to_alphazero_format(action, self.current_player)
        if alphazero_move == -1:
            raise ValueError(f"Invalid action: {action}")

        self.env.step(alphazero_move)
        # Read the outcome only AFTER applying the move: a pre-move snapshot would
        # miss a game that this very move ended (checkmate, draw, ...).
        _, _, termination, truncation, _ = self.env.last()
        self._terminal = termination or truncation
        reward = self._player_rewards()

        self.current_player = 1 - self.current_player
        self.turn += 1

        return TimeStep(observation=self.get_observation(), reward=reward, terminal=self._terminal)

    def check_action(self, action: str, agent_name: str) -> bool:
        """Return True if ``action`` parses and is legal for the side to move.

        ``agent_name`` is currently unused: legality is judged for the player
        whose turn it is, using the env's ``action_mask``.
        """
        alphazero_move = action_string_to_alphazero_format(action, self.current_player)
        if alphazero_move == -1:
            return False
        action_mask = self.env.last()[0]["action_mask"]
        return bool(action_mask[alphazero_move])

    def print(self):
        """Print the current board rendering."""
        print(self.env.render())


def test_chess_environment():
    """Smoke-test the environment by playing 1. e4 e5 2. Nf3 Nc6.

    Every move must pass ``check_action`` and none of them may end the game.
    """
    player_names = ["player1", "player2"]
    env = PettingzooChess(player_names)

    env.reset()
    assert env.get_next_player() == "player1"
    env.print()

    # Coordinates are (file, rank), 0-based, absolute: e2e4, e7e5, g1f3, b8c6.
    moves = ["Move (4, 1) to (4, 3)", "Move (4, 6) to (4, 4)",
             "Move (6, 0) to (5, 2)", "Move (1, 7) to (2, 5)"]

    for move in moves:
        player = env.get_next_player()
        assert env.check_action(move, player)
        timestep = env.step(player, move)
        print(timestep.reward)
        print(timestep.terminal)
        assert not timestep.terminal
        env.print()

    # After an even number of moves it is the first player's turn again.
    assert env.get_next_player() == "player1"


if __name__ == "__main__":
    # Test the conversion function with an example action string (a2a4 for white)
    action = "Move (0, 1) to (0, 3)"
    alphazero_move = action_string_to_alphazero_format(action, 0)
    print(alphazero_move)

    test_chess_environment()