Spaces:
Runtime error
Runtime error
"""Script to generate features for a given board state. | |
""" | |
from typing import Optional | |
from lczerolens import ModelWrapper | |
from lczerolens.xai import ActivationLens | |
from lczerolens.encodings import InputEncoding | |
import chess | |
import einops | |
import torch | |
from .sae import SparseAutoEncoder | |
class OutputGenerator:
    """Run a root/trajectory board pair through a model and a sparse autoencoder.

    The activations of the single module selected by ``module_exp`` are
    captured for both boards. They are flattened to one row per spatial
    position, concatenated channel-wise and passed to the SAE.
    """

    def __init__(
        self,
        sae: SparseAutoEncoder,
        wrapper: ModelWrapper,
        module_exp: Optional[str] = None,
    ):
        """Store the SAE and model wrapper, and build the activation lens.

        Args:
            sae: Autoencoder applied to the concatenated activations.
            wrapper: Model wrapper the boards are evaluated with.
            module_exp: Expression forwarded to ``ActivationLens`` to select the
                module whose activations are captured. ``generate`` requires
                exactly one module to match.
        """
        self.sae = sae
        self.wrapper = wrapper
        self.lens = ActivationLens(module_exp=module_exp)

    def generate(
        self,
        root_fen: Optional[str] = None,
        traj_fen: Optional[str] = None,
        root_board: Optional[chess.Board] = None,
        traj_board: Optional[chess.Board] = None,
    ):
        """Compute the model output, pixel activations and SAE output for a board pair.

        Supply either both boards or both FENs. If both boards are given, they
        take precedence and any FENs are ignored.

        Args:
            root_fen: FEN of the root position (used only with ``traj_fen``).
            traj_fen: FEN of the trajectory position (used only with ``root_fen``).
            root_board: Root board (used only with ``traj_board``).
            traj_board: Trajectory board (used only with ``root_board``).

        Returns:
            A ``(model_output, pixel_acts, sae_output)`` tuple. Each row of
            ``pixel_acts`` holds one spatial position's root-board channels
            followed by its trajectory-board channels. ``sae_output`` is
            whatever ``self.sae(pixel_acts, output_features=True)`` returns.

        Raises:
            ValueError: If neither a complete board pair nor a complete FEN pair
                is given, or if ``module_exp`` matched zero or several modules.
        """
        root_board, traj_board, input_encoding = self._resolve_inputs(
            root_fen, traj_fen, root_board, traj_board
        )
        acts, model_output = self._capture_activations(
            root_board, traj_board, input_encoding
        )
        pixel_acts = self._pixel_activations(acts)
        sae_output = self.sae(pixel_acts, output_features=True)
        return model_output, pixel_acts, sae_output

    @staticmethod
    def _resolve_inputs(root_fen, traj_fen, root_board, traj_board):
        """Return ``(root_board, traj_board, input_encoding)`` for the given inputs."""
        if root_board is not None and traj_board is not None:
            return root_board, traj_board, InputEncoding.INPUT_CLASSICAL_112_PLANE
        if root_fen is not None and traj_fen is not None:
            # Boards built from a FEN have no move stack. Presumably the
            # "REPEATED" encoding fills the history planes by repeating the
            # current position -- NOTE(review): confirm against lczerolens.
            return (
                chess.Board(root_fen),
                chess.Board(traj_fen),
                InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED,
            )
        raise ValueError(
            "Expected either both `root_board` and `traj_board`, "
            "or both `root_fen` and `traj_fen`."
        )

    def _capture_activations(self, root_board, traj_board, input_encoding):
        """Evaluate both boards as one batch and return ``(acts, model_output)``."""
        # The iterator yields a single batch: a 1-tuple holding the two boards,
        # so index 0 of the activations is the root and index 1 the trajectory.
        iter_boards = iter([([root_board, traj_board],)])
        result_iter = self.lens.analyse_batched_boards(
            iter_boards,
            self.wrapper,
            return_output=True,
            wrapper_kwargs={"input_encoding": input_encoding},
        )
        act_dict, (model_output,) = next(result_iter)
        if not act_dict:
            raise ValueError("No module matched the given expression.")
        if len(act_dict) > 1:
            raise ValueError("Multiple modules matched the given expression.")
        (acts,) = act_dict.values()
        return acts, model_output

    @staticmethod
    def _pixel_activations(acts):
        """Flatten per-board ``(c, h, w)`` activations to ``(h*w, 2c)`` rows."""
        root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
        traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
        # Channel-wise concatenation: root channels first, then trajectory ones.
        return torch.cat([root_acts, traj_acts], dim=1)