Spaces:
Runtime error
Runtime error
fen entry
Browse files- src/attention_interface.py +30 -16
src/attention_interface.py
CHANGED
@@ -13,23 +13,28 @@ from . import constants, state, visualisation
|
|
13 |
|
14 |
def compute_cache(
|
15 |
game_pgn,
|
|
|
16 |
attention_layer,
|
17 |
attention_head,
|
18 |
comp_index,
|
19 |
state_cache,
|
20 |
state_board_index,
|
21 |
):
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
33 |
state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
|
34 |
return (
|
35 |
*make_plot(
|
@@ -156,10 +161,19 @@ def next_board(
|
|
156 |
with gr.Blocks() as interface:
|
157 |
with gr.Row():
|
158 |
with gr.Column():
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
compute_cache_button = gr.Button("Compute cache")
|
164 |
with gr.Group():
|
165 |
with gr.Row():
|
@@ -228,7 +242,7 @@ with gr.Blocks() as interface:
|
|
228 |
state_board_index = gr.State(value=0)
|
229 |
compute_cache_button.click(
|
230 |
compute_cache,
|
231 |
-
inputs=[game_pgn, *static_inputs, state_cache, state_board_index],
|
232 |
outputs=[*static_outputs, state_cache],
|
233 |
)
|
234 |
|
|
|
13 |
|
14 |
def compute_cache(
|
15 |
game_pgn,
|
16 |
+
board_fen,
|
17 |
attention_layer,
|
18 |
attention_head,
|
19 |
comp_index,
|
20 |
state_cache,
|
21 |
state_board_index,
|
22 |
):
|
23 |
+
if game_pgn == "" and board_fen != "":
|
24 |
+
board = chess.Board(board_fen)
|
25 |
+
fen_list = [board.fen()]
|
26 |
+
else:
|
27 |
+
board = chess.Board()
|
28 |
+
fen_list = [board.fen()]
|
29 |
+
for move in game_pgn.split():
|
30 |
+
if move.endswith("."):
|
31 |
+
continue
|
32 |
+
try:
|
33 |
+
board.push_san(move)
|
34 |
+
fen_list.append(board.fen())
|
35 |
+
except ValueError:
|
36 |
+
gr.Warning(f"Invalid move {move}, stopping before it.")
|
37 |
+
break
|
38 |
state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
|
39 |
return (
|
40 |
*make_plot(
|
|
|
161 |
with gr.Blocks() as interface:
|
162 |
with gr.Row():
|
163 |
with gr.Column():
|
164 |
+
with gr.Group():
|
165 |
+
gr.Markdown(
|
166 |
+
"Specify the game PGN of FEN string that you want to analyse (PGN overrides FEN)."
|
167 |
+
)
|
168 |
+
game_pgn = gr.Textbox(
|
169 |
+
label="Game PGN",
|
170 |
+
lines=1,
|
171 |
+
)
|
172 |
+
board_fen = gr.Textbox(
|
173 |
+
label="Board FEN",
|
174 |
+
lines=1,
|
175 |
+
max_lines=1,
|
176 |
+
)
|
177 |
compute_cache_button = gr.Button("Compute cache")
|
178 |
with gr.Group():
|
179 |
with gr.Row():
|
|
|
242 |
state_board_index = gr.State(value=0)
|
243 |
compute_cache_button.click(
|
244 |
compute_cache,
|
245 |
+
inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
|
246 |
outputs=[*static_outputs, state_cache],
|
247 |
)
|
248 |
|