import gradio as gr
import chess.svg
from lczerolens import LczeroBoard, LczeroModel, Lens

from . import constants


def create_board_figure(
    board: LczeroBoard,
    *,
    orientation: bool = chess.WHITE,
    arrows: str = "",
    square: str = "",
    name: str = "board",
) -> str:
    """Render ``board`` as an SVG file and return the path of that file.

    Args:
        board: Position to draw.
        orientation: Side shown at the bottom, forwarded to
            ``chess.svg.board``.
        arrows: Whitespace-separated arrows, each written as two square names
            back to back (e.g. ``"e2e4 g1f3"``). If any arrow cannot be
            parsed, a Gradio warning is issued and no arrows are drawn.
        square: Name of a single square to fill with ``#FF0000``. If it cannot
            be parsed, a Gradio warning is issued and nothing is highlighted.
        name: File stem; the figure is written to
            ``{constants.FIGURE_DIRECTORY}/{name}.svg``.

    Returns:
        The path of the written SVG file.
    """
    try:
        # split() without an argument ignores leading, trailing and repeated
        # whitespace, so stray spaces do not invalidate the arrow list.
        chess_arrows = [
            (chess.parse_square(arrow[:2]), chess.parse_square(arrow[2:]))
            for arrow in arrows.split()
        ]
    except ValueError:
        # One malformed arrow discards all of them, as the warning states.
        chess_arrows = []
        gr.Warning("Invalid arrows, using none.")

    try:
        color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
    except ValueError:
        color_dict = {}
        gr.Warning("Invalid square, using none.")

    svg_board = chess.svg.board(
        board,
        size=350,
        orientation=orientation,
        arrows=chess_arrows,
        fill=color_dict,
    )

    figure_path = f"{constants.FIGURE_DIRECTORY}/{name}.svg"
    # Explicit encoding keeps the written file independent of the platform
    # default.
    with open(figure_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return figure_path


class OutputLens(Lens):
    """Lens whose intervention returns the model's saved output."""

    def _intervene(self, model: LczeroModel, **kwargs) -> dict:
        """Return ``model.output.save()``; extra keyword arguments are ignored.

        NOTE(review): the exact semantics of ``save()`` come from the
        lczerolens tracing layer and are not visible here - confirm there.
        """
        return model.output.save()


def get_info(model: LczeroModel, board: LczeroBoard):
    """Summarise the model's evaluation of ``board`` as a one-line string.

    Runs ``OutputLens`` on the position, reads the first three entries of the
    first row of the ``"wdl"`` output, and picks the legal move with the
    highest ``"policy"`` value.

    Returns:
        A string of the form ``"w: 0.00, d: 0.00, l: 0.00, best: <move>"``.
    """
    lens = OutputLens()
    output = lens.analyse(model, board)

    # presumably win / draw / loss values for the single analysed board,
    # judging by the labels used in the output string - TODO confirm
    wdl = output["wdl"]
    win, draw, loss = wdl[0, 0], wdl[0, 1], wdl[0, 2]

    legal_indices = board.get_legal_indices()
    # Restrict the policy to legal moves before taking the argmax; the result
    # is a position inside ``legal_indices``, not a raw policy index.
    legal_policy = output["policy"].gather(dim=1, index=legal_indices.unsqueeze(0))
    best_move_idx = legal_policy.argmax(dim=1).item()
    best_move = board.decode_move(legal_indices[best_move_idx])

    info = f"w: {win:.2f}, d: {draw:.2f}, l: {loss:.2f}, best: {best_move}"
    return info