mlabonne committed on
Commit
e5d94ce
·
verified ·
1 Parent(s): 1ea59ea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +388 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import io
3
+ import os
4
+ import re
5
+ import time
6
+ from collections import defaultdict
7
+ from datetime import datetime
8
+
9
+ import cairosvg
10
+ import chess
11
+ import chess.svg
12
+ import gistyc
13
+ import numpy as np
14
+ import outlines.models as models
15
+ import outlines.text.generate as generate
16
+ import pandas as pd
17
+ import requests
18
+ from tqdm.auto import tqdm
19
+ from IPython.display import Image as IPythonImage
20
+ from IPython.display import clear_output, update_display
21
+ from PIL import Image as PILImage
22
+ import gradio as gr
23
+
24
+ # Generate regular expression for legal moves
25
def generate_regex(board):
    """Build a regex alternation matching every legal move on ``board``.

    Each legal move is rendered with ``board.san``, any ``+`` / ``#``
    characters are removed from that text, and the regex-escaped results
    are joined with ``|``.

    :param board: object exposing ``legal_moves`` and ``san(move)``
        (presumably a ``chess.Board`` -- confirm against callers)
    :return: the alternation pattern; an empty string when there are no
        legal moves
    """
    alternatives = []
    for candidate in list(board.legal_moves):
        notation = board.san(candidate)
        # Strip the check/mate markers so only the bare move text is matched.
        notation = re.sub(r"[+#]", "", notation)
        alternatives.append(re.escape(notation))
    return "|".join(alternatives)
31
+
32
+
33
def write_pgn(
    pgn_moves, model_id_white, model_id_black, result, time_budget, termination
):
    """Assemble the PGN text for a game.

    The header holds the Event, Site, Date, White, Black, Result, Time,
    TimeControl and Termination tags, followed by a blank line and the
    movetext. Date and Time are taken from the current UTC clock.

    :param pgn_moves: movetext to place after the tag section
    :param model_id_white: value of the White tag
    :param model_id_black: value of the Black tag
    :param result: value of the Result tag (e.g. ``"1-0"``)
    :param time_budget: base time, written as ``"<time_budget>+0"``
    :param termination: value of the Termination tag
    :return: the PGN text (leading newline, tags, blank line, moves, newline)
    """
    # Function-scope import: the module itself only imports ``datetime``.
    from datetime import timezone

    def quoted(value):
        # PGN string tokens are delimited by double quotes; a backslash or a
        # double quote inside the value is escaped with a backslash.
        text = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{text}"'

    # Aware UTC timestamp (``datetime.utcnow`` is deprecated and naive).
    now_utc = datetime.now(timezone.utc)
    utc_date = now_utc.strftime("%Y.%m.%d")
    utc_time = now_utc.strftime("%H:%M:%S")

    tags = [
        ("Event", "Chess LLM Arena"),
        ("Site", "https://github.com/mlabonne/chessllm"),
        ("Date", utc_date),
        ("White", model_id_white),
        ("Black", model_id_black),
        ("Result", result),
        ("Time", utc_time),
        ("TimeControl", f"{time_budget}+0"),
        ("Termination", termination),
    ]
    header = "\n".join(f"[{name} {quoted(value)}]" for name, value in tags)

    return f"\n{header}\n\n{pgn_moves}\n"
57
+
58
+
59
def determine_termination(board, time_budget_white, time_budget_black):
    """Describe why the game stopped.

    Board conditions are tested first, in a fixed order; the clocks are
    only consulted when none of them applies.

    :param board: object exposing the ``is_*`` / ``can_claim_*`` predicates
        used below (presumably a ``chess.Board`` -- confirm against callers)
    :param time_budget_white: remaining time for white
    :param time_budget_black: remaining time for black
    :return: a human-readable termination label
    """
    # (predicate method name, label) pairs, probed in this exact order.
    board_outcomes = (
        ("is_checkmate", "Checkmate"),
        ("is_stalemate", "Stalemate"),
        ("is_insufficient_material", "Draw due to insufficient material"),
        ("can_claim_threefold_repetition", "Draw by threefold repetition"),
        ("can_claim_fifty_moves", "Draw by fifty-move rule"),
    )
    for probe, label in board_outcomes:
        if getattr(board, probe)():
            return label

    out_of_time = time_budget_white <= 0 or time_budget_black <= 0
    return "Timeout" if out_of_time else "Unknown"
74
+
75
def format_elapsed(seconds):
    """Render a duration as ``hh:mm:ss``, ``mm:ss`` or ``ss``.

    The input is truncated to whole seconds. Leading fields are dropped
    while they are zero: hours first, then minutes.

    :param seconds: duration in seconds (int or float)
    :return: the zero-padded, colon-separated representation
    """
    whole_seconds = int(seconds)
    hh, leftover = divmod(whole_seconds, 3600)
    mm, ss = divmod(leftover, 60)

    fields = [hh, mm, ss]
    if not hh:
        # No hours: keep minutes only when they are non-zero.
        fields = fields[1:] if mm else fields[2:]
    return ":".join(f"{field:02d}" for field in fields)
85
+
86
def create_gif(image_list, gif_path, duration):
    """Write a sequence of frames to an animated GIF.

    :param image_list: frames as numpy arrays accepted by
        ``PILImage.fromarray``
    :param gif_path: path of the GIF file to write
    :param duration: display time of each frame, passed to Pillow's GIF
        writer (milliseconds)
    :return: None; nothing is written when ``image_list`` is empty
    """
    # Convert numpy arrays back to PIL images
    pil_images = [PILImage.fromarray(image) for image in image_list]

    # A game that ended before any move was rendered has nothing to animate.
    if not pil_images:
        return

    first_frame, *remaining_frames = pil_images
    first_frame.save(
        gif_path,
        format="GIF",
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=0,  # loop forever
    )
89
+
90
+
91
def save_result_file(
    pgn_id, model_id_white, model_id_black, termination, result, auth_token, gist_id
):
    """Append one game result to the local results CSV and push it to a gist.

    The appended row holds, in order: ``pgn_id``, the local timestamp,
    ``model_id_white``, ``model_id_black``, ``termination`` and ``result``.

    :param pgn_id: identifier of the stored PGN
    :param model_id_white: identifier of the white model
    :param model_id_black: identifier of the black model
    :param termination: termination label of the game
    :param result: game result string
    :param auth_token: GitHub token used to update the gist
    :param gist_id: ID of the gist holding ``chessllm_results.csv``
    """
    # Function-scope import: ``csv`` is not among the module-level imports.
    import csv

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # csv.writer quotes any field containing a comma or quote, so such a
    # value cannot shift the columns of the row.
    with open("chessllm_results.csv", "a", newline="") as file:
        csv.writer(file, lineterminator="\n").writerow(
            [pgn_id, timestamp, model_id_white, model_id_black, termination, result]
        )

    # Update the Gist with the caller-supplied token (the parameter used to
    # be ignored in favour of the module-level GITHUB_TOKEN).
    gist_api = gistyc.GISTyc(auth_token=auth_token)
    gist_api.update_gist(file_name="chessllm_results.csv", gist_id=gist_id)
108
+
109
+
110
def save_pgn(final_pgn, file_name, auth_token):
    """Write the PGN to ``<file_name>.pgn`` and upload it as a new gist.

    :param final_pgn: PGN text to store
    :param file_name: file name without the ``.pgn`` extension
    :param auth_token: GitHub token used to create the gist
    :return: the ``"id"`` field of the gist-creation response
    """
    pgn_path = file_name + ".pgn"

    # Write final PGN to a file
    with open(pgn_path, "w") as file:
        file.write(final_pgn)

    # Use the caller-supplied token (the parameter used to be ignored in
    # favour of the module-level GITHUB_TOKEN).
    gist_api = gistyc.GISTyc(auth_token=auth_token)
    response_data = gist_api.create_gist(file_name=pgn_path)

    return response_data["id"]
119
+
120
+
121
def download_file(base_url, file_name, timeout=30):
    """Download ``base_url`` to ``file_name``, asking caches not to serve a stale copy.

    A ``ts`` query parameter carrying the current timestamp is added to the
    request, along with no-cache request headers. On a non-200 response the
    status code is printed and no file is written.

    :param base_url: URL to fetch
    :param file_name: local path the response body is written to
    :param timeout: seconds to wait for the server before ``requests``
        raises a timeout error instead of blocking indefinitely
    """
    # Unique query parameter to bypass cache (using a timestamp)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }

    # ``params`` lets requests merge ``ts`` into the URL correctly even when
    # ``base_url`` already carries a query string.
    response = requests.get(
        base_url, params={"ts": timestamp}, headers=headers, timeout=timeout
    )

    if response.status_code == 200:
        with open(file_name, "wb") as file:
            file.write(response.content)
    else:
        print(f"Failed to download file. HTTP status code: {response.status_code}")
139
+
140
+
141
def calculate_elo(rank1, rank2, result):
    """
    Compute player 1's ELO rating after one game against player 2.

    :param rank1: The current ELO rating of player 1
    :param rank2: The current ELO rating of player 2
    :param result: 1 if player 1 wins, 0 if player 2 wins, 0.5 for a draw
    :return: The updated ELO rating of player 1, rounded with ``round``
    """
    k_factor = 32
    rating_gap = (rank2 - rank1) / 400
    # Expected score of player 1 under the logistic ELO model.
    expected = 1 / (1 + 10 ** rating_gap)
    adjustment = k_factor * (result - expected)
    return round(rank1 + adjustment)
153
+
154
+
155
def update_elo_ratings(chess_data):
    """
    Replay every recorded game, in row order, to derive ELO ratings.

    Rows whose ``Result`` is ``"*"`` are skipped entirely. Rows with any
    other unrecognised result leave ratings untouched, although both
    models still receive the default rating if they had none.

    :param chess_data: DataFrame with ``Model1``, ``Model2`` and ``Result`` columns
    :return: A defaultdict (default 1000) mapping each model to its rating
    """
    # Score of Model1 for each finished result; Model2 gets the complement.
    model1_score_by_result = {"1-0": 1, "0-1": 0, "1/2-1/2": 0.5}

    ratings = defaultdict(lambda: 1000)  # Default ELO rating is 1000

    for _, game in chess_data.iterrows():
        outcome = game["Result"]
        if outcome == "*":
            continue  # Skip ongoing games

        first, second = game["Model1"], game["Model2"]

        # Read both ratings before writing either, so each update is based
        # on the pre-game values.
        first_before = ratings[first]
        second_before = ratings[second]

        score = model1_score_by_result.get(outcome)
        if score is None:
            continue

        ratings[first] = calculate_elo(first_before, second_before, score)
        ratings[second] = calculate_elo(second_before, first_before, 1 - score)

    return ratings
186
+
187
+
188
def _stored_elo(leaderboard, model_id):
    """Return the rating recorded for ``model_id`` in ``leaderboard``, or 1000 if absent."""
    ratings = leaderboard.loc[leaderboard["Model"] == model_id, "ELO Rating"]
    # Positional access: the filtered row keeps its original index label, so
    # a label lookup such as ``.get(0, ...)`` only finds the model in row 0.
    return ratings.iloc[0] if not ratings.empty else 1000


def _print_elo_change(model_id, old_elo, new_elo):
    """Print one line describing a model's rating change."""
    print(f"* {model_id}: {old_elo} -> {new_elo} ({new_elo - old_elo:+})")


def update(model_id_white, model_id_black):
    """Play one game between two models, yielding a board image after each move.

    Steps, in order: load both models, download the results and leaderboard
    CSVs, run the game loop (each side has a 180-second budget; generation
    is constrained to the legal moves by a regex), store the PGN as a gist,
    append the result to the results CSV, build the GIF, print the ELO
    change, then recompute the leaderboard and upload it.

    This is a generator: the value in the final ``return`` is only available
    as the ``StopIteration`` value, not as a yielded item.

    :param model_id_white: model identifier for white
    :param model_id_black: model identifier for black
    :yield: a PIL image of the board after each accepted move
    :return: the GIF file name (``<white>_vs_<black>.gif``)
    """
    model_white = models.transformers(model_id_white)
    model_black = models.transformers(model_id_black)

    TIME_BUDGET = 180
    # NOTE(review): the prompt is never extended with the moves played so
    # far, so every move is generated from the same text -- confirm intended.
    prompt = '1.'

    # Initialize the chess board
    board = chess.Board()
    board_images = []
    pgn_moves = ""
    move_number = 1
    result = None
    clear_output(wait=True)

    # Time budget
    time_budget_white = TIME_BUDGET
    time_budget_black = TIME_BUDGET
    bar_format = "{desc} {n:.0f} seconds left | Elapsed: {elapsed}"
    white_bar = tqdm(
        total=time_budget_white,
        desc=f"{model_id_white.split('/')[-1]}:",
        bar_format=bar_format,
        colour='white',
    )
    black_bar = tqdm(
        total=time_budget_black,
        desc=f"{model_id_black.split('/')[-1]}:",
        bar_format=bar_format,
        colour='black',
    )

    # Download results
    url1 = f"https://gist.githubusercontent.com/chessllm/{RESULT_GIST_ID}/raw"
    download_file(url1, "chessllm_results.csv")

    # Load ELO ratings for each model
    url2 = f"https://gist.githubusercontent.com/chessllm/{LEAD_GIST_ID}/raw"
    download_file(url2, "chessllm_leaderboard.csv")
    elo_ratings_df = pd.read_csv("chessllm_leaderboard.csv")

    # Game loop
    while not board.is_game_over():
        # Select model
        current_model = model_white if board.turn == chess.WHITE else model_black

        # Generate regex pattern
        regex_pattern = generate_regex(board)

        # Generate move, timing only the generation call
        start_time = time.time()
        guided = generate.regex(current_model, regex_pattern, max_tokens=10)(prompt)
        move_duration = time.time() - start_time

        # Parse move; only the parser is expected to raise ValueError here
        move_san = guided.strip()
        try:
            move = board.parse_san(move_san)
        except ValueError:
            print(f"Invalid move: {guided}")
            break
        if move not in board.legal_moves:
            print(f"Illegal move: {move_san}")
            break
        board.push(move)

        # Write move. After the push the side to move has flipped, so
        # "black to move" means white just played and the move gets a number.
        if board.turn == chess.BLACK:
            move_str = f"{move_number}. {move_san} "
            move_number += 1
        else:
            move_str = f"{move_san} "
        pgn_moves += move_str

        # Render the board to an image
        last_move = board.peek()
        svg = chess.svg.board(
            board=board, arrows=[(last_move.from_square, last_move.to_square)]
        ).encode("utf-8")
        png = cairosvg.svg2png(bytestring=svg)
        image = PILImage.open(io.BytesIO(png))
        board_images.append(np.array(image))

        # Deduct the time taken for the move from the model's time budget
        # (again keyed on the already-flipped side to move).
        if board.turn == chess.WHITE:
            time_budget_black -= move_duration
            black_bar.n = time_budget_black
            black_bar.set_postfix_str(f"{format_elapsed(black_bar.format_dict['elapsed'])} elapsed")
            black_bar.refresh()
            if time_budget_black <= 0:
                result = "1-0"
                break
        else:
            time_budget_white -= move_duration
            white_bar.n = time_budget_white
            white_bar.set_postfix_str(f"{format_elapsed(white_bar.format_dict['elapsed'])} elapsed")
            white_bar.refresh()
            if time_budget_white <= 0:
                result = "0-1"
                break

        # Display board
        yield image

    white_bar.close()
    black_bar.close()

    # Get result (a timeout has already set it inside the loop)
    if result is None:
        result = board.result()

    # Create PGN
    termination = determine_termination(board, time_budget_white, time_budget_black)
    final_pgn = write_pgn(
        pgn_moves, model_id_white, model_id_black, result, TIME_BUDGET, termination
    )
    file_name = f"{model_id_white.split('/')[-1]}_vs_{model_id_black.split('/')[-1]}"
    pgn_id = save_pgn(final_pgn, file_name, GITHUB_TOKEN)

    # Save results
    save_result_file(
        pgn_id, model_id_white, model_id_black, termination, result, GITHUB_TOKEN, RESULT_GIST_ID
    )

    # Create and display the GIF
    clear_output(wait=True)
    create_gif(board_images, file_name + ".gif", duration=400)

    # Print ELO ratings
    current_elo_white = _stored_elo(elo_ratings_df, model_id_white)
    current_elo_black = _stored_elo(elo_ratings_df, model_id_black)

    # White's score for each finished result; black gets the complement.
    score_white = {"1-0": 1, "0-1": 0, "1/2-1/2": 0.5}.get(result)
    if score_white is not None:
        new_elo_white = calculate_elo(current_elo_white, current_elo_black, score_white)
        new_elo_black = calculate_elo(current_elo_black, current_elo_white, 1 - score_white)
        if result == "1-0":
            print(f"{model_id_white} wins! ({termination})")
        elif result == "0-1":
            print(f"{model_id_black} wins! ({termination})")
        else:
            print(f"Draw! ({termination})")
        print("ELO change:")
        _print_elo_change(model_id_white, current_elo_white, new_elo_white)
        _print_elo_change(model_id_black, current_elo_black, new_elo_black)
    elif result == "*":
        print(f"Ongoing game! ({termination})")

    # Update ELO ratings for each model
    chess_data = pd.read_csv('chessllm_results.csv')
    elo_ratings = update_elo_ratings(chess_data)

    # Convert the dictionary to a DataFrame for better display
    elo_ratings_df = pd.DataFrame(elo_ratings.items(), columns=['Model', 'ELO Rating'])

    # Round the ELO ratings to the nearest integer
    elo_ratings_df['ELO Rating'] = elo_ratings_df['ELO Rating'].round().astype(int)

    elo_ratings_df.sort_values(by='ELO Rating', ascending=False, inplace=True)
    elo_ratings_df.reset_index(drop=True, inplace=True)
    elo_ratings_df.to_csv('chessllm_leaderboard.csv', index=False)

    # Upload chessllm_leaderboard.csv to GIST. The ID is passed as a plain
    # string; ``{LEAD_GIST_ID}`` would be a one-element set.
    gist_api = gistyc.GISTyc(auth_token=GITHUB_TOKEN)
    gist_api.update_gist(file_name='chessllm_leaderboard.csv', gist_id=LEAD_GIST_ID)

    return file_name + ".gif"
372
+
373
# --- Runtime configuration ---------------------------------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Read eagerly: a missing variable raises KeyError when the module is loaded.
GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
RESULT_GIST_ID = "c491299e7b8a45a61ce5403a70cf8656"
LEAD_GIST_ID = "696115fe2df47fb2350fcff2663678c9"

# --- Gradio UI ---------------------------------------------------------------
# NOTE(review): the nesting below is reconstructed from flattened text
# (textboxes in the first row, output image in the second) -- confirm layout.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")

    # Model selection
    with gr.Row():
        model_id_white = gr.Textbox(
            label="White Model ID", value="mlabonne/chesspythia-70m"
        )
        model_id_black = gr.Textbox(
            label="Black Model ID", value="BlueSunflower/Pythia-160M-chess"
        )

    btn = gr.Button("Run")

    # Board display
    with gr.Row():
        out = gr.Image(width=256)

    # Clicking the button calls ``update`` with both model IDs; its output
    # goes to the image component.
    btn.click(fn=update, inputs=[model_id_white, model_id_black], outputs=out)

demo.launch()