# Page header captured together with this file (hosting-site file view):
# fabio-deep · added links · 146a6ea · raw · history blame · 28.8 kB
import torch
import numpy as np
import gradio as gr
import matplotlib.pylab as plt
import torch.nn.functional as F
from vae import HVAE
from datasets import morphomnist, ukbb, mimic, get_attr_max_min
from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM
from app_utils import (
mnist_graph,
brain_graph,
chest_graph,
vae_preprocess,
normalize,
preprocess_brain,
get_fig_arr,
postprocess,
MidpointNormalize,
)
# Per-dataset state keyed by dataset/tab name. DATA receives the test
# DataLoader under "test" (see get_dataloader); MODELS receives the loaded
# models and their hyper-parameters (see load_pgm / load_vae / load_dataset).
DATA = {name: {} for name in ("Morpho-MNIST", "Brain MRI", "Chest X-ray")}
MODELS = {name: {} for name in DATA}

# --- Morpho-MNIST ---
DIGITS = [str(digit) for digit in range(10)]

# --- Brain MRI ---
MRISEQ_CAT = ["T1", "T2-FLAIR"]  # list index == encoded value (0, 1)
SEX_CAT = ["female", "male"]  # list index == encoded value (0, 1)
# HEIGHT is the display height passed to every image component of the UI.
HEIGHT = WIDTH = 270

# --- Chest X-ray ---
SEX_CAT_CHEST = ["male", "female"]  # list index == encoded value (0, 1)
RACE_CAT = ["white", "asian", "black"]  # list index == encoded value (0, 1, 2)
FIND_CAT = ["no disease", "pleural effusion"]

# First CUDA device when available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
class Hparams:
    """Mutable attribute bag for hyper-parameters restored from a checkpoint."""

    def update(self, dict):
        """Copy every ``key: value`` pair of the mapping onto ``self`` as attributes.

        The parameter name shadows the builtin ``dict`` but is kept unchanged so
        that keyword callers keep working.
        """
        for attr_name, attr_value in dict.items():
            setattr(self, attr_name, attr_value)
def get_paths(dataset_id):
    """Return the data directory and checkpoint paths for a dataset.

    The dataset is recognised by substring match on ``dataset_id``, checked in
    the order "MNIST", "Brain", "Chest" (e.g. the tab labels "Morpho-MNIST",
    "Brain MRI", "Chest X-ray").

    Returns:
        ``(data_path, vae_path, pgm_path)``. For chest X-rays ``vae_path`` is a
        two-element list: the base VAE checkpoint followed by the
        counterfactually trained DSCM checkpoint.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    if "MNIST" in dataset_id:
        data_path = "./data/morphomnist"
        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
    elif "Brain" in dataset_id:
        data_path = "./data/ukbb_subset"
        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
    elif "Chest" in dataset_id:
        data_path = "./data/mimic_subset"
        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
        vae_path = [
            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt",  # base vae
            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt",  # cf trained DSCM
        ]
    else:
        # Previously this fell through and failed with an opaque
        # UnboundLocalError at the return statement.
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    return data_path, vae_path, pgm_path
def load_pgm(dataset_id, pgm_path):
    """Build the dataset's PGM from a checkpoint and cache it in ``MODELS``.

    The checkpoint's ``"hparams"`` are restored into an ``Hparams`` object
    (with ``device`` set to ``DEVICE``). The dataset-specific PGM class is
    instantiated on that device, and its ``"ema_model_state_dict"`` weights
    are loaded.

    Side effects:
        Sets ``MODELS[dataset_id]["pgm"]`` and ``MODELS[dataset_id]["pgm_args"]``.

    Raises:
        ValueError: if ``dataset_id`` does not name a known dataset. This is
            checked before the checkpoint is read.
    """
    if "MNIST" in dataset_id:
        pgm_cls = MorphoMNISTPGM
    elif "Brain" in dataset_id:
        pgm_cls = FlowPGM
    elif "Chest" in dataset_id:
        pgm_cls = ChestPGM
    else:
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")

    checkpoint = torch.load(pgm_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    args.device = DEVICE
    pgm = pgm_cls(args).to(args.device)
    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
    # The model is only used for inference. Freshly constructed modules are
    # in training mode, so switch off any training-only behaviour (e.g.
    # dropout, batch-norm batch statistics) the model may contain.
    pgm.eval()
    MODELS[dataset_id]["pgm"] = pgm
    MODELS[dataset_id]["pgm_args"] = args
def load_vae(dataset_id, vae_path):
    """Build the dataset's HVAE from a checkpoint and cache it in ``MODELS``.

    Args:
        dataset_id: dataset/tab name, e.g. "Morpho-MNIST".
        vae_path: checkpoint path. For chest X-rays it is a pair
            ``[base_vae_path, dscm_path]``: the hyper-parameters come from the
            base VAE checkpoint, and the weights come from the ``vae.``-prefixed
            entries of the DSCM checkpoint's ``"ema_model_state_dict"``.

    Side effects:
        Sets ``MODELS[dataset_id]["vae"]`` and ``MODELS[dataset_id]["vae_args"]``.
    """
    dscm_path = None
    if "Chest" in dataset_id:
        vae_path, dscm_path = vae_path[0], vae_path[1]
    checkpoint = torch.load(vae_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    # Backwards compatibility: fill in fields missing from older checkpoints.
    if not hasattr(args, "vae"):
        args.vae = "hierarchical"
    if not hasattr(args, "cond_prior"):
        args.cond_prior = False
    if hasattr(args, "free_bits"):
        args.kl_free_bits = args.free_bits
    args.device = DEVICE
    vae = HVAE(args).to(args.device)

    if dscm_path is not None:
        dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
        prefix = "vae."
        # Keep only the VAE sub-module's weights and strip their prefix. Match
        # at the start of the key: a substring test would also select keys
        # that merely contain "vae." elsewhere and then mangle them.
        state_dict = {
            k[len(prefix):]: v
            for k, v in dscm_ckpt["ema_model_state_dict"].items()
            if k.startswith(prefix)
        }
    else:
        state_dict = checkpoint["ema_model_state_dict"]
    vae.load_state_dict(state_dict)
    # The model is only used for inference. Freshly constructed modules are
    # in training mode, so switch off any training-only behaviour (e.g.
    # dropout) the model may contain.
    vae.eval()
    MODELS[dataset_id]["vae"] = vae
    MODELS[dataset_id]["vae_args"] = args
def get_dataloader(dataset_id, data_path):
    """Create the test-split DataLoader for ``dataset_id`` and store it in ``DATA``.

    Requires ``MODELS[dataset_id]["pgm_args"]`` to be set already. Before those
    hyper-parameters are handed to the dataset constructor, ``data_path`` is
    written onto them as ``data_dir``. The loader is unshuffled and uses
    ``args.bs`` as its batch size.

    Side effects:
        Sets ``DATA[dataset_id]["test"]`` and ``MODELS[dataset_id]["pgm_args"].data_dir``.

    Raises:
        ValueError: if ``dataset_id`` does not name a known dataset.
    """
    args = MODELS[dataset_id]["pgm_args"]
    args.data_dir = data_path
    if "MNIST" in dataset_id:
        datasets = morphomnist(args)
    elif "Brain" in dataset_id:
        datasets = ukbb(args)
    elif "Chest" in dataset_id:
        datasets = mimic(args)
    else:
        # Previously this failed later with an opaque UnboundLocalError.
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
    )
def load_dataset(dataset_id):
    """Prepare ``DATA[dataset_id]["test"]`` without loading any model weights.

    Only the PGM checkpoint's ``"hparams"`` are used: they are stored as
    ``MODELS[dataset_id]["pgm_args"]`` (with ``device`` set), and
    ``get_dataloader`` then passes them to the dataset constructor.
    """
    paths = get_paths(dataset_id)
    data_dir, pgm_ckpt_path = paths[0], paths[2]
    saved_hparams = torch.load(pgm_ckpt_path, map_location=DEVICE)["hparams"]
    pgm_args = Hparams()
    pgm_args.update(saved_hparams)
    pgm_args.device = DEVICE
    MODELS[dataset_id]["pgm_args"] = pgm_args
    get_dataloader(dataset_id, data_dir)
def load_model(dataset_id):
    """Load the PGM, then the VAE, for ``dataset_id`` into ``MODELS``."""
    paths = get_paths(dataset_id)
    pgm_ckpt, vae_ckpt = paths[2], paths[1]
    load_pgm(dataset_id, pgm_ckpt)
    load_vae(dataset_id, vae_ckpt)
@torch.no_grad()
def counterfactual_inference(dataset_id, obs, do_pa):
    """Compute counterfactual parents and a counterfactual image.

    Args:
        dataset_id: key into ``MODELS``. Its "pgm", "vae" and "vae_args" entries
            must already be loaded.
        obs: dict of tensors with the image under "x" and the observed parent
            attributes under every other key.
        do_pa: interventions, mapping parent names to their new values.

    Returns:
        dict with "cf_x" (counterfactual image clamped to [-1, 1]), "rec_x"
        (VAE output location under the observed parents) and "cf_pa"
        (counterfactual parents returned by the PGM).
    """
    models = MODELS[dataset_id]
    # Counterfactual parents from the PGM.
    parents = {name: val.clone() for name, val in obs.items() if name != "x"}
    cf_parents = models["pgm"].counterfactual(
        obs=parents, intervention=do_pa, num_particles=1
    )

    # Counterfactual image from the VAE, conditioned on both parent sets.
    vae_args, vae = models["vae_args"], models["vae"]
    vae_pa = vae_preprocess(
        vae_args, {name: val.clone() for name, val in parents.items()}
    )
    vae_cf_pa = vae_preprocess(
        vae_args, {name: val.clone() for name, val in cf_parents.items()}
    )
    # One temperature serves both latent abduction and the counterfactual
    # noise scale: 0.1 for MNIST models, 1.0 otherwise.
    temperature = 0.1 if "mnist" in vae_args.hps else 1.0
    latents = vae.abduct(x=obs["x"], parents=vae_pa, t=temperature)
    if vae.cond_prior:
        latents = [stage["z"] for stage in latents]
    px_loc, px_scale = vae.forward_latents(latents=latents, parents=vae_pa)
    cf_loc, cf_scale = vae.forward_latents(latents=latents, parents=vae_cf_pa)
    # Standardised residual of x under the observed parents, reused as the
    # noise of the counterfactual image.
    noise = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
    cf_x = torch.clamp(cf_loc + (cf_scale * temperature) * noise, min=-1, max=1)
    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_parents}
def get_obs_item(dataset_id, idx=None):
    """Return ``(idx, item)`` from the test set of ``dataset_id``.

    When ``idx`` is None, a random index is drawn via ``torch.randperm``. The
    returned ``idx`` is always a Python int.
    """
    test_set = DATA[dataset_id]["test"].dataset
    if idx is None:
        idx = torch.randperm(len(test_set))[0]
    idx = int(idx)
    return idx, test_set[idx]
def get_mnist_obs(idx=None):
    """Fetch a Morpho-MNIST test observation formatted for the UI.

    The dataset is loaded on first use, and ``idx=None`` picks a random item.

    Returns:
        ``(idx, image, thickness, intensity, digit)``. Thickness and intensity
        are mapped from [-1, 1] back onto their original ranges and rounded to
        two decimals; digit is the string label.
    """
    dataset_id = "Morpho-MNIST"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    def to_range(value, lo, hi):
        # linear map of [-1, 1] onto [lo, hi]
        return (value + 1) / 2 * (hi - lo) + lo

    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    thickness = to_range(obs["thickness"].clone(), 0.87598526, 6.255515)
    intensity = to_range(obs["intensity"].clone(), 66.601204, 254.90317)
    digit = DIGITS[obs["digit"].clone().argmax(-1)]
    return (
        idx,
        image,
        float(np.round(thickness, 2)),
        float(np.round(intensity, 2)),
        digit,
    )
def get_brain_obs(idx=None):
    """Fetch a Brain MRI test observation formatted for the UI.

    The dataset is loaded on first use, and ``idx=None`` picks a random item.

    Returns:
        ``(idx, image, mri_seq, sex, age, brain_volume, ventricle_volume)``.
        Both volumes are divided by 1000 (ml, per the original comments) and
        rounded to two decimals.
    """
    dataset_id = "Brain MRI"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)
    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    mri_seq = MRISEQ_CAT[int(obs["mri_seq"].clone().item())]
    sex = SEX_CAT[int(obs["sex"].clone().item())]
    age = obs["age"].clone().item()
    brain_ml, ventricle_ml = (
        obs[key].clone().item() / 1000 for key in ("brain_volume", "ventricle_volume")
    )
    return (
        idx,
        image,
        mri_seq,
        sex,
        age,
        float(np.round(brain_ml, 2)),
        float(np.round(ventricle_ml, 2)),
    )
def get_chest_obs(idx=None):
    """Fetch a Chest X-ray test observation formatted for the UI.

    The dataset is loaded on first use, and ``idx=None`` picks a random item.

    Returns:
        ``(idx, image, race, sex, finding, age)``. Age is computed as
        ``(value + 1) * 50`` and rounded to one decimal.
    """
    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    def as_array(key):
        return obs[key].clone().squeeze().numpy()

    image = get_fig_arr(postprocess(obs["x"].clone()))
    sex = SEX_CAT_CHEST[int(as_array("sex"))]
    finding = FIND_CAT[int(as_array("finding"))]
    race = RACE_CAT[as_array("race").argmax(-1)]
    age = (as_array("age") + 1) * 50
    return (idx, image, race, sex, finding, float(np.round(age, 1)))
def infer_mnist_cf(*args):
    """Gradio callback that computes a Morpho-MNIST counterfactual.

    ``args`` arrive in UI order: ``(idx, image, thickness, intensity, digit,
    do_thickness, do_intensity, do_digit)``. The displayed image is ignored;
    the observation is re-read from the test set at ``idx``. The observation
    is tiled into 32 particles and the results are averaged over them.

    Returns:
        ``(cf_image, cf_std_image, effect_image, cf_thickness, cf_intensity,
        cf_digit)`` for display.
    """
    dataset_id = "Morpho-MNIST"
    idx, _, t, i, y, do_t, do_i, do_y = args
    n_particles = 32

    def to_range(value, lo, hi):
        # linear map of [-1, 1] onto [lo, hi]
        return (value + 1) / 2 * (hi - lo) + lo

    # Observation: rescale pixels, add a batch dim, move to device, tile.
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs["x"] = (obs["x"] - 127.5) / 127.5  # maps 0..255 onto -1..1
    device = MODELS[dataset_id]["vae_args"].device
    for key in list(obs):
        tensor = obs[key]
        tensor = tensor.unsqueeze(0) if len(tensor.shape) >= 1 else tensor.view(1, 1)
        tensor = tensor.to(device).float()
        if n_particles > 1:
            reps = (1, 1, 1) if key == "x" else (1,)
            tensor = tensor.repeat(n_particles, *reps)
        obs[key] = tensor

    # Interventions selected through the do(...) checkboxes.
    do_pa = {}
    if do_t:
        do_pa["thickness"] = torch.tensor(
            normalize(t, x_max=6.255515, x_min=0.87598526)
        ).view(1, 1)
    if do_i:
        do_pa["intensity"] = torch.tensor(
            normalize(i, x_max=254.90317, x_min=66.601204)
        ).view(1, 1)
    if do_y:
        do_pa["digit"] = F.one_hot(
            torch.tensor(DIGITS.index(y)), num_classes=10
        ).view(1, 10)
    do_pa = {
        key: val.to(device).float().repeat(n_particles, 1)
        for key, val in do_pa.items()
    }

    out = counterfactual_inference(dataset_id, obs, do_pa)
    cf_pa = out["cf_pa"]

    # Aggregate particles and convert back to display units.
    cf_img = postprocess(out["cf_x"].mean(0))
    cf_std = out["cf_x"].std(0).squeeze().detach().cpu().numpy()
    rec_img = postprocess(out["rec_x"].mean(0))
    cf_t = np.round(
        to_range(cf_pa["thickness"].mean(0).item(), 0.87598526, 6.255515), 2
    )
    cf_i = np.round(
        to_range(cf_pa["intensity"].mean(0).item(), 66.601204, 254.90317), 2
    )
    cf_y = DIGITS[cf_pa["digit"].mean(0).argmax(-1)]

    # Difference map drawn on a fixed, symmetric [-255, 255] colour scale.
    effect_img = get_fig_arr(
        cf_img - rec_img,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255),
    )
    return (
        get_fig_arr(cf_img),
        get_fig_arr(cf_std, cmap="jet"),
        effect_img,
        cf_t,
        cf_i,
        cf_y,
    )
def infer_brain_cf(*args):
    """Gradio callback that computes a Brain MRI counterfactual.

    ``args`` are the 7 observation values ``(idx, image, mri_seq, sex, age,
    brain_volume_ml, ventricle_volume_ml)`` followed by the 5 intervention
    checkboxes ``(do_mri_seq, do_sex, do_age, do_brain_volume,
    do_ventricle_volume)``. The displayed image is ignored; the observation is
    re-read from the test set at ``idx``. It is tiled into 16 particles and
    the results are aggregated over them.

    Returns:
        ``(cf_image, cf_std_image, effect_image, cf_mri_seq, cf_sex, cf_age,
        cf_brain_volume_ml, cf_ventricle_volume_ml)`` for display.
    """
    dataset_id = "Brain MRI"
    continuous = ["age", "brain_volume", "ventricle_volume"]
    idx, _, m, s, a, b, v = args[:7]
    do_m, do_s, do_a, do_b, do_v = args[7:]
    n_particles = 16
    # preprocessing: reload the observation and tile it across particles
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs = preprocess_brain(MODELS[dataset_id]["vae_args"], obs)
    if n_particles > 1:
        for k, _v in obs.items():
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = _v.repeat(n_particles, *ndims)
    # intervention(s); the UI shows volumes divided by 1000, so undo that here
    do_pa = {}
    if do_m:
        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
    if do_a:
        do_pa["age"] = torch.tensor(a).view(1, 1)
    if do_b:
        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)
    # normalize continuous attributes to [-1, 1] using get_attr_max_min
    for k in continuous:
        if k in do_pa:
            k_max, k_min = get_attr_max_min(k)
            do_pa[k] = (do_pa[k] - k_min) / (k_max - k_min)  # [0,1]
            do_pa[k] = 2 * do_pa[k] - 1  # [-1,1]
    device = MODELS[dataset_id]["vae_args"].device
    for k, _v in do_pa.items():
        do_pa[k] = _v.to(device).float().repeat(n_particles, 1)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    cf_pa = out["cf_pa"]
    # aggregate particles
    cf_x = postprocess(out["cf_x"].mean(0))
    cf_x_std = out["cf_x"].std(0).squeeze().detach().cpu().numpy()
    rec_x = postprocess(out["rec_x"].mean(0))
    # Majority vote over particles for binary attributes. int() on the
    # particle mean would truncate, reporting category 0 whenever the
    # particles disagree.
    cf_m = MRISEQ_CAT[round(cf_pa["mri_seq"].mean(0).item())]
    cf_s = SEX_CAT[round(cf_pa["sex"].mean(0).item())]
    cf_ = {}
    for k in continuous:  # unnormalize from [-1, 1]
        k_max, k_min = get_attr_max_min(k)
        cf_[k] = (cf_pa[k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
    # plots: difference map on a colour scale fitted to its own range
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (
        cf_x,
        cf_x_std,
        effect,
        cf_m,
        cf_s,
        np.round(cf_["age"], 1),
        np.round(cf_["brain_volume"] / 1000, 2),
        np.round(cf_["ventricle_volume"] / 1000, 2),
    )
def infer_chest_cf(*args):
    """Gradio callback that computes a Chest X-ray counterfactual.

    ``args`` are the 6 observation values ``(idx, image, race, sex, finding,
    age)`` followed by the 4 intervention checkboxes ``(do_race, do_sex,
    do_finding, do_age)``. The displayed image is ignored; the observation is
    re-read from the test set at ``idx``. It is tiled into 16 particles and
    the results are aggregated over them.

    Returns:
        ``(cf_image, cf_std_image, effect_image, cf_race, cf_sex, cf_finding,
        cf_age)`` for display.
    """
    dataset_id = "Chest X-ray"
    idx, _, r, s, f, a = args[:6]
    do_r, do_s, do_f, do_a = args[6:]
    n_particles = 16
    # preprocessing: reload the observation, move to device, tile particles
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    device = MODELS[dataset_id]["vae_args"].device
    for k, v in obs.items():
        obs[k] = v.to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = obs[k].repeat(n_particles, *ndims)
    # intervention(s); plain tensor construction, so no autograd context is needed
    do_pa = {}
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
    if do_f:
        do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
    if do_r:
        do_pa["race"] = F.one_hot(
            torch.tensor(RACE_CAT.index(r)), num_classes=3
        ).view(1, 3)
    if do_a:
        # inverse of the display mapping age = (value + 1) * 50
        do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
    for k, v in do_pa.items():
        do_pa[k] = v.to(device).float().repeat(n_particles, 1)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    cf_pa = out["cf_pa"]
    # aggregate particles
    cf_x = postprocess(out["cf_x"].mean(0))
    cf_x_std = out["cf_x"].std(0).squeeze().detach().cpu().numpy()
    rec_x = postprocess(out["rec_x"].mean(0))
    cf_r = RACE_CAT[cf_pa["race"].mean(0).argmax(-1)]
    # Majority vote over particles for binary attributes. int() on the
    # particle mean would truncate, reporting category 0 whenever the
    # particles disagree.
    cf_s = SEX_CAT_CHEST[round(cf_pa["sex"].mean(0).item())]
    cf_f = FIND_CAT[round(cf_pa["finding"].mean(0).item())]
    cf_a = (cf_pa["age"].mean(0).item() + 1) * 50
    # plots: difference map on a colour scale fitted to its own range
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
# Gradio UI: one tab per dataset. Each tab shows observation, counterfactual,
# uncertainty and direct-effect images, a set of do(...) intervention
# controls, and a causal-graph image. Event wiring follows the layout.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Morpho-MNIST") as mnist_tab:
            # Hidden textbox holding the tab label; it is passed as the
            # `dataset_id` input to load_model (see demo.load below).
            mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)

            with gr.Row().style(equal_height=True):
                # Hidden index of the current test item: written by
                # get_mnist_obs, read by "Reset" and "Submit".
                idx = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x = gr.Image(label="Counterfactual", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_std = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=1.75):
                    gr.Markdown(
                        "#### Intervention"
                        + 28 * "&emsp;"
                        + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                    )
                    # Intervention controls start non-interactive; the matching
                    # do(...) checkbox toggles them (see the .change handlers).
                    with gr.Column():
                        do_y = gr.Checkbox(label="do(digit)", value=False)
                        y = gr.Radio(DIGITS, label="", interactive=False)
                    with gr.Row():
                        with gr.Column(min_width=100):
                            do_t = gr.Checkbox(label="do(thickness)", value=False)
                            t = gr.Slider(
                                label="\u00A0",
                                minimum=0.9,
                                maximum=5.5,
                                step=0.01,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_i = gr.Checkbox(label="do(intensity)", value=False)
                            i = gr.Slider(
                                label="\u00A0",
                                minimum=50,
                                maximum=255,
                                step=0.01,
                                interactive=False,
                            )
                    with gr.Row():
                        new = gr.Button("New Observation")
                        reset = gr.Button("Reset", variant="stop")
                        submit = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    gr.Markdown("### &nbsp;")
                    causal_graph = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=300)
        with gr.TabItem("Brain MRI") as brain_tab:
            # Hidden tab label, passed as `dataset_id` to load_model on tab select.
            brain_id = gr.Textbox(value=brain_tab.label, visible=False)

            with gr.Row().style(equal_height=True):
                # Hidden index of the current test item.
                idx_brain = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_brain = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_brain = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_brain = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_brain = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(
                        "#### Intervention"
                        + 28 * "&emsp;"
                        + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                    )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
                            m = gr.Radio(
                                ["T1", "T2-FLAIR"], label="", interactive=False
                            )
                        with gr.Column(min_width=200):
                            do_s = gr.Checkbox(label="do(sex)", value=False)
                            s = gr.Radio(
                                ["female", "male"], label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=100):
                            do_a = gr.Checkbox(label="do(age)", value=False)
                            a = gr.Slider(
                                label="\u00A0",
                                value=50,
                                minimum=44,
                                maximum=73,
                                step=1,
                                interactive=False,
                            )
                        # Volume sliders are in the units shown by get_brain_obs
                        # (raw volume / 1000); infer_brain_cf multiplies by 1000.
                        with gr.Column(min_width=100):
                            do_b = gr.Checkbox(label="do(brain volume)", value=False)
                            b = gr.Slider(
                                label="\u00A0",
                                value=1000,
                                minimum=850,
                                maximum=1550,
                                step=20,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_v = gr.Checkbox(
                                label="do(ventricle volume)", value=False
                            )
                            v = gr.Slider(
                                label="\u00A0",
                                value=40,
                                minimum=10,
                                maximum=125,
                                step=2,
                                interactive=False,
                            )
                    with gr.Row():
                        new_brain = gr.Button("New Observation")
                        reset_brain = gr.Button("Reset", variant="stop")
                        submit_brain = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    # gr.Markdown("### &nbsp;")
                    causal_graph_brain = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=340)
        with gr.TabItem("Chest X-ray") as chest_tab:
            # Hidden tab label, passed as `dataset_id` to load_model on tab select.
            chest_id = gr.Textbox(value=chest_tab.label, visible=False)

            with gr.Row().style(equal_height=True):
                # Hidden index of the current test item.
                idx_chest = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_chest = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_chest = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_chest = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_chest = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(
                        "#### Intervention"
                        + 28 * "&emsp;"
                        + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                    )
                    with gr.Row().style(equal_height=True):
                        with gr.Column(min_width=200):
                            do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                            f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_s_chest = gr.Checkbox(label="do(sex)", value=False)
                            s_chest = gr.Radio(
                                SEX_CAT_CHEST, label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_r_chest = gr.Checkbox(label="do(race)", value=False)
                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_a_chest = gr.Checkbox(label="do(age)", value=False)
                            # NOTE(review): unlike every other intervention
                            # control, this slider is not created with
                            # interactive=False, so it may be draggable before
                            # do(age) is ticked (the value is ignored until
                            # then). Confirm whether that is intended.
                            a_chest = gr.Slider(
                                label="\u00A0", minimum=18, maximum=98, step=1
                            )
                    with gr.Row():
                        new_chest = gr.Button("New Observation")
                        reset_chest = gr.Button("Reset", variant="stop")
                        submit_chest = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    # gr.Markdown("### &nbsp;")
                    causal_graph_chest = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=345)

    # Component groups. List order must match the argument unpacking and the
    # return tuples of the get_*_obs / infer_*_cf callbacks.
    # morphomnist
    do = [do_t, do_i, do_y]
    obs = [idx, x, t, i, y]
    cf_out = [cf_x, cf_x_std, effect]
    # brain
    do_brain = [do_m, do_s, do_a, do_b, do_v]  # intervention checkboxes
    obs_brain = [idx_brain, x_brain, m, s, a, b, v]  # observed image/attributes
    cf_out_brain = [cf_x_brain, cf_x_std_brain, effect_brain]  # counterfactual outputs
    # chest
    do_chest = [do_r_chest, do_s_chest, do_f_chest, do_a_chest]
    obs_chest = [idx_chest, x_chest, r_chest, s_chest, f_chest, a_chest]
    cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]

    # on start: load new observations & causal graph
    # (only the MNIST models are loaded here; brain/chest load on tab select)
    demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
    demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
    demo.load(fn=load_model, inputs=mnist_id, outputs=None)
    demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
    demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # on tab select: load models
    brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
    chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)

    # "new" button: load new observations (no idx input -> random item)
    new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
    new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)

    # "new" button: reset causal graphs
    new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
    new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "new" button: reset cf output panels (clear all 3 images)
    for _k, _v in zip(
        [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)

    # "reset" button: reload current observations (same idx)
    reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
    reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
    reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)

    # "reset" button: deselect intervention checkboxes
    reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
    reset_brain.click(
        fn=lambda: (gr.update(value=False),) * len(do_brain),
        inputs=None,
        outputs=do_brain,
    )
    reset_chest.click(
        fn=lambda: (gr.update(value=False),) * len(do_chest),
        inputs=None,
        outputs=do_chest,
    )

    # "reset" button: close open matplotlib figures and reset cf output panels
    for _k, _v in zip(
        [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)

    # enable mnist interventions when checkbox is selected & update graph
    # (_k and _v are passed as event arguments immediately, so the loop
    # variables are not captured late by the lambdas)
    for _k, _v in zip(do, [t, i, y]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(mnist_graph, inputs=do, outputs=causal_graph)

    # enable brain interventions when checkbox is selected & update graph
    for _k, _v in zip(do_brain, [m, s, a, b, v]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)

    # enable chest interventions when checkbox is selected & update graph
    for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "submit" button: infer countefactuals; the counterfactual attribute
    # values are written back into the attribute controls
    submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
    submit_brain.click(
        fn=infer_brain_cf,
        inputs=obs_brain + do_brain,
        outputs=cf_out_brain + [m, s, a, b, v],
    )
    submit_chest.click(
        fn=infer_chest_cf,
        inputs=obs_chest + do_chest,
        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
    )
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving the app.
    demo.queue().launch()