File size: 5,396 Bytes
0d998a6
 
 
 
 
 
 
340463d
0d998a6
 
 
 
 
 
 
 
 
 
 
 
 
 
340463d
0d998a6
 
 
340463d
 
0d998a6
 
 
 
340463d
 
0d998a6
340463d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d998a6
 
 
 
 
 
340463d
 
0d998a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340463d
0d998a6
 
 
 
 
 
340463d
0d998a6
 
 
 
 
 
 
 
340463d
0d998a6
 
 
340463d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Gradio interface for plotting SAE feature heatmaps on chess boards (with the policy's best legal move drawn as an arrow).
"""

import chess
import gradio as gr
import uuid
import torch

from lczerolens.encodings import encode_move

from src import constants, global_variables, visualisation


def compute_features_fn(
    features,
    model_output,
    file_id,
    root_fen,
    traj_fen,
    feature_index
):
    """Generate activations for a (root, trajectory) FEN pair and summarise them.

    The incoming ``features`` and ``model_output`` state values are ignored:
    both are replaced by the freshly generated results. The selected feature
    is rendered through ``render_feature_index`` and a multi-line text summary
    (WDL, reconstruction errors, L0 counts, top features / squares) is built.

    Returns:
        The five values returned by ``render_feature_index`` followed by the
        summary string.
    """
    model_output, pixel_acts, sae_output = global_variables.generator.generate(
        root_fen=root_fen,
        traj_fen=traj_fen,
    )
    features = sae_output["features"]
    x_hat = sae_output["x_hat"]
    rendered = render_feature_index(features, model_output, file_id, traj_fen, feature_index)

    # Split points between the root/trajectory halves of the activation vector
    # and between the two halves of the dictionary.
    a_mid = constants.ACTIVATION_DIM // 2
    f_mid = constants.DICTIONARY_SIZE // 2
    is_active = features > 0

    def recon_error(cols):
        # Squared error summed across the selected columns, averaged over rows.
        err = torch.nn.functional.mse_loss(
            x_hat[:, cols], pixel_acts[:, cols], reduction='none'
        )
        return err.sum(dim=1).mean()

    def l0(cols):
        # Mean number of active features per row within the selected columns.
        return is_active[:, cols].sum(dim=1).float().mean()

    feat_mean = features.mean(dim=0)
    feat_freq = is_active.float().mean(dim=0)
    square_mean = features.mean(dim=1)
    square_freq = is_active.float().mean(dim=1)

    black_to_move = not chess.Board(traj_fen).turn

    def top_squares(per_square):
        # Same 8x8 row flip as render_feature_index uses for black to move.
        if black_to_move:
            per_square = per_square.view(8, 8).flip(0).view(64)
        return [chess.SQUARE_NAMES[sq] for sq in per_square.topk(5).indices.tolist()]

    whole = slice(None)
    lines = [
        f"Root WDL: {model_output['wdl'][0]}",
        f"Traj WDL: {model_output['wdl'][1]}",
        f"MSE loss: {recon_error(whole)}",
        f"MSE loss (root): {recon_error(slice(None, a_mid))}",
        f"MSE loss (traj): {recon_error(slice(a_mid, None))}",
        f"L0 loss: {l0(whole)}",
        f"L0 loss (c): {l0(slice(None, f_mid))}",
        f"L0 loss (d): {l0(slice(f_mid, None))}",
        f"Most active features (avg): {feat_mean.topk(5).indices.tolist()}",
        f"Most active features (active): {feat_freq.topk(5).indices.tolist()}",
        f"Most active pixels (avg): {top_squares(square_mean)}",
        f"Most active pixels (active): {top_squares(square_freq)}",
    ]
    return *rendered, "\n".join(lines)


def render_feature_index(
    features,
    model_output,
    file_id,
    traj_fen,
    feature_index
):
    """Render the heatmap of one SAE feature on the trajectory board.

    Writes the board SVG, with an arrow for the legal move that has the
    highest logit in ``model_output["policy"][1]``, to
    ``constants.FIGURES_FOLER/<file_id>.svg``.

    Args:
        features: Feature activations produced by ``compute_features_fn``,
            indexed as ``features[:, feature_index]`` and reshaped to 64
            squares (presumably ``(64, DICTIONARY_SIZE)`` — TODO confirm),
            or ``None`` if nothing has been computed yet.
        model_output: Model output dict whose ``"policy"`` row 1 is read as
            the trajectory policy logits, or ``None``.
        file_id: Name used for the SVG file; a new UUID is generated when
            ``None``.
        traj_fen: FEN of the trajectory position to draw.
        feature_index: Index of the feature to display. The Gradio slider
            may deliver it as a float, so it is cast to ``int``.

    Returns:
        ``(features, model_output, file_id, svg_path, fig)``. When features or
        the model output are missing, ``svg_path`` and ``fig`` are ``None``.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    # The slider's change event can fire before "Compute features" has filled
    # the state; there is nothing to render yet.
    if features is None or model_output is None:
        return features, model_output, file_id, None, None

    board = chess.Board(traj_fen)
    # Tensor indexing requires an int index.
    pixel_features = features[:, int(feature_index)]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Black to move: flip the 8x8 rows. Presumably the activations are in
        # side-to-move orientation — TODO confirm against the generator.
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    # Arg-max over legal moves of the trajectory policy logits.
    best_legal_logit = None
    best_legal_move = None
    for move in board.legal_moves:
        move_index = encode_move(move, (board.turn, not board.turn))
        logit = model_output["policy"][1, move_index].item()
        if best_legal_logit is None or logit > best_legal_logit:
            best_legal_logit = logit
            best_legal_move = move

    # Checkmate / stalemate positions have no legal move to draw.
    arrows = []
    if best_legal_move is not None:
        arrows.append((best_legal_move.from_square, best_legal_move.to_square))

    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )
    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        svg_path,
        fig
    )

# Layout: the left column holds the inputs (root / trajectory FENs, compute
# button, feature slider) plus the info text and colorbar plot; the right
# column shows the rendered board image. Component creation order inside these
# context managers defines the on-screen layout.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            root_fen = gr.Textbox(
                label="Root FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            # Default trajectory: the starting position after 1. e4.
            traj_fen = gr.Textbox(
                label="Trajectory FEN",
                lines=1,
                max_lines=1,
                value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
            )
            compute_features = gr.Button("Compute features")

            with gr.Group():
                with gr.Row():
                    # One slider position per dictionary feature.
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.DICTIONARY_SIZE-1,
                        step=1,
                        value=0,
                    )
                    
            with gr.Group():
                with gr.Row():    
                    info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Per-session state shared by the two callbacks: the last computed
    # features, the model output, and the id used to name the SVG file.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)
    
    # Regenerate features for the current FENs and refresh board, colorbar
    # and info summary.
    compute_features.click(
        compute_features_fn,
        inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar, info],
    )
    # Re-render only the board and colorbar for the selected feature, reusing
    # the stored state. NOTE(review): traj_fen is read live from the textbox,
    # so editing it without recomputing draws stale features on a new board.
    feature_index.change(
        render_feature_index,
        inputs=[features, model_output, file_id, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar],
    )