File size: 10,394 Bytes
bdafe83
 
 
 
 
 
 
 
 
 
 
 
53709ed
 
bdafe83
 
57d59fc
bdafe83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53709ed
 
bdafe83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53709ed
bdafe83
 
 
 
 
 
 
 
 
53709ed
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import logging
import os
import os.path as osp
from typing import Union

from colorama import Fore
from colorama import Style as CRStyle
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.styles import Style
from rich.console import Console

from agentreview.utility.utils import get_rebuttal_dir, load_llm_ac_decisions, \
    save_llm_ac_decisions
from ..arena import Arena, TooManyInvalidActions
from ..backends.human import HumanBackendError
from ..const import AGENTREVIEW_LOGO
from ..environments import PaperReview, PaperDecision

# Get the ASCII art from https://patorjk.com/software/taag/#p=display&f=Big&t=Chat%20Arena

color_dict = {
    "red": Fore.RED,
    "green": Fore.GREEN,

    "blue": Fore.BLUE,  # Paper Extractor
    "light_red": Fore.LIGHTRED_EX,  # AC
    "light_green": Fore.LIGHTGREEN_EX,  # Author
    "yellow": Fore.YELLOW,  # R1
    "magenta": Fore.MAGENTA,  # R2
    "cyan": Fore.CYAN,
    "white": Fore.WHITE,
    "black": Fore.BLACK,

    "light_yellow": Fore.LIGHTYELLOW_EX,
    "light_blue": Fore.LIGHTBLUE_EX,
    "light_magenta": Fore.LIGHTMAGENTA_EX,
    "light_cyan": Fore.LIGHTCYAN_EX,
    "light_white": Fore.LIGHTWHITE_EX,
    "light_black": Fore.LIGHTBLACK_EX,

}

visible_colors = [
    color
    for color in color_dict  # ANSI_COLOR_NAMES.keys()
    if color not in ["black", "white", "red", "green"] and "grey" not in color
]

try:
    import colorama
except ImportError:
    raise ImportError(
        "Please install colorama: `pip install colorama`"
    )

MAX_STEPS = 20  # We should not need this parameter for paper reviews anyway

# Set logging level to ERROR
logging.getLogger().setLevel(logging.ERROR)


class ArenaCLI:
    """The CLI user interface for ChatArena."""

    def __init__(self, arena: Arena):
        self.arena = arena
        self.args = arena.args


    def launch(self, max_steps: int = None, interactive: bool = True):
        """Run the CLI."""

        if not interactive and max_steps is None:
            max_steps = MAX_STEPS

        args = self.args

        console = Console()
        # Print ascii art
        timestep = self.arena.reset()
        console.print("🎓AgentReview Initialized!", style="bold green")

        env: Union[PaperReview, PaperDecision] = self.arena.environment
        players = self.arena.players

        env_desc = self.arena.global_prompt
        num_players = env.num_players
        player_colors = visible_colors[:num_players]  # sample different colors for players
        name_to_color = dict(zip(env.player_names, player_colors))

        print("name_to_color: ", name_to_color)
        # System and Moderator messages are printed in red
        name_to_color["System"] = "red"
        name_to_color["Moderator"] = "red"

        console.print(
            f"[bold green underline]Environment ({env.type_name}) description:[/]\n{env_desc}"
        )

        # Print the player name, role_desc and backend_type
        for i, player in enumerate(players):
            player_name_str = f"[{player.name} ({player.backend.type_name})] Role Description:"
            # player_name = Text(player_name_str)
            # player_name.stylize(f"bold {name_to_color[player.name]} underline")
            # console.print(player_name)
            # console.print(player.role_desc)

            logging.info(color_dict[name_to_color[player.name]] + player_name_str + CRStyle.RESET_ALL)
            logging.info(color_dict[name_to_color[player.name]] + player.role_desc + CRStyle.RESET_ALL)

        console.print(Fore.GREEN + "\n========= Arena Start! ==========\n" + CRStyle.RESET_ALL)

        step = 0
        while not timestep.terminal:
            if env.type_name == "paper_review":
                if env.phase_index > 4:
                    break

            elif env.type_name == "paper_decision":
                # Phase 5: AC makes decisions
                if env.phase_index > 5:
                    break

            else:
                raise NotImplementedError(f"Unknown environment type: {env.type_name}")

            if interactive:
                command = prompt(
                    [("class:command", "command (n/r/q/s/h) > ")],
                    style=Style.from_dict({"command": "blue"}),
                    completer=WordCompleter(
                        [
                            "next",
                            "n",
                            "reset",
                            "r",
                            "exit",
                            "quit",
                            "q",
                            "help",
                            "h",
                            "save",
                            "s",
                        ]
                    ),
                )
                command = command.strip()

                if command == "help" or command == "h":
                    console.print("Available commands:")
                    console.print("    [bold]next or n or <Enter>[/]: next step")
                    console.print("    [bold]exit or quit or q[/]: exit the game")
                    console.print("    [bold]help or h[/]: print this message")
                    console.print("    [bold]reset or r[/]: reset the game")
                    console.print("    [bold]save or s[/]: save the history to file")
                    continue
                elif command == "exit" or command == "quit" or command == "q":
                    break
                elif command == "reset" or command == "r":
                    timestep = self.arena.reset()
                    console.print(
                        "\n========= Arena Reset! ==========\n", style="bold green"
                    )
                    continue
                elif command == "next" or command == "n" or command == "":
                    pass
                elif command == "save" or command == "s":
                    # Prompt to get the file path
                    file_path = prompt(
                        [("class:command", "save file path > ")],
                        style=Style.from_dict({"command": "blue"}),
                    )
                    file_path = file_path.strip()
                    # Save the history to file
                    self.arena.save_history(file_path)
                    # Print the save success message
                    console.print(f"History saved to {file_path}", style="bold green")
                else:
                    console.print(f"Invalid command: {command}", style="bold red")
                    continue

            try:
                timestep = self.arena.step()
            except HumanBackendError as e:
                # Handle human input and recover with the game update
                human_player_name = env.get_next_player()
                if interactive:
                    human_input = prompt(
                        [
                            (
                                "class:user_prompt",
                                f"Type your input for {human_player_name}: ",
                            )
                        ],
                        style=Style.from_dict({"user_prompt": "ansicyan underline"}),
                    )
                    # If not, the conversation does not stop
                    timestep = env.step(human_player_name, human_input)
                else:
                    raise e  # cannot recover from this error in non-interactive mode
            except TooManyInvalidActions as e:
                # Print the error message
                # console.print(f"Too many invalid actions: {e}", style="bold red")
                print(Fore.RED + "This will be red text" + CRStyle.RESET_ALL)
                break

            # The messages that are not yet logged
            messages = [msg for msg in env.get_observation() if not msg.logged]

            # Print the new messages
            for msg in messages:
                message_str = f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}"
                if self.args.skip_logging:
                    console.print(color_dict[name_to_color[msg.agent_name]] + message_str + CRStyle.RESET_ALL)
                msg.logged = True

            step += 1
            if max_steps is not None and step >= max_steps:
                break

        console.print("\n========= Arena Ended! ==========\n", style="bold red")

        if env.type_name == "paper_review":

            paper_id = self.arena.environment.paper_id
            rebuttal_dir = get_rebuttal_dir(output_dir=self.args.output_dir,
                                            paper_id=paper_id,
                                            experiment_name=self.args.experiment_name,
                                            model_name=self.args.model_name,
                                            conference=self.args.conference)

            os.makedirs(rebuttal_dir, exist_ok=True)

            path_review_history = f"{rebuttal_dir}/{paper_id}.json"

            if osp.exists(path_review_history):
                raise Exception(f"History already exists!! ({path_review_history}). There must be something wrong with "
                                f"the path to save the history ")

            self.arena.save_history(path_review_history)

        elif env.type_name == "paper_decision":
            ac_decisions = load_llm_ac_decisions(output_dir=args.output_dir,
                                                            conference=args.conference,
                                                            model_name=args.model_name,
                                                            ac_scoring_method=args.ac_scoring_method,
                                                            experiment_name=args.experiment_name,
                                                            num_papers_per_area_chair=args.num_papers_per_area_chair)


            ac_decisions += [env.ac_decisions]

            save_llm_ac_decisions(ac_decisions,
                                 output_dir=args.output_dir,
                                 conference=args.conference,
                                 model_name=args.model_name,
                                 ac_scoring_method=args.ac_scoring_method,
                                 experiment_name=args.experiment_name)