Spaces:
Running
Running
"""Interface to play against the model. | |
""" | |
import os | |
import chess | |
import chess.pgn | |
import random | |
import gradio as gr | |
from lczerolens import LczeroBoard, LczeroModel | |
from lczerolens.play import PolicySampler | |
from .. import constants | |
from ..utils import create_board_figure | |
def get_sampler(model_name: str):
    """Load an ONNX model by file name and wrap it in a policy sampler.

    Args:
        model_name: File name of the model inside
            ``constants.ONNX_MODEL_DIRECTORY``.

    Returns:
        A ``PolicySampler`` built on the loaded ``LczeroModel``.
    """
    onnx_path = os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name)
    loaded_model = LczeroModel.from_onnx_path(onnx_path)
    return PolicySampler(loaded_model)
def get_pgn(board: LczeroBoard) -> str:
    """Return the moves played on *board* as a single line of PGN movetext.

    The moves in ``board.move_stack`` are chained into one main line and
    exported without headers, comments, variations or column wrapping, so the
    whole game is returned in one string (e.g. ``"1. e4 e5 2. Nf3 *"``).

    Assumes the board started from the standard initial position, which is
    the position ``chess.pgn.Game()`` starts from.
    """
    game = chess.pgn.Game()
    node = game
    for move in board.move_stack:
        # add_variation returns the new child node; continuing from it builds
        # the main line. Adding every move to the root would instead record
        # them all as alternative first moves.
        node = node.add_variation(move)
    # columns=None disables line wrapping, so long games are not truncated
    # to their last exported line.
    exporter = chess.pgn.StringExporter(
        headers=False,
        variations=False,
        comments=False,
        columns=None,
    )
    return game.accept(exporter)
def render_board(
    board: LczeroBoard,
):
    """Draw *board* from the point of view of the side to move.

    The most recent move (if any) is passed as an arrow, and the king of the
    side to move is passed as the highlighted square when it is in check.

    Returns:
        Whatever ``create_board_figure`` returns for the "play_board" figure
        (used by the UI as the image source).
    """
    last_move_uci = board.peek().uci() if board.move_stack else None
    checked_king = board.king(board.turn) if board.is_check() else None
    return create_board_figure(
        board,
        orientation=board.turn,
        arrows=last_move_uci,
        square=checked_king,
        name="play_board",
    )
def gather_outputs(board: LczeroBoard, sampler: PolicySampler):
    """Bundle the game state into the 6-tuple consumed by the UI handlers.

    Order: sampler, board, FEN string, PGN movetext, rendered board image,
    and an empty string that clears the move-entry textbox.
    """
    fen = board.fen()
    pgn = get_pgn(board)
    figure_path = render_board(board)
    return sampler, board, fen, pgn, figure_path, ""
def get_init(model_name: str):
    """Start a new game against the selected model.

    The model's colour is chosen at random; when it plays White it makes the
    first move before the board is shown to the user.

    Args:
        model_name: File name of the ONNX model to load.

    Returns:
        The tuple produced by ``gather_outputs`` for the new game.
    """
    new_sampler = get_sampler(model_name)
    ai_plays_white = random.choice([True, False])
    fresh_board = LczeroBoard()
    if ai_plays_white:
        play_ai_move(fresh_board, new_sampler)
    return gather_outputs(fresh_board, new_sampler)
def play_user_move_then_ai_move(
    uci_move: str,
    board: LczeroBoard,
    sampler: PolicySampler,
):
    """Play the user's move, let the model reply, and refresh the UI.

    Args:
        uci_move: Move typed by the user in UCI notation (e.g. ``"e2e4"``).
        board: Board held in the session state; mutated in place.
        sampler: Sampler used to choose the model's reply.

    Returns:
        The tuple produced by ``gather_outputs`` for the updated board.

    Raises:
        gr.Error: If the game has not been initialised yet, is already over,
            or the move is unparsable, illegal, or a null move.
    """
    if board is None or sampler is None:
        raise gr.Error("The game is not initialised yet; select a model and press Reset.")
    if board.is_game_over():
        raise gr.Error("The game is over; press Reset to start a new one.")

    move_text = (uci_move or "").strip()
    try:
        # parse_uci checks both syntax and legality without touching the board.
        move = board.parse_uci(move_text)
    except ValueError as err:
        raise gr.Error(f"Invalid move {move_text!r}: {err}") from err
    if not move:
        # "0000" parses to a falsy null move, which would silently pass the turn.
        raise gr.Error("Null moves are not allowed.")
    board.push(move)

    # Only ask the model for a reply if the user's move did not end the game.
    if not board.is_game_over():
        play_ai_move(board, sampler)
    return gather_outputs(board, sampler)
def play_ai_move(
    board: LczeroBoard,
    sampler: PolicySampler,
):
    """Let the model choose a move for the side to move and play it.

    The board is mutated in place. Nothing is played when the game is
    already over, since there is then no move for the model to make.
    """
    if board.is_game_over():
        return
    # get_next_moves takes a batch of boards; a single-board batch is passed
    # and only the move from its first result is used.
    move, _ = next(iter(sampler.get_next_moves([board])))
    board.push(move)
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            current_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            current_pgn = gr.Textbox(
                label="Action sequence",
                lines=1,
                value="",
            )
            with gr.Row():
                move_to_play = gr.Textbox(
                    label="Move to play (UCI)",
                    lines=1,
                    max_lines=1,
                    value="",
                )
                with gr.Column():
                    model_name = gr.Dropdown(
                        label="Model",
                        choices=constants.ONNX_MODEL_NAMES,
                        # Preselect a model explicitly so the initial load and
                        # Reset pass a real file name (not None) to get_init.
                        value=next(iter(constants.ONNX_MODEL_NAMES), None),
                    )
                    play_button = gr.Button("Play")
                    reset_button = gr.Button("Reset")
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)

    # Per-session state; both are filled in by get_init on load/reset.
    sampler = gr.State(value=None)
    board = gr.State(value=None)

    # Order must match the tuple returned by gather_outputs.
    outputs = [sampler, board, current_fen, current_pgn, image_board, move_to_play]

    # Clicking "Play" and pressing Enter in the move box do the same thing.
    play_button.click(
        play_user_move_then_ai_move,
        inputs=[move_to_play, board, sampler],
        outputs=outputs,
    )
    move_to_play.submit(
        play_user_move_then_ai_move,
        inputs=[move_to_play, board, sampler],
        outputs=outputs,
    )
    # Changing the model only swaps the sampler; the current game continues.
    model_name.change(
        get_sampler,
        inputs=[model_name],
        outputs=[sampler],
    )
    reset_button.click(
        get_init,
        inputs=[model_name],
        outputs=outputs,
    )
    interface.load(
        get_init,
        inputs=[model_name],
        outputs=outputs,
    )