Xmaster6y commited on
Commit
7acfda9
·
unverified ·
1 Parent(s): 55ecc31
Files changed (1) hide show
  1. src/attention_interface.py +30 -16
src/attention_interface.py CHANGED
@@ -13,23 +13,28 @@ from . import constants, state, visualisation
13
 
14
  def compute_cache(
15
  game_pgn,
 
16
  attention_layer,
17
  attention_head,
18
  comp_index,
19
  state_cache,
20
  state_board_index,
21
  ):
22
- board = chess.Board()
23
- fen_list = [board.fen()]
24
- for move in game_pgn.split():
25
- if move.endswith("."):
26
- continue
27
- try:
28
- board.push_san(move)
29
- fen_list.append(board.fen())
30
- except ValueError:
31
- gr.Warning(f"Invalid move {move}, stopping before it.")
32
- break
 
 
 
 
33
  state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
34
  return (
35
  *make_plot(
@@ -156,10 +161,19 @@ def next_board(
156
  with gr.Blocks() as interface:
157
  with gr.Row():
158
  with gr.Column():
159
- game_pgn = gr.Textbox(
160
- label="Game PGN",
161
- lines=1,
162
- )
 
 
 
 
 
 
 
 
 
163
  compute_cache_button = gr.Button("Compute cache")
164
  with gr.Group():
165
  with gr.Row():
@@ -228,7 +242,7 @@ with gr.Blocks() as interface:
228
  state_board_index = gr.State(value=0)
229
  compute_cache_button.click(
230
  compute_cache,
231
- inputs=[game_pgn, *static_inputs, state_cache, state_board_index],
232
  outputs=[*static_outputs, state_cache],
233
  )
234
 
 
13
 
14
  def compute_cache(
15
  game_pgn,
16
+ board_fen,
17
  attention_layer,
18
  attention_head,
19
  comp_index,
20
  state_cache,
21
  state_board_index,
22
  ):
23
+ if game_pgn == "" and board_fen != "":
24
+ board = chess.Board(board_fen)
25
+ fen_list = [board.fen()]
26
+ else:
27
+ board = chess.Board()
28
+ fen_list = [board.fen()]
29
+ for move in game_pgn.split():
30
+ if move.endswith("."):
31
+ continue
32
+ try:
33
+ board.push_san(move)
34
+ fen_list.append(board.fen())
35
+ except ValueError:
36
+ gr.Warning(f"Invalid move {move}, stopping before it.")
37
+ break
38
  state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
39
  return (
40
  *make_plot(
 
161
  with gr.Blocks() as interface:
162
  with gr.Row():
163
  with gr.Column():
164
+ with gr.Group():
165
+ gr.Markdown(
166
+ "Specify the game PGN of FEN string that you want to analyse (PGN overrides FEN)."
167
+ )
168
+ game_pgn = gr.Textbox(
169
+ label="Game PGN",
170
+ lines=1,
171
+ )
172
+ board_fen = gr.Textbox(
173
+ label="Board FEN",
174
+ lines=1,
175
+ max_lines=1,
176
+ )
177
  compute_cache_button = gr.Button("Compute cache")
178
  with gr.Group():
179
  with gr.Row():
 
242
  state_board_index = gr.State(value=0)
243
  compute_cache_button.click(
244
  compute_cache,
245
+ inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
246
  outputs=[*static_outputs, state_cache],
247
  )
248