Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from collections import defaultdict
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
import cairosvg
|
10 |
+
import chess
|
11 |
+
import chess.svg
|
12 |
+
import gistyc
|
13 |
+
import numpy as np
|
14 |
+
import outlines.models as models
|
15 |
+
import outlines.text.generate as generate
|
16 |
+
import pandas as pd
|
17 |
+
import requests
|
18 |
+
from tqdm.auto import tqdm
|
19 |
+
from IPython.display import Image as IPythonImage
|
20 |
+
from IPython.display import clear_output, update_display
|
21 |
+
from PIL import Image as PILImage
|
22 |
+
import gradio as gr
|
23 |
+
|
24 |
+
# Generate regular expression for legal moves
|
25 |
+
def generate_regex(board):
    """Build a regex alternation matching every legal move on ``board``.

    Each legal move is rendered in SAN via ``board.san``, the ``+`` and ``#``
    suffix characters are removed, and the regex-escaped results are joined
    with ``|``.

    :param board: object exposing ``legal_moves`` and ``san(move)``
    :return: the alternation pattern as a string (empty if there are no moves)
    """
    alternatives = []
    for candidate in list(board.legal_moves):
        san_text = board.san(candidate)
        # Drop the check / checkmate markers so the pattern holds bare moves.
        bare_san = san_text.replace("+", "").replace("#", "")
        alternatives.append(re.escape(bare_san))
    return "|".join(alternatives)
|
31 |
+
|
32 |
+
|
33 |
+
def write_pgn(
    pgn_moves, model_id_white, model_id_black, result, time_budget, termination
):
    """Assemble the PGN text for a finished game.

    Tag values are wrapped in double quotes, as the PGN standard requires;
    single-quoted values are not valid tag pairs.

    :param pgn_moves: movetext string, e.g. ``"1. e4 e5 2. Nf3 "``
    :param model_id_white: identifier written to the ``White`` tag
    :param model_id_black: identifier written to the ``Black`` tag
    :param result: result string written to the ``Result`` tag
    :param time_budget: base time, written to ``TimeControl`` as ``<budget>+0``
    :param termination: text written to the ``Termination`` tag
    :return: the PGN document as a string
    """
    # Local import: the file only imports ``datetime`` from the module.
    from datetime import timezone

    # Timezone-aware "now" in UTC; ``datetime.utcnow()`` is deprecated.
    current_utc_datetime = datetime.now(timezone.utc)
    utc_date = current_utc_datetime.strftime("%Y.%m.%d")
    utc_time = current_utc_datetime.strftime("%H:%M:%S")

    # Tag-pair section, a blank line, then the movetext.
    final_pgn = f"""
[Event "Chess LLM Arena"]
[Site "https://github.com/mlabonne/chessllm"]
[Date "{utc_date}"]
[White "{model_id_white}"]
[Black "{model_id_black}"]
[Result "{result}"]
[Time "{utc_time}"]
[TimeControl "{time_budget}+0"]
[Termination "{termination}"]

{pgn_moves}
"""

    return final_pgn
|
57 |
+
|
58 |
+
|
59 |
+
def determine_termination(board, time_budget_white, time_budget_black):
    """Describe why the game stopped.

    Board conditions are tested in a fixed order and the first one that holds
    decides the label. If none holds, an exhausted clock gives ``"Timeout"``;
    otherwise the reason is ``"Unknown"``.

    :param board: object exposing the predicate methods listed below
    :param time_budget_white: remaining time for White
    :param time_budget_black: remaining time for Black
    :return: a human-readable termination string
    """
    # Order matters: predicates are evaluated top to bottom, stopping at the
    # first that returns a truthy value.
    board_checks = (
        ("is_checkmate", "Checkmate"),
        ("is_stalemate", "Stalemate"),
        ("is_insufficient_material", "Draw due to insufficient material"),
        ("can_claim_threefold_repetition", "Draw by threefold repetition"),
        ("can_claim_fifty_moves", "Draw by fifty-move rule"),
    )
    for predicate_name, label in board_checks:
        if getattr(board, predicate_name)():
            return label

    if time_budget_white <= 0 or time_budget_black <= 0:
        return "Timeout"
    return "Unknown"
|
74 |
+
|
75 |
+
def format_elapsed(seconds):
    """Render a duration as ``hh:mm:ss``, ``mm:ss`` or ``ss``.

    The value is truncated to whole seconds. The hours field is shown only
    when non-zero; with zero hours, the minutes field is shown only when
    non-zero. Every field is zero-padded to two digits.
    """
    whole_seconds = int(seconds)
    hrs = whole_seconds // 3600
    mins = (whole_seconds % 3600) // 60
    secs = whole_seconds % 60

    if hrs:
        fields = (hrs, mins, secs)
    elif mins:
        fields = (mins, secs)
    else:
        fields = (secs,)
    return ":".join(f"{field:02d}" for field in fields)
|
85 |
+
|
86 |
+
def create_gif(image_list, gif_path, duration):
    """Write the given frames to ``gif_path`` as a looping animated GIF.

    Previously the frames were converted but never saved, so no file was
    produced and ``gif_path`` / ``duration`` were unused.

    :param image_list: sequence of numpy image arrays, one per frame
    :param gif_path: destination path of the GIF file
    :param duration: display time of each frame, in milliseconds
    :return: None; nothing is written when ``image_list`` is empty
    """
    # A game that ended before any move was rendered has no frames to save.
    if len(image_list) == 0:
        return

    # Convert numpy arrays back to PIL images
    pil_images = [PILImage.fromarray(image) for image in image_list]

    # The first frame carries the save call; the rest are appended to it.
    first_frame, *remaining_frames = pil_images
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=0,  # 0 = repeat forever
    )
|
89 |
+
|
90 |
+
|
91 |
+
def save_result_file(
    pgn_id, model_id_white, model_id_black, termination, result, auth_token, gist_id
):
    """Append one game result to ``chessllm_results.csv`` and push it to a Gist.

    :param pgn_id: identifier of the Gist that holds the game's PGN
    :param model_id_white: identifier of the model playing White
    :param model_id_black: identifier of the model playing Black
    :param termination: termination description for the game
    :param result: result string for the game
    :param auth_token: GitHub token used to authenticate the Gist update
    :param gist_id: id of the Gist that stores ``chessllm_results.csv``
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # One comma-separated row per game.
    data_str = f"{pgn_id},{timestamp},{model_id_white},{model_id_black},{termination},{result}\n"

    # Append data to the local results file
    with open("chessllm_results.csv", "a") as file:
        file.write(data_str)

    # Update the Gist with the caller-supplied token (the parameter was
    # previously ignored in favour of the module-level GITHUB_TOKEN).
    gist_api = gistyc.GISTyc(auth_token=auth_token)
    gist_api.update_gist(file_name="chessllm_results.csv", gist_id=gist_id)
|
108 |
+
|
109 |
+
|
110 |
+
def save_pgn(final_pgn, file_name, auth_token):
    """Write the PGN to ``<file_name>.pgn`` and publish it as a new Gist.

    :param final_pgn: PGN text to store
    :param file_name: output file name without the ``.pgn`` extension
    :param auth_token: GitHub token used to authenticate the Gist creation
    :return: the ``"id"`` field of the Gist-creation response
    """
    pgn_path = file_name + ".pgn"

    # Write final PGN to a file
    with open(pgn_path, "w") as file:
        file.write(final_pgn)

    # Use the caller-supplied token (the parameter was previously ignored in
    # favour of the module-level GITHUB_TOKEN).
    gist_api = gistyc.GISTyc(auth_token=auth_token)
    response_data = gist_api.create_gist(file_name=pgn_path)

    return response_data["id"]
|
119 |
+
|
120 |
+
|
121 |
+
def download_file(base_url, file_name, timeout=30):
    """Download ``base_url`` to ``file_name``, bypassing caches.

    A timestamp query parameter and no-cache headers are sent with the
    request. On a non-200 response nothing is written and a message is
    printed.

    :param base_url: URL to fetch
    :param file_name: local path the response body is written to
    :param timeout: seconds to wait for the server before ``requests`` raises;
        without a timeout the request could block indefinitely
    """
    # Unique query parameter to bypass cache (using a timestamp)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    url = f"{base_url}?ts={timestamp}"

    headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }

    response = requests.get(url, headers=headers, timeout=timeout)

    if response.status_code == 200:
        with open(file_name, "wb") as file:
            file.write(response.content)
    else:
        print(f"Failed to download file. HTTP status code: {response.status_code}")
|
139 |
+
|
140 |
+
|
141 |
+
def calculate_elo(rank1, rank2, result):
    """
    Compute player 1's ELO rating after one game against player 2.

    :param rank1: The current ELO rating of player 1
    :param rank2: The current ELO rating of player 2
    :param result: 1 if player 1 wins, 0 if player 2 wins, 0.5 for a draw
    :return: The updated ELO rating of player 1, rounded to an integer
    """
    k_factor = 32
    # Expected score of player 1 from the rating gap on a 400-point scale.
    rating_gap = (rank2 - rank1) / 400
    expected = 1 / (1 + 10 ** rating_gap)
    adjustment = k_factor * (result - expected)
    return round(rank1 + adjustment)
|
153 |
+
|
154 |
+
|
155 |
+
def update_elo_ratings(chess_data):
    """
    Replay every finished game in the dataset and accumulate ELO ratings.

    :param chess_data: DataFrame with ``Model1``, ``Model2`` and ``Result`` columns
    :return: A defaultdict mapping each player to its ELO rating (default 1000)
    """
    # (score for Model1, score for Model2) keyed by the result string.
    scores_by_result = {
        "1-0": (1, 0),  # Model1 wins
        "0-1": (0, 1),  # Model2 wins
        "1/2-1/2": (0.5, 0.5),  # Draw
    }
    ratings = defaultdict(lambda: 1000)

    for _, game in chess_data.iterrows():
        outcome = game["Result"]
        if outcome == "*":
            continue  # Skip ongoing games

        first_model = game["Model1"]
        second_model = game["Model2"]

        # Reading both ratings up front registers unseen players at 1000,
        # even when the result string is not one of the recognised outcomes.
        first_elo = ratings[first_model]
        second_elo = ratings[second_model]

        if outcome not in scores_by_result:
            continue

        first_score, second_score = scores_by_result[outcome]
        # Both updates use the ratings from before this game.
        ratings[first_model] = calculate_elo(first_elo, second_elo, first_score)
        ratings[second_model] = calculate_elo(second_elo, first_elo, second_score)

    return ratings
|
186 |
+
|
187 |
+
|
188 |
+
def _lookup_elo(leaderboard_df, model_id, default=1000):
    """Return the leaderboard ELO of ``model_id``, or ``default`` if it has no row."""
    matches = leaderboard_df.loc[leaderboard_df["Model"] == model_id, "ELO Rating"]
    if matches.empty:
        return default
    # Positional access: the filtered Series keeps the row's original index
    # label, so looking up label 0 only works for the first leaderboard row.
    return int(matches.iloc[0])


def update(model_id_white, model_id_black):
    """Play one game between two models, record it, and refresh the leaderboard.

    Generator: yields a PIL image of the board after each move played. After
    the game it saves the PGN and result row to Gists, builds the GIF, prints
    the ELO change, and rewrites and uploads ``chessllm_leaderboard.csv``.

    :param model_id_white: model identifier passed to ``models.transformers`` for White
    :param model_id_black: model identifier passed to ``models.transformers`` for Black
    """
    model_white = models.transformers(model_id_white)
    model_black = models.transformers(model_id_black)

    TIME_BUDGET = 180
    # NOTE(review): the prompt is never extended with the moves played so far,
    # so every move is generated from this fixed prefix - confirm intended.
    prompt = '1.'

    # Initialize the chess board
    board = chess.Board()
    board_images = []
    pgn_moves = ""
    move_number = 1
    result = None
    clear_output(wait=True)

    # Time budget
    time_budget_white = TIME_BUDGET
    time_budget_black = TIME_BUDGET
    white_bar_format = "{desc} {n:.0f} seconds left | Elapsed: {elapsed}"
    black_bar_format = "{desc} {n:.0f} seconds left | Elapsed: {elapsed}"
    white_bar = tqdm(total=time_budget_white, desc=f"{model_id_white.split('/')[-1]}:", bar_format=white_bar_format, colour='white')
    black_bar = tqdm(total=time_budget_black, desc=f"{model_id_black.split('/')[-1]}:", bar_format=black_bar_format, colour='black')

    # Download results
    url1 = (
        f"https://gist.githubusercontent.com/chessllm/{RESULT_GIST_ID}/raw"
    )
    download_file(url1, "chessllm_results.csv")

    # Load ELO ratings for each model
    url2 = f"https://gist.githubusercontent.com/chessllm/{LEAD_GIST_ID}/raw"
    download_file(url2, "chessllm_leaderboard.csv")
    elo_ratings_df = pd.read_csv("chessllm_leaderboard.csv")

    # Game loop
    while not board.is_game_over():
        # Select model
        current_model = model_white if board.turn == chess.WHITE else model_black

        # Generate regex pattern
        regex_pattern = generate_regex(board)

        # Generate move (timed, so it can be charged to the mover's clock)
        start_time = time.time()
        guided = generate.regex(current_model, regex_pattern, max_tokens=10)(prompt)
        end_time = time.time()
        move_duration = end_time - start_time

        try:
            # Parse move
            move_san = guided.strip()
            move = board.parse_san(move_san)
            if move not in board.legal_moves:
                print(f"Illegal move: {move_san}")
                break
            board.push(move)

            # Write move. After the push the turn has flipped, so Black to
            # move means White just played and the move number is written.
            if board.turn == chess.BLACK:
                move_str = f"{move_number}. {move_san} "
                move_number += 1
            else:
                move_str = f"{move_san} "
            pgn_moves += move_str

            # Render the board to an image
            last_move = board.peek()
            svg = chess.svg.board(board=board, arrows=[(last_move.from_square, last_move.to_square)]).encode("utf-8")
            png = cairosvg.svg2png(bytestring=svg)
            image = PILImage.open(io.BytesIO(png))
            board_images.append(np.array(image))

            # Deduct the time taken for the move from the model's time budget.
            # White to move now means Black made the move just timed.
            if board.turn == chess.WHITE:
                time_budget_black -= move_duration
                black_bar.n = time_budget_black
                black_bar.set_postfix_str(f"{format_elapsed(black_bar.format_dict['elapsed'])} elapsed")
                black_bar.refresh()
                if time_budget_black <= 0:
                    result = "1-0"
                    break
            else:
                time_budget_white -= move_duration
                white_bar.n = time_budget_white
                white_bar.set_postfix_str(f"{format_elapsed(white_bar.format_dict['elapsed'])} elapsed")
                white_bar.refresh()
                if time_budget_white <= 0:
                    result = "0-1"
                    break

            # Display board
            yield image

        except ValueError:
            print(f"Invalid move: {guided}")
            break

    white_bar.close()
    black_bar.close()

    # Get result (a timeout already set it inside the loop)
    if result is None:
        result = board.result()

    # Create PGN
    termination = determine_termination(board, time_budget_white, time_budget_black)
    final_pgn = write_pgn(
        pgn_moves, model_id_white, model_id_black, result, TIME_BUDGET, termination
    )
    file_name = f"{model_id_white.split('/')[-1]}_vs_{model_id_black.split('/')[-1]}"
    pgn_id = save_pgn(final_pgn, file_name, GITHUB_TOKEN)

    # Save results
    save_result_file(
        pgn_id, model_id_white, model_id_black, termination, result, GITHUB_TOKEN, RESULT_GIST_ID
    )

    # Create and display the GIF
    clear_output(wait=True)
    create_gif(board_images, file_name + ".gif", duration=400)

    # Print ELO ratings, starting from the leaderboard values loaded earlier.
    current_elo_white = _lookup_elo(elo_ratings_df, model_id_white)
    current_elo_black = _lookup_elo(elo_ratings_df, model_id_black)

    # (White score, Black score, headline) for each decided result.
    outcomes = {
        "1-0": (1, 0, f"{model_id_white} wins! ({termination})"),
        "0-1": (0, 1, f"{model_id_black} wins! ({termination})"),
        "1/2-1/2": (0.5, 0.5, f"Draw! ({termination})"),
    }
    if result in outcomes:
        score_white, score_black, headline = outcomes[result]
        new_elo_white = calculate_elo(current_elo_white, current_elo_black, score_white)
        new_elo_black = calculate_elo(current_elo_black, current_elo_white, score_black)
        print(headline)
        print("ELO change:")
        print(
            f"* {model_id_white}: {current_elo_white} -> {new_elo_white} ({new_elo_white - current_elo_white:+})"
        )
        print(
            f"* {model_id_black}: {current_elo_black} -> {new_elo_black} ({new_elo_black - current_elo_black:+})"
        )
    elif result == "*":
        print(f"Ongoing game! ({termination})")

    # Update ELO ratings for each model
    chess_data = pd.read_csv('chessllm_results.csv')
    elo_ratings = update_elo_ratings(chess_data)

    # Convert the dictionary to a DataFrame for better display
    elo_ratings_df = pd.DataFrame(elo_ratings.items(), columns=['Model', 'ELO Rating'])

    # Round the ELO ratings to the nearest integer
    elo_ratings_df['ELO Rating'] = elo_ratings_df['ELO Rating'].round().astype(int)

    elo_ratings_df.sort_values(by='ELO Rating', ascending=False, inplace=True)
    elo_ratings_df.reset_index(drop=True, inplace=True)
    elo_ratings_df.to_csv('chessllm_leaderboard.csv', index=False)

    # Upload chessllm_leaderboard.csv to GIST. The id is passed as a plain
    # string; ``{LEAD_GIST_ID}`` would build a one-element set instead.
    gist_api = gistyc.GISTyc(auth_token=GITHUB_TOKEN)
    gist_api.update_gist(file_name='chessllm_leaderboard.csv', gist_id=LEAD_GIST_ID)

    # NOTE(review): this function is a generator, so the value below is
    # attached to StopIteration rather than yielded to the caller.
    return file_name + ".gif"
|
372 |
+
|
373 |
+
# Runtime configuration: disable tokenizer parallelism and read the GitHub
# token (a missing GITHUB_TOKEN variable raises KeyError at start-up).
os.environ.update(TOKENIZERS_PARALLELISM="false")
GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]

# Gists holding the per-game results and the ELO leaderboard.
RESULT_GIST_ID = "c491299e7b8a45a61ce5403a70cf8656"
LEAD_GIST_ID = "696115fe2df47fb2350fcff2663678c9"

# UI: two model-id text boxes, a Run button, and an image showing the board.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        model_id_white = gr.Textbox(
            value="mlabonne/chesspythia-70m",
            label="White Model ID",
        )
        model_id_black = gr.Textbox(
            value="BlueSunflower/Pythia-160M-chess",
            label="Black Model ID",
        )
    btn = gr.Button("Run")
    with gr.Row():
        out = gr.Image(width=256)
    # Clicking Run calls ``update`` with the two ids; its output goes to ``out``.
    btn.click(update, inputs=[model_id_white, model_id_black], outputs=out)

demo.launch()
|