"""Generate sparse-autoencoder features for a (root, trajectory) pair of board states."""

from typing import Optional

from lczerolens import ModelWrapper
from lczerolens.xai import ActivationLens
from lczerolens.encodings import InputEncoding
import chess
import einops
import torch

from .sae import SparseAutoEncoder


class OutputGenerator:
    """Run a model on a pair of boards and encode the captured activations with an SAE.

    Activations are recorded from the module selected by ``module_exp`` (via
    ``ActivationLens``), flattened per square, concatenated across the two
    boards along the feature dimension, and fed to the sparse autoencoder.
    """

    def __init__(
        self,
        sae: SparseAutoEncoder,
        wrapper: ModelWrapper,
        module_exp: Optional[str] = None,
    ):
        """
        Args:
            sae: Sparse autoencoder applied to the concatenated activations.
            wrapper: Model wrapper the lens runs the boards through.
            module_exp: Expression forwarded to ``ActivationLens`` to select the
                module whose activations are recorded. ``generate`` requires it
                to match exactly one module.
        """
        self.sae = sae
        self.wrapper = wrapper
        self.lens = ActivationLens(module_exp=module_exp)

    @staticmethod
    def _resolve_inputs(
        root_fen: Optional[str],
        traj_fen: Optional[str],
        root_board: Optional[chess.Board],
        traj_board: Optional[chess.Board],
    ):
        """Return ``(root_board, traj_board, input_encoding)`` for ``generate``'s arguments.

        A complete pair of boards takes precedence over a pair of FENs.

        Raises:
            ValueError: If neither a complete board pair nor a complete FEN pair
                was supplied.
        """
        if root_board is not None and traj_board is not None:
            return root_board, traj_board, InputEncoding.INPUT_CLASSICAL_112_PLANE
        if root_fen is not None and traj_fen is not None:
            # NOTE(review): boards built from a FEN have no move history, which is
            # presumably why the REPEATED encoding is chosen here — confirm.
            return (
                chess.Board(root_fen),
                chess.Board(traj_fen),
                InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED,
            )
        raise ValueError(
            "Provide either both `root_board` and `traj_board`, "
            "or both `root_fen` and `traj_fen`."
        )

    # Called form: the bare `@torch.no_grad` decorator is only accepted by newer
    # PyTorch releases, while `@torch.no_grad()` works everywhere.
    @torch.no_grad()
    def generate(
        self,
        root_fen: Optional[str] = None,
        traj_fen: Optional[str] = None,
        root_board: Optional[chess.Board] = None,
        traj_board: Optional[chess.Board] = None,
    ):
        """Compute model output, per-square activations and SAE output for a board pair.

        Supply either ``root_board`` and ``traj_board``, or ``root_fen`` and
        ``traj_fen``. If both pairs are given, the boards are used.

        Returns:
            A tuple ``(model_output, pixel_acts, sae_output)``:
            - ``model_output``: the model output yielded by the lens for the
              two-board batch.
            - ``pixel_acts``: a 2-D tensor with one row per square (``h * w``
              rows). The root board's channels come first, followed by the
              trajectory board's channels (``2 * c`` columns).
            - ``sae_output``: whatever ``self.sae(pixel_acts, output_features=True)``
              returns.

        Raises:
            ValueError: If the arguments do not form a complete board or FEN
                pair, or if the module expression matched zero or several
                modules.
        """
        root_board, traj_board, input_encoding = self._resolve_inputs(
            root_fen, traj_fen, root_board, traj_board
        )

        # A single batch holding both boards; index 0 is root, index 1 is trajectory.
        iter_boards = iter([([root_board, traj_board],)])
        result_iter = self.lens.analyse_batched_boards(
            iter_boards,
            self.wrapper,
            return_output=True,
            wrapper_kwargs={
                "input_encoding": input_encoding,
            },
        )
        act_dict, (model_output,) = next(result_iter)

        if not act_dict:
            raise ValueError("No module matched the given expression.")
        if len(act_dict) > 1:
            raise ValueError("Multiple modules matched the given expression.")
        acts = next(iter(act_dict.values()))

        # Each board's activation map (c, h, w) becomes one row per square: (h*w, c).
        root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
        traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
        pixel_acts = torch.cat([root_acts, traj_acts], dim=1)

        sae_output = self.sae(pixel_acts, output_features=True)
        return model_output, pixel_acts, sae_output