Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import gradio as gr | |
import matplotlib.pylab as plt | |
import torch.nn.functional as F | |
from vae import HVAE | |
from datasets import morphomnist, ukbb, mimic, get_attr_max_min | |
from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM | |
from app_utils import ( | |
mnist_graph, | |
brain_graph, | |
chest_graph, | |
vae_preprocess, | |
normalize, | |
preprocess_brain, | |
get_fig_arr, | |
postprocess, | |
MidpointNormalize, | |
) | |
# Per-dataset caches, keyed by the tab label of each demo.
# DATA[id] holds data loaders, MODELS[id] holds models and their hparams.
DATA, MODELS = {}, {}
for k in ("Morpho-MNIST", "Brain MRI", "Chest X-ray"):
    DATA[k] = {}
    MODELS[k] = {}

# mnist
DIGITS = [str(d) for d in range(10)]

# brain
MRISEQ_CAT = ["T1", "T2-FLAIR"]  # 0,1
SEX_CAT = ["female", "male"]  # 0,1
HEIGHT = 270
WIDTH = 270

# chest (note: the sex encoding order differs from the brain dataset)
SEX_CAT_CHEST = ["male", "female"]  # 0,1
RACE_CAT = ["white", "asian", "black"]  # 0,1,2
FIND_CAT = ["no disease", "pleural effusion"]

# Use the first GPU when available, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
class Hparams:
    """Plain attribute container for hyper-parameters."""

    def update(self, dict):
        """Set every key/value pair of ``dict`` as an attribute on this object."""
        for entry in dict.items():
            # entry is a (name, value) pair
            setattr(self, *entry)
def get_paths(dataset_id):
    """Return the data directory and checkpoint paths for a dataset.

    The dataset is selected by substring match on ``dataset_id``
    ("MNIST", "Brain" or "Chest", checked in that order).

    Args:
        dataset_id: dataset label, e.g. "Morpho-MNIST".

    Returns:
        A ``(data_path, vae_path, pgm_path)`` tuple. For the chest dataset
        ``vae_path`` is a two-element list (base VAE checkpoint, then the
        counterfactual-trained DSCM checkpoint); otherwise it is a string.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    if "MNIST" in dataset_id:
        data_path = "./data/morphomnist"
        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
    elif "Brain" in dataset_id:
        data_path = "./data/ukbb_subset"
        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
    elif "Chest" in dataset_id:
        data_path = "./data/mimic_subset"
        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
        vae_path = [
            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt",  # base vae
            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt",  # cf trained DSCM
        ]
    else:
        # Previously this fell through to the return and failed with an
        # opaque UnboundLocalError.
        raise ValueError(f"Unknown dataset id: {dataset_id!r}")
    return data_path, vae_path, pgm_path
def load_pgm(dataset_id, pgm_path):
    """Load the PGM for ``dataset_id`` and cache it in ``MODELS``.

    The model class is chosen by substring match on ``dataset_id``. The
    checkpoint's "hparams" entry configures the model and its
    "ema_model_state_dict" entry provides the weights. Results are stored
    under ``MODELS[dataset_id]["pgm"]`` and ``MODELS[dataset_id]["pgm_args"]``.

    Args:
        dataset_id: dataset label containing "MNIST", "Brain" or "Chest".
        pgm_path: path to the PGM checkpoint file.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    # Resolve the model class before touching the disk so that a bad id
    # fails fast with a clear message instead of an UnboundLocalError later.
    if "MNIST" in dataset_id:
        pgm_cls = MorphoMNISTPGM
    elif "Brain" in dataset_id:
        pgm_cls = FlowPGM
    elif "Chest" in dataset_id:
        pgm_cls = ChestPGM
    else:
        raise ValueError(f"Unknown dataset id: {dataset_id!r}")

    checkpoint = torch.load(pgm_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    args.device = DEVICE

    pgm = pgm_cls(args).to(args.device)
    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
    MODELS[dataset_id]["pgm"] = pgm
    MODELS[dataset_id]["pgm_args"] = args
def load_vae(dataset_id, vae_path):
    """Load the HVAE for ``dataset_id`` and cache it in ``MODELS``.

    For the chest dataset ``vae_path`` is a two-element sequence: the first
    checkpoint supplies the hyper-parameters, the second (DSCM) checkpoint
    supplies the weights, taken from its ``vae.``-prefixed entries. For the
    other datasets ``vae_path`` is a single checkpoint providing both.

    Results are stored under ``MODELS[dataset_id]["vae"]`` and
    ``MODELS[dataset_id]["vae_args"]``.

    Args:
        dataset_id: dataset label; "Chest" in the label selects the
            two-checkpoint path.
        vae_path: checkpoint path, or a pair of paths for the chest dataset.
    """
    is_chest = "Chest" in dataset_id
    if is_chest:
        vae_path, dscm_path = vae_path[0], vae_path[1]

    checkpoint = torch.load(vae_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    # backwards compatibility hack
    if not hasattr(args, "vae"):
        args.vae = "hierarchical"
    if not hasattr(args, "cond_prior"):
        args.cond_prior = False
    if hasattr(args, "free_bits"):
        args.kl_free_bits = args.free_bits
    args.device = DEVICE

    vae = HVAE(args).to(args.device)
    if is_chest:
        dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
        # Keep only entries that actually start with the prefix and strip
        # exactly that prefix. The previous `"vae." in k` test paired with
        # `k[4:]` would corrupt any key containing "vae." elsewhere.
        prefix = "vae."
        state_dict = {
            k[len(prefix):]: v
            for k, v in dscm_ckpt["ema_model_state_dict"].items()
            if k.startswith(prefix)
        }
    else:
        state_dict = checkpoint["ema_model_state_dict"]
    vae.load_state_dict(state_dict)

    MODELS[dataset_id]["vae"] = vae
    MODELS[dataset_id]["vae_args"] = args
def get_dataloader(dataset_id, data_path):
    """Build the test-split data loader for ``dataset_id``.

    Sets ``data_dir`` on the cached ``MODELS[dataset_id]["pgm_args"]`` and
    stores the loader under ``DATA[dataset_id]["test"]``.

    Args:
        dataset_id: dataset label containing "MNIST", "Brain" or "Chest".
        data_path: directory the dataset is read from.
    """
    args = MODELS[dataset_id]["pgm_args"]
    args.data_dir = data_path

    if "MNIST" in dataset_id:
        splits = morphomnist(args)
    elif "Brain" in dataset_id:
        splits = ukbb(args)
    elif "Chest" in dataset_id:
        splits = mimic(args)

    test_loader = torch.utils.data.DataLoader(
        splits["test"],
        batch_size=args.bs,
        shuffle=False,
        num_workers=4,
    )
    DATA[dataset_id]["test"] = test_loader
def load_dataset(dataset_id):
    """Prepare the test data loader for ``dataset_id``.

    Reads the hyper-parameters from the PGM checkpoint, caches them as
    ``MODELS[dataset_id]["pgm_args"]`` and then builds the data loader.
    """
    paths = get_paths(dataset_id)
    data_path, pgm_path = paths[0], paths[2]

    hparams = torch.load(pgm_path, map_location=DEVICE)["hparams"]
    pgm_args = Hparams()
    pgm_args.update(hparams)
    pgm_args.device = DEVICE

    MODELS[dataset_id]["pgm_args"] = pgm_args
    get_dataloader(dataset_id, data_path)
def load_model(dataset_id):
    """Load the PGM and VAE for ``dataset_id`` into ``MODELS``.

    This runs on page load and on every tab selection, so the models are
    only loaded once: if both are already cached the call is a no-op,
    which avoids re-reading the checkpoints and rebuilding the networks.
    """
    cached = MODELS.get(dataset_id, {})
    if "pgm" in cached and "vae" in cached:
        return
    _, vae_path, pgm_path = get_paths(dataset_id)
    load_pgm(dataset_id, pgm_path)
    load_vae(dataset_id, vae_path)
def counterfactual_inference(dataset_id, obs, do_pa):
    """Compute a counterfactual image for ``obs`` under intervention ``do_pa``.

    Steps: (1) the PGM maps the observed parents and the intervention to
    counterfactual parents; (2) the VAE abducts latents from the observed
    image; (3) the latents are decoded under both the observed and the
    counterfactual parents; (4) the observed residual ``u`` is re-applied to
    the counterfactual decoding.

    Args:
        dataset_id: key into ``MODELS`` selecting the cached PGM and VAE.
        obs: dict of tensors holding the image under "x" and the parent
            variables under the remaining keys.
        do_pa: dict of intervened parent tensors.

    Returns:
        dict with "cf_x" (counterfactual image, clamped to [-1, 1]),
        "rec_x" (decoding under the observed parents) and "cf_pa"
        (counterfactual parents).
    """
    models = MODELS[dataset_id]
    # Pure inference: run without autograd so no graph is kept per particle.
    with torch.no_grad():
        pa = {k: v.clone() for k, v in obs.items() if k != "x"}
        cf_pa = models["pgm"].counterfactual(
            obs=pa, intervention=do_pa, num_particles=1
        )
        args, vae = models["vae_args"], models["vae"]
        _pa = vae_preprocess(args, {k: v.clone() for k, v in pa.items()})
        _cf_pa = vae_preprocess(args, {k: v.clone() for k, v in cf_pa.items()})

        # The abduction temperature and the cf sampling temperature use the
        # same rule, so evaluate it once.
        z_t = u_t = 0.1 if "mnist" in args.hps else 1.0

        z = vae.abduct(x=obs["x"], parents=_pa, t=z_t)
        if vae.cond_prior:
            z = [z[j]["z"] for j in range(len(z))]

        px_loc, px_scale = vae.forward_latents(latents=z, parents=_pa)
        cf_loc, cf_scale = vae.forward_latents(latents=z, parents=_cf_pa)

        # Residual of the observation; the clamp guards against division by zero.
        u = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
        cf_scale = cf_scale * u_t
        cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_pa}
def get_obs_item(dataset_id, idx=None):
    """Return ``(idx, item)`` from the test set of ``dataset_id``.

    When ``idx`` is None a random index is drawn; otherwise ``idx`` is
    converted to ``int`` and used directly.
    """
    test_set = DATA[dataset_id]["test"].dataset
    if idx is None:
        # first entry of a random permutation of all test indices
        idx = torch.randperm(len(test_set))[0]
    idx = int(idx)
    return idx, test_set[idx]
def get_mnist_obs(idx=None):
    """Fetch a Morpho-MNIST test item and format it for display.

    The dataset is loaded on first use. Thickness and intensity are mapped
    from [-1, 1] back to their original ranges.

    Args:
        idx: test-set index, or None to pick a random item.

    Returns:
        ``(idx, image, thickness, intensity, digit)`` where thickness and
        intensity are floats rounded to 2 decimals and digit is a string.
    """
    dataset_id = "Morpho-MNIST"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    t_min, t_max = 0.87598526, 6.255515
    i_min, i_max = 66.601204, 254.90317

    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    thickness = (obs["thickness"].clone() + 1) / 2 * (t_max - t_min) + t_min
    intensity = (obs["intensity"].clone() + 1) / 2 * (i_max - i_min) + i_min
    digit = DIGITS[obs["digit"].clone().argmax(-1)]
    return (
        idx,
        image,
        float(np.round(thickness, 2)),
        float(np.round(intensity, 2)),
        digit,
    )
def get_brain_obs(idx=None):
    """Fetch a Brain MRI test item and format it for display.

    The dataset is loaded on first use.

    Args:
        idx: test-set index, or None to pick a random item.

    Returns:
        ``(idx, image, mri_seq, sex, age, brain_volume, ventricle_volume)``;
        the two volumes are divided by 1000 and rounded to 2 decimals.
    """
    dataset_id = "Brain MRI"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    mri_seq = MRISEQ_CAT[int(obs["mri_seq"].clone().item())]
    sex = SEX_CAT[int(obs["sex"].clone().item())]
    age = obs["age"].clone().item()
    brain_vol = obs["brain_volume"].clone().item() / 1000  # in ml
    ventricle_vol = obs["ventricle_volume"].clone().item() / 1000  # in ml
    return (
        idx,
        image,
        mri_seq,
        sex,
        age,
        float(np.round(brain_vol, 2)),
        float(np.round(ventricle_vol, 2)),
    )
def get_chest_obs(idx=None):
    """Fetch a Chest X-ray test item and format it for display.

    The dataset is loaded on first use. Age is mapped from its stored value
    via ``(age + 1) * 50``.

    Args:
        idx: test-set index, or None to pick a random item.

    Returns:
        ``(idx, image, race, sex, finding, age)`` with age rounded to
        1 decimal.
    """
    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    image = get_fig_arr(postprocess(obs["x"].clone()))
    sex = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
    finding = FIND_CAT[int(obs["finding"].clone().squeeze().numpy())]
    race = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
    age = (obs["age"].clone().squeeze().numpy() + 1) * 50
    return (idx, image, race, sex, finding, float(np.round(age, 1)))
def infer_mnist_cf(*args):
    """Infer a Morpho-MNIST counterfactual for the selected interventions.

    Args:
        *args: eight positional values, in order ``idx, x, t, i, y`` (test
            index, unused image, thickness, intensity, digit) followed by
            the intervention flags ``do_t, do_i, do_y``. Only attributes
            whose flag is truthy are intervened on.

    Returns:
        ``(cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)``: figures for the mean
        counterfactual, its per-pixel std over particles and the difference
        to the reconstruction, followed by the counterfactual thickness,
        intensity and digit.
    """
    dataset_id = "Morpho-MNIST"
    idx, _, t, i, y, do_t, do_i, do_y = args
    n_particles = 32
    device = MODELS[dataset_id]["vae_args"].device
    # ranges used to map thickness/intensity between raw units and [-1, 1]
    t_min, t_max = 0.87598526, 6.255515
    i_min, i_max = 66.601204, 254.90317

    # preprocess into a fresh dict so the item returned by the dataset is
    # never modified in place
    item = DATA[dataset_id]["test"].dataset[int(idx)]
    obs = {}
    for k, v in item.items():
        if k == "x":
            v = (v - 127.5) / 127.5
        v = v.view(1, 1) if len(v.shape) < 1 else v.unsqueeze(0)
        v = v.to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            v = v.repeat(n_particles, *ndims)
        obs[k] = v

    # intervention(s)
    do_pa = {}
    if do_t:
        do_pa["thickness"] = torch.tensor(
            normalize(t, x_max=t_max, x_min=t_min)
        ).view(1, 1)
    if do_i:
        do_pa["intensity"] = torch.tensor(
            normalize(i, x_max=i_max, x_min=i_min)
        ).view(1, 1)
    if do_y:
        do_pa["digit"] = F.one_hot(
            torch.tensor(DIGITS.index(y)), num_classes=10
        ).view(1, 10)
    do_pa = {
        k: v.to(device).float().repeat(n_particles, 1) for k, v in do_pa.items()
    }

    # infer counterfactual; no gradients are needed for inference
    with torch.no_grad():
        out = counterfactual_inference(dataset_id, obs, do_pa)

    # avg cf particles
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_t = out["cf_pa"]["thickness"].mean(0)
    cf_i = out["cf_pa"]["intensity"].mean(0)
    cf_y = out["cf_pa"]["digit"].mean(0)

    # post process
    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_t = np.round((cf_t.item() + 1) / 2 * (t_max - t_min) + t_min, 2)
    cf_i = np.round((cf_i.item() + 1) / 2 * (i_max - i_min) + i_min, 2)
    cf_y = DIGITS[cf_y.argmax(-1)]

    # plots
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect, cmap="RdBu_r", norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255)
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)
def infer_brain_cf(*args):
    """Infer a Brain MRI counterfactual for the selected interventions.

    Args:
        *args: twelve positional values: ``idx, x, m, s, a, b, v`` (test
            index, unused image, MRI sequence, sex, age, brain volume,
            ventricle volume) followed by the intervention flags
            ``do_m, do_s, do_a, do_b, do_v``. Only attributes whose flag is
            truthy are intervened on.

    Returns:
        ``(cf_x, cf_x_std, effect, cf_m, cf_s, cf_age, cf_brain_volume,
        cf_ventricle_volume)``: three figures followed by the counterfactual
        attributes (age rounded to 1 decimal, volumes divided by 1000 and
        rounded to 2 decimals).
    """
    dataset_id = "Brain MRI"
    idx, _, m, s, a, b, v = args[:7]
    do_m, do_s, do_a, do_b, do_v = args[7:]
    n_particles = 16
    vae_args = MODELS[dataset_id]["vae_args"]
    continuous = ["age", "brain_volume", "ventricle_volume"]

    # preprocessing
    item = DATA[dataset_id]["test"].dataset[int(idx)]
    item = preprocess_brain(vae_args, item)
    # replicate into a fresh dict rather than rewriting the dict being iterated
    obs = {}
    for k, _v in item.items():
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            _v = _v.repeat(n_particles, *ndims)
        obs[k] = _v

    # intervention(s)
    do_pa = {}
    if do_m:
        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
    if do_a:
        do_pa["age"] = torch.tensor(a).view(1, 1)
    if do_b:
        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)

    # normalize continuous attributes
    for k in continuous:
        if k in do_pa:
            k_max, k_min = get_attr_max_min(k)
            do_pa[k] = (do_pa[k] - k_min) / (k_max - k_min)  # [0,1]
            do_pa[k] = 2 * do_pa[k] - 1  # [-1,1]

    do_pa = {
        k: _v.to(vae_args.device).float().repeat(n_particles, 1)
        for k, _v in do_pa.items()
    }

    # infer counterfactual; no gradients are needed for inference
    with torch.no_grad():
        out = counterfactual_inference(dataset_id, obs, do_pa)

    # avg cf particles
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_m = out["cf_pa"]["mri_seq"].mean(0)
    cf_s = out["cf_pa"]["sex"].mean(0)

    # post process
    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_m = MRISEQ_CAT[int(cf_m.item())]
    cf_s = SEX_CAT[int(cf_s.item())]
    cf_ = {}
    for k in continuous:  # unnormalize
        k_max, k_min = get_attr_max_min(k)
        cf_[k] = (out["cf_pa"][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min

    # plots: colour scale spans the observed range of the difference image
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (
        cf_x,
        cf_x_std,
        effect,
        cf_m,
        cf_s,
        np.round(cf_["age"], 1),
        np.round(cf_["brain_volume"] / 1000, 2),
        np.round(cf_["ventricle_volume"] / 1000, 2),
    )
def infer_chest_cf(*args):
    """Infer a Chest X-ray counterfactual for the selected interventions.

    Args:
        *args: ten positional values: ``idx, x, r, s, f, a`` (test index,
            unused image, race, sex, finding, age) followed by the
            intervention flags ``do_r, do_s, do_f, do_a``. Only attributes
            whose flag is truthy are intervened on.

    Returns:
        ``(cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, cf_a)``: three figures
        followed by the counterfactual race, sex, finding and age (rounded
        to 1 decimal).
    """
    dataset_id = "Chest X-ray"
    idx, _, r, s, f, a = args[:6]
    do_r, do_s, do_f, do_a = args[6:]
    n_particles = 16
    device = MODELS[dataset_id]["vae_args"].device

    # preprocessing into a fresh dict so the dataset item is not modified
    item = DATA[dataset_id]["test"].dataset[int(idx)]
    obs = {}
    for k, v in item.items():
        v = v.to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            v = v.repeat(n_particles, *ndims)
        obs[k] = v

    # The whole intervention + inference path runs without autograd: tensor
    # construction is unaffected by it and inference needs no gradients.
    with torch.no_grad():
        # intervention(s)
        do_pa = {}
        if do_s:
            do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
        if do_f:
            do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
        if do_r:
            do_pa["race"] = F.one_hot(
                torch.tensor(RACE_CAT.index(r)), num_classes=3
            ).view(1, 3)
        if do_a:
            # inverse of the (age + 1) * 50 mapping used for display
            do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
        do_pa = {
            k: v.to(device).float().repeat(n_particles, 1)
            for k, v in do_pa.items()
        }
        # infer counterfactual
        out = counterfactual_inference(dataset_id, obs, do_pa)

    # avg cf particles
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_r = out["cf_pa"]["race"].mean(0)
    cf_s = out["cf_pa"]["sex"].mean(0)
    cf_f = out["cf_pa"]["finding"].mean(0)
    cf_a = out["cf_pa"]["age"].mean(0)

    # post process
    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_r = RACE_CAT[cf_r.argmax(-1)]
    cf_s = SEX_CAT_CHEST[int(cf_s.item())]
    cf_f = FIND_CAT[int(cf_f.item())]
    cf_a = (cf_a.item() + 1) * 50

    # plots: colour scale spans the observed range of the difference image
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
# Gradio UI: one tab per dataset. Each tab shows the observation, the
# counterfactual, its uncertainty and the causal effect, plus intervention
# controls and the causal graph.
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been stripped - confirm against the rendered page.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Morpho-MNIST") as mnist_tab:
            # hidden textbox carrying the dataset id into callbacks
            mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)
            with gr.Row().style(equal_height=True):
                idx = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x = gr.Image(label="Counterfactual", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_std = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=1.75):
                    gr.Markdown(
                        "#### Intervention"
                        + 28 * " "
                        + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                    )
                    with gr.Column():
                        do_y = gr.Checkbox(label="do(digit)", value=False)
                        y = gr.Radio(DIGITS, label="", interactive=False)
                    with gr.Row():
                        with gr.Column(min_width=100):
                            do_t = gr.Checkbox(label="do(thickness)", value=False)
                            t = gr.Slider(
                                label="\u00A0",
                                minimum=0.9,
                                maximum=5.5,
                                step=0.01,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_i = gr.Checkbox(label="do(intensity)", value=False)
                            i = gr.Slider(
                                label="\u00A0",
                                minimum=50,
                                maximum=255,
                                step=0.01,
                                interactive=False,
                            )
                    with gr.Row():
                        new = gr.Button("New Observation")
                        reset = gr.Button("Reset", variant="stop")
                        submit = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    gr.Markdown("### ")
                    causal_graph = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=300)

        with gr.TabItem("Brain MRI") as brain_tab:
            brain_id = gr.Textbox(value=brain_tab.label, visible=False)
            with gr.Row().style(equal_height=True):
                idx_brain = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_brain = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_brain = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_brain = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_brain = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(
                        "#### Intervention"
                        + 28 * " "
                        + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                    )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
                            m = gr.Radio(
                                ["T1", "T2-FLAIR"], label="", interactive=False
                            )
                        with gr.Column(min_width=200):
                            do_s = gr.Checkbox(label="do(sex)", value=False)
                            s = gr.Radio(
                                ["female", "male"], label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=100):
                            do_a = gr.Checkbox(label="do(age)", value=False)
                            a = gr.Slider(
                                label="\u00A0",
                                value=50,
                                minimum=44,
                                maximum=73,
                                step=1,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_b = gr.Checkbox(label="do(brain volume)", value=False)
                            b = gr.Slider(
                                label="\u00A0",
                                value=1000,
                                minimum=850,
                                maximum=1550,
                                step=20,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_v = gr.Checkbox(
                                label="do(ventricle volume)", value=False
                            )
                            v = gr.Slider(
                                label="\u00A0",
                                value=40,
                                minimum=10,
                                maximum=125,
                                step=2,
                                interactive=False,
                            )
                    with gr.Row():
                        new_brain = gr.Button("New Observation")
                        reset_brain = gr.Button("Reset", variant="stop")
                        submit_brain = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    causal_graph_brain = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=340)

        with gr.TabItem("Chest X-ray") as chest_tab:
            chest_id = gr.Textbox(value=chest_tab.label, visible=False)
            with gr.Row().style(equal_height=True):
                idx_chest = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_chest = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_chest = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_chest = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_chest = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)
            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(
                        "#### Intervention"
                        + 28 * " "
                        + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
                    )
                    with gr.Row().style(equal_height=True):
                        with gr.Column(min_width=200):
                            do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                            f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_s_chest = gr.Checkbox(label="do(sex)", value=False)
                            s_chest = gr.Radio(
                                SEX_CAT_CHEST, label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_r_chest = gr.Checkbox(label="do(race)", value=False)
                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_a_chest = gr.Checkbox(label="do(age)", value=False)
                            # Starts disabled like every other intervention
                            # widget; the do(age) checkbox enables it.
                            a_chest = gr.Slider(
                                label="\u00A0",
                                minimum=18,
                                maximum=98,
                                step=1,
                                interactive=False,
                            )
                    with gr.Row():
                        new_chest = gr.Button("New Observation")
                        reset_chest = gr.Button("Reset", variant="stop")
                        submit_chest = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    causal_graph_chest = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=345)

    # Component groups. Their order must match the positional unpacking in
    # the infer_*_cf callbacks and the tuples returned by get_*_obs.
    # morphomnist
    do = [do_t, do_i, do_y]
    obs = [idx, x, t, i, y]
    cf_out = [cf_x, cf_x_std, effect]
    # brain
    do_brain = [do_m, do_s, do_a, do_b, do_v]  # intervention checkboxes
    obs_brain = [idx_brain, x_brain, m, s, a, b, v]  # observed image/attributes
    cf_out_brain = [cf_x_brain, cf_x_std_brain, effect_brain]  # counterfactual outputs
    # chest
    do_chest = [do_r_chest, do_s_chest, do_f_chest, do_a_chest]
    obs_chest = [idx_chest, x_chest, r_chest, s_chest, f_chest, a_chest]
    cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]

    # on start: load new observations & causal graph
    demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
    demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
    demo.load(fn=load_model, inputs=mnist_id, outputs=None)
    demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
    demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # on tab select: load models
    brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
    chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)

    # "new" button: load new observations
    new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
    new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)

    # "new" button: reset causal graphs
    new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
    new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "new" button: reset cf output panels
    for _k, _v in zip(
        [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)

    # "reset" button: reload current observations
    reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
    reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
    reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)

    # "reset" button: deselect intervention checkboxes
    reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
    reset_brain.click(
        fn=lambda: (gr.update(value=False),) * len(do_brain),
        inputs=None,
        outputs=do_brain,
    )
    reset_chest.click(
        fn=lambda: (gr.update(value=False),) * len(do_chest),
        inputs=None,
        outputs=do_chest,
    )

    # "reset" button: close matplotlib figures & reset cf output panels
    for _k, _v in zip(
        [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)

    # enable mnist interventions when checkbox is selected & update graph
    for _k, _v in zip(do, [t, i, y]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(mnist_graph, inputs=do, outputs=causal_graph)

    # enable brain interventions when checkbox is selected & update graph
    for _k, _v in zip(do_brain, [m, s, a, b, v]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)

    # enable chest interventions when checkbox is selected & update graph
    for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "submit" button: infer counterfactuals
    submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
    submit_brain.click(
        fn=infer_brain_cf,
        inputs=obs_brain + do_brain,
        outputs=cf_out_brain + [m, s, a, b, v],
    )
    submit_chest.click(
        fn=infer_chest_cf,
        inputs=obs_chest + do_chest,
        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()