Spaces:
Build error
Build error
File size: 5,405 Bytes
343fa36 |
|
"""
Gradio interface for computing and plotting policy, LRP and probing statistics of a model.
"""
import gradio as gr
from demo import utils, visualisation
from lczerolens import GameDataset
from lczerolens.xai import ConceptDataset, HasThreatConcept
# Most recent results of the three "Compute ..." callbacks. Each stays None
# until its callback has run; the "Plot ..." callbacks re-render from these.
current_policy_statistics = None
current_lrp_statistics = None
current_probing_statistics = None
# Game dataset used by the policy and LRP callbacks; built once at import time.
dataset = GameDataset("assets/test_stockfish_10.jsonl")
# Concept used by the probing callback.
# NOTE(review): "K" with relative=True presumably means "a threat on the
# king" (i.e. check), as the variable name suggests — confirm in lczerolens.
check_concept = HasThreatConcept("K", relative=True)
# Concept dataset derived from the game dataset above, then bound to the
# concept; this is what the probing callback analyses.
unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
unique_check_dataset.set_concept(check_concept)
def list_models():
    """
    Build the rows displayed in the "Available models" table.

    Each entry returned by ``utils.get_models_info(leela=False)`` contributes
    a one-element row holding its first field (presumably the model name);
    the rows are returned in sorted order.
    """
    rows = []
    for model_info in utils.get_models_info(leela=False):
        rows.append([model_info[0]])
    rows.sort()
    return rows
def on_select_model_df(evt: gr.SelectData):
    """
    Forward the value of the selected table cell.

    The interface wires this callback's return value to the
    "Selected model" textbox.
    """
    selected_value = evt.value
    return selected_value
def compute_policy_statistics(
    model_name,
):
    """
    Compute the policy statistics of the selected model and plot them.

    The result is stored in the module-level ``current_policy_statistics``
    so that it can be re-plotted later without recomputing.

    Args:
        model_name: Content of the "Selected model" textbox. An empty or
            missing (``None``) value aborts with a warning.

    Returns:
        The value returned by ``make_policy_plot()``, or ``None`` when no
        model is selected.
    """
    global current_policy_statistics

    # A textbox may deliver None as well as ""; treat both as "nothing
    # selected" instead of forwarding None to the model loader.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None

    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
    # NOTE(review): the LRP and probing callbacks call
    # ``lens.compute_statistics`` while this one calls
    # ``lens.analyse_dataset`` — confirm both names exist on the lens API.
    # NOTE(review): the positional 10 is presumably a batch size — confirm.
    current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
    return make_policy_plot()
def make_policy_plot():
    """
    Render the most recently computed policy statistics.

    Returns the output of ``visualisation.render_policy_statistics``; when
    nothing has been computed yet, shows a warning and returns ``None``.
    """
    statistics = current_policy_statistics
    if statistics is not None:
        return visualisation.render_policy_statistics(statistics)

    gr.Warning("Please compute policy statistics first.")
    return None
def compute_lrp_statistics(
    model_name,
):
    """
    Compute the LRP statistics of the selected model and plot them.

    The result is stored in the module-level ``current_lrp_statistics`` so
    that it can be re-plotted later without recomputing.

    Args:
        model_name: Content of the "Selected model" textbox. An empty or
            missing (``None``) value aborts with a warning.

    Returns:
        The value returned by ``make_lrp_plot()``, or ``(None, None, None)``
        when no model is selected (one entry per LRP plot in the interface).
    """
    global current_lrp_statistics

    # A textbox may deliver None as well as ""; treat both as "nothing
    # selected" instead of forwarding None to the model loader.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None, None, None

    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
    # NOTE(review): the positional 10 is presumably a batch size — confirm.
    current_lrp_statistics = lens.compute_statistics(dataset, wrapper, 10)
    return make_lrp_plot()
def make_lrp_plot():
    """
    Render the most recently computed LRP statistics.

    Returns the output of ``visualisation.render_relevance_proportion``
    (the interface binds it to three plots); when nothing has been computed
    yet, shows a warning and returns ``(None, None, None)``.
    """
    statistics = current_lrp_statistics
    if statistics is not None:
        return visualisation.render_relevance_proportion(statistics)

    gr.Warning("Please compute LRP statistics first.")
    return None, None, None
def compute_probing_statistics(
    model_name,
):
    """
    Compute the probing statistics of the selected model and plot them.

    Probing uses the module-level ``check_concept`` and
    ``unique_check_dataset``. The result is stored in the module-level
    ``current_probing_statistics`` so that it can be re-plotted later
    without recomputing.

    Args:
        model_name: Content of the "Selected model" textbox. An empty or
            missing (``None``) value aborts with a warning.

    Returns:
        The value returned by ``make_probing_plot()``, or ``None`` when no
        model is selected.
    """
    global current_probing_statistics

    # A textbox may deliver None as well as ""; treat both as "nothing
    # selected" instead of forwarding None to the model loader.
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return None

    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept)
    # NOTE(review): the positional 10 is presumably a batch size — confirm.
    current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10)
    return make_probing_plot()
def make_probing_plot():
    """
    Render the most recently computed probing statistics.

    Returns the output of ``visualisation.render_probing_statistics``; when
    nothing has been computed yet, shows a warning and returns ``None``.
    """
    statistics = current_probing_statistics
    if statistics is not None:
        return visualisation.render_probing_statistics(statistics)

    gr.Warning("Please compute probing statistics first.")
    return None
# NOTE(review): the nesting below is reconstructed from the order of the
# context managers — confirm the layout against the running app.
with gr.Blocks() as interface:
    # Model selection: table of available models next to the current choice.
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                value=list_models,
                headers=["Available models"],
                datatype=["str"],
                type="array",
                interactive=False,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(
                    label="Selected model",
                    lines=1,
                    interactive=False,
                    scale=7,
                )

    # One column per analysis: policy, LRP and probing.
    with gr.Row():
        with gr.Column():
            policy_plot = gr.Plot(label="Policy statistics")
            policy_compute_button = gr.Button(value="Compute policy statistics")
            policy_plot_button = gr.Button(value="Plot policy statistics")

        with gr.Column():
            lrp_plot_hist = gr.Plot(label="LRP history statistics")
            with gr.Row():
                with gr.Column():
                    lrp_plot_planes = gr.Plot(label="LRP planes statistics")
                with gr.Column():
                    lrp_plot_pieces = gr.Plot(label="LRP pieces statistics")
            with gr.Row():
                lrp_compute_button = gr.Button(value="Compute LRP statistics")
            with gr.Row():
                lrp_plot_button = gr.Button(value="Plot LRP statistics")

        with gr.Column():
            probing_plot = gr.Plot(label="Probing statistics")
            probing_compute_button = gr.Button(value="Compute probing statistics")
            probing_plot_button = gr.Button(value="Plot probing statistics")

    # Event wiring, registered in the same order as before.
    model_df.select(fn=on_select_model_df, inputs=None, outputs=model_name)

    policy_compute_button.click(
        fn=compute_policy_statistics,
        inputs=[model_name],
        outputs=[policy_plot],
    )
    policy_plot_button.click(fn=make_policy_plot, outputs=[policy_plot])

    lrp_plots = [lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces]
    lrp_compute_button.click(
        fn=compute_lrp_statistics,
        inputs=[model_name],
        outputs=list(lrp_plots),
    )
    lrp_plot_button.click(fn=make_lrp_plot, outputs=list(lrp_plots))
    del lrp_plots  # keep the module namespace as it was

    probing_compute_button.click(
        fn=compute_probing_statistics,
        inputs=[model_name],
        outputs=[probing_plot],
    )
    probing_plot_button.click(fn=make_probing_plot, outputs=[probing_plot])
|