File size: 5,405 Bytes
343fa36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Gradio interface for visualizing the policy of a model.
"""

import gradio as gr

from demo import utils, visualisation
from lczerolens import GameDataset
from lczerolens.xai import ConceptDataset, HasThreatConcept

# Results of the most recent "Compute ..." click for each analysis. They stay
# ``None`` until the matching statistics are computed. These are plain module
# globals, so every session served by this process sees the same values.
current_policy_statistics = current_lrp_statistics = current_probing_statistics = None

# Games shared by the policy and LRP analyses. The path is relative, so it is
# presumably resolved against the process working directory.
dataset = GameDataset("assets/test_stockfish_10.jsonl")

# Concept used by the probing analysis. The variable name suggests it detects
# a threatened king ("K"), i.e. check.
# NOTE(review): confirm the meaning of relative=True in lczerolens.
check_concept = HasThreatConcept("K", relative=True)

# Concept-labelled view of the same games, used for probing.
unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
unique_check_dataset.set_concept(check_concept)


def list_models():
    """
    Return the available model names as sorted, single-column dataframe rows.

    Entries come from ``utils.get_models_info(leela=False)``. Only the first
    field of each entry is kept; that field is the value later used as the
    selected model name.
    """
    names = sorted(info[0] for info in utils.get_models_info(leela=False))
    # Each name becomes its own one-cell row for the single-column gr.Dataframe.
    return [[name] for name in names]


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    Handle a click on a cell of the models dataframe.

    Gradio supplies the selection event through the ``gr.SelectData``
    annotation; no explicit inputs are wired to this listener. The clicked
    cell's value is returned and becomes the content of the output component
    (the "Selected model" textbox).
    """
    clicked_value = evt.value
    return clicked_value


def compute_policy_statistics(
    model_name,
):
    """
    Compute policy statistics for the selected model and plot them.

    Args:
        model_name: value of the "Selected model" textbox. An empty string or
            ``None`` means no model has been selected yet.

    Returns:
        The value produced by ``make_policy_plot`` from the freshly computed
        statistics, or ``None`` after a UI warning when no model is selected.
    """
    # Only the cached result is reassigned; ``dataset`` is read-only here and
    # needs no ``global`` declaration.
    global current_policy_statistics

    # ``not model_name`` also rejects ``None``, so it is never passed to the
    # model loader.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
    # NOTE(review): the trailing 10 looks like a batch size; confirm against
    # the lens API.
    current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
    return make_policy_plot()


def make_policy_plot():
    """
    Render the most recently computed policy statistics.

    Returns whatever ``visualisation.render_policy_statistics`` produces (shown
    in the policy ``gr.Plot``). If nothing has been computed yet, it shows a UI
    warning and returns ``None``.
    """
    stats = current_policy_statistics
    if stats is not None:
        return visualisation.render_policy_statistics(stats)
    gr.Warning("Please compute policy statistics first.")
    return None


def compute_lrp_statistics(
    model_name,
):
    """
    Compute LRP statistics for the selected model and plot them.

    Args:
        model_name: value of the "Selected model" textbox. An empty string or
            ``None`` means no model has been selected yet.

    Returns:
        The value produced by ``make_lrp_plot``: one value per LRP plot
        (history, planes, pieces). When no model is selected, a UI warning is
        shown and ``(None, None, None)`` is returned.
    """
    # Only the cached result is reassigned; ``dataset`` is read-only here.
    global current_lrp_statistics

    # ``not model_name`` also rejects ``None``, so it is never passed to the
    # model loader.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        # One ``None`` per wired output component.
        return None, None, None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
    # NOTE(review): the trailing 10 looks like a batch size; confirm against
    # the lens API.
    current_lrp_statistics = lens.compute_statistics(dataset, wrapper, 10)
    return make_lrp_plot()


def make_lrp_plot():
    """
    Render the most recently computed LRP statistics.

    Returns whatever ``visualisation.render_relevance_proportion`` produces.
    The listeners unpack it into the history, planes and pieces plots. If
    nothing has been computed yet, it shows a UI warning and returns
    ``(None, None, None)``.
    """
    if current_lrp_statistics is None:
        gr.Warning("Please compute LRP statistics first.")
        return None, None, None
    return visualisation.render_relevance_proportion(current_lrp_statistics)


def compute_probing_statistics(
    model_name,
):
    """
    Compute concept-probing statistics for the selected model and plot them.

    The probing lens is built for the module-level ``check_concept``.
    Statistics are computed over ``unique_check_dataset``.

    Args:
        model_name: value of the "Selected model" textbox. An empty string or
            ``None`` means no model has been selected yet.

    Returns:
        The value produced by ``make_probing_plot`` from the freshly computed
        statistics, or ``None`` after a UI warning when no model is selected.
    """
    # Only the cached result is reassigned; the concept and the concept
    # dataset are read-only here.
    global current_probing_statistics

    # ``not model_name`` also rejects ``None``, so it is never passed to the
    # model loader.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept)
    # NOTE(review): the trailing 10 looks like a batch size; confirm against
    # the lens API.
    current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10)
    return make_probing_plot()


def make_probing_plot():
    """
    Render the most recently computed probing statistics.

    Returns whatever ``visualisation.render_probing_statistics`` produces
    (shown in the probing ``gr.Plot``). If nothing has been computed yet, it
    shows a UI warning and returns ``None``.
    """
    stats = current_probing_statistics
    if stats is not None:
        return visualisation.render_probing_statistics(stats)
    gr.Warning("Please compute probing statistics first.")
    return None


# Build the Gradio page. Components are laid out in the order they are
# created, so the statement order below defines the page layout.
# NOTE(review): the computed statistics live in module-level globals, so every
# browser session of this app shares (and overwrites) them. Per-session
# gr.State would isolate users.
with gr.Blocks() as interface:
    # Top row: the list of available models (wider, left) and the currently
    # selected model name (right).
    with gr.Row():
        with gr.Column(scale=2):
            # Read-only table; list_models is passed as a callable, so Gradio
            # calls it to fill the table when the page loads.
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                # Not user-editable; only filled by clicking a row of model_df.
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
    # Clicking a dataframe cell copies its value into the "Selected model" box.
    # No inputs are wired; the handler receives the gr.SelectData event.
    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    # Policy section (left column). The LRP history plot shares this row
    # (right column).
    with gr.Row():
        with gr.Column():
            policy_plot = gr.Plot(label="Policy statistics")
            policy_compute_button = gr.Button(value="Compute policy statistics")
            policy_plot_button = gr.Button(value="Plot policy statistics")

            # "Compute" recomputes for the selected model and plots.
            # "Plot" redraws the cached result without recomputing.
            policy_compute_button.click(
                compute_policy_statistics,
                inputs=[model_name],
                outputs=[policy_plot],
            )
            policy_plot_button.click(make_policy_plot, outputs=[policy_plot])

        with gr.Column():
            lrp_plot_hist = gr.Plot(label="LRP history statistics")

    # Remaining LRP plots side by side.
    with gr.Row():
        with gr.Column():
            lrp_plot_planes = gr.Plot(label="LRP planes statistics")

        with gr.Column():
            lrp_plot_pieces = gr.Plot(label="LRP pieces statistics")

    # LRP buttons, each on its own full-width row.
    with gr.Row():
        lrp_compute_button = gr.Button(value="Compute LRP statistics")
    with gr.Row():
        lrp_plot_button = gr.Button(value="Plot LRP statistics")

    # Both LRP handlers return three values, mapped in order onto the
    # history, planes and pieces plots.
    lrp_compute_button.click(
        compute_lrp_statistics,
        inputs=[model_name],
        outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
    )
    lrp_plot_button.click(
        make_lrp_plot,
        outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
    )

    # Probing section: plot plus compute/redraw buttons, stacked vertically.
    with gr.Column():
        probing_plot = gr.Plot(label="Probing statistics")
        probing_compute_button = gr.Button(value="Compute probing statistics")
        probing_plot_button = gr.Button(value="Plot probing statistics")

        probing_compute_button.click(
            compute_probing_statistics,
            inputs=[model_name],
            outputs=[probing_plot],
        )
        probing_plot_button.click(make_probing_plot, outputs=[probing_plot])