Spaces:
Build error
Build error
File size: 5,405 Bytes
343fa36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
"""
Gradio interface for computing and visualizing statistics of a selected model:
policy statistics, LRP (relevance) statistics and probing statistics.
"""
import gradio as gr
from demo import utils, visualisation
from lczerolens import GameDataset
from lczerolens.xai import ConceptDataset, HasThreatConcept
# Latest results of each "Compute ..." callback. They stay ``None`` until the
# corresponding computation has been run at least once.
current_policy_statistics = current_lrp_statistics = current_probing_statistics = None

# Game positions analysed by the policy and LRP callbacks.
dataset = GameDataset("assets/test_stockfish_10.jsonl")

# Concept dataset analysed by the probing callback: built from the same games
# and labelled with ``check_concept``.
# NOTE(review): HasThreatConcept("K", relative=True) presumably flags positions
# in which the king is threatened (i.e. in check) — confirm against lczerolens.
check_concept = HasThreatConcept("K", relative=True)
unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
unique_check_dataset.set_concept(check_concept)
def list_models():
    """
    Return the model names reported by ``utils.get_models_info(leela=False)``.

    The first field of every entry is wrapped in its own one-element list, so
    the result can be used directly as the rows of a single-column
    ``gr.Dataframe``. Rows are returned in sorted order.
    """
    rows = [[info[0]] for info in utils.get_models_info(leela=False)]
    rows.sort()
    return rows
def on_select_model_df(
    evt: gr.SelectData,
):
    """
    Handle a selection in the models table.

    Returns the value of the selected cell; the event wiring sends it to the
    "Selected model" textbox.
    """
    selected_model = evt.value
    return selected_model
def compute_policy_statistics(
    model_name,
):
    """
    Compute policy statistics of ``model_name`` over ``dataset`` and plot them.

    The statistics are stored in the module-level ``current_policy_statistics``
    so that ``make_policy_plot`` can re-render them later.

    Returns:
        The result of ``make_policy_plot()``, or ``None`` (after showing a
        warning) when no model is selected.
    """
    global current_policy_statistics
    # ``not`` rejects both the empty string and ``None``.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
    # NOTE(review): the LRP and probing callbacks call ``lens.compute_statistics``;
    # confirm ``analyse_dataset`` is the intended statistics entry point here.
    current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
    return make_policy_plot()
def make_policy_plot():
    """
    Render the most recently computed policy statistics.

    Returns:
        The figure built by ``visualisation.render_policy_statistics``, or
        ``None`` (after showing a warning) if nothing has been computed yet.
    """
    stats = current_policy_statistics
    if stats is not None:
        return visualisation.render_policy_statistics(stats)
    gr.Warning(
        "Please compute policy statistics first.",
    )
    return None
def compute_lrp_statistics(
    model_name,
):
    """
    Compute LRP statistics of ``model_name`` over ``dataset`` and plot them.

    The statistics are stored in the module-level ``current_lrp_statistics``
    so that ``make_lrp_plot`` can re-render them later.

    Returns:
        The result of ``make_lrp_plot()``, or ``(None, None, None)`` (after
        showing a warning) when no model is selected. Three values are needed
        because this callback feeds three plot components.
    """
    global current_lrp_statistics
    # ``not`` rejects both the empty string and ``None``.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None, None, None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
    current_lrp_statistics = lens.compute_statistics(dataset, wrapper, 10)
    return make_lrp_plot()
def make_lrp_plot():
    """
    Render the most recently computed LRP statistics.

    Returns:
        Whatever ``visualisation.render_relevance_proportion`` produces. It is
        expected to be three figures, matching the three LRP plot outputs.
        Returns ``(None, None, None)`` (after showing a warning) if nothing has
        been computed yet.
    """
    stats = current_lrp_statistics
    if stats is not None:
        return visualisation.render_relevance_proportion(stats)
    gr.Warning(
        "Please compute LRP statistics first.",
    )
    return None, None, None
def compute_probing_statistics(
    model_name,
):
    """
    Compute probing statistics of ``model_name`` for ``check_concept`` and plot them.

    The probing lens is built for ``check_concept`` and evaluated on
    ``unique_check_dataset``. The result is stored in the module-level
    ``current_probing_statistics`` so that ``make_probing_plot`` can re-render it.

    Returns:
        The result of ``make_probing_plot()``, or ``None`` (after showing a
        warning) when no model is selected.
    """
    global current_probing_statistics
    # ``not`` rejects both the empty string and ``None``.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept)
    current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10)
    return make_probing_plot()
def make_probing_plot():
    """
    Render the most recently computed probing statistics.

    Returns:
        The figure built by ``visualisation.render_probing_statistics``, or
        ``None`` (after showing a warning) if nothing has been computed yet.
    """
    stats = current_probing_statistics
    if stats is not None:
        return visualisation.render_probing_statistics(stats)
    gr.Warning(
        "Please compute probing statistics first.",
    )
    return None
with gr.Blocks() as interface:
    # ------------------------------------------------------------------ layout
    # Top row: table of available models next to the current selection.
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(
                    label="Selected model",
                    lines=1,
                    interactive=False,
                    scale=7,
                )

    # Bottom row: one column per analysis (policy, LRP, probing).
    with gr.Row():
        with gr.Column():
            policy_plot = gr.Plot(label="Policy statistics")
            policy_compute_button = gr.Button(value="Compute policy statistics")
            policy_plot_button = gr.Button(value="Plot policy statistics")

        with gr.Column():
            lrp_plot_hist = gr.Plot(label="LRP history statistics")
            with gr.Row():
                with gr.Column():
                    lrp_plot_planes = gr.Plot(label="LRP planes statistics")
                with gr.Column():
                    lrp_plot_pieces = gr.Plot(label="LRP pieces statistics")
            with gr.Row():
                lrp_compute_button = gr.Button(value="Compute LRP statistics")
            with gr.Row():
                lrp_plot_button = gr.Button(value="Plot LRP statistics")

        with gr.Column():
            probing_plot = gr.Plot(label="Probing statistics")
            probing_compute_button = gr.Button(value="Compute probing statistics")
            probing_plot_button = gr.Button(value="Plot probing statistics")

    # ------------------------------------------------------------ event wiring
    # Registered in the same order as the components above, so the event
    # indices stay stable.
    model_df.select(fn=on_select_model_df, inputs=None, outputs=model_name)

    policy_compute_button.click(
        compute_policy_statistics,
        inputs=[model_name],
        outputs=[policy_plot],
    )
    policy_plot_button.click(make_policy_plot, outputs=[policy_plot])

    lrp_compute_button.click(
        compute_lrp_statistics,
        inputs=[model_name],
        outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
    )
    lrp_plot_button.click(
        make_lrp_plot,
        outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
    )

    probing_compute_button.click(
        compute_probing_statistics,
        inputs=[model_name],
        outputs=[probing_plot],
    )
    probing_plot_button.click(make_probing_plot, outputs=[probing_plot])
|