Spaces:
Running
Running
import gradio as gr | |
import chess.svg | |
from lczerolens import LczeroBoard, LczeroModel, Lens | |
from . import constants | |
def create_board_figure(
    board: LczeroBoard,
    *,
    orientation: bool = chess.WHITE,
    arrows: str = "",
    square: str = "",
    name: str = "board",
):
    """Render ``board`` to an SVG file and return the path of that file.

    Args:
        board: Position to draw.
        orientation: Point of view passed to ``chess.svg.board``
            (``chess.WHITE`` by default).
        arrows: Whitespace-separated moves such as ``"e2e4 g1f3"``. In each
            token the first two characters are the tail square and the rest
            the head square. If any token fails to parse, no arrows are drawn
            and a Gradio warning is shown.
        square: Optional square name (e.g. ``"e4"``) to fill with ``#FF0000``.
            If it fails to parse, nothing is filled and a Gradio warning is
            shown.
        name: File name (without extension) of the SVG written under
            ``constants.FIGURE_DIRECTORY``.

    Returns:
        The path ``f"{constants.FIGURE_DIRECTORY}/{name}.svg"`` that was written.
    """
    try:
        # str.split() without a separator ignores repeated, leading and trailing
        # whitespace. split(" ") produced "" tokens there, which made
        # parse_square raise and threw away every arrow.
        chess_arrows = [
            (chess.parse_square(token[:2]), chess.parse_square(token[2:]))
            for token in (arrows or "").split()
        ]
    except ValueError:
        chess_arrows = []
        gr.Warning("Invalid arrows, using none.")

    # Tolerate surrounding whitespace (and None) in the highlighted square.
    square = (square or "").strip()
    try:
        color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
    except ValueError:
        color_dict = {}
        gr.Warning("Invalid square, using none.")

    svg_board = chess.svg.board(
        board,
        size=350,
        orientation=orientation,
        arrows=chess_arrows,
        fill=color_dict,
    )

    figure_path = f"{constants.FIGURE_DIRECTORY}/{name}.svg"
    with open(figure_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return figure_path
class OutputLens(Lens):
    """Lens that performs no intervention and just captures the model output.

    Its ``_intervene`` hook returns ``model.output.save()``, so the lens'
    analysis result is whatever the model produced for the analysed input.
    """

    def _intervene(self, model: LczeroModel, **kwargs) -> dict:
        # Extra keyword arguments are accepted for interface compatibility but
        # are not used.
        saved_output = model.output.save()
        return saved_output
def get_info(model: LczeroModel, board: LczeroBoard):
    """Summarise ``model``'s evaluation of ``board`` as a one-line string.

    Runs ``OutputLens`` over the position, formats the three ``"wdl"`` values,
    and picks the legal move with the highest ``"policy"`` score.

    Args:
        model: Network to evaluate.
        board: Position to evaluate.

    Returns:
        ``"w: <w>, d: <d>, l: <l>, best: <move>"`` with values to two decimals.
        If the position has no legal moves (checkmate or stalemate), ``best``
        is ``None`` instead of the call raising.
    """
    lens = OutputLens()
    output = lens.analyse(model, board)

    # [0, k] indexing assumes a leading batch dimension holding this one board.
    # NOTE(review): presumably win/draw/loss from the side to move — confirm.
    wdl = output["wdl"]
    win, draw, loss = wdl[0, 0], wdl[0, 1], wdl[0, 2]

    legal_indices = board.get_legal_indices()
    if legal_indices.numel() == 0:
        # argmax over a zero-width selection raises; a finished game has no
        # best move to report.
        best_move = None
    else:
        # Restrict the policy to legal move indices before taking the argmax,
        # then map the winning position back to its move index.
        legal_scores = output["policy"].gather(dim=1, index=legal_indices.unsqueeze(0))
        best_move_idx = legal_scores.argmax(dim=1).item()
        best_move = board.decode_move(legal_indices[best_move_idx])

    return f"w: {win:.2f}, d: {draw:.2f}, l: {loss:.2f}, best: {best_move}"