Yiqiao Jin committed on
Commit
c2aafbd
·
2 Parent(s): da473f8 c291457

Merge branch 'gradio'

Browse files
agentreview/environments/paper_review.py CHANGED
@@ -57,7 +57,7 @@ class PaperReview(Conversation):
57
 
58
  if self._phases is not None:
59
  return self._phases
60
-
61
  reviewer_names = [name for name in self.player_names if name.startswith("Reviewer")]
62
 
63
  num_reviewers = len(reviewer_names)
@@ -180,13 +180,10 @@ class PaperReview(Conversation):
180
  "Phase V. (AC makes decisions).")
181
 
182
  else:
183
- logger.info(f"Phase {self.phase_index}: end of the speaking order. Move to Phase ({self.phase_index + 1}).")
184
  self.phase_index += 1
185
  self._current_turn += 1
186
 
187
-
188
-
189
-
190
  else:
191
  self._next_player_index += 1
192
 
@@ -200,7 +197,7 @@ class PaperReview(Conversation):
200
 
201
  def get_next_player(self) -> str:
202
  """Get the next player in the current phase."""
203
- speaking_order = self.phases[self.phase_index]["speaking_order"]
204
  next_player = speaking_order[self._next_player_index]
205
  return next_player
206
 
@@ -214,3 +211,7 @@ class PaperReview(Conversation):
214
  player_name, phase_index=self.phase_index, next_player_idx=self._next_player_index,
215
  player_names=self.player_names
216
  )
 
 
 
 
 
57
 
58
  if self._phases is not None:
59
  return self._phases
60
+
61
  reviewer_names = [name for name in self.player_names if name.startswith("Reviewer")]
62
 
63
  num_reviewers = len(reviewer_names)
 
180
  "Phase V. (AC makes decisions).")
181
 
182
  else:
183
+ print(f"Phase {self.phase_index}: end of the speaking order. Move to Phase ({self.phase_index + 1}).")
184
  self.phase_index += 1
185
  self._current_turn += 1
186
 
 
 
 
187
  else:
188
  self._next_player_index += 1
189
 
 
197
 
198
def get_next_player(self) -> str:
    """Return the name of whoever speaks next in the active phase.

    The phase selected by ``self.phase_index`` supplies a
    ``"speaking_order"`` sequence, which is indexed with
    ``self._next_player_index``.
    """
    active_phase = self.phases[self.phase_index]
    return active_phase["speaking_order"][self._next_player_index]
203
 
 
211
  player_name, phase_index=self.phase_index, next_player_idx=self._next_player_index,
212
  player_names=self.player_names
213
  )
214
+
215
def get_messages_from_player(self, player_name: str) -> List["Message"]:
    """Return every message in the pool that was sent by ``player_name``.

    Thin wrapper that delegates to
    ``self.message_pool.get_messages_from_player``.

    Parameters:
        player_name (str): The name of the player whose messages are wanted.

    Returns:
        List[Message]: Whatever the message pool returns for that player.
    """
    return self.message_pool.get_messages_from_player(player_name)
agentreview/message.py CHANGED
@@ -84,6 +84,7 @@ class MessagePool:
84
  """
85
  self._messages.append(message)
86
 
 
87
  def print(self):
88
  """Print all the messages in the pool."""
89
  for message in self._messages:
@@ -148,3 +149,16 @@ class MessagePool:
148
  ):
149
  visible_messages.append(message)
150
  return visible_messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  """
85
  self._messages.append(message)
86
 
87
+
88
  def print(self):
89
  """Print all the messages in the pool."""
90
  for message in self._messages:
 
149
  ):
150
  visible_messages.append(message)
151
  return visible_messages
152
+
153
def get_messages_from_player(self, player_name: str) -> List[Message]:
    """
    Collect the messages authored by one player.

    Parameters:
        player_name (str): Name compared against each message's ``agent_name``.

    Returns:
        List[Message]: The matching messages, in the order they sit in the pool.
    """
    authored = []
    for entry in self._messages:
        if entry.agent_name == player_name:
            authored.append(entry)
    return authored
164
+
agentreview/paper_review_arena.py CHANGED
@@ -101,6 +101,7 @@ class PaperReviewArena(Arena):
101
  player.role_desc = get_reviewer_description(phase="reviewer_ac_discussion",
102
  **self.environment.experiment_setting["players"][
103
  'Reviewer'][reviewer_index - 1])
 
104
 
105
  elif self.environment.phase_index == 5: # Phase 5 AC Makes Decisions
106
 
 
101
  player.role_desc = get_reviewer_description(phase="reviewer_ac_discussion",
102
  **self.environment.experiment_setting["players"][
103
  'Reviewer'][reviewer_index - 1])
104
+
105
 
106
  elif self.environment.phase_index == 5: # Phase 5 AC Makes Decisions
107
 
agentreview/paper_review_message.py CHANGED
@@ -60,7 +60,7 @@ class PaperReviewMessagePool(MessagePool):
60
  visible_messages = []
61
 
62
  elif phase_index == 4:
63
- if agent_name.startswith("AC"):
64
  area_chair_type = self.experiment_setting['players']['AC'][0]["area_chair_type"]
65
 
66
  # 'BASELINE' means we do not specify the area chair's characteristics in the config file
@@ -86,7 +86,6 @@ class PaperReviewMessagePool(MessagePool):
86
  else:
87
  raise ValueError(f"Unknown Area chair type: {area_chair_type}.")
88
 
89
-
90
  else:
91
 
92
  visible_messages = []
 
60
  visible_messages = []
61
 
62
  elif phase_index == 4:
63
+ if agent_name.startswith("AC"):
64
  area_chair_type = self.experiment_setting['players']['AC'][0]["area_chair_type"]
65
 
66
  # 'BASELINE' means we do not specify the area chair's characteristics in the config file
 
86
  else:
87
  raise ValueError(f"Unknown Area chair type: {area_chair_type}.")
88
 
 
89
  else:
90
 
91
  visible_messages = []
agentreview/paper_review_player.py CHANGED
@@ -78,6 +78,7 @@ class PaperExtractorPlayer(Player):
78
  paper_decision: str,
79
  conference: str,
80
  backend: Union[BackendConfig, IntelligenceBackend],
 
81
  global_prompt: str = None,
82
  **kwargs,
83
  ):
@@ -85,6 +86,9 @@ class PaperExtractorPlayer(Player):
85
  self.paper_id = paper_id
86
  self.paper_decision = paper_decision
87
  self.conference: str = conference
 
 
 
88
 
89
  def act(self, observation: List[Message]) -> str:
90
  """
@@ -96,12 +100,17 @@ class PaperExtractorPlayer(Player):
96
  Returns:
97
  str: The action (response) of the player.
98
  """
99
-
100
- logging.info(f"Loading {self.conference} paper {self.paper_id} ({self.paper_decision}) ...")
 
 
101
 
102
  loader = PDFReader()
103
- document_path = Path(os.path.join(self.args.data_dir, self.conference, "paper", self.paper_decision,
104
- f"{self.paper_id}.pdf")) #
 
 
 
105
  documents = loader.load_data(file=document_path)
106
 
107
  num_words = 0
@@ -118,5 +127,7 @@ class PaperExtractorPlayer(Player):
118
  main_contents += text + ' '
119
  if FLAG:
120
  break
121
-
 
 
122
  return main_contents
 
78
  paper_decision: str,
79
  conference: str,
80
  backend: Union[BackendConfig, IntelligenceBackend],
81
+ paper_pdf_path: str = None,
82
  global_prompt: str = None,
83
  **kwargs,
84
  ):
 
86
  self.paper_id = paper_id
87
  self.paper_decision = paper_decision
88
  self.conference: str = conference
89
+
90
+ if paper_pdf_path is not None:
91
+ self.paper_pdf_path = paper_pdf_path
92
 
93
  def act(self, observation: List[Message]) -> str:
94
  """
 
100
  Returns:
101
  str: The action (response) of the player.
102
  """
103
+ if self.paper_pdf_path is not None:
104
+ logging.info(f"Loading paper from {self.paper_pdf_path} ...")
105
+ else:
106
+ logging.info(f"Loading {self.conference} paper {self.paper_id} ({self.paper_decision}) ...")
107
 
108
  loader = PDFReader()
109
+ if self.paper_pdf_path is not None:
110
+ document_path = Path(self.paper_pdf_path)
111
+ else:
112
+ document_path = Path(os.path.join(self.args.data_dir, self.conference, "paper", self.paper_decision,
113
+ f"{self.paper_id}.pdf")) #
114
  documents = loader.load_data(file=document_path)
115
 
116
  num_words = 0
 
127
  main_contents += text + ' '
128
  if FLAG:
129
  break
130
+
131
+ print(main_contents)
132
+
133
  return main_contents
agentreview/utility/authentication_utils.py CHANGED
@@ -16,13 +16,6 @@ def get_openai_client(client_type: str):
16
 
17
  assert client_type in ["azure_openai", "openai"]
18
 
19
- endpoint: str = os.environ['AZURE_ENDPOINT']
20
-
21
- if not endpoint.startswith("https://"):
22
- endpoint = f"https://{endpoint}.openai.azure.com"
23
-
24
- os.environ['AZURE_ENDPOINT'] = endpoint
25
-
26
  if not os.environ.get('OPENAI_API_VERSION'):
27
  os.environ['OPENAI_API_VERSION'] = "2023-05-15"
28
 
@@ -32,6 +25,13 @@ def get_openai_client(client_type: str):
32
  )
33
 
34
  elif client_type == "azure_openai":
 
 
 
 
 
 
 
35
  client = openai.AzureOpenAI(
36
  api_key=os.environ['AZURE_OPENAI_KEY'],
37
  azure_endpoint=os.environ['AZURE_ENDPOINT'], # f"https://YOUR_END_POINT.openai.azure.com"
 
16
 
17
  assert client_type in ["azure_openai", "openai"]
18
 
 
 
 
 
 
 
 
19
  if not os.environ.get('OPENAI_API_VERSION'):
20
  os.environ['OPENAI_API_VERSION'] = "2023-05-15"
21
 
 
25
  )
26
 
27
  elif client_type == "azure_openai":
28
+ endpoint: str = os.environ['AZURE_ENDPOINT']
29
+
30
+ if not endpoint.startswith("https://"):
31
+ endpoint = f"https://{endpoint}.openai.azure.com"
32
+
33
+ os.environ['AZURE_ENDPOINT'] = endpoint
34
+
35
  client = openai.AzureOpenAI(
36
  api_key=os.environ['AZURE_OPENAI_KEY'],
37
  azure_endpoint=os.environ['AZURE_ENDPOINT'], # f"https://YOUR_END_POINT.openai.azure.com"
app.py CHANGED
@@ -1,7 +1,725 @@
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def echo_text(text):
4
- return text
5
 
6
- iface = gr.Interface(fn=echo_text, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from glob import glob
4
+ from argparse import Namespace
5
+
6
  import gradio as gr
7
 
 
 
8
 
9
+ from agentreview import const
10
+ from agentreview.config import AgentConfig
11
+ from agentreview.agent import Player
12
+ from agentreview.backends import BACKEND_REGISTRY
13
+ from agentreview.environments import PaperReview
14
+ from agentreview.paper_review_arena import PaperReviewArena
15
+ from agentreview.utility.experiment_utils import initialize_players
16
+ from agentreview.paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer
17
+ from agentreview.role_descriptions import (get_reviewer_description, get_ac_description, get_author_config,
18
+ get_paper_extractor_config, get_author_description)
19
+
20
# Purpose of this file: front-end interaction. It builds the page, collects the
# user's configuration from it, passes that configuration to the back end, and
# shows the results in the matching panels in real time.

# Stylesheet handed to gr.Blocks(css=...) below.
css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
#header {text-align: center;}
#col-chatbox {flex: 1; max-height: min(900px, 100%);}
#label {font-size: 2em; padding: 0.5em; margin: 0;}
.message {font-size: 1.2em;}
.message-wrap {max-height: min(700px, 100vh);}
"""
# Rules that are currently disabled:
# .wrap {min-width: min(640px, 100vh)}
# #env-desc {max-height: 100px; overflow-y: auto;}
# .textarea {height: 100px; max-height: 100px;}
# #chatbot-tab-all {height: 750px; max-height: min(750px, 100%);}
# #chatbox {height: min(750px, 100%); max-height: min(750px, 100%);}
# #chatbox.block {height: 730px}
# .wrap {max-height: 680px;}
# .scroll-hide {overflow-y: scroll; max-height: 100px;}

# Not referenced anywhere else in this file.
DEBUG = False

# Initial value of every "backend" dropdown.
DEFAULT_BACKEND = "openai-chat"
# Number of player slots built in the UI (three reviewers, AC, Author).
MAX_NUM_PLAYERS = 5
# Slots with an index below this value are created visible.
DEFAULT_NUM_PLAYERS = 5
# Module-wide copy of the arena's phase index: written by step_game, reset to 0
# by restart_game, and read by the reviewer role-description callback.
CURRENT_STEP_INDEX = 0
44
+
45
def load_examples():
    """Load the example configurations stored under ``examples/*.json``.

    The path is relative to the current working directory.

    Returns:
        dict: Maps each example's ``"name"`` field to its parsed JSON object.
        Files that cannot be read or parsed, whose top level is not a JSON
        object, or that lack a ``"name"`` field are reported on stdout and
        skipped, so one bad file cannot stop the app from starting.
    """
    example_configs = {}
    # Sorted so that, when two files share a name, the one that wins does not
    # depend on the arbitrary order glob() returns.
    for example_file in sorted(glob("examples/*.json")):
        try:
            with open(example_file, encoding="utf-8") as f:
                example = json.load(f)
        except (OSError, ValueError) as err:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            print(f"Example {example_file} could not be loaded ({err}). Skipping.")
            continue
        try:
            example_configs[example["name"]] = example
        except (KeyError, TypeError):
            # TypeError: the top level is a list/str/number rather than an object.
            print(f"Example {example_file} is missing a name field. Skipping.")
    return example_configs


# Loaded once, at import time.
EXAMPLE_REGISTRY = load_examples()
60
+
61
+ # DB = SupabaseDB() if supabase_available else None
62
+
63
def get_player_components(name, visible):
    """Build the configuration widgets for one reviewer tab.

    Creates a hidden role-name textbox, three trait dropdowns (intention,
    knowledgeability, responsibility), a role-description textbox that is
    regenerated whenever a dropdown selection changes, and the backend
    controls (backend type, temperature, max tokens).

    NOTE(review): the nesting of the layout containers was reconstructed;
    confirm it against the rendered page.

    Parameters:
        name: Value of the hidden role-name textbox; also used in the
            accordion title.
        visible: Passed as ``visible`` to the description textbox and the
            backend controls.

    Returns:
        list: ``[role_name, Intention_config, Knowledge_config,
        Responsibility_config, backend_type, accordion, temperature,
        max_tokens]``. The description textbox is not part of the list.
    """
    with gr.Row():
        with gr.Column():
            role_name = gr.Textbox(
                lines=1,
                show_label=False,
                interactive=True,
                visible=False,
                value=name,
            )

            with gr.Row():
                # The three reviewer attributes are exposed as dropdowns.
                Intention_config = gr.Dropdown(
                    choices=["Benign", "Malicious", "Neutral"],
                    interactive=True,
                    label="Intention",
                    show_label=True,
                    value="Neutral",
                )

                Knowledge_config = gr.Dropdown(
                    choices=["Knowledgeable", "Unknownledgeable", "Normal"],
                    interactive=True,
                    label="Knowledgeability",
                    show_label=True,
                    value="Normal",
                )

                Responsibility_config = gr.Dropdown(
                    choices=["Responsible", "Irresponsible", "Normal"],
                    interactive=True,
                    label="Responsibility",
                    show_label=True,
                    value="Normal",
                )

            role_desc = gr.Textbox(
                lines=8,
                max_lines=8,
                show_label=False,
                interactive=True,
                visible=visible,
                autoscroll=False,
                value=get_reviewer_description(),
            )

            def update_role_desc(Intention_config, Knowledge_config, Responsibility_config):
                """Regenerate the description text from the three dropdown values."""
                # Each trait is tri-state: True / False / None (no preference).
                is_benign = True if Intention_config == "Benign" else (
                    False if Intention_config == "Malicious" else None)
                is_knowledgeable = True if Knowledge_config == "Knowledgeable" else (
                    False if Knowledge_config == "Unknownledgeable" else None)
                # The dropdown offers "Irresponsible"; the previous check for
                # "Lazy" alone could never match. "Lazy" is still accepted.
                is_responsible = True if Responsibility_config == "Responsible" else (
                    False if Responsibility_config in ("Irresponsible", "Lazy") else None)

                phase = 'reviewer_write_reviews' if CURRENT_STEP_INDEX < 2 else 'reviewer_ac_discussion'
                # FIXME: should change according to the phase
                return get_reviewer_description(is_benign, is_knowledgeable, is_responsible, phase=phase)

            trait_dropdowns = [Intention_config, Knowledge_config, Responsibility_config]
            Intention_config.select(fn=update_role_desc, inputs=trait_dropdowns, outputs=[role_desc])
            Knowledge_config.select(fn=update_role_desc, inputs=trait_dropdowns, outputs=[role_desc])
            Responsibility_config.select(fn=update_role_desc, inputs=trait_dropdowns, outputs=[role_desc])

        with gr.Column():
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                interactive=True,
                visible=visible,
                value=DEFAULT_BACKEND,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0.,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    visible=visible,
                    label="temperature",
                    value=0.7,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    interactive=True,
                    visible=visible,
                    label="max tokens",
                    value=200,
                )

    return [role_name, Intention_config, Knowledge_config, Responsibility_config, backend_type, accordion,
            temperature, max_tokens]
162
+
163
+
164
def get_author_components(name, visible):
    """Build the configuration widgets for the Author tab.

    Creates a role-description textbox pre-filled from
    ``get_author_description()`` and the backend controls (backend type,
    temperature, max tokens).

    NOTE(review): the nesting of the layout containers was reconstructed;
    confirm it against the rendered page.

    Parameters:
        name: Used in the accordion title.
        visible: Passed as ``visible`` to every widget created here.

    Returns:
        list: ``[role_desc, backend_type, accordion, temperature, max_tokens]``.
        Unlike the reviewer and AC builders, the first element is the
        description textbox; there is no role-name textbox.
    """
    with gr.Row():
        with gr.Column():
            role_desc = gr.Textbox(
                lines=8,
                max_lines=8,
                show_label=False,
                interactive=True,
                visible=visible,
                value=get_author_description(),
            )
        with gr.Column():
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                interactive=True,
                visible=visible,
                value=DEFAULT_BACKEND,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0.,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    visible=visible,
                    label="temperature",
                    value=0.7,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    interactive=True,
                    visible=visible,
                    label="max tokens",
                    value=200,
                )

    return [role_desc, backend_type, accordion, temperature, max_tokens]
206
+
207
+
208
def get_area_chair_components(name, visible):
    """Build the configuration widgets for the area chair (AC) tab.

    Creates a hidden role-name textbox, an "AC Type" dropdown, a
    role-description textbox that is regenerated when the dropdown selection
    changes, and the backend controls (backend type, temperature, max tokens).

    NOTE(review): the nesting of the layout containers was reconstructed;
    confirm it against the rendered page.

    Parameters:
        name: Value of the hidden role-name textbox; also used in the
            accordion title.
        visible: Passed as ``visible`` to the visible widgets created here.

    Returns:
        list: ``[role_name, AC_type, backend_type, accordion, temperature,
        max_tokens]``. The description textbox is not part of the list.
    """
    with gr.Row():
        with gr.Column():

            role_name = gr.Textbox(
                lines=1,
                show_label=False,
                interactive=True,
                visible=False,
                value=name,
            )

            AC_type = gr.Dropdown(
                label = "AC Type",
                show_label=True,
                choices=["Inclusive", "Conformist", "Authoritarian", "Normal"],
                interactive=True,
                visible=visible,
                value="Normal",
            )

            role_desc = gr.Textbox(
                lines=8,
                max_lines=8,
                show_label=False,
                interactive=True,
                visible=visible,
                value=get_ac_description("BASELINE", "ac_write_metareviews", 'None', 1),
            )

            def update_role_desc(AC_type):
                # "Normal" in the UI corresponds to the 'BASELINE' type; every
                # other choice is passed on lower-cased.
                ac_type = 'BASELINE' if AC_type == "Normal" else AC_type.lower()
                return get_ac_description(ac_type, "ac_write_metareviews", "None", 1)  # FIXME: should change according to the phase

            AC_type.select(fn=update_role_desc, inputs=[AC_type], outputs=[role_desc])

        with gr.Column():
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                interactive=True,
                visible=visible,
                value=DEFAULT_BACKEND,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    visible=visible,
                    label="temperature",
                    value=0.7,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    interactive=True,
                    visible=visible,
                    label="max tokens",
                    value=200,
                )

    return [role_name, AC_type, backend_type, accordion, temperature, max_tokens]
275
+
276
+
277
def get_empty_state():
    """Return a ``gr.State`` wrapping a dict whose ``"arena"`` entry is ``None``."""
    initial_state = {"arena": None}
    return gr.State(initial_state)
279
+
280
+
281
# Page definition: layout, the step/restart handlers, and the event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# (it follows the ChatArena demo this page derives from); confirm it against
# the rendered page.
with gr.Blocks(css=css) as demo:
    state = get_empty_state()
    all_components = []

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# 🤖[AgentReview](https://arxiv.org/abs/2406.12708)<br>
S Multi-Agent to Simulate conference reviews on your own papers.
**[Project Homepage](https://github.com/Ahren09/AgentReview)**""",
            elem_id="header",
        )

        # Environment configuration
        env_desc_textbox = gr.Textbox(
            show_label=True,
            lines=2,
            visible=True,
            label="Environment Description",
            interactive=True,
            value=const.GLOBAL_PROMPT,
        )

        all_components += [env_desc_textbox]

        with gr.Row():
            with gr.Column(elem_id="col-chatbox"):
                with gr.Tab("All", visible=True):
                    chatbot = gr.Chatbot(
                        elem_id="chatbox", visible=True, show_label=False, height=600
                    )

                # One chat tab per player slot: Reviewer 1-3, AC, Author.
                player_chatbots = []
                for i in range(MAX_NUM_PLAYERS):
                    if i in [0, 1, 2]:
                        player_name = f"Reviewer {i + 1}"

                    elif i == 3:
                        player_name = "AC"

                    elif i == 4:
                        player_name = "Author"

                    with gr.Tab(player_name, visible=(i < DEFAULT_NUM_PLAYERS)):
                        player_chatbot = gr.Chatbot(
                            elem_id=f"chatbox-{i}",
                            visible=i < DEFAULT_NUM_PLAYERS,
                            label=player_name,
                            show_label=False,
                            height=600,  # FIXME: this parameter is not working
                        )
                        player_chatbots.append(player_chatbot)

                all_components += [chatbot, *player_chatbots]

            with gr.Column(elem_id="col-config"):  # Player Configuration
                # players_idx2comp maps a slot index to that slot's widgets
                # (plus its Tab), in the order returned by the builder.
                all_players_components, players_idx2comp = [], {}
                with gr.Blocks():
                    for i in range(MAX_NUM_PLAYERS):
                        if i in [0, 1, 2]:
                            player_name = f"Reviewer {i + 1}"

                        elif i == 3:
                            player_name = "AC"

                        elif i == 4:
                            player_name = "Author"

                        else:
                            raise ValueError(f"Invalid player index: {i}")
                        with gr.Tab(
                            player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                        ) as tab:
                            if "Reviewer" in player_name:
                                player_comps = get_player_components(
                                    player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                                )
                            elif player_name == "AC":
                                player_comps = get_area_chair_components(
                                    player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                                )
                            elif player_name == "Author":
                                player_comps = get_author_components(
                                    player_name, visible=(i < DEFAULT_NUM_PLAYERS)
                                )

                        players_idx2comp[i] = player_comps + [tab]
                        all_players_components += player_comps + [tab]

                all_components += all_players_components

        upload_file_box = gr.File(
            visible=True,
            height=100,
        )

        with gr.Row():
            btn_step = gr.Button("Submit")
            btn_restart = gr.Button("Clear")

        all_components += [upload_file_box, btn_step, btn_restart]

    def _convert_to_chatbot_output(all_messages, display_recv=False):
        """Turn messages into the (left, right) tuples a gr.Chatbot displays.

        Each message becomes ``**<agent>**: <text>`` (with the recipients
        added when ``display_recv`` is true). Runs of newlines are replaced
        by ``<br>``. Messages from "Moderator" go in the first tuple slot,
        everything else in the second.
        """
        chatbot_output = []
        for message in all_messages:
            agent_name = message.agent_name
            recv = str(message.visible_to)
            # Preprocess message for chatbot output
            new_msg = re.sub(r"\n+", "<br>", message.content.strip())
            if display_recv:
                new_msg = f"**{agent_name} (-> {recv})**: {new_msg}"  # Add role to the message
            else:
                new_msg = f"**{agent_name}**: {new_msg}"

            if agent_name == "Moderator":
                chatbot_output.append((new_msg, None))
            else:
                chatbot_output.append((None, new_msg))
        return chatbot_output

    def _create_arena_config_from_components(all_comps: dict):
        """Build a PaperReviewArena from the current widget values.

        Parameters:
            all_comps (dict): Maps each input component to its current value.

        Returns:
            PaperReviewArena: Arena holding the players configured in the UI,
            a paper extractor reading the uploaded file, and an author player.
        """
        env_desc = all_comps[env_desc_textbox]
        paper_pdf_path = all_comps[upload_file_box]

        # Step 1: Initialize the players
        num_players = MAX_NUM_PLAYERS

        # You can ignore these fields for the demo
        conference = "EMNLP2024"
        paper_decision = "Accept"
        data_dir = ''
        paper_id = "12345"

        args = Namespace(openai_client_type="azure_openai",
                         experiment_name="test",
                         max_num_words=16384)

        # Only the AC is enabled during the paper_decision phase.
        players = []

        # role_desc cannot be read from the UI directly; a setting has to be
        # generated from the intention / knowledge / responsibility dropdowns.
        experiment_setting = {
            "paper_id": paper_id,
            "paper_decision": paper_decision,
            "players": {

                # Paper Extractor is a special player that extracts a paper from the dataset.
                # Its constructor does not take any arguments.
                "Paper Extractor": [{}],

                # Assume there is only one area chair (AC) in the experiment.
                "AC": [],

                # Author role with default configuration.
                "Author": [{}],

                # Reviewer settings are generated based on reviewer types provided in the settings.
                "Reviewer": [],
            },
        }

        for i in range(num_players):

            role_name = role_desc = backend_type = temperature = max_tokens = None

            if i in [0, 1, 2]:  # reviewer
                (role_name, intention_config, knowledge_config, responsibility_config,
                 backend_type, temperature, max_tokens) = (
                    all_comps[c]
                    for c in players_idx2comp[i]
                    if not isinstance(c, (gr.Accordion, gr.Tab))
                )

                # Each trait is tri-state: True / False / None (no preference).
                is_benign = True if intention_config == "Benign" else (
                    False if intention_config == "Malicious" else None)
                is_knowledgeable = True if knowledge_config == "Knowledgeable" else (
                    False if knowledge_config == "Unknownledgeable" else None)
                # The dropdown offers "Irresponsible"; the previous check for
                # "Lazy" alone could never match. "Lazy" is still accepted.
                is_responsible = True if responsibility_config == "Responsible" else (
                    False if responsibility_config in ("Irresponsible", "Lazy") else None)

                experiment_setting["players"]['Reviewer'].append({"is_benign": is_benign,
                                                                  "is_knowledgeable": is_knowledgeable,
                                                                  "is_responsible": is_responsible,
                                                                  "knows_authors": 'unfamous'})

                role_desc = get_reviewer_description(is_benign, is_knowledgeable, is_responsible)

            elif i == 3:  # AC
                role_name, ac_type, backend_type, temperature, max_tokens = (
                    all_comps[c]
                    for c in players_idx2comp[i]
                    if not isinstance(c, (gr.Accordion, gr.Tab))
                )

                ac_type = 'BASELINE' if ac_type == "Normal" else ac_type.lower()

                experiment_setting["players"]['AC'].append({"area_chair_type": ac_type})

                role_desc = get_ac_description(ac_type, "ac_write_metareviews", "None", 1)

            elif i == 4:  # Author
                # NOTE(review): get_author_components returns the description
                # textbox first, so role_name receives the description text
                # here, not a name. Confirm the intended player name.
                role_name, backend_type, temperature, max_tokens = (
                    all_comps[c]
                    for c in players_idx2comp[i]
                    if not isinstance(c, (gr.Accordion, gr.Tab))
                )

                role_desc = get_author_description()

            else:
                raise ValueError(f"Invalid player index: {i}")

            # common config for all players
            player_config = {
                "name": role_name,
                "role_desc": role_desc,
                "global_prompt": env_desc,
                "backend": {
                    "backend_type": backend_type,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                },
            }

            player_config = AgentConfig(**player_config)

            # NOTE(review): with five slots this builds the AC slot (i == 3) as
            # a Reviewer and the Author slot (i == 4) as an AreaChair; it looks
            # like a leftover from a four-slot layout. Confirm before changing.
            if i < num_players-1:
                player = Reviewer(data_dir=data_dir, conference=conference, args=args, **player_config)
            else:
                player_config["env_type"] = "paper_review"
                player = AreaChair(data_dir=data_dir, conference=conference, args=args, **player_config)

            players.append(player)

        # Manually add the paper extractor.
        paper_extractor_config = get_paper_extractor_config(max_tokens=2048)

        paper_extractor = PaperExtractorPlayer(paper_pdf_path=paper_pdf_path,
                                               data_dir=data_dir, paper_id=paper_id,
                                               paper_decision=paper_decision, args=args,
                                               conference=conference, **paper_extractor_config)
        players.append(paper_extractor)

        # Manually add the author.
        author_config = get_author_config()
        author = Player(data_dir=data_dir, conference=conference, args=args,
                        **author_config)

        players.append(author)

        player_names = [player.name for player in players]

        # Step 2: Initialize the environment
        env = PaperReview(player_names=player_names, paper_decision=paper_decision, paper_id=paper_id,
                          args=args, experiment_setting=experiment_setting)

        # Step 3: Initialize the Arena
        arena = PaperReviewArena(players=players, environment=env, args=args, global_prompt=env_desc)

        return arena

    def step_game(all_comps: dict):
        """Advance the arena by one step and refresh the chat panels.

        Generator used as the "Submit" handler. The first yield disables both
        buttons; if ``arena.step()`` returns a truthy timestep, a second yield
        carries the new chat contents, button states and state.
        """
        global CURRENT_STEP_INDEX

        yield {
            btn_step: gr.update(value="Running...", interactive=False),
            btn_restart: gr.update(interactive=False),
        }

        cur_state = all_comps[state]

        # If arena is not yet created, create it
        if cur_state["arena"] is None:
            arena = _create_arena_config_from_components(all_comps)
            cur_state["arena"] = arena
        else:
            arena = cur_state["arena"]

        # TODO: run continuously

        timestep = arena.step()

        CURRENT_STEP_INDEX = int(arena.environment.phase_index)

        # Update the front end
        if timestep:
            all_messages = timestep.observation
            # Hide the long extracted paper text in the UI.
            # NOTE(review): this overwrites the message object itself; if the
            # observation shares objects with the message pool, later prompts
            # lose the paper content. Confirm, and copy for display if so.
            all_messages[0].content = 'Paper content has been extracted.'
            chatbot_output = _convert_to_chatbot_output(all_messages, display_recv=True)

            # Initialize update dictionary
            update_dict = {
                chatbot: chatbot_output,
                btn_step: gr.update(
                    value="Next Step", interactive=not timestep.terminal
                ),
                btn_restart: gr.update(interactive=True),
                state: cur_state,
            }

            # Define a mapping of player names to their respective chatbots
            player_name_to_chatbot = {
                "Reviewer 1": player_chatbots[0],
                "Reviewer 2": player_chatbots[1],
                "Reviewer 3": player_chatbots[2],
                "AC": player_chatbots[3],
                "Author": player_chatbots[4],
            }

            # Update each player's chatbot output
            for player in arena.players:
                player_name = player.name
                if player_name in player_name_to_chatbot:
                    player_messages = arena.environment.get_messages_from_player(player_name)
                    player_output = _convert_to_chatbot_output(player_messages)
                    update_dict[player_name_to_chatbot[player_name]] = player_output

            yield update_dict

    def restart_game(all_comps: dict):
        """Drop the current arena and return the page to its initial state.

        Generator used as the "Clear" handler. The first yield empties every
        chat panel and disables the buttons; the second re-enables them and
        clears the uploaded file.
        """
        global CURRENT_STEP_INDEX
        CURRENT_STEP_INDEX = 0

        cur_state = all_comps[state]
        cur_state["arena"] = None

        cleared = {
            chatbot: [],
            btn_restart: gr.update(interactive=False),
            btn_step: gr.update(interactive=False),
            state: cur_state,
        }
        # Also empty the per-player tabs, not just the "All" tab.
        for player_chatbot in player_chatbots:
            cleared[player_chatbot] = []
        yield cleared

        yield {
            btn_step: gr.update(value="Start", interactive=True),
            btn_restart: gr.update(interactive=True),
            upload_file_box: gr.update(value=None),
            state: cur_state,
        }

    # Remove Accordion and Tab from the list of components
    all_components = [
        comp for comp in all_components if not isinstance(comp, (gr.Accordion, gr.Tab))
    ]

    def _disable_step_button(session_state):
        """Disable the step button if an arena already exists; otherwise no-op."""
        if session_state["arena"] is not None:
            return gr.update(interactive=False)
        else:
            return gr.update()

    # If any of the Textbox, Slider, Checkbox, Dropdown, RadioButtons is changed, the Step button is disabled
    for comp in all_components:
        if (
            isinstance(
                comp, (gr.Textbox, gr.Slider, gr.Checkbox, gr.Dropdown, gr.Radio)
            )
            and comp is not upload_file_box
        ):
            comp.change(_disable_step_button, state, btn_step)

    btn_step.click(
        step_game,
        set(all_components + [state]),
        [chatbot, *player_chatbots, btn_step, btn_restart, state, upload_file_box],
    )

    btn_restart.click(
        restart_game,
        set(all_components + [state]),
        [chatbot, *player_chatbots, btn_step, btn_restart, state, upload_file_box],
    )
724
+ demo.queue()
725
+ demo.launch()
data DELETED
@@ -1 +0,0 @@
1
- ../agent4reviews/data
 
 
template.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from glob import glob
4
+
5
+ import gradio as gr
6
+
7
+ from chatarena.arena import Arena, TooManyInvalidActions
8
+ from chatarena.backends import BACKEND_REGISTRY
9
+ from chatarena.backends.human import HumanBackendError
10
+ from chatarena.config import ArenaConfig
11
+ from chatarena.database import SupabaseDB, log_arena, log_messages, supabase_available
12
+ from chatarena.environments import ENV_REGISTRY
13
+ from chatarena.message import Message
14
+
15
+ css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
16
+ #header {text-align: center;}
17
+ #col-chatbox {flex: 1; max-height: min(750px, 100%);}
18
+ #label {font-size: 2em; padding: 0.5em; margin: 0;}
19
+ .message {font-size: 1.2em;}
20
+ .message-wrap {max-height: min(700px, 100vh);}
21
+ """
22
+ # .wrap {min-width: min(640px, 100vh)}
23
+ # #env-desc {max-height: 100px; overflow-y: auto;}
24
+ # .textarea {height: 100px; max-height: 100px;}
25
+ # #chatbot-tab-all {height: 750px; max-height: min(750px, 100%);}
26
+ # #chatbox {height: min(750px, 100%); max-height: min(750px, 100%);}
27
+ # #chatbox.block {height: 730px}
28
+ # .wrap {max-height: 680px;}
29
+ # .scroll-hide {overflow-y: scroll; max-height: 100px;}
30
+
31
+
32
+ DEBUG = False
33
+
34
+ DEFAULT_BACKEND = "openai-chat"
35
+ DEFAULT_ENV = "conversation"
36
+ MAX_NUM_PLAYERS = 6
37
+ DEFAULT_NUM_PLAYERS = 2
38
+
39
+
40
def load_examples():
    """Load the example configurations found in ``examples/*.json``.

    The glob pattern is relative to the current working directory.  Files are
    visited in sorted order so the key order of the returned mapping is
    deterministic.  A file that cannot be decoded as JSON, or that has no
    usable top-level ``"name"`` field, is reported on stdout and skipped.

    Returns:
        dict: Example name -> parsed configuration.
    """
    example_configs = {}
    # glob() returns matches in arbitrary order; sort for a stable result.
    for example_file in sorted(glob("examples/*.json")):
        with open(example_file, encoding="utf-8") as f:
            try:
                example = json.load(f)
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Example {example_file} is not valid JSON. Skipping.")
                continue
        try:
            example_configs[example["name"]] = example
        except (KeyError, TypeError):
            # KeyError: object without "name".  TypeError: the top-level JSON
            # value is not an object (or the name is unhashable).
            print(f"Example {example_file} is missing a name field. Skipping.")
    return example_configs
52
+
53
+
54
+ EXAMPLE_REGISTRY = load_examples()
55
+
56
+ DB = SupabaseDB() if supabase_available else None
57
+
58
+
59
def get_moderator_components(visible=True):
    """Create the moderator's configuration widgets.

    Args:
        visible: Initial ``visible`` flag passed to every widget created here.

    Returns:
        list: ``[role_desc, terminal_condition, backend_type, accordion,
        temperature, max_tokens]`` in that order.
    """
    name = "Moderator"
    # Flags shared by every value-carrying widget below.
    shared = {"visible": visible, "interactive": True}

    with gr.Row():
        with gr.Column():
            role_desc = gr.Textbox(
                label="Moderator role",
                lines=1,
                placeholder=f"Enter the role description for {name}",
                **shared,
            )
            terminal_condition = gr.Textbox(
                show_label=False,
                lines=1,
                placeholder="Enter the termination criteria",
                **shared,
            )
        with gr.Column():
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                value=DEFAULT_BACKEND,
                **shared,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.1,
                    label="temperature",
                    value=0.7,
                    **shared,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    label="max tokens",
                    value=200,
                    **shared,
                )

    widgets = [role_desc, terminal_condition, backend_type]
    widgets += [accordion, temperature, max_tokens]
    return widgets
115
+
116
+
117
def get_player_components(name, visible):
    """Create the configuration widgets for one player slot.

    Args:
        name: Display name of the slot; used in the placeholders and in the
            accordion title.
        visible: Initial ``visible`` flag passed to every widget created here.

    Returns:
        list: ``[role_name, role_desc, backend_type, accordion, temperature,
        max_tokens]`` in that order.
    """
    with gr.Row():
        with gr.Column():
            role_name = gr.Textbox(
                # The Textbox keyword is ``lines``; the previous ``line=1``
                # was a typo and not a Textbox parameter.
                lines=1,
                show_label=False,
                interactive=True,
                visible=visible,
                placeholder=f"Player name for {name}",
            )
            role_desc = gr.Textbox(
                lines=3,
                show_label=False,
                interactive=True,
                visible=visible,
                placeholder=f"Enter the role description for {name}",
            )
        with gr.Column():
            backend_type = gr.Dropdown(
                show_label=False,
                choices=list(BACKEND_REGISTRY.keys()),
                interactive=True,
                visible=visible,
                value=DEFAULT_BACKEND,
            )
            with gr.Accordion(
                f"{name} Parameters", open=False, visible=visible
            ) as accordion:
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    visible=visible,
                    label="temperature",
                    value=0.7,
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    step=10,
                    interactive=True,
                    visible=visible,
                    label="max tokens",
                    value=200,
                )

    return [role_name, role_desc, backend_type, accordion, temperature, max_tokens]
165
+
166
+
167
def get_empty_state():
    """Return a new ``gr.State`` whose payload records that no arena exists."""
    initial_payload = {"arena": None}
    return gr.State(initial_payload)
169
+
170
+
171
+ with gr.Blocks(css=css) as demo:
172
+ state = get_empty_state()
173
+ all_components = []
174
+
175
+ with gr.Column(elem_id="col-container"):
176
+ gr.Markdown(
177
+ """# 🏟 ChatArena️<br>
178
+ Prompting multiple AI agents to play games in a language-driven environment.
179
+ **[Project Homepage](https://github.com/chatarena/chatarena)**""",
180
+ elem_id="header",
181
+ )
182
+
183
+ with gr.Row():
184
+ env_selector = gr.Dropdown(
185
+ choices=list(ENV_REGISTRY.keys()),
186
+ value=DEFAULT_ENV,
187
+ interactive=True,
188
+ label="Environment Type",
189
+ show_label=True,
190
+ )
191
+ example_selector = gr.Dropdown(
192
+ choices=list(EXAMPLE_REGISTRY.keys()),
193
+ interactive=True,
194
+ label="Select Example",
195
+ show_label=True,
196
+ )
197
+
198
+ # Environment configuration
199
+ env_desc_textbox = gr.Textbox(
200
+ show_label=True,
201
+ lines=2,
202
+ visible=True,
203
+ label="Environment Description",
204
+ placeholder="Enter a description of a scenario or the game rules.",
205
+ )
206
+
207
+ all_components += [env_selector, example_selector, env_desc_textbox]
208
+
209
+ with gr.Row():
210
+ with gr.Column(elem_id="col-chatbox"):
211
+ with gr.Tab("All", visible=True):
212
+ chatbot = gr.Chatbot(
213
+ elem_id="chatbox", visible=True, show_label=False
214
+ )
215
+
216
+ player_chatbots = []
217
+ for i in range(MAX_NUM_PLAYERS):
218
+ player_name = f"Player {i + 1}"
219
+ with gr.Tab(player_name, visible=(i < DEFAULT_NUM_PLAYERS)):
220
+ player_chatbot = gr.Chatbot(
221
+ elem_id=f"chatbox-{i}",
222
+ visible=i < DEFAULT_NUM_PLAYERS,
223
+ label=player_name,
224
+ show_label=False,
225
+ )
226
+ player_chatbots.append(player_chatbot)
227
+
228
+ all_components += [chatbot, *player_chatbots]
229
+
230
+ with gr.Column(elem_id="col-config"): # Player Configuration
231
+ # gr.Markdown("Player Configuration")
232
+ parallel_checkbox = gr.Checkbox(
233
+ label="Parallel Actions", value=False, visible=True
234
+ )
235
+ with gr.Accordion("Moderator", open=False, visible=True):
236
+ moderator_components = get_moderator_components(True)
237
+ all_components += [parallel_checkbox, *moderator_components]
238
+
239
+ all_players_components, players_idx2comp = [], {}
240
+ with gr.Blocks():
241
+ num_player_slider = gr.Slider(
242
+ 2,
243
+ MAX_NUM_PLAYERS,
244
+ value=DEFAULT_NUM_PLAYERS,
245
+ step=1,
246
+ label="Number of players:",
247
+ )
248
+ for i in range(MAX_NUM_PLAYERS):
249
+ player_name = f"Player {i + 1}"
250
+ with gr.Tab(
251
+ player_name, visible=(i < DEFAULT_NUM_PLAYERS)
252
+ ) as tab:
253
+ player_comps = get_player_components(
254
+ player_name, visible=(i < DEFAULT_NUM_PLAYERS)
255
+ )
256
+
257
+ players_idx2comp[i] = player_comps + [tab]
258
+ all_players_components += player_comps + [tab]
259
+
260
+ all_components += [num_player_slider] + all_players_components
261
+
262
def variable_players(k):
    """Show the widgets of the first ``k`` player slots and hide the rest.

    Args:
        k: Requested number of players; converted with ``int``.

    Returns:
        dict: Component -> ``gr.update(visible=...)`` covering the
        configuration components and the chatbot of each of the
        ``MAX_NUM_PLAYERS`` slots.
    """
    shown_count = int(k)
    update_dict = {}
    for slot in range(MAX_NUM_PLAYERS):
        is_shown = slot < shown_count
        # Configuration widgets first, then the slot's transcript.
        for target in (*players_idx2comp[slot], player_chatbots[slot]):
            update_dict[target] = gr.update(visible=is_shown)
    return update_dict
275
+
276
+ num_player_slider.change(
277
+ variable_players,
278
+ num_player_slider,
279
+ all_players_components + player_chatbots,
280
+ )
281
+
282
+ human_input_textbox = gr.Textbox(
283
+ show_label=True,
284
+ label="Human Input",
285
+ lines=1,
286
+ visible=True,
287
+ interactive=True,
288
+ placeholder="Enter your input here",
289
+ )
290
+ with gr.Row():
291
+ btn_step = gr.Button("Start")
292
+ btn_restart = gr.Button("Clear")
293
+
294
+ all_components += [human_input_textbox, btn_step, btn_restart]
295
+
296
+ def _convert_to_chatbot_output(all_messages, display_recv=False):
297
+ chatbot_output = []
298
+ for i, message in enumerate(all_messages):
299
+ agent_name, msg, recv = (
300
+ message.agent_name,
301
+ message.content,
302
+ str(message.visible_to),
303
+ )
304
+ new_msg = re.sub(
305
+ r"\n+", "<br>", msg.strip()
306
+ ) # Preprocess message for chatbot output
307
+ if display_recv:
308
+ new_msg = f"**{agent_name} (-> {recv})**: {new_msg}" # Add role to the message
309
+ else:
310
+ new_msg = f"**{agent_name}**: {new_msg}"
311
+
312
+ if agent_name == "Moderator":
313
+ chatbot_output.append((new_msg, None))
314
+ else:
315
+ chatbot_output.append((None, new_msg))
316
+ return chatbot_output
317
+
318
def _create_arena_config_from_components(all_comps: dict) -> ArenaConfig:
    """Assemble an ``ArenaConfig`` from the current widget values.

    Args:
        all_comps: Mapping from each input component to its current value.

    Returns:
        ArenaConfig: Built from one player entry per selected player slot and
        an environment entry that embeds the moderator settings.
    """
    env_desc = all_comps[env_desc_textbox]

    # Initialize the players.  Coerce the slider value the same way
    # variable_players does, since range() rejects a float.
    num_players = int(all_comps[num_player_slider])
    player_configs = []
    for i in range(num_players):
        # Accordion/Tab entries are containers without a value of their own.
        role_name, role_desc, backend_type, temperature, max_tokens = (
            all_comps[c]
            for c in players_idx2comp[i]
            if not isinstance(c, (gr.Accordion, gr.Tab))
        )
        player_configs.append(
            {
                "name": role_name,
                "role_desc": role_desc,
                "global_prompt": env_desc,
                "backend": {
                    "backend_type": backend_type,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                },
            }
        )

    # Initialize the environment
    env_type = all_comps[env_selector]
    # Get moderator config
    (
        mod_role_desc,
        mod_terminal_condition,
        moderator_backend_type,
        mod_temp,
        mod_max_tokens,
    ) = (
        all_comps[c]
        for c in moderator_components
        if not isinstance(c, (gr.Accordion, gr.Tab))
    )
    moderator_config = {
        "role_desc": mod_role_desc,
        "global_prompt": env_desc,
        "terminal_condition": mod_terminal_condition,
        "backend": {
            "backend_type": moderator_backend_type,
            "temperature": mod_temp,
            "max_tokens": mod_max_tokens,
        },
    }
    env_config = {
        "env_type": env_type,
        "parallel": all_comps[parallel_checkbox],
        "moderator": moderator_config,
        "moderator_visibility": "all",
        "moderator_period": None,
    }

    return ArenaConfig(players=player_configs, environment=env_config)
377
+
378
def step_game(all_comps: dict):
    """Advance the arena by one step and stream the resulting UI updates.

    Generator event handler: each yielded dict is a partial UI update keyed
    by component.

    Args:
        all_comps: Mapping from each input component to its current value;
            must contain ``state`` and ``human_input_textbox``.

    Yields:
        First an update that locks both buttons.  Then either a prompt for
        valid human input (when a human turn had an empty input box) or the
        refreshed transcripts, button states and session state.
    """
    # Lock both buttons while the step is being computed.
    yield {
        btn_step: gr.update(value="Running...", interactive=False),
        btn_restart: gr.update(interactive=False),
    }

    cur_state = all_comps[state]

    # If arena is not yet created, create it
    if cur_state["arena"] is None:
        arena_config = _create_arena_config_from_components(all_comps)
        arena = Arena.from_config(arena_config)
        log_arena(arena, database=DB)
        cur_state["arena"] = arena
    else:
        arena = cur_state["arena"]

    try:
        timestep = arena.step()
    except HumanBackendError as e:
        # Handle human input and recover with the game update
        human_input = all_comps[human_input_textbox]
        if human_input == "":
            timestep = None  # Failed to get human input
        else:
            timestep = arena.environment.step(e.agent_name, human_input)
    except TooManyInvalidActions:
        # End the game on the current timestep with an explanatory message.
        timestep = arena.current_timestep
        timestep.observation.append(
            Message(
                "System",
                "Too many invalid actions. Game over.",
                turn=-1,
                visible_to="all",
            )
        )
        timestep.terminal = True

    if timestep is None:
        yield {
            human_input_textbox: gr.update(
                value="", placeholder="Please enter a valid input"
            ),
            btn_step: gr.update(value="Next Step", interactive=True),
            btn_restart: gr.update(interactive=True),
        }
        return

    all_messages = timestep.observation  # user sees what the moderator sees
    log_messages(arena, all_messages, database=DB)

    chatbot_output = _convert_to_chatbot_output(all_messages, display_recv=True)
    update_dict = {
        # gr.update is used for every other update in this file; the
        # component-level Textbox.update helper is a deprecated API.
        human_input_textbox: gr.update(value=""),
        chatbot: chatbot_output,
        btn_step: gr.update(value="Next Step", interactive=not timestep.terminal),
        btn_restart: gr.update(interactive=True),
        state: cur_state,
    }
    # Get the visible messages for each player
    for player, player_chatbot in zip(arena.players, player_chatbots):
        player_messages = arena.environment.get_observation(player.name)
        update_dict[player_chatbot] = _convert_to_chatbot_output(player_messages)

    if DEBUG:
        arena.environment.print()

    yield update_dict
450
+
451
def restart_game(all_comps: dict):
    """Clear the transcripts and rebuild the arena from the current settings.

    Generator event handler: each yielded dict is a partial UI update keyed
    by component.

    Args:
        all_comps: Mapping from each input component to its current value;
            must contain the ``state`` component.

    Yields:
        First an update that empties every transcript and disables both
        buttons, then an update that re-enables them (step button relabelled
        "Start") and stores the new arena in the session state.
    """
    cur_state = all_comps[state]
    cur_state["arena"] = None

    cleared = {
        chatbot: [],
        btn_restart: gr.update(interactive=False),
        btn_step: gr.update(interactive=False),
        state: cur_state,
    }
    # The per-player transcripts are registered outputs of this handler as
    # well; empty them so they do not keep showing the previous game.
    cleared.update({player_chatbot: [] for player_chatbot in player_chatbots})
    yield cleared

    # Build and log a new arena from whatever the widgets hold right now.
    arena_config = _create_arena_config_from_components(all_comps)
    arena = Arena.from_config(arena_config)
    log_arena(arena, database=DB)
    cur_state["arena"] = arena

    yield {
        btn_step: gr.update(value="Start", interactive=True),
        btn_restart: gr.update(interactive=True),
        state: cur_state,
    }
471
+
472
# Keep only the non-container components (Accordion and Tab are dropped).
all_components = [
    component
    for component in all_components
    if not isinstance(component, (gr.Accordion, gr.Tab))
]


def _disable_step_button(current_state):
    """Disable the Step button if a setting changes while an arena exists."""
    if current_state["arena"] is None:
        return gr.update()
    return gr.update(interactive=False)


# Any edit to a Textbox, Slider, Checkbox, Dropdown or Radio (other than the
# human input box) disables the Step button once an arena has been created.
watched_types = (gr.Textbox, gr.Slider, gr.Checkbox, gr.Dropdown, gr.Radio)
for comp in all_components:
    if comp is human_input_textbox:
        continue
    if isinstance(comp, watched_types):
        comp.change(_disable_step_button, state, btn_step)

btn_step.click(
    step_game,
    {*all_components, state},
    [chatbot, *player_chatbots, btn_step, btn_restart, state, human_input_textbox],
)
btn_restart.click(
    restart_game,
    {*all_components, state},
    [chatbot, *player_chatbots, btn_step, btn_restart, state, human_input_textbox],
)
504
+
505
+ # If an example is selected, update the components
506
def update_components_from_example(all_comps: dict):
    """Translate the selected example into widget updates.

    Args:
        all_comps: Mapping from each input component to its current value;
            the value of ``example_selector`` is looked up in
            ``EXAMPLE_REGISTRY``.

    Returns:
        dict: Component -> ``gr.update(value=...)`` for the environment
        widgets, the moderator widgets (only when the example's environment
        has a ``"moderator"`` entry), the player-count slider and the widgets
        of each player listed in the example.
    """
    selected = all_comps[example_selector]
    example_config = EXAMPLE_REGISTRY[selected]
    env_config = example_config["environment"]
    containers = (gr.Accordion, gr.Tab)

    # Update the environment components
    update_dict = {
        env_desc_textbox: gr.update(value=example_config["global_prompt"]),
        env_selector: gr.update(value=env_config["env_type"]),
        parallel_checkbox: gr.update(value=env_config["parallel"]),
    }

    # Update the moderator components
    if "moderator" in env_config:
        desc_box, terminal_box, backend_box, temp_slider, tokens_slider = [
            c for c in moderator_components if not isinstance(c, containers)
        ]
        moderator = env_config["moderator"]
        update_dict[desc_box] = gr.update(value=moderator["role_desc"])
        update_dict[terminal_box] = gr.update(value=moderator["terminal_condition"])
        mod_backend = moderator["backend"]
        update_dict[backend_box] = gr.update(value=mod_backend["backend_type"])
        update_dict[temp_slider] = gr.update(value=mod_backend["temperature"])
        update_dict[tokens_slider] = gr.update(value=mod_backend["max_tokens"])

    # Update the player components
    players = example_config["players"]
    update_dict[num_player_slider] = gr.update(value=len(players))
    for slot, player_config in enumerate(players):
        name_box, desc_box, backend_box, temp_slider, tokens_slider = [
            c for c in players_idx2comp[slot] if not isinstance(c, containers)
        ]
        update_dict[name_box] = gr.update(value=player_config["name"])
        update_dict[desc_box] = gr.update(value=player_config["role_desc"])
        player_backend = player_config["backend"]
        update_dict[backend_box] = gr.update(value=player_backend["backend_type"])
        update_dict[temp_slider] = gr.update(value=player_backend["temperature"])
        update_dict[tokens_slider] = gr.update(value=player_backend["max_tokens"])

    return update_dict
568
+
569
# Selecting an example feeds every component (plus the session state) to the
# handler and lets it update any of them.
example_selector.change(
    update_components_from_example,
    {*all_components, state},
    [*all_components, state],
)

demo.queue()
demo.launch(debug=DEBUG, server_port=8080)