File size: 2,113 Bytes
0d998a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340463d
0d998a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340463d
 
0d998a6
 
340463d
 
 
0d998a6
 
340463d
0d998a6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Script to generate features for a given board state.
"""

from typing import Optional

from lczerolens import ModelWrapper
from lczerolens.xai import ActivationLens
from lczerolens.encodings import InputEncoding
import chess
import einops
import torch

from .sae import SparseAutoEncoder


class OutputGenerator:
    """Run a (root, trajectory) board pair through a model and an SAE.

    The wrapped model is analysed with an ``ActivationLens`` built from
    ``module_exp``; the activations it captures for the two boards are
    flattened per square, concatenated along the channel axis and passed to
    the sparse auto-encoder.
    """

    def __init__(self, sae: SparseAutoEncoder, wrapper: ModelWrapper, module_exp: Optional[str] = None):
        """Store the SAE and model wrapper and build the activation lens.

        Args:
            sae: Sparse auto-encoder called on the concatenated activations.
            wrapper: Model wrapper handed to the lens for the forward pass.
            module_exp: Expression forwarded to ``ActivationLens`` to select
                the module to record. ``generate`` requires that exactly one
                module ends up being recorded.
        """
        self.sae = sae
        self.wrapper = wrapper
        self.lens = ActivationLens(module_exp=module_exp)

    # Called form of the decorator: the bare ``@torch.no_grad`` spelling is
    # not accepted by older PyTorch releases, while ``@torch.no_grad()``
    # works on all of them.
    @torch.no_grad()
    def generate(
        self,
        root_fen: Optional[str] = None,
        traj_fen: Optional[str] = None,
        root_board: Optional[chess.Board] = None,
        traj_board: Optional[chess.Board] = None,
    ):
        """Compute model output, per-square activations and SAE output.

        Either both boards or both FEN strings must be supplied. When both
        pairs are given, the boards take precedence and the FENs are ignored.

        Args:
            root_fen: FEN of the root position.
            traj_fen: FEN of the trajectory position.
            root_board: Root position as a ``chess.Board``.
            traj_board: Trajectory position as a ``chess.Board``.

        Returns:
            A tuple ``(model_output, pixel_acts, sae_output)`` where
            ``model_output`` is the model output returned by the lens,
            ``pixel_acts`` has one row per ``(h, w)`` location with the root
            channels followed by the trajectory channels, and ``sae_output``
            is the result of ``self.sae(pixel_acts, output_features=True)``.

        Raises:
            ValueError: If neither a complete board pair nor a complete FEN
                pair is given, or if the lens recorded zero or more than one
                module.
        """
        if root_board is not None and traj_board is not None:
            input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE
        elif root_fen is not None and traj_fen is not None:
            root_board = chess.Board(root_fen)
            traj_board = chess.Board(traj_fen)
            # NOTE(review): the "repeated" encoding is presumably chosen
            # because a board built from a FEN carries no move history —
            # confirm against the lczerolens encoding documentation.
            input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
        else:
            raise ValueError(
                "Provide either both root_board and traj_board, "
                "or both root_fen and traj_fen."
            )

        # The lens consumes an iterator of batches; feed it a single batch
        # holding the two boards so that index 0 is root and index 1 is traj.
        iter_boards = iter([([root_board, traj_board],)])
        result_iter = self.lens.analyse_batched_boards(
            iter_boards,
            self.wrapper,
            return_output=True,
            wrapper_kwargs={
                "input_encoding": input_encoding,
            },
        )
        act_dict, (model_output,) = next(result_iter)

        n_matched = len(act_dict)
        if n_matched == 0:
            raise ValueError("No module matched the given expression.")
        if n_matched > 1:
            raise ValueError("Multiple modules matched the given expression.")

        acts = next(iter(act_dict.values()))
        # Flatten the spatial grid so each row is one location's channels.
        root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
        traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
        # Concatenate along channels: (h*w, c) + (h*w, c) -> (h*w, 2c).
        pixel_acts = torch.cat([root_acts, traj_acts], dim=1)
        sae_output = self.sae(pixel_acts, output_features=True)
        return model_output, pixel_acts, sae_output