fabio-deep committed · Commit 146a6ea · Parent(s): 8c4fe8b

added links
Files changed:
- .gitignore +1 -0
- app.py +343 -192
- app_utils.py +176 -114
- datasets.py +168 -123
- pgm/flow_pgm.py +310 -380
- pgm/layers.py +50 -42
- vae.py +136 -88
.gitignore
CHANGED
@@ -1,2 +1,3 @@
+.vscode
 __pycache__
 *.pyc
app.py
CHANGED
@@ -8,48 +8,55 @@ from vae import HVAE
 from datasets import morphomnist, ukbb, mimic, get_attr_max_min
 from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM
 from app_utils import (
+    mnist_graph,
+    brain_graph,
+    chest_graph,
+    vae_preprocess,
+    normalize,
+    preprocess_brain,
+    get_fig_arr,
+    postprocess,
+    MidpointNormalize,
 )
 
 DATA, MODELS = {}, {}
+for k in ["Morpho-MNIST", "Brain MRI", "Chest X-ray"]:
     DATA[k], MODELS[k] = {}, {}
 
 # mnist
 DIGITS = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
 # brain
+MRISEQ_CAT = ["T1", "T2-FLAIR"]  # 0,1
+SEX_CAT = ["female", "male"]  # 0,1
 HEIGHT, WIDTH = 270, 270
 # chest
+SEX_CAT_CHEST = ["male", "female"]  # 0,1
+RACE_CAT = ["white", "asian", "black"]  # 0,1,2
+FIND_CAT = ["no disease", "pleural effusion"]
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 class Hparams:
     def update(self, dict):
         for k, v in dict.items():
+            setattr(self, k, v)
 
 
 def get_paths(dataset_id):
+    if "MNIST" in dataset_id:
+        data_path = "./data/morphomnist"
+        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
+        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
+    elif "Brain" in dataset_id:
+        data_path = "./data/ukbb_subset"
+        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
+        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
+    elif "Chest" in dataset_id:
+        data_path = "./data/mimic_subset"
+        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
         vae_path = [
+            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt",  # base vae
+            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt",  # cf trained DSCM
         ]
     return data_path, vae_path, pgm_path
 
@@ -57,64 +64,71 @@ def get_paths(dataset_id):
 def load_pgm(dataset_id, pgm_path):
     checkpoint = torch.load(pgm_path, map_location=DEVICE)
     args = Hparams()
+    args.update(checkpoint["hparams"])
     args.device = DEVICE
+    if "MNIST" in dataset_id:
         pgm = MorphoMNISTPGM(args).to(args.device)
+    elif "Brain" in dataset_id:
         pgm = FlowPGM(args).to(args.device)
+    elif "Chest" in dataset_id:
         pgm = ChestPGM(args).to(args.device)
+    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
+    MODELS[dataset_id]["pgm"] = pgm
+    MODELS[dataset_id]["pgm_args"] = args
 
 
 def load_vae(dataset_id, vae_path):
+    if "Chest" in dataset_id:
         vae_path, dscm_path = vae_path[0], vae_path[1]
     checkpoint = torch.load(vae_path, map_location=DEVICE)
     args = Hparams()
+    args.update(checkpoint["hparams"])
     # backwards compatibility hack
+    if not hasattr(args, "vae"):
+        args.vae = "hierarchical"
+    if not hasattr(args, "cond_prior"):
         args.cond_prior = False
+    if hasattr(args, "free_bits"):
         args.kl_free_bits = args.free_bits
     args.device = DEVICE
     vae = HVAE(args).to(args.device)
 
+    if "Chest" in dataset_id:
         dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
+        vae.load_state_dict(
+            {
+                k[4:]: v
+                for k, v in dscm_ckpt["ema_model_state_dict"].items()
+                if "vae." in k
+            }
+        )
     else:
+        vae.load_state_dict(checkpoint["ema_model_state_dict"])
+    MODELS[dataset_id]["vae"] = vae
+    MODELS[dataset_id]["vae_args"] = args
 
 
 def get_dataloader(dataset_id, data_path):
+    MODELS[dataset_id]["pgm_args"].data_dir = data_path
+    args = MODELS[dataset_id]["pgm_args"]
+    if "MNIST" in dataset_id:
         datasets = morphomnist(args)
+    elif "Brain" in dataset_id:
         datasets = ukbb(args)
+    elif "Chest" in dataset_id:
         datasets = mimic(args)
+    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
+        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
+    )
 
 
 def load_dataset(dataset_id):
     data_path, _, pgm_path = get_paths(dataset_id)
     checkpoint = torch.load(pgm_path, map_location=DEVICE)
     args = Hparams()
+    args.update(checkpoint["hparams"])
     args.device = DEVICE
+    MODELS[dataset_id]["pgm_args"] = args
     get_dataloader(dataset_id, data_path)
 
 
@@ -122,167 +136,179 @@ def load_model(dataset_id):
     _, vae_path, pgm_path = get_paths(dataset_id)
     load_pgm(dataset_id, pgm_path)
     load_vae(dataset_id, vae_path)
+
 
 @torch.no_grad()
 def counterfactual_inference(dataset_id, obs, do_pa):
+    pa = {k: v.clone() for k, v in obs.items() if k != "x"}
+    cf_pa = MODELS[dataset_id]["pgm"].counterfactual(
+        obs=pa, intervention=do_pa, num_particles=1
+    )
+    args, vae = MODELS[dataset_id]["vae_args"], MODELS[dataset_id]["vae"]
     _pa = vae_preprocess(args, {k: v.clone() for k, v in pa.items()})
+    _cf_pa = vae_preprocess(args, {k: v.clone() for k, v in cf_pa.items()})
+    z_t = 0.1 if "mnist" in args.hps else 1.0
+    z = vae.abduct(x=obs["x"], parents=_pa, t=z_t)
     if vae.cond_prior:
+        z = [z[j]["z"] for j in range(len(z))]
     px_loc, px_scale = vae.forward_latents(latents=z, parents=_pa)
     cf_loc, cf_scale = vae.forward_latents(latents=z, parents=_cf_pa)
+    u = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
+    u_t = 0.1 if "mnist" in args.hps else 1.0  # cf sampling temp
     cf_scale = cf_scale * u_t
     cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
+    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_pa}
 
 
 def get_obs_item(dataset_id, idx=None):
     if idx is None:
+        n_test = len(DATA[dataset_id]["test"].dataset)
         idx = torch.randperm(n_test)[0]
     idx = int(idx)
+    return idx, DATA[dataset_id]["test"].dataset.__getitem__(idx)
 
 
 def get_mnist_obs(idx=None):
+    dataset_id = "Morpho-MNIST"
     if not DATA[dataset_id]:
         load_dataset(dataset_id)
     idx, obs = get_obs_item(dataset_id, idx)
+    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
+    t = (obs["thickness"].clone() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526
+    i = (obs["intensity"].clone() + 1) / 2 * (254.90317 - 66.601204) + 66.601204
+    y = DIGITS[obs["digit"].clone().argmax(-1)]
     return (idx, x, float(np.round(t, 2)), float(np.round(i, 2)), y)
 
 
 def get_brain_obs(idx=None):
+    dataset_id = "Brain MRI"
     if not DATA[dataset_id]:
         load_dataset(dataset_id)
     idx, obs = get_obs_item(dataset_id, idx)
+    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
+    m = MRISEQ_CAT[int(obs["mri_seq"].clone().item())]
+    s = SEX_CAT[int(obs["sex"].clone().item())]
+    a = obs["age"].clone().item()
+    b = obs["brain_volume"].clone().item() / 1000  # in ml
+    v = obs["ventricle_volume"].clone().item() / 1000  # in ml
     return (idx, x, m, s, a, float(np.round(b, 2)), float(np.round(v, 2)))
 
 
 def get_chest_obs(idx=None):
+    dataset_id = "Chest X-ray"
     if not DATA[dataset_id]:
+        load_dataset(dataset_id)
     idx, obs = get_obs_item(dataset_id, idx)
+    x = get_fig_arr(postprocess(obs["x"].clone()))
+    s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
+    f = FIND_CAT[int(obs["finding"].clone().squeeze().numpy())]
+    r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
+    a = (obs["age"].clone().squeeze().numpy() + 1) * 50
     return (idx, x, r, s, f, float(np.round(a, 1)))
 
 
 def infer_mnist_cf(*args):
+    dataset_id = "Morpho-MNIST"
     idx, _, t, i, y, do_t, do_i, do_y = args
     n_particles = 32
     # preprocess
+    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
+    obs["x"] = (obs["x"] - 127.5) / 127.5
     for k, v in obs.items():
         obs[k] = v.view(1, 1) if len(v.shape) < 1 else v.unsqueeze(0)
+        obs[k] = obs[k].to(MODELS[dataset_id]["vae_args"].device).float()
         if n_particles > 1:
+            ndims = (1,) * 3 if k == "x" else (1,)
             obs[k] = obs[k].repeat(n_particles, *ndims)
     # intervention(s)
     do_pa = {}
     if do_t:
+        do_pa["thickness"] = torch.tensor(
+            normalize(t, x_max=6.255515, x_min=0.87598526)
+        ).view(1, 1)
     if do_i:
+        do_pa["intensity"] = torch.tensor(
+            normalize(i, x_max=254.90317, x_min=66.601204)
+        ).view(1, 1)
     if do_y:
+        do_pa["digit"] = F.one_hot(torch.tensor(DIGITS.index(y)), num_classes=10).view(
+            1, 10
+        )
+
     for k, v in do_pa.items():
+        do_pa[k] = (
+            v.to(MODELS[dataset_id]["vae_args"].device).float().repeat(n_particles, 1)
+        )
     # infer counterfactual
     out = counterfactual_inference(dataset_id, obs, do_pa)
     # avg cf particles
+    cf_x = out["cf_x"].mean(0)
+    cf_x_std = out["cf_x"].std(0)
+    rec_x = out["rec_x"].mean(0)
+    cf_t = out["cf_pa"]["thickness"].mean(0)
+    cf_i = out["cf_pa"]["intensity"].mean(0)
+    cf_y = out["cf_pa"]["digit"].mean(0)
     # post process
     cf_x = postprocess(cf_x)
     cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
     rec_x = postprocess(rec_x)
     cf_t = np.round((cf_t.item() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526, 2)
+    cf_i = np.round((cf_i.item() + 1) / 2 * (254.90317 - 66.601204) + 66.601204, 2)
     cf_y = DIGITS[cf_y.argmax(-1)]
+    # plots
     # plt.close('all')
     effect = cf_x - rec_x
+    effect = get_fig_arr(
+        effect, cmap="RdBu_r", norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255)
+    )
     cf_x = get_fig_arr(cf_x)
+    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
     return (cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)
 
 
 def infer_brain_cf(*args):
+    dataset_id = "Brain MRI"
     idx, _, m, s, a, b, v = args[:7]
     do_m, do_s, do_a, do_b, do_v = args[7:]
     n_particles = 16
     # preprocessing
+    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
+    obs = preprocess_brain(MODELS[dataset_id]["vae_args"], obs)
     for k, _v in obs.items():
         if n_particles > 1:
+            ndims = (1,) * 3 if k == "x" else (1,)
             obs[k] = _v.repeat(n_particles, *ndims)
     # interventions(s)
     do_pa = {}
     if do_m:
+        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
     if do_s:
+        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
     if do_a:
+        do_pa["age"] = torch.tensor(a).view(1, 1)
     if do_b:
+        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
+        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)
     # normalize continuous attributes
+    for k in ["age", "brain_volume", "ventricle_volume"]:
         if k in do_pa.keys():
             k_max, k_min = get_attr_max_min(k)
             do_pa[k] = (do_pa[k] - k_min) / (k_max - k_min)  # [0,1]
             do_pa[k] = 2 * do_pa[k] - 1  # [-1,1]
 
     for k, _v in do_pa.items():
+        do_pa[k] = (
+            _v.to(MODELS[dataset_id]["vae_args"].device).float().repeat(n_particles, 1)
+        )
     # infer counterfactual
     out = counterfactual_inference(dataset_id, obs, do_pa)
     # avg cf particles
+    cf_x = out["cf_x"].mean(0)
+    cf_x_std = out["cf_x"].std(0)
+    rec_x = out["rec_x"].mean(0)
+    cf_m = out["cf_pa"]["mri_seq"].mean(0)
+    cf_s = out["cf_pa"]["sex"].mean(0)
     # post process
     cf_x = postprocess(cf_x)
     cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
@@ -290,54 +316,70 @@ def infer_brain_cf(*args):
     cf_m = MRISEQ_CAT[int(cf_m.item())]
     cf_s = SEX_CAT[int(cf_s.item())]
     cf_ = {}
+    for k in ["age", "brain_volume", "ventricle_volume"]:  # unnormalize
         k_max, k_min = get_attr_max_min(k)
+        cf_[k] = (out["cf_pa"][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
     # plots
+    # plt.close('all')
     effect = cf_x - rec_x
+    effect = get_fig_arr(
+        effect,
+        cmap="RdBu_r",
+        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
+    )
     cf_x = get_fig_arr(cf_x)
+    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
+    return (
+        cf_x,
+        cf_x_std,
+        effect,
+        cf_m,
+        cf_s,
+        np.round(cf_["age"], 1),
+        np.round(cf_["brain_volume"] / 1000, 2),
+        np.round(cf_["ventricle_volume"] / 1000, 2),
+    )
 
 
 def infer_chest_cf(*args):
+    dataset_id = "Chest X-ray"
     idx, _, r, s, f, a = args[:6]
     do_r, do_s, do_f, do_a = args[6:]
     n_particles = 16
     # preprocessing
+    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
     for k, v in obs.items():
+        obs[k] = v.to(MODELS[dataset_id]["vae_args"].device).float()
         if n_particles > 1:
+            ndims = (1,) * 3 if k == "x" else (1,)
             obs[k] = obs[k].repeat(n_particles, *ndims)
     # intervention(s)
     do_pa = {}
     with torch.no_grad():
         if do_s:
+            do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
         if do_f:
+            do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
         if do_r:
+            do_pa["race"] = F.one_hot(
+                torch.tensor(RACE_CAT.index(r)), num_classes=3
+            ).view(1, 3)
         if do_a:
+            do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
     for k, v in do_pa.items():
+        do_pa[k] = (
+            v.to(MODELS[dataset_id]["vae_args"].device).float().repeat(n_particles, 1)
+        )
     # infer counterfactual
     out = counterfactual_inference(dataset_id, obs, do_pa)
     # avg cf particles
+    cf_x = out["cf_x"].mean(0)
+    cf_x_std = out["cf_x"].std(0)
+    rec_x = out["rec_x"].mean(0)
+    cf_r = out["cf_pa"]["race"].mean(0)
+    cf_s = out["cf_pa"]["sex"].mean(0)
+    cf_f = out["cf_pa"]["finding"].mean(0)
+    cf_a = out["cf_pa"]["age"].mean(0)
     # post process
     cf_x = postprocess(cf_x)
     cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
@@ -349,10 +391,13 @@ def infer_chest_cf(*args):
     # plots
     # plt.close('all')
     effect = cf_x - rec_x
+    effect = get_fig_arr(
+        effect,
+        cmap="RdBu_r",
+        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
+    )
     cf_x = get_fig_arr(cf_x)
+    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
     return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
 
 
@@ -364,33 +409,59 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
             with gr.Row().style(equal_height=True):
                 idx = gr.Number(value=0, visible=False)
                 with gr.Column(scale=1, min_width=200):
+                    x = gr.Image(label="Observation", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
+                    cf_x = gr.Image(label="Counterfactual", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
+                    cf_x_std = gr.Image(
+                        label="Counterfactual Uncertainty", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
+                    effect = gr.Image(
+                        label="Direct Causal Effect", interactive=False
+                    ).style(height=HEIGHT)
             with gr.Row().style(equal_height=True):
                 with gr.Column(scale=1.75):
+                    gr.Markdown(
+                        "#### Intervention"
+                        + 28 * " "
+                        + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
+                    )
                    with gr.Column():
                         do_y = gr.Checkbox(label="do(digit)", value=False)
                         y = gr.Radio(DIGITS, label="", interactive=False)
                     with gr.Row():
                         with gr.Column(min_width=100):
                             do_t = gr.Checkbox(label="do(thickness)", value=False)
+                            t = gr.Slider(
+                                label="\u00A0",
+                                minimum=0.9,
+                                maximum=5.5,
+                                step=0.01,
+                                interactive=False,
+                            )
                         with gr.Column(min_width=100):
                             do_i = gr.Checkbox(label="do(intensity)", value=False)
+                            i = gr.Slider(
+                                label="\u00A0",
+                                minimum=50,
+                                maximum=255,
+                                step=0.01,
+                                interactive=False,
+                            )
                     with gr.Row():
                         new = gr.Button("New Observation")
                         reset = gr.Button("Reset", variant="stop")
                         submit = gr.Button("Submit", variant="primary")
                 with gr.Column(scale=1):
                     gr.Markdown("### ")
+                    causal_graph = gr.Image(
+                        label="Causal Graph", interactive=False
+                    ).style(height=300)
 
         with gr.TabItem("Brain MRI") as brain_tab:
             brain_id = gr.Textbox(value=brain_tab.label, visible=False)
@@ -398,40 +469,81 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
             with gr.Row().style(equal_height=True):
                 idx_brain = gr.Number(value=0, visible=False)
                 with gr.Column(scale=1, min_width=200):
+                    x_brain = gr.Image(label="Observation", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
+                    cf_x_brain = gr.Image(
+                        label="Counterfactual", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
+                    cf_x_std_brain = gr.Image(
+                        label="Counterfactual Uncertainty", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
+                    effect_brain = gr.Image(
+                        label="Direct Causal Effect", interactive=False
+                    ).style(height=HEIGHT)
             with gr.Row():
                 with gr.Column(scale=2.55):
+                    gr.Markdown(
+                        "#### Intervention"
+                        + 28 * " "
+                        + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
+                    )
                    with gr.Row():
                         with gr.Column(min_width=200):
                             do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
+                            m = gr.Radio(
+                                ["T1", "T2-FLAIR"], label="", interactive=False
+                            )
                         with gr.Column(min_width=200):
                             do_s = gr.Checkbox(label="do(sex)", value=False)
+                            s = gr.Radio(
+                                ["female", "male"], label="", interactive=False
+                            )
                     with gr.Row():
                         with gr.Column(min_width=100):
                             do_a = gr.Checkbox(label="do(age)", value=False)
+                            a = gr.Slider(
+                                label="\u00A0",
+                                value=50,
+                                minimum=44,
+                                maximum=73,
+                                step=1,
+                                interactive=False,
+                            )
                         with gr.Column(min_width=100):
                             do_b = gr.Checkbox(label="do(brain volume)", value=False)
+                            b = gr.Slider(
+                                label="\u00A0",
+                                value=1000,
+                                minimum=850,
+                                maximum=1550,
+                                step=20,
+                                interactive=False,
+                            )
                         with gr.Column(min_width=100):
+                            do_v = gr.Checkbox(
+                                label="do(ventricle volume)", value=False
+                            )
+                            v = gr.Slider(
+                                label="\u00A0",
+                                value=40,
+                                minimum=10,
+                                maximum=125,
+                                step=2,
+                                interactive=False,
+                            )
                     with gr.Row():
                         new_brain = gr.Button("New Observation")
+                        reset_brain = gr.Button("Reset", variant="stop")
+                        submit_brain = gr.Button("Submit", variant="primary")
                 with gr.Column(scale=1):
                     # gr.Markdown("### ")
+                    causal_graph_brain = gr.Image(
+                        label="Causal Graph", interactive=False
+                    ).style(height=340)
 
         with gr.TabItem("Chest X-ray") as chest_tab:
             chest_id = gr.Textbox(value=chest_tab.label, visible=False)
@@ -439,40 +551,58 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
             with gr.Row().style(equal_height=True):
                 idx_chest = gr.Number(value=0, visible=False)
                 with gr.Column(scale=1, min_width=200):
+                    x_chest = gr.Image(label="Observation", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
+                    cf_x_chest = gr.Image(
+                        label="Counterfactual", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
+                    cf_x_std_chest = gr.Image(
+                        label="Counterfactual Uncertainty", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
+                    effect_chest = gr.Image(
+                        label="Direct Causal Effect", interactive=False
+                    ).style(height=HEIGHT)
 
             with gr.Row():
                 with gr.Column(scale=2.55):
+                    gr.Markdown(
+                        "#### Intervention"
+                        + 28 * " "
+                        + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
+                    )
                    with gr.Row().style(equal_height=True):
                         with gr.Column(min_width=200):
                             do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                             f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                         with gr.Column(min_width=200):
                             do_s_chest = gr.Checkbox(label="do(sex)", value=False)
+                            s_chest = gr.Radio(
+                                SEX_CAT_CHEST, label="", interactive=False
+                            )
 
                     with gr.Row():
                         with gr.Column(min_width=200):
                             do_r_chest = gr.Checkbox(label="do(race)", value=False)
+                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                         with gr.Column(min_width=200):
                             do_a_chest = gr.Checkbox(label="do(age)", value=False)
+                            a_chest = gr.Slider(
+                                label="\u00A0", minimum=18, maximum=98, step=1
+                            )
+
                     with gr.Row():
                         new_chest = gr.Button("New Observation")
                         reset_chest = gr.Button("Reset", variant="stop")
                         submit_chest = gr.Button("Submit", variant="primary")
                 with gr.Column(scale=1):
                     # gr.Markdown("### ")
+                    causal_graph_chest = gr.Image(
+                        label="Causal Graph", interactive=False
+                    ).style(height=345)
 
     # morphomnist
     do = [do_t, do_i, do_y]
@@ -514,29 +644,41 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
 
     # "new" button: reset cf output panels
+    for _k, _v in zip(
+        [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
+    ):
+        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
 
     # "reset" button: reload current observations
     reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
     reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
     reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
+
     # "reset" button: deselect intervention checkboxes
+    reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
+    reset_brain.click(
+        fn=lambda: (gr.update(value=False),) * len(do_brain),
+        inputs=None,
+        outputs=do_brain,
+    )
+    reset_chest.click(
+        fn=lambda: (gr.update(value=False),) * len(do_chest),
+        inputs=None,
+        outputs=do_chest,
+    )
 
     # "reset" button: reset cf output panels
+    for _k, _v in zip(
+        [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
+    ):
+        _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
+        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
 
     # enable mnist interventions when checkbox is selected & update graph
     for _k, _v in zip(do, [t, i, y]):
         _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
         _k.change(mnist_graph, inputs=do, outputs=causal_graph)
+
     # enable brain interventions when checkbox is selected & update graph
     for _k, _v in zip(do_brain, [m, s, a, b, v]):
         _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
@@ -546,11 +688,20 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
         _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
         _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
+
     # "submit" button: infer countefactuals
     submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
+    submit_brain.click(
+        fn=infer_brain_cf,
+        inputs=obs_brain + do_brain,
+        outputs=cf_out_brain + [m, s, a, b, v],
+    )
+    submit_chest.click(
+        fn=infer_chest_cf,
+        inputs=obs_chest + do_chest,
+        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
+    )
 
 if __name__ == "__main__":
+    demo.queue()
+    demo.launch()
app_utils.py
CHANGED
@@ -3,16 +3,18 @@ import numpy as np
 import networkx as nx
 import matplotlib.pyplot as plt
 
+from PIL import Image
+
 from matplotlib import rc, patches, colors
+
+rc("font", **{"family": "serif", "serif": ["Roman"]})
+rc("text", usetex=True)
+rc("image", interpolation="none")
+rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
 
 from datasets import get_attr_max_min
 
+HAMMER = np.array(Image.open("./hammer.png").resize((35, 35))) / 255
 
 
 class MidpointNormalize(colors.Normalize):
@@ -21,7 +23,7 @@ class MidpointNormalize(colors.Normalize):
         colors.Normalize.__init__(self, vmin, vmax, clip)
 
     def __call__(self, value, clip=None):
+        v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
         x, y = [-v_ext, self.midpoint, v_ext], [0, 0.5, 1]
         return np.ma.masked_array(np.interp(value, x, y))
 
@@ -31,10 +33,10 @@ def postprocess(x):
 
 
 def mnist_graph(*args):
+    x, t, i, y = r"$\mathbf{x}$", r"$t$", r"$i$", r"$y$"
+    ut, ui, uy = r"$\mathbf{U}_t$", r"$\mathbf{U}_i$", r"$\mathbf{U}_y$"
+    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"
+
     G = nx.DiGraph()
     G.add_edge(t, x)
     G.add_edge(i, x)
@@ -47,31 +49,36 @@ def mnist_graph(*args):
     G.add_edge(ex, x)
 
     pos = {
+        y: (0, 0),
+        uy: (-1, 0),
+        t: (0, 0.5),
+        ut: (0, 1),
+        x: (1, 0),
+        zx: (2, 0.375),
+        ex: (2, 0),
+        i: (1, 0.5),
+        ui: (1, 1),
     }
 
     node_c = {}
     for node in G:
+        node_c[node] = "lightgrey" if node in [x, t, i, y] else "white"
+    node_line_c = {k: "black" for k, _ in node_c.items()}
+    edge_c = {e: "black" for e in G.edges}
 
     if args[0]:  # do_t
+        edge_c[(ut, t)] = "lightgrey"
         # G.remove_edge(ut, t)
+        node_line_c[t] = "red"
     if args[1]:  # do_i
+        edge_c[(ui, i)] = "lightgrey"
+        edge_c[(t, i)] = "lightgrey"
         # G.remove_edges_from([(ui, i), (t, i)])
+        node_line_c[i] = "red"
    if args[2]:  # do_y
+        edge_c[(uy, y)] = "lightgrey"
         # G.remove_edge(uy, y)
+        node_line_c[y] = "red"
 
     fs = 30
     options = {
@@ -83,27 +90,36 @@ def mnist_graph(*args):
         "linewidths": 2,
         "width": 2,
     }
+    plt.close("all")
+    fig, ax = plt.subplots(1, 1, figsize=(6, 4.1))  # , constrained_layout=True)
     # fig.patch.set_visible(False)
     ax.margins(x=0.06, y=0.15, tight=False)
     ax.axis("off")
+    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
     # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
     x_lim = (-1.348, 2.348)
     y_lim = (-0.215, 1.215)
     ax.set_xlim(x_lim)
     ax.set_ylim(y_lim)
+    rect = patches.FancyBboxPatch(
+        (1.75, -0.16),
+        0.5,
+        0.7,
+        boxstyle="round, pad=0.05, rounding_size=0",
+        linewidth=2,
+        edgecolor="black",
+        facecolor="none",
+        linestyle="-",
+    )
     ax.add_patch(rect)
     ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
 
     if args[0]:  # do_t
+        fig.figimage(HAMMER, 0.26 * fig.bbox.xmax, 0.525 * fig.bbox.ymax, zorder=10)
     if args[1]:  # do_i
+        fig.figimage(HAMMER, 0.5175 * fig.bbox.xmax, 0.525 * fig.bbox.ymax, zorder=11)
     if args[2]:  # do_y
+        fig.figimage(HAMMER, 0.26 * fig.bbox.xmax, 0.2 * fig.bbox.ymax, zorder=12)
 
     fig.tight_layout()
     fig.canvas.draw()
@@ -111,10 +127,16 @@ def mnist_graph(*args):
 
 
 def brain_graph(*args):
+    x, m, s, a, b, v = r"$\mathbf{x}$", r"$m$", r"$s$", r"$a$", r"$b$", r"$v$"
+    um, us, ua, ub, uv = (
+        r"$\mathbf{U}_m$",
+        r"$\mathbf{U}_s$",
+        r"$\mathbf{U}_a$",
+        r"$\mathbf{U}_b$",
+        r"$\mathbf{U}_v$",
+    )
+    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"
+
     G = nx.DiGraph()
     G.add_edge(m, x)
     G.add_edge(s, x)
@@ -132,44 +154,51 @@ def brain_graph(*args):
     G.add_edge(uv, v)
 
     pos = {
+        x: (0, 0),
+        zx: (-0.25, -1),
+        ex: (0.25, -1),
+        a: (0, 1),
+        ua: (0, 2),
+        s: (1, 0),
+        us: (1, -1),
+        b: (1, 1),
+        ub: (1, 2),
+        m: (-1, 0),
+        um: (-1, -1),
+        v: (-1, 1),
+        uv: (-1, 2),
    }
 
     node_c = {}
     for node in G:
+        node_c[node] = "lightgrey" if node in [x, m, s, a, b, v] else "white"
+    node_line_c = {k: "black" for k, _ in node_c.items()}
+    edge_c = {e: "black" for e in G.edges}
 
     if args[0]:  # do_m
         # G.remove_edge(um, m)
+        edge_c[(um, m)] = "lightgrey"
+        node_line_c[m] = "red"
     if args[1]:  # do_s
         # G.remove_edge(us, s)
+        edge_c[(us, s)] = "lightgrey"
+        node_line_c[s] = "red"
     if args[2]:  # do_a
         # G.remove_edge(ua, a)
+        edge_c[(ua, a)] = "lightgrey"
+        node_line_c[a] = "red"
     if args[3]:  # do_b
         # G.remove_edges_from([(ub, b), (s, b), (a, b)])
+        edge_c[(ub, b)] = "lightgrey"
+        edge_c[(s, b)] = "lightgrey"
+        edge_c[(a, b)] = "lightgrey"
+        node_line_c[b] = "red"
     if args[4]:  # do_v
         # G.remove_edges_from([(uv, v), (a, v), (b, v)])
+        edge_c[(uv, v)] = "lightgrey"
+        edge_c[(a, v)] = "lightgrey"
+        edge_c[(b, v)] = "lightgrey"
+        node_line_c[v] = "red"
 
     fs = 30
     options = {
@@ -182,33 +211,49 @@ def brain_graph(*args):
         "width": 2,
     }
 
+    plt.close("all")
+    fig, ax = plt.subplots(1, 1, figsize=(5, 5))  # , constrained_layout=True)
     # fig.patch.set_visible(False)
     ax.margins(x=0.1, y=0.08, tight=False)
     ax.axis("off")
+    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
     # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
     x_lim = (-1.32, 1.32)
     y_lim = (-1.414, 2.414)
     ax.set_xlim(x_lim)
     ax.set_ylim(y_lim)
+    rect = patches.FancyBboxPatch(
+        (-0.5, -1.325),
+        1,
+        0.65,
+        boxstyle="round, pad=0.05, rounding_size=0",
+        linewidth=2,
+        edgecolor="black",
+        facecolor="none",
+        linestyle="-",
+    )
     ax.add_patch(rect)
     # ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
+
     if args[0]:  # do_m
+        fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=10)
     if args[1]:  # do_s
-        fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.395*fig.bbox.ymax, zorder=11)
     if args[2]:  # do_a
-        fig.figimage(HAMMER, 0.363*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=12)
     if args[3]:  # do_b
-        fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=13)
     if args[4]:  # do_v
-        fig.figimage(HAMMER, 0.0075*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=14)
     else:  # b -> v
-        a3 = patches.FancyArrowPatch(
     ax.add_patch(a3)
     # print(ax.get_xlim())
     # print(ax.get_ylim())
@@ -217,12 +262,16 @@ def brain_graph(*args):
     return np.array(fig.canvas.renderer.buffer_rgba())
 
 
 def chest_graph(*args):
-    x, a, d, r, s= r
-    ua, ud, ur, us =
     G = nx.DiGraph()
     G.add_edge(ua, a)
     G.add_edge(ud, d)
@@ -237,7 +286,7 @@ def chest_graph(*args):
     G.add_edge(a, x)
 
     pos = {
-        x: (0, 0),
         a: (-1, 1),
         d: (0, 1),
         r: (1, 1),
@@ -246,34 +295,34 @@ def chest_graph(*args):
         ud: (0, 2),
         ur: (1, 2),
         us: (1, -1),
-        zx: (-0.25, -1),
         ex: (0.25, -1),
     }
 
     node_c = {}
     for node in G:
-        node_c[node] =
 
-    edge_c = {e:
-    node_line_c = {k:
 
     if args[0]:  # do_r
         # G.remove_edge(ur, r)
-        edge_c[(ur, r)] =
-        node_line_c[r] =
     if args[1]:  # do_s
         # G.remove_edges_from([(us, s)])
-        edge_c[(us, s)] =
-        node_line_c[s] =
     if args[2]:  # do_f (do_d)
         # G.remove_edges_from([(ud, d), (a, d)])
-        edge_c[(ud, d)] =
-        edge_c[(a, d)] =
-        node_line_c[d] =
     if args[3]:  # do_a
         # G.remove_edge(ua, a)
-        edge_c[(ua, a)] =
-        node_line_c[a] =
 
     fs = 30
     options = {
@@ -285,29 +334,38 @@ def chest_graph(*args):
         "linewidths": 2,
         "width": 2,
     }
-    plt.close(
-    fig, ax = plt.subplots(1, 1, figsize=(5,5))
     # fig.patch.set_visible(False)
     ax.margins(x=0.1, y=0.08, tight=False)
     ax.axis("off")
-    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle=
     # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
     x_lim = (-1.32, 1.32)
     y_lim = (-1.414, 2.414)
     ax.set_xlim(x_lim)
     ax.set_ylim(y_lim)
-    rect = patches.FancyBboxPatch(
     ax.add_patch(rect)
     ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
 
     if args[0]:  # do_r
-        fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=10)
     if args[1]:  # do_s
-        fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.395*fig.bbox.ymax, zorder=11)
     if args[2]:  # do_f
-        fig.figimage(HAMMER, 0.363*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=12)
     if args[3]:  # do_a
-        fig.figimage(HAMMER, 0.0075*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=13)
 
     fig.tight_layout()
     fig.canvas.draw()
@@ -315,51 +373,55 @@ def chest_graph(*args):
 
 
 def vae_preprocess(args, pa):
-    if
         # preprocessing ukbb parents for the vae which was originally trained using
         # log standardized parents. The pgm was trained using [-1,1] normalization
         # first undo [-1,1] parent preprocessing back to original range
         for k, v in pa.items():
-            if k !=
                 pa[k] = (v + 1) / 2  # [-1,1] -> [0,1]
                 _max, _min = get_attr_max_min(k)
                 pa[k] = pa[k] * (_max - _min) + _min
         # log_standardize parents for vae input
        for k, v in pa.items():
             logpa_k = torch.log(v.clamp(min=1e-12))
-            if k ==
                 pa[k] = (logpa_k - 4.112339973449707) / 0.11769197136163712
-            elif k ==
                 pa[k] = (logpa_k - 13.965583801269531) / 0.09537758678197861
-            elif k ==
                 pa[k] = (logpa_k - 10.345998764038086) / 0.43127763271331787
     # concatenate parents expand to input res for conditioning the vae
-    pa = torch.cat(
     return pa
 
 
 def preprocess_brain(args, obs):
-    obs[
     # for all other variables except x
-    for k in [k for k in obs.keys() if k !=
         obs[k] = obs[k].float().to(args.device).view(1, 1)
-        if k in [
             k_max, k_min = get_attr_max_min(k)
             obs[k] = (obs[k] - k_min) / (k_max - k_min)  # [0,1]
             obs[k] = 2 * obs[k] - 1  # [-1,1]
     return obs
 
 
-def get_fig_arr(x, width=4, height=4, dpi=144, cmap=
     fig = plt.figure(figsize=(width, height), dpi=dpi)
-    ax = plt.axes([0,0,1,1], frameon=False)
-    if cmap ==
         ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
     else:
         ax.imshow(x, cmap=cmap, norm=norm)
-    ax.axis(
     fig.canvas.draw()
     return np.array(fig.canvas.renderer.buffer_rgba())
 
@@ -370,4 +432,4 @@ def normalize(x, x_min=None, x_max=None, zero_one=False):
     if x_max is None:
         x_max = x.max()
     x = (x - x_min) / (x_max - x_min)  # [0,1]
-    return x if zero_one else 2 * x - 1  # else [-1,1]
|
240 |
if args[1]: # do_s
|
241 |
+
fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
|
242 |
if args[2]: # do_a
|
243 |
+
fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
|
244 |
if args[3]: # do_b
|
245 |
+
fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)
|
246 |
if args[4]: # do_v
|
247 |
+
fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=14)
|
248 |
else: # b -> v
|
249 |
+
a3 = patches.FancyArrowPatch(
|
250 |
+
(0.86, 1.21),
|
251 |
+
(-0.86, 1.21),
|
252 |
+
connectionstyle="arc3,rad=.3",
|
253 |
+
linewidth=2,
|
254 |
+
arrowstyle="simple, head_width=10, head_length=10",
|
255 |
+
color="k",
|
256 |
+
)
|
257 |
ax.add_patch(a3)
|
258 |
# print(ax.get_xlim())
|
259 |
# print(ax.get_ylim())
|
|
|
262 |
return np.array(fig.canvas.renderer.buffer_rgba())
|
263 |
|
264 |
|
|
|
265 |
def chest_graph(*args):
|
266 |
+
x, a, d, r, s = r"$\mathbf{x}$", r"$a$", r"$d$", r"$r$", r"$s$"
|
267 |
+
ua, ud, ur, us = (
|
268 |
+
r"$\mathbf{U}_a$",
|
269 |
+
r"$\mathbf{U}_d$",
|
270 |
+
r"$\mathbf{U}_r$",
|
271 |
+
r"$\mathbf{U}_s$",
|
272 |
+
)
|
273 |
+
zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"
|
274 |
+
|
275 |
G = nx.DiGraph()
|
276 |
G.add_edge(ua, a)
|
277 |
G.add_edge(ud, d)
|
|
|
286 |
G.add_edge(a, x)
|
287 |
|
288 |
pos = {
|
289 |
+
x: (0, 0),
|
290 |
a: (-1, 1),
|
291 |
d: (0, 1),
|
292 |
r: (1, 1),
|
|
|
295 |
ud: (0, 2),
|
296 |
ur: (1, 2),
|
297 |
us: (1, -1),
|
298 |
+
zx: (-0.25, -1),
|
299 |
ex: (0.25, -1),
|
300 |
}
|
301 |
|
302 |
node_c = {}
|
303 |
for node in G:
|
304 |
+
node_c[node] = "lightgrey" if node in [x, a, d, r, s] else "white"
|
305 |
|
306 |
+
edge_c = {e: "black" for e in G.edges}
|
307 |
+
node_line_c = {k: "black" for k, _ in node_c.items()}
|
308 |
|
309 |
if args[0]: # do_r
|
310 |
# G.remove_edge(ur, r)
|
311 |
+
edge_c[(ur, r)] = "lightgrey"
|
312 |
+
node_line_c[r] = "red"
|
313 |
if args[1]: # do_s
|
314 |
# G.remove_edges_from([(us, s)])
|
315 |
+
edge_c[(us, s)] = "lightgrey"
|
316 |
+
node_line_c[s] = "red"
|
317 |
if args[2]: # do_f (do_d)
|
318 |
# G.remove_edges_from([(ud, d), (a, d)])
|
319 |
+
edge_c[(ud, d)] = "lightgrey"
|
320 |
+
edge_c[(a, d)] = "lightgrey"
|
321 |
+
node_line_c[d] = "red"
|
322 |
if args[3]: # do_a
|
323 |
# G.remove_edge(ua, a)
|
324 |
+
edge_c[(ua, a)] = "lightgrey"
|
325 |
+
node_line_c[a] = "red"
|
326 |
|
327 |
fs = 30
|
328 |
options = {
|
|
|
334 |
"linewidths": 2,
|
335 |
"width": 2,
|
336 |
}
|
337 |
+
plt.close("all")
|
338 |
+
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) # , constrained_layout=True)
|
339 |
# fig.patch.set_visible(False)
|
340 |
ax.margins(x=0.1, y=0.08, tight=False)
|
341 |
ax.axis("off")
|
342 |
+
nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
|
343 |
# need to reuse x, y limits so that the graphs plot the same way before and after removing edges
|
344 |
x_lim = (-1.32, 1.32)
|
345 |
y_lim = (-1.414, 2.414)
|
346 |
ax.set_xlim(x_lim)
|
347 |
ax.set_ylim(y_lim)
|
348 |
+
rect = patches.FancyBboxPatch(
|
349 |
+
(-0.5, -1.325),
|
350 |
+
1,
|
351 |
+
0.65,
|
352 |
+
boxstyle="round, pad=0.05, rounding_size=0",
|
353 |
+
linewidth=2,
|
354 |
+
edgecolor="black",
|
355 |
+
facecolor="none",
|
356 |
+
linestyle="-",
|
357 |
+
)
|
358 |
ax.add_patch(rect)
|
359 |
ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
|
360 |
+
|
361 |
if args[0]: # do_r
|
362 |
+
fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=10)
|
363 |
if args[1]: # do_s
|
364 |
+
fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
|
365 |
if args[2]: # do_f
|
366 |
+
fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
|
367 |
if args[3]: # do_a
|
368 |
+
fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)
|
369 |
|
370 |
fig.tight_layout()
|
371 |
fig.canvas.draw()
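Editor's note: all three graph functions share one pattern: an intervention greys out the edges into the do()-node and outlines the node in red rather than deleting edges, and the x/y limits are pinned so the layout is identical before and after. A toy sketch of that pattern (names illustrative, not from the repo):

import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph([("u", "t"), ("t", "x")])
pos = {"u": (0, 1), "t": (0, 0), "x": (1, 0)}
do_t = True  # pretend do(t) is switched on
edge_c = {e: "lightgrey" if do_t and e == ("u", "t") else "black" for e in G.edges}
node_line_c = {n: "red" if do_t and n == "t" else "black" for n in G}
fig, ax = plt.subplots(figsize=(3, 3))
ax.axis("off")
nx.draw_networkx(G, pos, node_color="lightgrey",
                 edgecolors=list(node_line_c.values()),
                 edge_color=list(edge_c.values()), ax=ax)
ax.set_xlim(-0.5, 1.5)  # fixed limits keep the layout stable across renders
ax.set_ylim(-0.5, 1.5)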
|
|
|
373 |
|
374 |
|
375 |
def vae_preprocess(args, pa):
|
376 |
+
if "ukbb" in args.hps:
|
377 |
# preprocessing ukbb parents for the vae which was originally trained using
|
378 |
# log standardized parents. The pgm was trained using [-1,1] normalization
|
379 |
# first undo [-1,1] parent preprocessing back to original range
|
380 |
for k, v in pa.items():
|
381 |
+
if k != "mri_seq" and k != "sex":
|
382 |
pa[k] = (v + 1) / 2 # [-1,1] -> [0,1]
|
383 |
_max, _min = get_attr_max_min(k)
|
384 |
pa[k] = pa[k] * (_max - _min) + _min
|
385 |
# log_standardize parents for vae input
|
386 |
for k, v in pa.items():
|
387 |
logpa_k = torch.log(v.clamp(min=1e-12))
|
388 |
+
if k == "age":
|
389 |
pa[k] = (logpa_k - 4.112339973449707) / 0.11769197136163712
|
390 |
+
elif k == "brain_volume":
|
391 |
pa[k] = (logpa_k - 13.965583801269531) / 0.09537758678197861
|
392 |
+
elif k == "ventricle_volume":
|
393 |
pa[k] = (logpa_k - 10.345998764038086) / 0.43127763271331787
|
394 |
# concatenate parents and expand to input res for conditioning the vae
|
395 |
+
pa = torch.cat(
|
396 |
+
[pa[k] if len(pa[k].shape) > 1 else pa[k][..., None] for k in args.parents_x],
|
397 |
+
dim=1,
|
398 |
+
)
|
399 |
+
pa = (
|
400 |
+
pa[..., None, None].repeat(1, 1, *(args.input_res,) * 2).to(args.device).float()
|
401 |
+
)
|
402 |
return pa
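Editor's note: vae_preprocess undoes the PGM's [-1,1] scaling and re-applies the VAE's log-standardisation; the constants are the training-set log moments. A quick numeric check for 'age', using the same constants:

import torch

age_norm = torch.tensor([[0.0]])             # PGM scale: 0 is mid-range in [-1,1]
age = (age_norm + 1) / 2 * (73 - 44) + 44    # back to years via (max, min) = (73, 44)
age_log_std = (torch.log(age) - 4.112339973449707) / 0.11769197136163712
print(age.item(), age_log_std.item())        # 58.5 years -> about -0.37 log-std units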
|
403 |
|
404 |
|
405 |
def preprocess_brain(args, obs):
|
406 |
+
obs["x"] = (obs["x"][None, ...].float().to(args.device) - 127.5) / 127.5 # [-1,1]
|
407 |
# for all other variables except x
|
408 |
+
for k in [k for k in obs.keys() if k != "x"]:
|
409 |
obs[k] = obs[k].float().to(args.device).view(1, 1)
|
410 |
+
if k in ["age", "brain_volume", "ventricle_volume"]:
|
411 |
k_max, k_min = get_attr_max_min(k)
|
412 |
obs[k] = (obs[k] - k_min) / (k_max - k_min) # [0,1]
|
413 |
obs[k] = 2 * obs[k] - 1 # [-1,1]
|
414 |
return obs
|
415 |
|
416 |
|
417 |
+
def get_fig_arr(x, width=4, height=4, dpi=144, cmap="Greys_r", norm=None):
|
418 |
fig = plt.figure(figsize=(width, height), dpi=dpi)
|
419 |
+
ax = plt.axes([0, 0, 1, 1], frameon=False)
|
420 |
+
if cmap == "Greys_r":
|
421 |
ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
|
422 |
else:
|
423 |
ax.imshow(x, cmap=cmap, norm=norm)
|
424 |
+
ax.axis("off")
|
425 |
fig.canvas.draw()
|
426 |
return np.array(fig.canvas.renderer.buffer_rgba())
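Editor's note: get_fig_arr returns the rendered figure as an RGBA uint8 array (576x576 for the 4 in x 4 in, 144 dpi defaults), which can be handed straight to an image output. Sketch (filename illustrative):

import numpy as np
from PIL import Image

x = np.random.randint(0, 256, (192, 192), dtype=np.uint8)  # stand-in grayscale image
arr = get_fig_arr(x)                       # (576, 576, 4) uint8 RGBA
Image.fromarray(arr).save("preview.png")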
|
427 |
|
|
|
432 |
if x_max is None:
|
433 |
x_max = x.max()
|
434 |
x = (x - x_min) / (x_max - x_min) # [0,1]
|
435 |
+
return x if zero_one else 2 * x - 1 # else [-1,1]
|
datasets.py
CHANGED
@@ -23,44 +23,45 @@ def normalize(x, x_min=None, x_max=None, zero_one=False):
|
|
23 |
x_min = x.min()
|
24 |
if x_max is None:
|
25 |
x_max = x.max()
|
26 |
-
print(f'max: {x_max}, min: {x_min}')
|
27 |
x = (x - x_min) / (x_max - x_min) # [0,1]
|
28 |
return x if zero_one else 2 * x - 1 # else [-1,1]
|
29 |
|
30 |
|
31 |
class UKBBDataset(Dataset):
|
32 |
-
def __init__(self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True):
|
|
|
|
|
33 |
super().__init__()
|
34 |
self.root = root
|
35 |
self.transform = transform
|
36 |
self.concat_pa = concat_pa # return concatenated parents
|
37 |
|
38 |
-
print(f'\nLoading csv data: {csv_file}')
|
39 |
self.df = pd.read_csv(csv_file)
|
40 |
self.columns = columns
|
41 |
if self.columns is None:
|
42 |
# ['eid', 'sex', 'age', 'brain_volume', 'ventricle_volume', 'mri_seq']
|
43 |
self.columns = list(self.df.columns) # return all
|
44 |
self.columns.pop(0) # remove redundant 'index' column
|
45 |
-
print(f'columns: {self.columns}')
|
46 |
-
self.samples = {i: torch.as_tensor(
|
47 |
-
self.df[i]).float() for i in self.columns}
|
48 |
|
49 |
-
for k in ['age', 'brain_volume', 'ventricle_volume']:
|
50 |
-
print(f'{k} normalization: {norm}')
|
51 |
if k in self.columns:
|
52 |
-
if norm == '[-1,1]':
|
53 |
self.samples[k] = normalize(self.samples[k])
|
54 |
-
elif norm == '[0,1]':
|
55 |
self.samples[k] = normalize(self.samples[k], zero_one=True)
|
56 |
-
elif norm == 'log_standard':
|
57 |
self.samples[k] = log_standardize(self.samples[k])
|
58 |
elif norm == None:
|
59 |
pass
|
60 |
else:
|
61 |
-
NotImplementedError(f'{norm} not implemented.')
|
62 |
-
print(f'#samples: {len(self.df)}')
|
63 |
-
self.return_x = True if 'eid' in self.columns else False
|
64 |
|
65 |
def __len__(self):
|
66 |
return len(self.df)
|
@@ -69,31 +70,32 @@ class UKBBDataset(Dataset):
|
|
69 |
sample = {k: v[idx] for k, v in self.samples.items()}
|
70 |
|
71 |
if self.return_x:
|
72 |
-
mri_seq = 'T1' if sample['mri_seq'] == 0.0 else 'T2_FLAIR'
|
73 |
# Load scan
|
74 |
-
filename = f'{int(sample["eid"])}_' + \
|
75 |
-
mri_seq+'_unbiased_brain_rigid_to_mni.png'
|
76 |
-
|
|
|
77 |
|
78 |
if self.transform is not None:
|
79 |
-
sample['x'] = self.transform(x)
|
80 |
-
sample.pop('eid', None)
|
81 |
|
82 |
if self.concat_pa:
|
83 |
-
sample['pa'] = torch.cat([
|
84 |
-
torch.tensor([sample[k]]) for k in self.columns if k != 'eid'
|
85 |
-
|
86 |
|
87 |
return sample
|
88 |
|
89 |
|
90 |
def get_attr_max_min(attr):
|
91 |
# some ukbb dataset (max, min) stats
|
92 |
-
if attr == 'age':
|
93 |
return 73, 44
|
94 |
-
elif attr == 'brain_volume':
|
95 |
return 1629520, 841919
|
96 |
-
elif attr == 'ventricle_volume':
|
97 |
return 157075, 7613.27001953125
|
98 |
else:
|
99 |
NotImplementedError
|
@@ -102,37 +104,43 @@ def get_attr_max_min(attr):
|
|
102 |
def ukbb(args):
|
103 |
csv_dir = args.data_dir
|
104 |
augmentation = {
|
105 |
-        'train': TF.Compose([
-            TF.Resize((args.input_res, args.input_res)),
-            TF.RandomCrop(size=(args.input_res, args.input_res),
-                          padding=[2 * args.pad, args.pad]),
-            TF.RandomHorizontalFlip(p=args.hflip),
-            TF.PILToTensor()]),
-        'eval': TF.Compose([
-            TF.Resize((args.input_res, args.input_res)),
-            TF.PILToTensor()]),
116 |
}
|
117 |
|
118 |
datasets = {}
|
119 |
# for split in ['train', 'valid', 'test']:
|
120 |
-
for split in ['test']:
|
121 |
datasets[split] = UKBBDataset(
|
122 |
root=args.data_dir,
|
123 |
-
csv_file=os.path.join(csv_dir, split+'.csv'),
|
124 |
-
transform=augmentation[('eval' if split != 'train' else split)],
|
125 |
-
columns=(None if not args.parents_x else ['eid'] + args.parents_x),
|
126 |
-
norm=(None if not hasattr(args, 'context_norm') else args.context_norm),
|
127 |
-
|
128 |
-
|
129 |
|
130 |
return datasets
|
131 |
|
132 |
|
133 |
def _load_uint8(f):
|
134 |
-
idx_dtype, ndim = struct.unpack('BBBB', f.read(4))[2:]
|
135 |
-
shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
|
136 |
buffer_length = int(np.prod(shape))
|
137 |
data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
|
138 |
return data
|
@@ -152,8 +160,8 @@ def load_idx(path: str) -> np.ndarray:
|
|
152 |
----------
|
153 |
http://yann.lecun.com/exdb/mnist/
|
154 |
"""
|
155 |
-
open_fcn = gzip.open if path.endswith('.gz') else open
|
156 |
-
with open_fcn(path, 'rb') as f:
|
157 |
return _load_uint8(f)
|
158 |
|
159 |
|
@@ -168,8 +176,9 @@ def _get_paths(root_dir, train):
|
|
168 |
return images_path, labels_path, metrics_path
|
169 |
|
170 |
|
171 |
-
def load_morphomnist_like(root_dir, train: bool = True, columns=None) \
|
172 |
-        -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
|
|
173 |
"""
|
174 |
Args:
|
175 |
root_dir: path to data directory
|
@@ -184,68 +193,84 @@ def load_morphomnist_like(root_dir, train: bool = True, columns=None) \
|
|
184 |
images = load_idx(images_path)
|
185 |
labels = load_idx(labels_path)
|
186 |
|
187 |
-
if columns is not None and 'index' not in columns:
|
188 |
-
usecols = ['index'] + list(columns)
|
189 |
else:
|
190 |
usecols = columns
|
191 |
-
metrics = pd.read_csv(metrics_path, usecols=usecols, index_col='index')
|
192 |
return images, labels, metrics
|
193 |
|
194 |
|
195 |
class MorphoMNIST(Dataset):
|
196 |
-
def __init__(self, root_dir, train=True, transform=None, columns=None, norm=None, concat_pa=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
self.train = train
|
198 |
self.transform = transform
|
199 |
self.columns = columns
|
200 |
self.concat_pa = concat_pa
|
201 |
self.norm = norm
|
202 |
|
203 |
-
cols_not_digit = [c for c in self.columns if c != 'digit']
|
204 |
images, labels, metrics_df = load_morphomnist_like(
|
205 |
-
root_dir, train, cols_not_digit)
|
|
|
206 |
self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
|
207 |
self.labels = F.one_hot(
|
208 |
-
torch.from_numpy(np.array(labels)).long(), num_classes=10)
|
|
|
209 |
|
210 |
if self.columns is None:
|
211 |
self.columns = metrics_df.columns
|
212 |
self.samples = {k: torch.tensor(metrics_df[k]) for k in cols_not_digit}
|
213 |
|
214 |
self.min_max = {
|
215 |
-            'thickness': [0.87598526, 6.255515],
216 |
-            'intensity': [66.601204, 254.90317],
|
217 |
}
|
218 |
|
219 |
for k, v in self.samples.items(): # optional preprocessing
|
220 |
-
print(f'{k} normalization: {norm}')
|
221 |
-
if norm == '[-1,1]':
|
222 |
-
self.samples[k] = normalize(
|
223 |
-                v, x_min=self.min_max[k][0], x_max=self.min_max[k][1])
224 |
-            elif norm == '[0,1]':
225 |
-                self.samples[k] = normalize(
226 |
-                    v, x_min=self.min_max[k][0], x_max=self.min_max[k][1], zero_one=True)
|
|
|
|
227 |
elif norm == None:
|
228 |
pass
|
229 |
else:
|
230 |
-
NotImplementedError(f'{norm} not implemented.')
|
231 |
-
print(f'#samples: {len(metrics_df)}\n')
|
232 |
|
233 |
-
self.samples.update({'digit': self.labels})
|
234 |
|
235 |
def __len__(self):
|
236 |
return len(self.images)
|
237 |
|
238 |
def __getitem__(self, idx):
|
239 |
sample = {}
|
240 |
-
sample['x'] = self.images[idx]
|
241 |
|
242 |
if self.transform is not None:
|
243 |
-
sample['x'] = self.transform(sample['x'])
|
244 |
|
245 |
if self.concat_pa:
|
246 |
-            sample['pa'] = torch.cat([v[idx] if k == 'digit' else torch.tensor([v[idx]])
247 |
-                                      for k, v in self.samples.items()], dim=0)
|
|
|
|
|
|
|
|
249 |
else:
|
250 |
sample.update({k: v[idx] for k, v in self.samples.items()})
|
251 |
return sample
|
@@ -254,39 +279,43 @@ class MorphoMNIST(Dataset):
|
|
254 |
def morphomnist(args):
|
255 |
# Load data
|
256 |
augmentation = {
|
257 |
-        'train': TF.Compose([
-            TF.RandomCrop((args.input_res, args.input_res), padding=args.pad),
-        ]),
-        'eval': TF.Compose([
-            TF.Pad(padding=2),  # (32, 32)
-        ]),
263 |
}
|
264 |
|
265 |
datasets = {}
|
266 |
# for split in ['train', 'valid', 'test']:
|
267 |
-
for split in ['test']:
|
268 |
datasets[split] = MorphoMNIST(
|
269 |
root_dir=args.data_dir,
|
270 |
-
train=(split == 'train'),
|
271 |
-
transform=augmentation[('eval' if split != 'train' else split)],
|
272 |
columns=args.parents_x,
|
273 |
norm=args.context_norm,
|
274 |
-
concat_pa=
|
275 |
)
|
276 |
return datasets
|
277 |
|
278 |
|
279 |
def preproc_mimic(batch):
|
280 |
for k, v in batch.items():
|
281 |
-
if k == 'x':
|
282 |
-
batch['x'] = (batch['x'].float() - 127.5) / 127.5  # [-1,1]
|
283 |
-
elif k in ['age']:
|
284 |
batch[k] = batch[k].float().unsqueeze(-1)
|
285 |
-
batch[k] = batch[k] / 100.
|
286 |
batch[k] = batch[k] * 2 - 1 # [-1,1]
|
287 |
-
elif k in ['race']:
|
288 |
batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
|
289 |
-
elif k in ['finding']:
|
290 |
batch[k] = batch[k].unsqueeze(-1).float()
|
291 |
else:
|
292 |
batch[k] = batch[k].float().unsqueeze(-1)
|
@@ -294,39 +323,52 @@ def preproc_mimic(batch):
|
|
294 |
|
295 |
|
296 |
class MIMICDataset(Dataset):
|
297 |
-
def __init__(self, root, csv_file, transform=None, columns=None, concat_pa=True, only_pleural_eff=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
self.data = pd.read_csv(csv_file)
|
299 |
self.transform = transform
|
300 |
-
self.disease_labels = ['No Finding', 'Other', 'Pleural Effusion']
|
|
|
|
|
|
|
|
|
|
|
301 |
self.samples = {
|
302 |
-            'age': [],
-            'sex': [],
-            'finding': [],
-            'x': [],
-            'race': [],
-            'lung_opacity': [],
-            'pleural_effusion': [],
309 |
}
|
310 |
|
311 |
-
for idx, _ in enumerate(tqdm(range(len(self.data)), desc='Loading MIMIC Data')):
|
312 |
-
if only_pleural_eff and self.data.loc[idx, 'disease'] == 'Other':
|
313 |
continue
|
314 |
-
img_path = os.path.join(root, self.data.loc[idx, 'path_preproc'])
|
315 |
|
316 |
-
lung_opacity = self.data.loc[idx, 'Lung Opacity']
|
317 |
-
self.samples['lung_opacity'].append(lung_opacity)
|
318 |
|
319 |
-
pleural_effusion = self.data.loc[idx, 'Pleural Effusion']
|
320 |
-
self.samples['pleural_effusion'].append(pleural_effusion)
|
321 |
|
322 |
-
disease = self.data.loc[idx, 'disease']
|
323 |
-
finding = 0 if disease == 'No Finding' else 1
|
324 |
|
325 |
-
self.samples['x'].append(img_path)
|
326 |
-
self.samples['finding'].append(finding)
|
327 |
-
self.samples['age'].append(self.data.loc[idx, 'age'])
|
328 |
-
self.samples['race'].append(self.data.loc[idx, 'race_label'])
|
329 |
-
self.samples['sex'].append(self.data.loc[idx, 'sex_label'])
|
330 |
|
331 |
self.columns = columns
|
332 |
if self.columns is None:
|
@@ -336,33 +378,36 @@ class MIMICDataset(Dataset):
|
|
336 |
self.concat_pa = concat_pa
|
337 |
|
338 |
def __len__(self):
|
339 |
-
return len(self.samples['x'])
|
340 |
|
341 |
def __getitem__(self, idx):
|
342 |
sample = {k: v[idx] for k, v in self.samples.items()}
|
343 |
-
sample['x'] = imread(sample['x']).astype(np.float32)[None, ...]
|
344 |
|
345 |
for k, v in sample.items():
|
346 |
sample[k] = torch.tensor(v)
|
347 |
|
348 |
if self.transform:
|
349 |
-
sample['x'] = self.transform(sample['x'])
|
350 |
|
351 |
sample = preproc_mimic(sample)
|
352 |
if self.concat_pa:
|
353 |
-
sample['pa'] = torch.cat([sample[k] for k in self.columns], dim=0)
|
354 |
return sample
|
355 |
|
356 |
|
357 |
def mimic(args):
|
358 |
args.csv_dir = args.data_dir
|
359 |
datasets = {}
|
360 |
-
datasets['test'] = MIMICDataset(
|
361 |
root=args.data_dir,
|
362 |
-
csv_file=os.path.join(args.csv_dir, 'mimic.sample.test.csv'),
|
363 |
columns=args.parents_x,
|
364 |
-
transform=TF.Compose(
|
365 |
-
|
366 |
-
|
|
|
|
|
|
|
367 |
)
|
368 |
-
return datasets
|
|
|
23 |
x_min = x.min()
|
24 |
if x_max is None:
|
25 |
x_max = x.max()
|
26 |
+
print(f"max: {x_max}, min: {x_min}")
|
27 |
x = (x - x_min) / (x_max - x_min) # [0,1]
|
28 |
return x if zero_one else 2 * x - 1 # else [-1,1]
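Editor's note: a quick numeric check of normalize() as defined above (it also prints the max/min debug line added in this commit):

import torch

ages = torch.tensor([44.0, 58.5, 73.0])
print(normalize(ages))                 # tensor([-1.,  0.,  1.])
print(normalize(ages, zero_one=True))  # tensor([0.0000, 0.5000, 1.0000])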
|
29 |
|
30 |
|
31 |
class UKBBDataset(Dataset):
|
32 |
+
def __init__(
|
33 |
+
self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True
|
34 |
+
):
|
35 |
super().__init__()
|
36 |
self.root = root
|
37 |
self.transform = transform
|
38 |
self.concat_pa = concat_pa # return concatenated parents
|
39 |
|
40 |
+
print(f"\nLoading csv data: {csv_file}")
|
41 |
self.df = pd.read_csv(csv_file)
|
42 |
self.columns = columns
|
43 |
if self.columns is None:
|
44 |
# ['eid', 'sex', 'age', 'brain_volume', 'ventricle_volume', 'mri_seq']
|
45 |
self.columns = list(self.df.columns) # return all
|
46 |
self.columns.pop(0) # remove redundant 'index' column
|
47 |
+
print(f"columns: {self.columns}")
|
48 |
+
self.samples = {i: torch.as_tensor(self.df[i]).float() for i in self.columns}
|
|
|
49 |
|
50 |
+
for k in ["age", "brain_volume", "ventricle_volume"]:
|
51 |
+
print(f"{k} normalization: {norm}")
|
52 |
if k in self.columns:
|
53 |
+
if norm == "[-1,1]":
|
54 |
self.samples[k] = normalize(self.samples[k])
|
55 |
+
elif norm == "[0,1]":
|
56 |
self.samples[k] = normalize(self.samples[k], zero_one=True)
|
57 |
+
elif norm == "log_standard":
|
58 |
self.samples[k] = log_standardize(self.samples[k])
|
59 |
elif norm is None:
|
60 |
pass
|
61 |
else:
|
62 |
+
raise NotImplementedError(f"{norm} not implemented.")
|
63 |
+
print(f"#samples: {len(self.df)}")
|
64 |
+
self.return_x = True if "eid" in self.columns else False
|
65 |
|
66 |
def __len__(self):
|
67 |
return len(self.df)
|
|
|
70 |
sample = {k: v[idx] for k, v in self.samples.items()}
|
71 |
|
72 |
if self.return_x:
|
73 |
+
mri_seq = "T1" if sample["mri_seq"] == 0.0 else "T2_FLAIR"
|
74 |
# Load scan
|
75 |
+
filename = (
|
76 |
+
f'{int(sample["eid"])}_' + mri_seq + "_unbiased_brain_rigid_to_mni.png"
|
77 |
+
)
|
78 |
+
x = Image.open(os.path.join(self.root, "thumbs_192x192", filename))
|
79 |
|
80 |
if self.transform is not None:
|
81 |
+
sample["x"] = self.transform(x)
|
82 |
+
sample.pop("eid", None)
|
83 |
|
84 |
if self.concat_pa:
|
85 |
+
sample["pa"] = torch.cat(
|
86 |
+
[torch.tensor([sample[k]]) for k in self.columns if k != "eid"], dim=0
|
87 |
+
)
|
88 |
|
89 |
return sample
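Editor's note: a usage sketch for the class above; paths are placeholders for the real UKBB layout, and omitting 'eid' from columns skips image loading entirely:

from torch.utils.data import DataLoader

ds = UKBBDataset(root="./ukbb", csv_file="./ukbb/test.csv",
                 columns=["sex", "age", "brain_volume"], norm="[-1,1]")
loader = DataLoader(ds, batch_size=32, shuffle=False)
batch = next(iter(loader))  # parent tensors plus 'pa' (concat_pa defaults to True)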
|
90 |
|
91 |
|
92 |
def get_attr_max_min(attr):
|
93 |
# some ukbb dataset (max, min) stats
|
94 |
+
if attr == "age":
|
95 |
return 73, 44
|
96 |
+
elif attr == "brain_volume":
|
97 |
return 1629520, 841919
|
98 |
+
elif attr == "ventricle_volume":
|
99 |
return 157075, 7613.27001953125
|
100 |
else:
|
101 |
raise NotImplementedError
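Editor's note: these stats are how the app moves between display units and the [-1,1] model scale; inverting the scaling for, say, a slider value:

a_max, a_min = get_attr_max_min("age")
age_norm = 0.0                                   # e.g. a slider value on the [-1,1] scale
age_years = (age_norm + 1) / 2 * (a_max - a_min) + a_min
print(age_years)                                 # 58.5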
|
|
|
104 |
def ukbb(args):
|
105 |
csv_dir = args.data_dir
|
106 |
augmentation = {
|
107 |
+
"train": TF.Compose(
|
108 |
+
[
|
109 |
+
TF.Resize((args.input_res, args.input_res), antialias=None),
|
110 |
+
TF.RandomCrop(
|
111 |
+
size=(args.input_res, args.input_res),
|
112 |
+
padding=[2 * args.pad, args.pad],
|
113 |
+
),
|
114 |
+
TF.RandomHorizontalFlip(p=args.hflip),
|
115 |
+
TF.PILToTensor(),
|
116 |
+
]
|
117 |
+
),
|
118 |
+
"eval": TF.Compose(
|
119 |
+
[
|
120 |
+
TF.Resize((args.input_res, args.input_res), antialias=None),
|
121 |
+
TF.PILToTensor(),
|
122 |
+
]
|
123 |
+
),
|
124 |
}
|
125 |
|
126 |
datasets = {}
|
127 |
# for split in ['train', 'valid', 'test']:
|
128 |
+
for split in ["test"]:
|
129 |
datasets[split] = UKBBDataset(
|
130 |
root=args.data_dir,
|
131 |
+
csv_file=os.path.join(csv_dir, split + ".csv"),
|
132 |
+
transform=augmentation[("eval" if split != "train" else split)],
|
133 |
+
columns=(None if not args.parents_x else ["eid"] + args.parents_x),
|
134 |
+
norm=(None if not hasattr(args, "context_norm") else args.context_norm),
|
135 |
+
concat_pa=False,
|
136 |
+
)
|
137 |
|
138 |
return datasets
|
139 |
|
140 |
|
141 |
def _load_uint8(f):
|
142 |
+
idx_dtype, ndim = struct.unpack("BBBB", f.read(4))[2:]
|
143 |
+
shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
|
144 |
buffer_length = int(np.prod(shape))
|
145 |
data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
|
146 |
return data
|
|
|
160 |
----------
|
161 |
http://yann.lecun.com/exdb/mnist/
|
162 |
"""
|
163 |
+
open_fcn = gzip.open if path.endswith(".gz") else open
|
164 |
+
with open_fcn(path, "rb") as f:
|
165 |
return _load_uint8(f)
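Editor's note: _load_uint8/load_idx parse the MNIST IDX layout: four magic bytes (third = dtype code, 0x08 for uint8; fourth = ndim), one big-endian uint32 per dimension, then the raw bytes. A self-contained check:

import io
import struct

# 2x3 uint8 array: magic (0, 0, 8, 2), big-endian shape (2, 3), six data bytes
payload = struct.pack("BBBB", 0, 0, 8, 2) + struct.pack(">II", 2, 3) + bytes(range(6))
arr = _load_uint8(io.BytesIO(payload))
print(arr.shape, arr.dtype)  # (2, 3) uint8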
|
166 |
|
167 |
|
|
|
176 |
return images_path, labels_path, metrics_path
|
177 |
|
178 |
|
179 |
+
def load_morphomnist_like(
|
180 |
+
root_dir, train: bool = True, columns=None
|
181 |
+
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
|
182 |
"""
|
183 |
Args:
|
184 |
root_dir: path to data directory
|
|
|
193 |
images = load_idx(images_path)
|
194 |
labels = load_idx(labels_path)
|
195 |
|
196 |
+
if columns is not None and "index" not in columns:
|
197 |
+
usecols = ["index"] + list(columns)
|
198 |
else:
|
199 |
usecols = columns
|
200 |
+
metrics = pd.read_csv(metrics_path, usecols=usecols, index_col="index")
|
201 |
return images, labels, metrics
|
202 |
|
203 |
|
204 |
class MorphoMNIST(Dataset):
|
205 |
+
def __init__(
|
206 |
+
self,
|
207 |
+
root_dir,
|
208 |
+
train=True,
|
209 |
+
transform=None,
|
210 |
+
columns=None,
|
211 |
+
norm=None,
|
212 |
+
concat_pa=True,
|
213 |
+
):
|
214 |
self.train = train
|
215 |
self.transform = transform
|
216 |
self.columns = columns
|
217 |
self.concat_pa = concat_pa
|
218 |
self.norm = norm
|
219 |
|
220 |
+
cols_not_digit = [c for c in self.columns if c != "digit"]
|
221 |
images, labels, metrics_df = load_morphomnist_like(
|
222 |
+
root_dir, train, cols_not_digit
|
223 |
+
)
|
224 |
self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
|
225 |
self.labels = F.one_hot(
|
226 |
+
torch.from_numpy(np.array(labels)).long(), num_classes=10
|
227 |
+
)
|
228 |
|
229 |
if self.columns is None:
|
230 |
self.columns = metrics_df.columns
|
231 |
self.samples = {k: torch.tensor(metrics_df[k]) for k in cols_not_digit}
|
232 |
|
233 |
self.min_max = {
|
234 |
+
"thickness": [0.87598526, 6.255515],
|
235 |
+
"intensity": [66.601204, 254.90317],
|
236 |
}
|
237 |
|
238 |
for k, v in self.samples.items(): # optional preprocessing
|
239 |
+
print(f"{k} normalization: {norm}")
|
240 |
+
if norm == "[-1,1]":
|
241 |
+
self.samples[k] = normalize(
|
242 |
+
v, x_min=self.min_max[k][0], x_max=self.min_max[k][1]
|
243 |
+
)
|
244 |
+
elif norm == "[0,1]":
|
245 |
+
self.samples[k] = normalize(
|
246 |
+
v, x_min=self.min_max[k][0], x_max=self.min_max[k][1], zero_one=True
|
247 |
+
)
|
248 |
elif norm is None:
|
249 |
pass
|
250 |
else:
|
251 |
+
raise NotImplementedError(f"{norm} not implemented.")
|
252 |
+
print(f"#samples: {len(metrics_df)}\n")
|
253 |
|
254 |
+
self.samples.update({"digit": self.labels})
|
255 |
|
256 |
def __len__(self):
|
257 |
return len(self.images)
|
258 |
|
259 |
def __getitem__(self, idx):
|
260 |
sample = {}
|
261 |
+
sample["x"] = self.images[idx]
|
262 |
|
263 |
if self.transform is not None:
|
264 |
+
sample["x"] = self.transform(sample["x"])
|
265 |
|
266 |
if self.concat_pa:
|
267 |
+
sample["pa"] = torch.cat(
|
268 |
+
[
|
269 |
+
v[idx] if k == "digit" else torch.tensor([v[idx]])
|
270 |
+
for k, v in self.samples.items()
|
271 |
+
],
|
272 |
+
dim=0,
|
273 |
+
)
|
274 |
else:
|
275 |
sample.update({k: v[idx] for k, v in self.samples.items()})
|
276 |
return sample
|
|
|
279 |
def morphomnist(args):
|
280 |
# Load data
|
281 |
augmentation = {
|
282 |
+
"train": TF.Compose(
|
283 |
+
[
|
284 |
+
TF.RandomCrop((args.input_res, args.input_res), padding=args.pad),
|
285 |
+
]
|
286 |
+
),
|
287 |
+
"eval": TF.Compose(
|
288 |
+
[
|
289 |
+
TF.Pad(padding=2), # (32, 32)
|
290 |
+
]
|
291 |
+
),
|
292 |
}
|
293 |
|
294 |
datasets = {}
|
295 |
# for split in ['train', 'valid', 'test']:
|
296 |
+
for split in ["test"]:
|
297 |
datasets[split] = MorphoMNIST(
|
298 |
root_dir=args.data_dir,
|
299 |
+
train=(split == "train"), # test set is valid set
|
300 |
+
transform=augmentation[("eval" if split != "train" else split)],
|
301 |
columns=args.parents_x,
|
302 |
norm=args.context_norm,
|
303 |
+
concat_pa=False,
|
304 |
)
|
305 |
return datasets
|
306 |
|
307 |
|
308 |
def preproc_mimic(batch):
|
309 |
for k, v in batch.items():
|
310 |
+
if k == "x":
|
311 |
+
batch["x"] = (batch["x"].float() - 127.5) / 127.5 # [-1,1]
|
312 |
+
elif k in ["age"]:
|
313 |
batch[k] = batch[k].float().unsqueeze(-1)
|
314 |
+
batch[k] = batch[k] / 100.0
|
315 |
batch[k] = batch[k] * 2 - 1 # [-1,1]
|
316 |
+
elif k in ["race"]:
|
317 |
batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
|
318 |
+
elif k in ["finding"]:
|
319 |
batch[k] = batch[k].unsqueeze(-1).float()
|
320 |
else:
|
321 |
batch[k] = batch[k].float().unsqueeze(-1)
|
|
|
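Editor's note: preproc_mimic mutates and returns the batch; a dummy batch makes the per-key scaling concrete (shapes and values illustrative):

import torch

batch = {
    "x": torch.full((1, 1, 192, 192), 255.0),  # pixels -> [-1,1], here all 1.0
    "age": torch.tensor([50]),                 # years / 100 * 2 - 1 -> 0.0
    "race": torch.tensor([1]),                 # label in {0,1,2} -> one-hot of size 3
    "finding": torch.tensor([1]),
    "sex": torch.tensor([1]),
}
batch = preproc_mimic(batch)
print(batch["age"], batch["race"])  # tensor([[0.]]) tensor([0., 1., 0.])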
323 |
|
324 |
|
325 |
class MIMICDataset(Dataset):
|
326 |
+
def __init__(
|
327 |
+
self,
|
328 |
+
root,
|
329 |
+
csv_file,
|
330 |
+
transform=None,
|
331 |
+
columns=None,
|
332 |
+
concat_pa=True,
|
333 |
+
only_pleural_eff=True,
|
334 |
+
):
|
335 |
self.data = pd.read_csv(csv_file)
|
336 |
self.transform = transform
|
337 |
+
self.disease_labels = [
|
338 |
+
"No Finding",
|
339 |
+
"Other",
|
340 |
+
"Pleural Effusion",
|
341 |
+
# "Lung Opacity",
|
342 |
+
]
|
343 |
self.samples = {
|
344 |
+
"age": [],
|
345 |
+
"sex": [],
|
346 |
+
"finding": [],
|
347 |
+
"x": [],
|
348 |
+
"race": [],
|
349 |
+
# "lung_opacity": [],
|
350 |
+
# "pleural_effusion": [],
|
351 |
}
|
352 |
|
353 |
+
for idx, _ in enumerate(tqdm(range(len(self.data)), desc="Loading MIMIC Data")):
|
354 |
+
if only_pleural_eff and self.data.loc[idx, "disease"] == "Other":
|
355 |
continue
|
356 |
+
img_path = os.path.join(root, self.data.loc[idx, "path_preproc"])
|
357 |
|
358 |
+
# lung_opacity = self.data.loc[idx, "Lung Opacity"]
|
359 |
+
# self.samples["lung_opacity"].append(lung_opacity)
|
360 |
|
361 |
+
# pleural_effusion = self.data.loc[idx, "Pleural Effusion"]
|
362 |
+
# self.samples["pleural_effusion"].append(pleural_effusion)
|
363 |
|
364 |
+
disease = self.data.loc[idx, "disease"]
|
365 |
+
finding = 0 if disease == "No Finding" else 1
|
366 |
|
367 |
+
self.samples["x"].append(img_path)
|
368 |
+
self.samples["finding"].append(finding)
|
369 |
+
self.samples["age"].append(self.data.loc[idx, "age"])
|
370 |
+
self.samples["race"].append(self.data.loc[idx, "race_label"])
|
371 |
+
self.samples["sex"].append(self.data.loc[idx, "sex_label"])
|
372 |
|
373 |
self.columns = columns
|
374 |
if self.columns is None:
|
|
|
378 |
self.concat_pa = concat_pa
|
379 |
|
380 |
def __len__(self):
|
381 |
+
return len(self.samples["x"])
|
382 |
|
383 |
def __getitem__(self, idx):
|
384 |
sample = {k: v[idx] for k, v in self.samples.items()}
|
385 |
+
sample["x"] = imread(sample["x"]).astype(np.float32)[None, ...]
|
386 |
|
387 |
for k, v in sample.items():
|
388 |
sample[k] = torch.tensor(v)
|
389 |
|
390 |
if self.transform:
|
391 |
+
sample["x"] = self.transform(sample["x"])
|
392 |
|
393 |
sample = preproc_mimic(sample)
|
394 |
if self.concat_pa:
|
395 |
+
sample["pa"] = torch.cat([sample[k] for k in self.columns], dim=0)
|
396 |
return sample
|
397 |
|
398 |
|
399 |
def mimic(args):
|
400 |
args.csv_dir = args.data_dir
|
401 |
datasets = {}
|
402 |
+
datasets["test"] = MIMICDataset(
|
403 |
root=args.data_dir,
|
404 |
+
csv_file=os.path.join(args.csv_dir, "mimic.sample.test.csv"),
|
405 |
columns=args.parents_x,
|
406 |
+
transform=TF.Compose(
|
407 |
+
[
|
408 |
+
TF.Resize((args.input_res, args.input_res), antialias=None),
|
409 |
+
]
|
410 |
+
),
|
411 |
+
concat_pa=False,
|
412 |
)
|
413 |
+
return datasets
|
pgm/flow_pgm.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-    import numpy as np
2 |
|
|
|
3 |
import torch
|
4 |
-    from torch import nn
5 |
import torch.nn.functional as F
|
6 |
|
7 |
import pyro
|
@@ -15,53 +16,69 @@ from pyro.distributions.conditional import ConditionalTransformedDistribution
|
|
15 |
from .layers import (
|
16 |
ConditionalTransformedDistributionGumbelMax,
|
17 |
ConditionalGumbelMax,
|
18 |
-
ConditionalAffineTransform,
|
|
|
|
|
19 |
)
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
class BasePGM(nn.Module):
|
23 |
def __init__(self):
|
24 |
super().__init__()
|
25 |
|
26 |
def scm(self, *args, **kwargs):
|
27 |
def config(msg):
|
28 |
-
if isinstance(msg['fn'], dist.TransformedDistribution):
|
29 |
return TransformReparam()
|
30 |
else:
|
31 |
return None
|
|
|
32 |
return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
|
33 |
|
34 |
-
def sample_scm(self, n_samples=1, t=None):
|
35 |
-
with pyro.plate('obs', n_samples):
|
36 |
-
samples = self.scm(t)
|
37 |
return samples
|
38 |
|
39 |
-
def sample(self, n_samples=1, t=None):
|
40 |
-
with pyro.plate('obs', n_samples):
|
41 |
-
samples = self.model(t)
|
42 |
return samples
|
43 |
|
44 |
-
def infer_exogeneous(self, obs):
|
45 |
batch_size = list(obs.values())[0].shape[0]
|
46 |
# assuming that we use transformed distributions for everything:
|
47 |
cond_model = pyro.condition(self.sample, data=obs)
|
48 |
-
cond_trace = pyro.poutine.trace(
|
49 |
-
cond_model).get_trace(batch_size)
|
50 |
|
51 |
output = {}
|
52 |
for name, node in cond_trace.nodes.items():
|
53 |
-
if 'z' in name or 'fn' not in node.keys():
|
54 |
continue
|
55 |
-
fn = node['fn']
|
56 |
if isinstance(fn, dist.Independent):
|
57 |
fn = fn.base_dist
|
58 |
if isinstance(fn, dist.TransformedDistribution):
|
59 |
# compute exogenous base dist (created with TransformReparam) at all sites
|
60 |
-
output[name + '_base'] = T.ComposeTransform(
-                    fn.transforms).inv(node['value'])
|
|
62 |
return output
|
63 |
|
64 |
-
def counterfactual(self, obs, intervention, num_particles=1, detach=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
dag_variables = self.variables.keys()
|
66 |
assert set(obs.keys()) == set(dag_variables)
|
67 |
avg_cfs = {k: torch.zeros_like(obs[k]) for k in obs.keys()}
|
@@ -70,43 +87,50 @@ class BasePGM(nn.Module):
|
|
70 |
for _ in range(num_particles):
|
71 |
# Abduction
|
72 |
exo_noise = self.infer_exogeneous(obs)
|
73 |
-
exo_noise = {k: v.detach() if detach else v for k,
|
74 |
-
v in exo_noise.items()}
|
75 |
# condition on root node variables (no exogeneous noise available)
|
76 |
for k in dag_variables:
|
77 |
if k not in intervention.keys():
|
78 |
-
if k not in [i.split('_base')[0] for i in exo_noise.keys()]:
|
79 |
exo_noise[k] = obs[k]
|
80 |
# Abducted SCM
|
81 |
-
abducted_scm = pyro.poutine.condition(
|
82 |
-
self.sample_scm, data=exo_noise)
|
83 |
# Action
|
84 |
-
counterfactual_scm = pyro.poutine.do(
|
85 |
-
abducted_scm, data=intervention)
|
86 |
# Prediction
|
87 |
-
counterfactuals = counterfactual_scm(batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
for k, v in counterfactuals.items():
|
90 |
-
avg_cfs[k] += (v / num_particles)
|
91 |
return avg_cfs
|
92 |
|
93 |
|
94 |
class FlowPGM(BasePGM):
|
95 |
-
def __init__(self, args):
|
96 |
super().__init__()
|
97 |
self.variables = {
|
98 |
-            'sex': 'binary',
-            'mri_seq': 'binary',
-            'age': 'continuous',
-            'brain_volume': 'continuous',
-            'ventricle_volume': 'continuous',
103 |
}
|
104 |
# priors: s, m, a, b and v
|
105 |
self.s_logit = nn.Parameter(torch.zeros(1))
|
106 |
self.m_logit = nn.Parameter(torch.zeros(1))
|
107 |
-
for k in ['a', 'b', 'v']:
|
108 |
-
self.register_buffer(f'{k}_base_loc', torch.zeros(1))
|
109 |
-
self.register_buffer(f'{k}_base_scale', torch.ones(1))
|
110 |
|
111 |
# constraint, assumes data is [-1,1] normalized
|
112 |
# normalize_transform = T.ComposeTransform([
|
@@ -116,23 +140,19 @@ class FlowPGM(BasePGM):
|
|
116 |
|
117 |
# age flow
|
118 |
self.age_module = T.ComposeTransformModule(
|
119 |
-
[T.Spline(1, count_bins=4, order='linear')])
-        self.age_flow = T.ComposeTransform([self.age_module])
122 |
# self.age_module, normalize_transform])
|
123 |
|
124 |
# brain volume (conditional) flow: (sex, age) -> brain_vol
|
125 |
-
bvol_net = DenseNN(
-            2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
127 |
-
self.bvol_flow = ConditionalAffineTransform(
|
128 |
-
context_nn=bvol_net, event_dim=0)
|
129 |
# self.bvol_flow = [self.bvol_flow, normalize_transform]
|
130 |
|
131 |
# ventricle volume (conditional) flow: (brain_vol, age) -> ventricle_vol
|
132 |
-
vvol_net = DenseNN(
-            2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
134 |
-
self.vvol_flow = ConditionalAffineTransform(
|
135 |
-
context_nn=vvol_net, event_dim=0)
|
136 |
# self.vvol_flow = [self.vvol_transf, normalize_transform]
|
137 |
|
138 |
# if args.setup != 'sup_pgm':
|
@@ -148,148 +168,152 @@ class FlowPGM(BasePGM):
|
|
148 |
self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1)
|
149 |
# q(v | x) = Normal(mu(x), sigma(x))
|
150 |
self.encoder_v = CNN(input_shape, num_outputs=2)
|
151 |
-
self.f = lambda x: args.std_fixed * torch.ones_like(x) \
-            if args.std_fixed > 0 else F.softplus(x)
|
|
|
|
|
|
153 |
|
154 |
-
def model(self, t=None):
|
155 |
# p(s), sex dist
|
156 |
ps = dist.Bernoulli(logits=self.s_logit).to_event(1)
|
157 |
-
sex = pyro.sample('sex', ps)
|
158 |
|
159 |
# p(m), mri_seq dist
|
160 |
pm = dist.Bernoulli(logits=self.m_logit).to_event(1)
|
161 |
-
mri_seq = pyro.sample('mri_seq', pm)
|
162 |
|
163 |
# p(a), age flow
|
164 |
pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
|
165 |
pa = dist.TransformedDistribution(pa_base, self.age_flow)
|
166 |
-
age = pyro.sample('age', pa)
|
167 |
|
168 |
# p(b | s, a), brain volume flow
|
169 |
-
pb_sa_base = dist.Normal(
|
170 |
-
self.b_base_loc, self.b_base_scale).to_event(1)
|
171 |
pb_sa = ConditionalTransformedDistribution(
|
172 |
-
pb_sa_base, [self.bvol_flow]).condition(torch.cat([sex, age], dim=1))
-        bvol = pyro.sample('brain_volume', pb_sa)
|
|
174 |
# _ = self.bvol_transf # register with pyro
|
175 |
|
176 |
# p(v | b, a), ventricle volume flow
|
177 |
-
pv_ba_base = dist.Normal(
|
178 |
-
self.v_base_loc, self.v_base_scale).to_event(1)
|
179 |
pv_ba = ConditionalTransformedDistribution(
|
180 |
-
pv_ba_base, [self.vvol_flow]).condition(torch.cat([bvol, age], dim=1))
-        vvol = pyro.sample('ventricle_volume', pv_ba)
|
|
182 |
# _ = self.vvol_transf # register with pyro
|
183 |
|
184 |
return {
|
185 |
-            'sex': sex,
-            'mri_seq': mri_seq,
-            'age': age,
-            'brain_volume': bvol,
-            'ventricle_volume': vvol,
190 |
}
|
191 |
|
192 |
-
def guide(self, **obs):
|
193 |
# guide for (optional) semi-supervised learning
|
194 |
-
pyro.module('FlowPGM', self)
|
195 |
-
with pyro.plate('observations', obs['x'].shape[0]):
|
196 |
# q(m | x)
|
197 |
-
if obs['mri_seq'] is None:
|
198 |
-
m_prob = torch.sigmoid(self.encoder_m(obs['x']))
|
199 |
-
m = pyro.sample('mri_seq', dist.Bernoulli(
|
200 |
-
probs=m_prob).to_event(1))
|
201 |
|
202 |
# q(v | x)
|
203 |
-
if obs['ventricle_volume'] is None:
|
204 |
-
v_loc, v_logscale = self.encoder_v(obs['x']).chunk(2, dim=-1)
|
205 |
qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
|
206 |
-
obs['ventricle_volume'] = pyro.sample('ventricle_volume', qv_x)
|
207 |
|
208 |
# q(b | x, v)
|
209 |
-
if obs['brain_volume'] is None:
|
210 |
b_loc, b_logscale = self.encoder_b(
|
211 |
-
obs['x'], y=obs['ventricle_volume']).chunk(2, dim=-1)
|
|
|
212 |
qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
|
213 |
-
obs['brain_volume'] = pyro.sample('brain_volume', qb_xv)
|
214 |
|
215 |
# q(s | x, b)
|
216 |
-
if obs['sex'] is None:
|
217 |
-
s_prob = torch.sigmoid(self.encoder_s(
|
218 |
-
obs['x'], y=obs['brain_volume']))  # .squeeze()
|
219 |
-
|
|
|
220 |
|
221 |
# q(a | b, v)
|
222 |
-
if obs['age'] is None:
|
223 |
-
ctx = torch.cat(
|
224 |
-
[obs['brain_volume'], obs['ventricle_volume']], dim=-1)
|
225 |
a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
|
226 |
-
pyro.sample('age', dist.Normal(
|
227 |
-
a_loc, self.f(a_logscale)).to_event(1))
|
228 |
|
229 |
-
def model_anticausal(self, **obs):
|
230 |
# assumes all variables are observed
|
231 |
-
pyro.module('FlowPGM', self)
|
232 |
-
with pyro.plate('observations', obs['x'].shape[0]):
|
233 |
# q(v | x)
|
234 |
-
v_loc, v_logscale = self.encoder_v(obs['x']).chunk(2, dim=-1)
|
235 |
qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
|
236 |
-
pyro.sample('ventricle_volume_aux', qv_x,
|
237 |
-
obs=obs['ventricle_volume'])
|
238 |
|
239 |
# q(b | x, v)
|
240 |
b_loc, b_logscale = self.encoder_b(
|
241 |
-
obs['x'], y=obs['ventricle_volume']).chunk(2, dim=-1)
|
|
|
242 |
qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
|
243 |
-
pyro.sample('brain_volume_aux', qb_xv, obs=obs['brain_volume'])
|
244 |
|
245 |
# q(a | b, v)
|
246 |
-
ctx = torch.cat(
|
247 |
-
[obs['brain_volume'], obs['ventricle_volume']], dim=-1)
|
248 |
a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
|
249 |
-
pyro.sample('age_aux', dist.Normal(
-                a_loc, self.f(a_logscale)).to_event(1), obs=obs['age'])
|
|
|
|
|
|
251 |
|
252 |
# q(s | x, b)
|
253 |
-
s_prob = torch.sigmoid(self.encoder_s(
|
254 |
-
obs['x'], y=obs['brain_volume']))
|
255 |
qs_xb = dist.Bernoulli(probs=s_prob).to_event(1)
|
256 |
-
pyro.sample('sex_aux', qs_xb, obs=obs['sex'])
|
257 |
|
258 |
# q(m | x)
|
259 |
-
m_prob = torch.sigmoid(self.encoder_m(obs['x']))
|
260 |
qm_x = dist.Bernoulli(probs=m_prob).to_event(1)
|
261 |
-
pyro.sample('mri_seq_aux', qm_x, obs=obs['mri_seq'])
|
262 |
|
263 |
-
def predict(self, **obs):
|
264 |
# q(v | x)
|
265 |
-
v_loc, v_logscale = self.encoder_v(obs['x']).chunk(2, dim=-1)
|
266 |
# v_loc = torch.tanh(v_loc)
|
267 |
# q(b | x, v)
|
268 |
-
b_loc, b_logscale = self.encoder_b(
-            obs['x'], y=obs['ventricle_volume']).chunk(2, dim=-1)
|
|
270 |
# b_loc = torch.tanh(b_loc)
|
271 |
# q(a | b, v)
|
272 |
-
ctx = torch.cat([obs['brain_volume'], obs['ventricle_volume']], dim=-1)
|
273 |
a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
|
274 |
# a_loc = torch.tanh(b_loc)
|
275 |
# q(s | x, b)
|
276 |
-
s_prob = torch.sigmoid(self.encoder_s(obs['x'], y=obs['brain_volume']))
|
277 |
# q(m | x)
|
278 |
-
m_prob = torch.sigmoid(self.encoder_m(obs['x']))
|
279 |
|
280 |
return {
|
281 |
-            'sex': s_prob,
-            'mri_seq': m_prob,
-            'age': a_loc,
-            'brain_volume': b_loc,
-            'ventricle_volume': v_loc,
286 |
}
|
287 |
|
288 |
-
def svi_model(self, **obs):
|
289 |
-
with pyro.plate(
|
290 |
pyro.condition(self.model, data=obs)()
|
291 |
|
292 |
-
def guide_pass(self, **obs):
|
293 |
pass
|
294 |
|
295 |
|
@@ -297,173 +321,174 @@ class MorphoMNISTPGM(BasePGM):
|
|
297 |
def __init__(self, args):
|
298 |
super().__init__()
|
299 |
self.variables = {
|
300 |
-            'thickness': 'continuous',
-            'intensity': 'continuous',
-            'digit': 'categorical',
303 |
}
|
304 |
# priors
|
305 |
self.digit_logits = nn.Parameter(torch.zeros(1, 10)) # uniform prior
|
306 |
-
for k in ['t', 'i']:
|
307 |
-
self.register_buffer(f'{k}_base_loc', torch.zeros(1))
|
308 |
-
self.register_buffer(f'{k}_base_scale', torch.ones(1))
|
309 |
|
310 |
# constraint, assumes data is [-1,1] normalized
|
311 |
-
normalize_transform = T.ComposeTransform([
|
312 |
-
T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)]
|
|
|
313 |
|
314 |
# thickness flow
|
315 |
self.thickness_module = T.ComposeTransformModule(
|
316 |
-
[T.Spline(1, count_bins=4, order='linear')])
|
317 |
-
|
318 |
-
|
|
|
|
|
319 |
|
320 |
# intensity (conditional) flow: thickness -> intensity
|
321 |
-
intensity_net = DenseNN(
|
322 |
-
1, args.widths, [1, 1], nonlinearity=nn.GELU())
|
323 |
self.context_nn = ConditionalAffineTransform(
|
324 |
-
context_nn=intensity_net, event_dim=0
|
|
|
325 |
self.intensity_flow = [self.context_nn, normalize_transform]
|
326 |
|
327 |
-
if args.setup != 'sup_pgm':
|
328 |
# anticausal predictors
|
329 |
input_shape = (args.input_channels, args.input_res, args.input_res)
|
330 |
# q(t | x, i) = Normal(mu(x, i), sigma(x, i)), 2 outputs: loc & scale
|
331 |
-
self.encoder_t = CNN(input_shape, num_outputs=2,
|
332 |
-
context_dim=1, width=8)
|
333 |
# q(i | x) = Normal(mu(x), sigma(x))
|
334 |
self.encoder_i = CNN(input_shape, num_outputs=2, width=8)
|
335 |
# q(y | x) = Categorical(pi(x))
|
336 |
self.encoder_y = CNN(input_shape, num_outputs=10, width=8)
|
337 |
-
self.f =
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
342 |
# p(y), digit label prior dist
|
343 |
py = dist.OneHotCategorical(
|
344 |
-
probs=F.softmax(self.digit_logits, dim=-1)
|
|
|
345 |
# with pyro.poutine.scale(scale=0.05):
|
346 |
-
digit = pyro.sample('digit', py)
|
347 |
|
348 |
# p(t), thickness flow
|
349 |
pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1)
|
350 |
pt = dist.TransformedDistribution(pt_base, self.thickness_flow)
|
351 |
-
thickness = pyro.sample('thickness', pt)
|
352 |
|
353 |
# p(i | t), intensity conditional flow
|
354 |
pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1)
|
355 |
pi_t = ConditionalTransformedDistribution(
|
356 |
-
pi_t_base, self.intensity_flow).condition(thickness)
-        intensity = pyro.sample('intensity', pi_t)
|
|
358 |
_ = self.context_nn
|
359 |
|
360 |
-
return {'thickness': thickness, 'intensity': intensity, 'digit': digit}
|
361 |
|
362 |
-
def guide(self, **obs):
|
363 |
# guide for (optional) semi-supervised learning
|
364 |
-
with pyro.plate('observations', obs['x'].shape[0]):
|
365 |
# q(i | x)
|
366 |
-
if obs['intensity'] is None:
|
367 |
-
i_loc, i_logscale = self.encoder_i(obs['x']).chunk(2, dim=-1)
|
368 |
-
qi_t = dist.Normal(torch.tanh(
-                    i_loc), self.f(i_logscale)).to_event(1)
370 |
-
obs['intensity'] = pyro.sample('intensity', qi_t)
|
371 |
|
372 |
# q(t | x, i)
|
373 |
-
if obs['thickness'] is None:
|
374 |
-
t_loc, t_logscale = self.encoder_t(
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
obs['thickness'] = pyro.sample('thickness', qt_x)
|
379 |
|
380 |
# q(y | x)
|
381 |
-
if obs['digit'] is None:
|
382 |
-
y_prob = F.softmax(self.encoder_y(obs['x']), dim=-1)
|
383 |
-
qy_x = dist.OneHotCategorical(probs=y_prob).to_event(1)
|
384 |
-
pyro.sample('digit', qy_x)
|
385 |
|
386 |
-
def model_anticausal(self, **obs):
|
387 |
# assumes all variables are observed & continuous ones are in [-1,1]
|
388 |
-
pyro.module('MorphoMNISTPGM', self)
|
389 |
-
with pyro.plate('observations', obs['x'].shape[0]):
|
390 |
# q(t | x, i)
|
391 |
-
t_loc, t_logscale = self.encoder_t(
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
pyro.sample('thickness_aux', qt_x, obs=obs['thickness'])
|
396 |
|
397 |
# q(i | x)
|
398 |
-
i_loc, i_logscale = self.encoder_i(obs['x']).chunk(2, dim=-1)
|
399 |
-
qi_t = dist.Normal(torch.tanh(
-                i_loc), self.f(i_logscale)).to_event(1)
401 |
-
pyro.sample('intensity_aux', qi_t, obs=obs['intensity'])
|
402 |
|
403 |
# q(y | x)
|
404 |
-
y_prob = F.softmax(self.encoder_y(obs['x']), dim=-1)
|
405 |
-
qy_x = dist.OneHotCategorical(probs=y_prob).to_event(1)
|
406 |
-
pyro.sample('digit_aux', qy_x, obs=obs['digit'])
|
407 |
|
408 |
-
def predict(self, **obs):
|
409 |
# q(t | x, i)
|
410 |
-
t_loc, t_logscale = self.encoder_t(
-            obs['x'], y=obs['intensity']).chunk(2, dim=-1)
|
|
412 |
t_loc = torch.tanh(t_loc)
|
413 |
# q(i | x)
|
414 |
-
i_loc, i_logscale = self.encoder_i(obs['x']).chunk(2, dim=-1)
|
415 |
i_loc = torch.tanh(i_loc)
|
416 |
# q(y | x)
|
417 |
-
y_prob = F.softmax(self.encoder_y(obs['x']), dim=-1)
|
418 |
-
return {'thickness': t_loc, 'intensity': i_loc, 'digit': y_prob}
|
419 |
|
420 |
-
def svi_model(self, **obs):
|
421 |
-
with pyro.plate(
|
422 |
pyro.condition(self.model, data=obs)()
|
423 |
|
424 |
-
def guide_pass(self, **obs):
|
425 |
pass
|
426 |
|
427 |
|
428 |
-
class ChestPGM(nn.Module):
|
429 |
-
def __init__(self, args):
|
430 |
super().__init__()
|
431 |
self.variables = {
|
432 |
-            'age': 'continuous',
-            'race': 'categorical',
-            'sex': 'binary',
-            'finding': 'binary',
|
436 |
}
|
437 |
# Discrete variables that are not root nodes
|
438 |
-
self.discrete_variables = {
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
self.register_buffer(f'{k}_base_loc', torch.zeros(1))
|
445 |
-
self.register_buffer(f'{k}_base_scale', torch.ones(1))
|
446 |
-
|
447 |
-
# age flow
|
448 |
self.age_flow_components = T.ComposeTransformModule([T.Spline(1)])
|
449 |
# self.age_constraints = T.ComposeTransform([
|
450 |
# T.AffineTransform(loc=4.09541458484, scale=0.32548387126),
|
451 |
# T.ExpTransform()])
|
452 |
-
self.age_flow = T.ComposeTransform(
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
|
|
457 |
# Finding (conditional) via MLP, a -> f
|
458 |
-
finding_net = DenseNN(
|
459 |
-
1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid())#.cuda()
|
460 |
self.finding_transform_GumbelMax = ConditionalGumbelMax(
|
461 |
-
context_nn=finding_net,
|
462 |
-
|
463 |
# log space for sex and race
|
464 |
-
self.sex_logit = nn.Parameter(torch.
|
465 |
-
|
466 |
-
self.race_logits = nn.Parameter(np.log(1/3)*torch.ones(1, 3))
|
467 |
|
468 |
input_shape = (args.input_channels, args.input_res, args.input_res)
|
469 |
|
@@ -477,207 +502,112 @@ class ChestPGM(nn.Module):
|
|
477 |
# q(a | x, f) ~ Normal(mu(x), sigma(x))
|
478 |
self.encoder_a = CNN(input_shape, num_outputs=1, context_dim=1)
|
479 |
|
480 |
-
def model(self, t=None):
|
|
|
481 |
# p(s), sex dist
|
482 |
ps = dist.Bernoulli(logits=self.sex_logit).to_event(1)
|
483 |
-
sex = pyro.sample('sex', ps)
|
484 |
|
485 |
# p(a), age flow
|
486 |
pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
|
487 |
pa = dist.TransformedDistribution(pa_base, self.age_flow)
|
488 |
-
age = pyro.sample('age', pa)
|
489 |
# age_ = self.age_constraints.inv(age)
|
490 |
_ = self.age_flow_components # register with pyro
|
491 |
|
492 |
# p(r), race dist
|
493 |
-
|
494 |
-
race = pyro.sample('race', pr)
|
495 |
|
496 |
# p(f | a), finding as OneHotCategorical conditioned on age
|
497 |
-
finding_dist_base = dist.Gumbel(
|
498 |
-
|
|
|
499 |
finding_dist = ConditionalTransformedDistributionGumbelMax(
|
500 |
-
finding_dist_base,
|
501 |
-
|
502 |
finding = pyro.sample("finding", finding_dist)
|
503 |
|
504 |
return {
|
505 |
-            'sex': sex,
-            'race': race,
-            'age': age,
-            'finding': finding,
|
509 |
}
|
510 |
|
511 |
-
def guide(self, **obs):
|
512 |
-
|
513 |
-
pyro.module('ChestPGM', self)
|
514 |
-
with pyro.plate('observations', obs['x'].shape[0]):
|
515 |
# q(s | x)
|
516 |
-
if obs['sex'] is None:
|
517 |
-
s_prob = torch.sigmoid(self.encoder_s(obs['x']))
|
518 |
-
|
519 |
-
probs=s_prob).to_event(1))
|
520 |
# q(r | x)
|
521 |
-
if obs['race'] is None:
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
logits=r_logits).to_event(1))
|
526 |
# q(f | x)
|
527 |
-
if obs['finding'] is None:
|
528 |
-
f_prob = torch.sigmoid(self.encoder_f(obs['x']))
|
529 |
-
|
530 |
-
|
531 |
# q(a | x, f)
|
532 |
-
if obs['age'] is None:
|
533 |
-
a_loc = self.encoder_a(
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
|
|
539 |
# assumes all variables are observed, train classifiers
|
540 |
-
pyro.module('ChestPGM', self)
|
541 |
-
with pyro.plate('observations', obs['x'].shape[0]):
|
542 |
-
|
543 |
-
s_prob = torch.sigmoid(self.encoder_s(obs['x']))
|
544 |
-
|
545 |
-
|
|
|
546 |
|
547 |
# q(r | x)
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
|
553 |
# q(f | x)
|
554 |
-
f_prob = torch.sigmoid(self.encoder_f(obs['x']))
|
555 |
qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
|
556 |
-
|
557 |
|
558 |
# q(a | x, f)
|
559 |
-
a_loc = self.encoder_a(
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
# q(r | x)
|
568 |
-
r_logits = F.softmax(self.encoder_r(obs['x']), dim=-1) # .squeeze()
|
569 |
-
# q(f | x)
|
570 |
-
f_prob = torch.sigmoid(self.encoder_f(obs['x']))
|
571 |
-
# q(a | x, f)
|
572 |
-
a_loc = self.encoder_a(
|
573 |
-
obs['x'], y=obs['finding'])
|
574 |
-
|
575 |
-
return {
|
576 |
-
'sex': s_prob,
|
577 |
-
'race': r_logits,
|
578 |
-
'age': a_loc,
|
579 |
-
'finding': f_prob,
|
580 |
-
}
|
581 |
-
|
582 |
-
def predict_unnorm(self, **obs):
|
583 |
# q(s | x)
|
584 |
-
s_prob = self.encoder_s(obs['x'])
|
585 |
# q(r | x)
|
586 |
-
|
587 |
# q(f | x)
|
588 |
-
f_prob = self.encoder_f(obs['x'])
|
589 |
-
qf_x = dist.Bernoulli(probs=torch.sigmoid(f_prob)).to_event(1)
|
590 |
-
obs_finding = pyro.sample('finding', qf_x)
|
591 |
# q(a | x, f)
|
592 |
-
a_loc = self.encoder_a(
|
593 |
-
obs['x'],
|
594 |
-
# y=obs['finding'],
|
595 |
-
y=obs_finding,
|
596 |
-
)
|
597 |
|
598 |
return {
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
}
|
604 |
|
605 |
-
def svi_model(self, **obs):
|
606 |
-
with pyro.plate(
|
607 |
pyro.condition(self.model, data=obs)()
|
608 |
|
609 |
-
def guide_pass(self, **obs):
|
610 |
pass
|
611 |
-
|
612 |
-
def infer_exogeneous(self, obs):
|
613 |
-
batch_size = list(obs.values())[0].shape[0]
|
614 |
-
# assuming that we use transformed distributions for everything:
|
615 |
-
cond_model = pyro.condition(self.sample, data=obs)
|
616 |
-
cond_trace = pyro.poutine.trace(
|
617 |
-
cond_model).get_trace(batch_size)
|
618 |
-
|
619 |
-
output = {}
|
620 |
-
for name, node in cond_trace.nodes.items():
|
621 |
-
if 'z' in name or 'fn' not in node.keys():
|
622 |
-
continue
|
623 |
-
fn = node['fn']
|
624 |
-
if isinstance(fn, dist.Independent):
|
625 |
-
fn = fn.base_dist
|
626 |
-
if isinstance(fn, dist.TransformedDistribution):
|
627 |
-
# compute exogenous base dist (created with TransformReparam) at all sites
|
628 |
-
output[name + '_base'] = T.ComposeTransform(
|
629 |
-
fn.transforms).inv(node['value'])
|
630 |
-
return output
|
631 |
-
|
632 |
-
def scm(self, *args, **kwargs):
|
633 |
-
def config(msg):
|
634 |
-
if isinstance(msg['fn'], dist.TransformedDistribution):
|
635 |
-
return TransformReparam()
|
636 |
-
else:
|
637 |
-
return None
|
638 |
-
return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
|
639 |
-
|
640 |
-
def sample_scm(self, n_samples=1, t=None):
|
641 |
-
with pyro.plate('obs', n_samples):
|
642 |
-
samples = self.scm(t)
|
643 |
-
return samples
|
644 |
-
|
645 |
-
def sample(self, n_samples=1, t=None):
|
646 |
-
with pyro.plate('obs', n_samples):
|
647 |
-
samples = self.model(t)
|
648 |
-
return samples
|
649 |
-
|
650 |
-
def counterfactual(self, obs, intervention, num_particles=1, detach=True, t=None):
|
651 |
-
dag_variables = self.variables.keys()
|
652 |
-
obs_ = {k: v for k, v in obs.items() if k in dag_variables}
|
653 |
-
assert set(obs_.keys()) == set(dag_variables)
|
654 |
-
# For continuos variables
|
655 |
-
avg_cfs = {k: torch.zeros_like(obs_[k]) for k in obs_.keys()}
|
656 |
-
batch_size = list(obs_.values())[0].shape[0]
|
657 |
-
|
658 |
-
for _ in range(num_particles):
|
659 |
-
# Abduction
|
660 |
-
exo_noise = self.infer_exogeneous(obs_)
|
661 |
-
exo_noise = {k: v.detach() if detach else v for k,
|
662 |
-
v in exo_noise.items()}
|
663 |
-
# condition on root node variables (no exogeneous noise available)
|
664 |
-
for k in dag_variables:
|
665 |
-
if k not in intervention.keys():
|
666 |
-
if k not in [i.split('_base')[0] for i in exo_noise.keys()]:
|
667 |
-
exo_noise[k] = obs_[k]
|
668 |
-
# Abducted SCM
|
669 |
-
abducted_scm = pyro.poutine.condition(
|
670 |
-
self.sample_scm, data=exo_noise)
|
671 |
-
# Action
|
672 |
-
counterfactual_scm = pyro.poutine.do(
|
673 |
-
abducted_scm, data=intervention)
|
674 |
-
# Prediction
|
675 |
-
counterfactuals = counterfactual_scm(batch_size, t)
|
676 |
-
# Check if we should change "finding", i.e. if its parents and itself are not intervened,
|
677 |
-
# then we use its observed value. This is needed due to stochastic abduction of discrete variables.
|
678 |
-
if 'age' not in intervention.keys() and 'finding' not in intervention.keys():
|
679 |
-
counterfactuals['finding'] = obs_['finding']
|
680 |
-
|
681 |
-
for k, v in counterfactuals.items():
|
682 |
-
avg_cfs[k] += (v / num_particles)
|
683 |
-
return avg_cfs
|
|
|
1 |
+
from typing import Dict
|
2 |
|
3 |
+
import numpy as np
|
4 |
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
import torch.nn.functional as F
|
7 |
|
8 |
import pyro
|
|
|
16 |
from .layers import (
|
17 |
ConditionalTransformedDistributionGumbelMax,
|
18 |
ConditionalGumbelMax,
|
19 |
+
ConditionalAffineTransform,
|
20 |
+
MLP,
|
21 |
+
CNN,
|
22 |
)
|
23 |
|
24 |
|
25 |
+
class Hparams:
|
26 |
+
def update(self, dict):
|
27 |
+
for k, v in dict.items():
|
28 |
+
setattr(self, k, v)
|
29 |
+
|
30 |
+
|
31 |
class BasePGM(nn.Module):
|
32 |
def __init__(self):
|
33 |
super().__init__()
|
34 |
|
35 |
def scm(self, *args, **kwargs):
|
36 |
def config(msg):
|
37 |
+
if isinstance(msg["fn"], dist.TransformedDistribution):
|
38 |
return TransformReparam()
|
39 |
else:
|
40 |
return None
|
41 |
+
|
42 |
return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
|
43 |
|
44 |
+
def sample_scm(self, n_samples: int = 1):
|
45 |
+
with pyro.plate("obs", n_samples):
|
46 |
+
samples = self.scm()
|
47 |
return samples
|
48 |
|
49 |
+
def sample(self, n_samples: int = 1):
|
50 |
+
with pyro.plate("obs", n_samples):
|
51 |
+
samples = self.model() # NOTE: not ideal as model is defined in child class
|
52 |
return samples
|
53 |
|
54 |
+
def infer_exogeneous(self, obs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
55 |
batch_size = list(obs.values())[0].shape[0]
|
56 |
# assuming that we use transformed distributions for everything:
|
57 |
cond_model = pyro.condition(self.sample, data=obs)
|
58 |
+
cond_trace = pyro.poutine.trace(cond_model).get_trace(batch_size)
|
|
|
59 |
|
60 |
output = {}
|
61 |
for name, node in cond_trace.nodes.items():
|
62 |
+
if "z" in name or "fn" not in node.keys():
|
63 |
continue
|
64 |
+
fn = node["fn"]
|
65 |
if isinstance(fn, dist.Independent):
|
66 |
fn = fn.base_dist
|
67 |
if isinstance(fn, dist.TransformedDistribution):
|
68 |
# compute exogenous base dist (created with TransformReparam) at all sites
|
69 |
+
output[name + "_base"] = T.ComposeTransform(fn.transforms).inv(
|
70 |
+
node["value"]
|
71 |
+
)
|
72 |
return output
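Editor's note: the inversion above is the abduction step; for a TransformedDistribution the exogenous noise is just the inverse flow applied to the observed value. A standalone illustration with a toy mechanism (not the trained one):

import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

flow = T.ComposeTransform([T.AffineTransform(loc=4.0, scale=0.1), T.ExpTransform()])
p_age = dist.TransformedDistribution(dist.Normal(0.0, 1.0), [flow])
age = p_age.sample((3,))
eps = flow.inv(age)                    # recover the exogenous noise u
assert torch.allclose(flow(eps), age)  # mechanism(u) reproduces the observation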
|
73 |
|
74 |
+
def counterfactual(
|
75 |
+
self,
|
76 |
+
obs: Dict[str, Tensor],
|
77 |
+
intervention: Dict[str, Tensor],
|
78 |
+
num_particles: int = 1,
|
79 |
+
detach: bool = True,
|
80 |
+
) -> Dict[str, Tensor]:
|
81 |
+
# NOTE: not ideal as "variables" is defined in child class
|
82 |
dag_variables = self.variables.keys()
|
83 |
assert set(obs.keys()) == set(dag_variables)
|
84 |
avg_cfs = {k: torch.zeros_like(obs[k]) for k in obs.keys()}
|
|
|
87 |
for _ in range(num_particles):
|
88 |
# Abduction
|
89 |
exo_noise = self.infer_exogeneous(obs)
|
90 |
+
exo_noise = {k: v.detach() if detach else v for k, v in exo_noise.items()}
|
|
|
91 |
# condition on root node variables (no exogenous noise available)
|
92 |
for k in dag_variables:
|
93 |
if k not in intervention.keys():
|
94 |
+
if k not in [i.split("_base")[0] for i in exo_noise.keys()]:
|
95 |
exo_noise[k] = obs[k]
|
96 |
# Abducted SCM
|
97 |
+
abducted_scm = pyro.poutine.condition(self.sample_scm, data=exo_noise)
|
|
|
98 |
# Action
|
99 |
+
counterfactual_scm = pyro.poutine.do(abducted_scm, data=intervention)
|
|
|
100 |
# Prediction
|
101 |
+
counterfactuals = counterfactual_scm(batch_size)
|
102 |
+
|
103 |
+
if hasattr(self, "discrete_variables"): # hack for MIMIC
|
104 |
+
# Check if we should change "finding", i.e. if its parents and/or
|
105 |
+
# itself are not intervened on, then we use its observed value.
|
106 |
+
# This is used due to stochastic abduction of discrete variables
|
107 |
+
if (
|
108 |
+
"age" not in intervention.keys()
|
109 |
+
and "finding" not in intervention.keys()
|
110 |
+
):
|
111 |
+
counterfactuals["finding"] = obs["finding"]
|
112 |
|
113 |
for k, v in counterfactuals.items():
|
114 |
+
avg_cfs[k] += v / num_particles
|
115 |
return avg_cfs
|
116 |
|
117 |
|
class FlowPGM(BasePGM):
    def __init__(self, args: Hparams):
        super().__init__()
        self.variables = {
            "sex": "binary",
            "mri_seq": "binary",
            "age": "continuous",
            "brain_volume": "continuous",
            "ventricle_volume": "continuous",
        }
        # priors: s, m, a, b and v
        self.s_logit = nn.Parameter(torch.zeros(1))
        self.m_logit = nn.Parameter(torch.zeros(1))
        for k in ["a", "b", "v"]:
            self.register_buffer(f"{k}_base_loc", torch.zeros(1))
            self.register_buffer(f"{k}_base_scale", torch.ones(1))

        # constraint, assumes data is [-1,1] normalized
        # normalize_transform = T.ComposeTransform([
        # ...

        # age flow
        self.age_module = T.ComposeTransformModule(
            [T.Spline(1, count_bins=4, order="linear")]
        )
        self.age_flow = T.ComposeTransform([self.age_module])
        # self.age_module, normalize_transform])

        # brain volume (conditional) flow: (sex, age) -> brain_vol
        bvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
        self.bvol_flow = ConditionalAffineTransform(context_nn=bvol_net, event_dim=0)
        # self.bvol_flow = [self.bvol_flow, normalize_transform]

        # ventricle volume (conditional) flow: (brain_vol, age) -> ventricle_vol
        vvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
        self.vvol_flow = ConditionalAffineTransform(context_nn=vvol_net, event_dim=0)
        # self.vvol_flow = [self.vvol_transf, normalize_transform]

        # if args.setup != 'sup_pgm':
        # ...
        self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1)
        # q(v | x) = Normal(mu(x), sigma(x))
        self.encoder_v = CNN(input_shape, num_outputs=2)
        self.f = (
            lambda x: args.std_fixed * torch.ones_like(x)
            if args.std_fixed > 0
            else F.softplus(x)
        )

    def model(self) -> Dict[str, Tensor]:
        # p(s), sex dist
        ps = dist.Bernoulli(logits=self.s_logit).to_event(1)
        sex = pyro.sample("sex", ps)

        # p(m), mri_seq dist
        pm = dist.Bernoulli(logits=self.m_logit).to_event(1)
        mri_seq = pyro.sample("mri_seq", pm)

        # p(a), age flow
        pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
        pa = dist.TransformedDistribution(pa_base, self.age_flow)
        age = pyro.sample("age", pa)

        # p(b | s, a), brain volume flow
        pb_sa_base = dist.Normal(self.b_base_loc, self.b_base_scale).to_event(1)
        pb_sa = ConditionalTransformedDistribution(
            pb_sa_base, [self.bvol_flow]
        ).condition(torch.cat([sex, age], dim=1))
        bvol = pyro.sample("brain_volume", pb_sa)
        # _ = self.bvol_transf  # register with pyro

        # p(v | b, a), ventricle volume flow
        pv_ba_base = dist.Normal(self.v_base_loc, self.v_base_scale).to_event(1)
        pv_ba = ConditionalTransformedDistribution(
            pv_ba_base, [self.vvol_flow]
        ).condition(torch.cat([bvol, age], dim=1))
        vvol = pyro.sample("ventricle_volume", pv_ba)
        # _ = self.vvol_transf  # register with pyro

        return {
            "sex": sex,
            "mri_seq": mri_seq,
            "age": age,
            "brain_volume": bvol,
            "ventricle_volume": vvol,
        }

    def guide(self, **obs) -> None:
        # guide for (optional) semi-supervised learning
        pyro.module("FlowPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(m | x)
            if obs["mri_seq"] is None:
                m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
                m = pyro.sample("mri_seq", dist.Bernoulli(probs=m_prob).to_event(1))

            # q(v | x)
            if obs["ventricle_volume"] is None:
                v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
                qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
                obs["ventricle_volume"] = pyro.sample("ventricle_volume", qv_x)

            # q(b | x, v)
            if obs["brain_volume"] is None:
                b_loc, b_logscale = self.encoder_b(
                    obs["x"], y=obs["ventricle_volume"]
                ).chunk(2, dim=-1)
                qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
                obs["brain_volume"] = pyro.sample("brain_volume", qb_xv)

            # q(s | x, b)
            if obs["sex"] is None:
                s_prob = torch.sigmoid(
                    self.encoder_s(obs["x"], y=obs["brain_volume"])
                )  # .squeeze()
                pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))

            # q(a | b, v)
            if obs["age"] is None:
                ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
                a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
                pyro.sample("age", dist.Normal(a_loc, self.f(a_logscale)).to_event(1))

    def model_anticausal(self, **obs) -> None:
        # assumes all variables are observed
        pyro.module("FlowPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(v | x)
            v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
            qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
            pyro.sample("ventricle_volume_aux", qv_x, obs=obs["ventricle_volume"])

            # q(b | x, v)
            b_loc, b_logscale = self.encoder_b(
                obs["x"], y=obs["ventricle_volume"]
            ).chunk(2, dim=-1)
            qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
            pyro.sample("brain_volume_aux", qb_xv, obs=obs["brain_volume"])

            # q(a | b, v)
            ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
            a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
            pyro.sample(
                "age_aux",
                dist.Normal(a_loc, self.f(a_logscale)).to_event(1),
                obs=obs["age"],
            )

            # q(s | x, b)
            s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
            qs_xb = dist.Bernoulli(probs=s_prob).to_event(1)
            pyro.sample("sex_aux", qs_xb, obs=obs["sex"])

            # q(m | x)
            m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
            qm_x = dist.Bernoulli(probs=m_prob).to_event(1)
            pyro.sample("mri_seq_aux", qm_x, obs=obs["mri_seq"])

    def predict(self, **obs) -> Dict[str, Tensor]:
        # q(v | x)
        v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
        # v_loc = torch.tanh(v_loc)
        # q(b | x, v)
        b_loc, b_logscale = self.encoder_b(obs["x"], y=obs["ventricle_volume"]).chunk(
            2, dim=-1
        )
        # b_loc = torch.tanh(b_loc)
        # q(a | b, v)
        ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
        a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
        # a_loc = torch.tanh(b_loc)
        # q(s | x, b)
        s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
        # q(m | x)
        m_prob = torch.sigmoid(self.encoder_m(obs["x"]))

        return {
            "sex": s_prob,
            "mri_seq": m_prob,
            "age": a_loc,
            "brain_volume": b_loc,
            "ventricle_volume": v_loc,
        }

    def svi_model(self, **obs) -> None:
        with pyro.plate("observations", obs["x"].shape[0]):
            pyro.condition(self.model, data=obs)()

    def guide_pass(self, **obs) -> None:
        pass

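Abduction in `infer_exogeneous` (defined in the shared base class) works because each mechanism above is an invertible conditional affine transform: the context network emits a location and a log-scale, and the exogenous noise can be recovered in closed form. A self-contained sketch of that mechanism, with illustrative hidden widths:

# Sketch of a conditional affine mechanism b = loc(s, a) + scale(s, a) * u_b,
# mirroring bvol_flow; the widths and inputs here are illustrative.
import torch
import torch.nn as nn
from pyro.nn import DenseNN

net = DenseNN(2, [8, 8], [1, 1], nonlinearity=nn.LeakyReLU(0.1))
context = torch.randn(4, 2)          # (sex, age) pairs
loc, log_scale = net(context)        # two heads: shift and log-scale
u = torch.randn(4, 1)                # exogenous noise
b = loc + log_scale.exp() * u        # forward mechanism
u_hat = (b - loc) / log_scale.exp()  # abduction inverts it exactly
assert torch.allclose(u, u_hat)
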

class MorphoMNISTPGM(BasePGM):
    def __init__(self, args):
        super().__init__()
        self.variables = {
            "thickness": "continuous",
            "intensity": "continuous",
            "digit": "categorical",
        }
        # priors
        self.digit_logits = nn.Parameter(torch.zeros(1, 10))  # uniform prior
        for k in ["t", "i"]:  # thickness, intensity, standard Gaussian
            self.register_buffer(f"{k}_base_loc", torch.zeros(1))
            self.register_buffer(f"{k}_base_scale", torch.ones(1))

        # constraint, assumes data is [-1,1] normalized
        normalize_transform = T.ComposeTransform(
            [T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)]
        )

        # thickness flow
        self.thickness_module = T.ComposeTransformModule(
            [T.Spline(1, count_bins=4, order="linear")]
        )
        self.thickness_flow = T.ComposeTransform(
            [self.thickness_module, normalize_transform]
        )

        # intensity (conditional) flow: thickness -> intensity
        intensity_net = DenseNN(1, args.widths, [1, 1], nonlinearity=nn.GELU())
        self.context_nn = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0
        )
        self.intensity_flow = [self.context_nn, normalize_transform]

        if args.setup != "sup_pgm":
            # anticausal predictors
            input_shape = (args.input_channels, args.input_res, args.input_res)
            # q(t | x, i) = Normal(mu(x, i), sigma(x, i)), 2 outputs: loc & scale
            self.encoder_t = CNN(input_shape, num_outputs=2, context_dim=1, width=8)
            # q(i | x) = Normal(mu(x), sigma(x))
            self.encoder_i = CNN(input_shape, num_outputs=2, width=8)
            # q(y | x) = Categorical(pi(x))
            self.encoder_y = CNN(input_shape, num_outputs=10, width=8)
            self.f = (
                lambda x: args.std_fixed * torch.ones_like(x)
                if args.std_fixed > 0
                else F.softplus(x)
            )

    def model(self) -> Dict[str, Tensor]:
        pyro.module("MorphoMNISTPGM", self)
        # p(y), digit label prior dist
        py = dist.OneHotCategorical(
            probs=F.softmax(self.digit_logits, dim=-1)
        )  # .to_event(1)
        # with pyro.poutine.scale(scale=0.05):
        digit = pyro.sample("digit", py)

        # p(t), thickness flow
        pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1)
        pt = dist.TransformedDistribution(pt_base, self.thickness_flow)
        thickness = pyro.sample("thickness", pt)

        # p(i | t), intensity conditional flow
        pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1)
        pi_t = ConditionalTransformedDistribution(
            pi_t_base, self.intensity_flow
        ).condition(thickness)
        intensity = pyro.sample("intensity", pi_t)
        _ = self.context_nn

        return {"thickness": thickness, "intensity": intensity, "digit": digit}

    def guide(self, **obs) -> None:
        # guide for (optional) semi-supervised learning
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(i | x)
            if obs["intensity"] is None:
                i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
                qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
                obs["intensity"] = pyro.sample("intensity", qi_t)

            # q(t | x, i)
            if obs["thickness"] is None:
                t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
                    2, dim=-1
                )
                qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
                obs["thickness"] = pyro.sample("thickness", qt_x)

            # q(y | x)
            if obs["digit"] is None:
                y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
                qy_x = dist.OneHotCategorical(probs=y_prob)  # .to_event(1)
                pyro.sample("digit", qy_x)

    def model_anticausal(self, **obs) -> None:
        # assumes all variables are observed & continuous ones are in [-1,1]
        pyro.module("MorphoMNISTPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(t | x, i)
            t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
                2, dim=-1
            )
            qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
            pyro.sample("thickness_aux", qt_x, obs=obs["thickness"])

            # q(i | x)
            i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
            qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
            pyro.sample("intensity_aux", qi_t, obs=obs["intensity"])

            # q(y | x)
            y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
            qy_x = dist.OneHotCategorical(probs=y_prob)  # .to_event(1)
            pyro.sample("digit_aux", qy_x, obs=obs["digit"])

    def predict(self, **obs) -> Dict[str, Tensor]:
        # q(t | x, i)
        t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
            2, dim=-1
        )
        t_loc = torch.tanh(t_loc)
        # q(i | x)
        i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
        i_loc = torch.tanh(i_loc)
        # q(y | x)
        y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
        return {"thickness": t_loc, "intensity": i_loc, "digit": y_prob}

    def svi_model(self, **obs) -> None:
        with pyro.plate("observations", obs["x"].shape[0]):
            pyro.condition(self.model, data=obs)()

    def guide_pass(self, **obs) -> None:
        pass

+
class ChestPGM(BasePGM):
|
459 |
+
def __init__(self, args: Hparams):
|
460 |
super().__init__()
|
461 |
self.variables = {
|
462 |
+
"race": "categorical",
|
463 |
+
"sex": "binary",
|
464 |
+
"finding": "binary",
|
465 |
+
"age": "continuous",
|
466 |
}
|
467 |
# Discrete variables that are not root nodes
|
468 |
+
self.discrete_variables = {"finding": "binary"}
|
469 |
+
# define base distributions
|
470 |
+
for k in ["a"]: # , "f"]:
|
471 |
+
self.register_buffer(f"{k}_base_loc", torch.zeros(1))
|
472 |
+
self.register_buffer(f"{k}_base_scale", torch.ones(1))
|
473 |
+
# age spline flow
|
|
|
|
|
|
|
|
|
474 |
self.age_flow_components = T.ComposeTransformModule([T.Spline(1)])
|
475 |
# self.age_constraints = T.ComposeTransform([
|
476 |
# T.AffineTransform(loc=4.09541458484, scale=0.32548387126),
|
477 |
# T.ExpTransform()])
|
478 |
+
self.age_flow = T.ComposeTransform(
|
479 |
+
[
|
480 |
+
self.age_flow_components,
|
481 |
+
# self.age_constraints,
|
482 |
+
]
|
483 |
+
)
|
484 |
# Finding (conditional) via MLP, a -> f
|
485 |
+
finding_net = DenseNN(1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid())
|
|
|
486 |
self.finding_transform_GumbelMax = ConditionalGumbelMax(
|
487 |
+
context_nn=finding_net, event_dim=0
|
488 |
+
)
|
489 |
# log space for sex and race
|
490 |
+
self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
|
491 |
+
self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
|
|
|
492 |
|
493 |
input_shape = (args.input_channels, args.input_res, args.input_res)
|
494 |
|
|
|
502 |
# q(a | x, f) ~ Normal(mu(x), sigma(x))
|
503 |
self.encoder_a = CNN(input_shape, num_outputs=1, context_dim=1)
|
504 |
|
505 |
+
def model(self) -> Dict[str, Tensor]:
|
506 |
+
pyro.module("ChestPGM", self)
|
507 |
# p(s), sex dist
|
508 |
ps = dist.Bernoulli(logits=self.sex_logit).to_event(1)
|
509 |
+
sex = pyro.sample("sex", ps)
|
510 |
|
511 |
# p(a), age flow
|
512 |
pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
|
513 |
pa = dist.TransformedDistribution(pa_base, self.age_flow)
|
514 |
+
age = pyro.sample("age", pa)
|
515 |
# age_ = self.age_constraints.inv(age)
|
516 |
_ = self.age_flow_components # register with pyro
|
517 |
|
518 |
# p(r), race dist
|
519 |
+
pr = dist.OneHotCategorical(logits=self.race_logits) # .to_event(1)
|
520 |
+
race = pyro.sample("race", pr)
|
521 |
|
522 |
# p(f | a), finding as OneHotCategorical conditioned on age
|
523 |
+
# finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
|
524 |
+
finding_dist_base = dist.Gumbel(torch.zeros(1), torch.ones(1)).to_event(1)
|
525 |
+
|
526 |
finding_dist = ConditionalTransformedDistributionGumbelMax(
|
527 |
+
finding_dist_base, [self.finding_transform_GumbelMax]
|
528 |
+
).condition(age)
|
529 |
finding = pyro.sample("finding", finding_dist)
|
530 |
|
531 |
return {
|
532 |
+
"sex": sex,
|
533 |
+
"race": race,
|
534 |
+
"age": age,
|
535 |
+
"finding": finding,
|
536 |
}
|
537 |
|
538 |
+
def guide(self, **obs) -> None:
|
539 |
+
with pyro.plate("observations", obs["x"].shape[0]):
|
|
|
|
|
540 |
# q(s | x)
|
541 |
+
if obs["sex"] is None:
|
542 |
+
s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
|
543 |
+
pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))
|
|
|
544 |
# q(r | x)
|
545 |
+
if obs["race"] is None:
|
546 |
+
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
|
547 |
+
qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1)
|
548 |
+
pyro.sample("race", qr_x)
|
|
|
549 |
# q(f | x)
|
550 |
+
if obs["finding"] is None:
|
551 |
+
f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
|
552 |
+
qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
|
553 |
+
obs["finding"] = pyro.sample("finding", qf_x)
|
554 |
# q(a | x, f)
|
555 |
+
if obs["age"] is None:
|
556 |
+
a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
|
557 |
+
2, dim=-1
|
558 |
+
)
|
559 |
+
qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
|
560 |
+
pyro.sample("age_aux", qa_xf)
|
561 |
+
|
562 |
+
def model_anticausal(self, **obs) -> None:
|
563 |
# assumes all variables are observed, train classfiers
|
564 |
+
pyro.module("ChestPGM", self)
|
565 |
+
with pyro.plate("observations", obs["x"].shape[0]):
|
566 |
+
# q(s | x)
|
567 |
+
s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
|
568 |
+
qs_x = dist.Bernoulli(probs=s_prob).to_event(1)
|
569 |
+
# with pyro.poutine.scale(scale=0.8):
|
570 |
+
pyro.sample("sex_aux", qs_x, obs=obs["sex"])
|
571 |
|
572 |
# q(r | x)
|
573 |
+
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
|
574 |
+
qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1)
|
575 |
+
# with pyro.poutine.scale(scale=0.5):
|
576 |
+
pyro.sample("race_aux", qr_x, obs=obs["race"])
|
577 |
|
578 |
# q(f | x)
|
579 |
+
f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
|
580 |
qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
|
581 |
+
pyro.sample("finding_aux", qf_x, obs=obs["finding"])
|
582 |
|
583 |
# q(a | x, f)
|
584 |
+
a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
|
585 |
+
2, dim=-1
|
586 |
+
)
|
587 |
+
qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
|
588 |
+
# with pyro.poutine.scale(scale=2):
|
589 |
+
pyro.sample("age_aux", qa_xf, obs=obs["age"])
|
590 |
+
|
591 |
+
def predict(self, **obs) -> Dict[str, Tensor]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
592 |
# q(s | x)
|
593 |
+
s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
|
594 |
# q(r | x)
|
595 |
+
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
|
596 |
# q(f | x)
|
597 |
+
f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
|
|
|
|
|
598 |
# q(a | x, f)
|
599 |
+
a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
|
|
|
|
|
|
|
|
|
600 |
|
601 |
return {
|
602 |
+
"sex": s_prob,
|
603 |
+
"race": r_probs,
|
604 |
+
"finding": f_prob,
|
605 |
+
"age": a_loc,
|
606 |
}
|
607 |
|
608 |
+
def svi_model(self, **obs) -> None:
|
609 |
+
with pyro.plate("observations", obs["x"].shape[0]):
|
610 |
pyro.condition(self.model, data=obs)()
|
611 |
|
612 |
+
def guide_pass(self, **obs) -> None:
|
613 |
pass
|
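All three PGMs expose the same pyro-facing hooks: `svi_model` conditions the generative model on an observed batch, and `guide_pass` is a no-op guide, which together give fully supervised maximum-likelihood training. A hedged sketch of how these would plug into pyro's SVI loop (`pgm` and `batch` are assumed to exist, and the optimizer settings are illustrative):

# Hypothetical training step; assumes a PGM instance `pgm` and a batch dict
# containing "x" plus all of the PGM's DAG variables.
import pyro
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

svi = SVI(
    model=pgm.svi_model,   # p(variables), conditioned on the observed batch
    guide=pgm.guide_pass,  # no-op guide: nothing is latent when supervised
    optim=Adam({"lr": 1e-3}),
    loss=TraceGraph_ELBO(),
)
loss = svi.step(**batch)   # one gradient step on -log p(batch)
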
pgm/layers.py
CHANGED
from pyro.distributions.conditional import (
    ConditionalTransformModule,
    ConditionalTransformedDistribution,
    TransformedDistribution,
)
from pyro.distributions.torch_distribution import TorchDistributionMixin

# ...


    def condition(self, context):
        loc, log_scale = self.context_nn(context)
        return torch.distributions.transforms.AffineTransform(
            loc, log_scale.exp(), event_dim=self.event_dim
        )


class MLP(nn.Module):
    # ...


class CNN(nn.Module):
    # ...
            nn.BatchNorm2d(width),
            activation,
            (nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
            nn.Conv2d(width, 2 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 2 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 4 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 4 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 8 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(8 * width),
            activation,
        )
        self.fc = nn.Sequential(
            nn.Linear(8 * width + context_dim, 8 * width, bias=False),
            nn.BatchNorm1d(8 * width),
            activation,
            nn.Linear(8 * width, num_outputs),
        )

    def forward(self, x, y=None):
        # ...


class ArgMaxGumbelMax(Transform):
    # ...
        super(ArgMaxGumbelMax, self).__init__(cache_size=cache_size)
        self.logits = logits
        self._event_dim = event_dim
        self._categorical = pyro.distributions.torch.Categorical(
            logits=self.logits
        ).to_event(0)

    @property
    def event_dim(self):
        # ...

    def __call__(self, gumbels):
        """
        Computes the forward transform.
        """
        assert self.logits != None, "Logits not defined."
        # ...

    @property
    def domain(self):
        """
        Domain of input (Gumbel variables): real.
        """
        if self.event_dim == 0:
            return constraints.real
        # ...

    @property
    def codomain(self):
        """
        Domain of output (categorical variables); should be natural numbers,
        but set to real for now.
        """
        if self.event_dim == 0:
            return constraints.real
        # ...

    # ... (inverse transform: infers posterior Gumbels for an observed class k)
        assert self.logits != None, "Logits not defined."

        uniforms = torch.rand(
            self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
        )
        gumbels = -((-(uniforms.log())).log())
        # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
        # (batch_size, num_classes) mask to select kth class
        # print(f'k : {k.size()}')
        mask = F.one_hot(
            k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
        )
        # print(f'mask: {mask.size()}, {mask.dtype}')
        # (batch_size, 1) select topgumbel for truncation of other classes
        topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
            mask * self.logits
        ).sum(dim=-1, keepdim=True)
        mask = 1 - mask  # invert mask to select other != k classes
        g = gumbels + self.logits
        # (batch_size, num_classes)
        epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
            mask * self.logits
        )
        return epsilons

    def log_abs_det_jacobian(self, x, y):
        """We use the log_abs_det_jacobian to account for the categorical prob.
        x: Gumbels; y: argmax(x + logits)
        P(y) = softmax
        """
        # print(f"logits: {torch.log(F.softmax(self.logits, dim=-1)).size()}")
        # print(f'y: {y.size()} ')
        # ...


class ConditionalGumbelMax(ConditionalTransformModule):
    # ...
    def condition(self, context):
        """Given context (age), output the Categorical results"""
        logits = self.context_nn(
            context
        )  # The logits for calculating argmax(Gumbel + logits)
        return ArgMaxGumbelMax(logits)

    def _logits(self, context):
        # ...

    @property
    def domain(self):
        """
        Domain of input (Gumbel variables): real.
        """
        if self.event_dim == 0:
            return constraints.real
        # ...

    @property
    def codomain(self):
        """
        Domain of output (categorical variables); should be natural numbers,
        but set to real for now.
        """
        if self.event_dim == 0:
            return constraints.real
        # ...


class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
    r"""TransformedDistribution for the Gumbel-max mechanism."""

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def log_prob(self, value):
        # ...
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x
        # print(f"log_prob: {log_prob.size()}")
        return log_prob


class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
    def condition(self, context):
        base_dist = self.base_dist.condition(context)
        transforms = [t.condition(context) for t in self.transforms]
        return TransformedDistributionGumbelMax(base_dist, transforms)

    def clear_cache(self):
        pass

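`ArgMaxGumbelMax` rests on the Gumbel-max trick: adding i.i.d. standard Gumbel noise to logits and taking the argmax yields an exact sample from Categorical(softmax(logits)); the inverse direction samples the posterior Gumbels consistent with an observed class, which is what makes the discrete finding variable abductable. A standalone sanity check of the forward direction:

# Sketch: argmax(gumbel + logits) samples Categorical(softmax(logits)).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.0, 0.0, -1.0])
n = 200_000
gumbels = -torch.log(-torch.log(torch.rand(n, logits.numel())))
samples = torch.argmax(gumbels + logits, dim=-1)
freq = torch.bincount(samples, minlength=3).float() / n
assert torch.allclose(freq, F.softmax(logits, dim=-1), atol=0.01)
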
vae.py
CHANGED
EPS = -9  # minimum logscale


@torch.jit.script
def gaussian_kl(q_loc, q_logscale, p_loc, p_logscale):
    return (
        -0.5
        + p_logscale
        - q_logscale
        + 0.5
        * (q_logscale.exp().pow(2) + (q_loc - p_loc).pow(2))
        / p_logscale.exp().pow(2)
    )
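`gaussian_kl` is the closed-form KL divergence between two diagonal Gaussians parameterized by log-scales. A quick cross-check against torch.distributions, with illustrative values:

# Sketch: gaussian_kl should agree with torch's analytic KL divergence.
import torch
import torch.distributions as dist

q_loc, q_logscale = torch.tensor([0.5]), torch.tensor([-0.3])
p_loc, p_logscale = torch.tensor([0.0]), torch.tensor([0.1])
kl_manual = gaussian_kl(q_loc, q_logscale, p_loc, p_logscale)
kl_torch = dist.kl_divergence(
    dist.Normal(q_loc, q_logscale.exp()), dist.Normal(p_loc, p_logscale.exp())
)
assert torch.allclose(kl_manual, kl_torch)
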

@torch.jit.script
def sample_gaussian(loc, logscale):
    # ...


class Block(nn.Module):
    def __init__(
        self,
        in_width,
        bottleneck,
        out_width,
        kernel_size=3,
        residual=True,
        down_rate=None,
        version=None,
    ):
        super().__init__()
        self.d = down_rate
        self.residual = residual
        padding = 0 if kernel_size == 1 else 1

        if version == "light":  # for ukbb
            activation = nn.ReLU()
            self.conv = nn.Sequential(
                activation,
                nn.Conv2d(in_width, bottleneck, kernel_size, 1, padding),
                activation,
                nn.Conv2d(bottleneck, out_width, kernel_size, 1, padding),
            )
        else:  # for morphomnist
            activation = nn.GELU()
            # ...
                activation,
                nn.Conv2d(bottleneck, bottleneck, kernel_size, 1, padding),
                activation,
                nn.Conv2d(bottleneck, out_width, 1, 1),
            )

        if self.residual and (self.d or in_width > out_width):
            # ...


class Encoder(nn.Module):
    # ...
        super().__init__()
        # parse architecture
        stages = []
        for i, stage in enumerate(args.enc_arch.split(",")):
            start = stage.index("b") + 1
            end = stage.index("d") if "d" in stage else None
            n_blocks = int(stage[start:end])

            if i == 0:  # define network stem
                if n_blocks == 0 and "d" not in stage:
                    print("Using stride=2 conv encoder stem.")
                    self.stem = nn.Conv2d(
                        args.input_channels,
                        args.widths[1],
                        kernel_size=7,
                        stride=2,
                        padding=3,
                    )
                    continue
                else:
                    self.stem = nn.Conv2d(
                        args.input_channels,
                        args.widths[0],
                        kernel_size=7,
                        stride=1,
                        padding=3,
                    )

            stages += [(args.widths[i], None) for _ in range(n_blocks)]
            if "d" in stage:  # downsampling block
                stages += [(args.widths[i + 1], int(stage[stage.index("d") + 1]))]
        blocks = []
        for i, (width, d) in enumerate(stages):
            prev_width = stages[max(0, i - 1)][0]
            bottleneck = int(prev_width / args.bottleneck)
            blocks.append(
                Block(prev_width, bottleneck, width, down_rate=d, version=args.vr)
            )
        # scale weights of last conv layer in each block
        for b in blocks:
            b.conv[-1].weight.data *= np.sqrt(1 / len(blocks))
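The encoder architecture is given as a comma-separated spec: each stage lists a resolution, a block count after "b", and an optional downsampling rate after "d". A sketch of the parsing on a made-up spec (the real specs live in the training hparams, which are not part of this diff):

# Hypothetical spec; only the parsing logic mirrors the code above.
enc_arch = "64b1d2,32b3d2,16b7d2,8b3"
widths = [16, 32, 64, 128]  # illustrative channel widths per stage

stages = []
for i, stage in enumerate(enc_arch.split(",")):
    start = stage.index("b") + 1
    end = stage.index("d") if "d" in stage else None
    n_blocks = int(stage[start:end])
    stages += [(widths[i], None) for _ in range(n_blocks)]
    if "d" in stage:  # downsampling block takes the next stage's width
        stages += [(widths[i + 1], int(stage[stage.index("d") + 1]))]
print(stages)
# -> [(16, None), (32, 2), (32, None), (32, None), (32, None), (64, 2), ...]
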

class DecoderBlock(nn.Module):
    # ...
        super().__init__()
        bottleneck = int(in_width / args.bottleneck)
        self.res = resolution
        self.stochastic = self.res <= args.z_max_res
        self.z_dim = args.z_dim
        self.cond_prior = args.cond_prior
        k = 3 if self.res > 2 else 1
        # ...
        # self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
        self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)

        self.prior = Block(
            p_in_width,
            bottleneck,
            2 * self.z_dim + in_width,
            kernel_size=k,
            residual=False,
            version=args.vr,
        )
        if self.stochastic:
            self.posterior = Block(
                2 * in_width + args.context_dim,
                bottleneck,
                2 * self.z_dim,
                kernel_size=k,
                residual=False,
                version=args.vr,
            )
        self.z_proj = nn.Conv2d(self.z_dim + args.context_dim, in_width, 1)
        self.conv = Block(
            in_width, bottleneck, out_width, kernel_size=k, version=args.vr
        )

    def forward_prior(self, z, pa=None, t=None):
        if self.cond_prior:
            z = torch.cat([z, pa], dim=1)
        z = self.prior(z)
        p_loc = z[:, : self.z_dim, ...]
        p_logscale = z[:, self.z_dim : 2 * self.z_dim, ...]
        p_features = z[:, 2 * self.z_dim :, ...]
        if t is not None:
            p_logscale = p_logscale + torch.tensor(t).to(z.device).log()
        return p_loc, p_logscale, p_features
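`forward_prior` applies the sampling temperature in log-space: adding log(t) to the log-scale multiplies the prior standard deviation by t. A one-line check of that identity:

# Sketch: logscale + log(t) is the same as scaling sigma by t.
import torch

logscale, t = torch.tensor([-0.7]), 0.5
assert torch.allclose(
    (logscale + torch.tensor(t).log()).exp(), logscale.exp() * t
)
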

class Decoder(nn.Module):
    # ...
        super().__init__()
        # parse architecture
        stages = []
        for i, stage in enumerate(args.dec_arch.split(",")):
            res = int(stage.split("b")[0])
            n_blocks = int(stage[stage.index("b") + 1 :])
            stages += [(res, args.widths[::-1][i]) for _ in range(n_blocks)]
        self.blocks = []
        for i, (res, width) in enumerate(stages):
            next_width = stages[min(len(stages) - 1, i + 1)][1]
            self.blocks.append(DecoderBlock(args, width, next_width, res))
        self._scale_weights()
        self.blocks = nn.ModuleList(self.blocks)
        # bias params
        self.all_res = list(np.unique([stages[i][0] for i in range(len(stages))]))
        bias = []
        for i, res in enumerate(self.all_res):
            if res <= args.bias_max_res:
                bias.append(
                    nn.Parameter(torch.zeros(1, args.widths[::-1][i], res, res))
                )
        self.bias = nn.ParameterList(bias)
        self.cond_prior = args.cond_prior
        self.is_drop_cond = True if "mnist" in args.hps else False  # hacky

    def _scale_weights(self):
        scale = np.sqrt(1 / len(self.blocks))
        # ...

            res = block.res  # current block resolution, e.g. 64x64
            pa = parents[..., :res, :res].clone()  # select parents @ res

            if (
                self.is_drop_cond
            ):  # for morphomnist w/ conditioning dropout. Hacky, clean up later
                pa_drop1 = pa.clone()
                pa_drop1[:, 2:, ...] = pa_drop1[:, 2:, ...] * p1
                pa_drop2 = pa.clone()
                pa_drop2[:, 2:, ...] = pa_drop2[:, 2:, ...] * p2
            else:  # for ukbb
                pa_drop1 = pa_drop2 = pa

            if h.size(-1) < res:  # upsample previous layer output
                b = bias[res] if res in bias.keys() else 0  # broadcasting
                h = b + F.interpolate(h, scale_factor=res / h.shape[-1])

            if block.cond_prior:  # conditional prior: p(z_i | z_<i, pa_x)
                # w/ posterior correction
                # p_loc, p_logscale, p_feat = block.forward_prior(h, pa_drop1, t=t)
                if z.size(-1) < res:  # w/o posterior correction
                    z = b + F.interpolate(z, scale_factor=res / z.shape[-1])
                p_loc, p_logscale, p_feat = block.forward_prior(z, pa_drop1, t=t)
            else:  # exogenous prior: p(z_i | z_<i)
                if z.size(-1) < res:
                    z = b + F.interpolate(z, scale_factor=res / z.shape[-1])
                p_loc, p_logscale, p_feat = block.forward_prior(z, t=t)

            # computation tree:
            # ...

            if block.stochastic:
                if x is not None:  # z_i ~ q(z_i | z_<i, pa_x, x)
                    q_loc, q_logscale = block.forward_posterior(h, pa, x[res], t=t)
                    z = sample_gaussian(q_loc, q_logscale)
                    stat = dict(kl=gaussian_kl(q_loc, q_logscale, p_loc, p_logscale))
                    # abduct exogenous noise
                    if abduct:
                        if block.cond_prior:  # z* if conditional prior
                            stat.update(
                                dict(
                                    z={"z": z, "q_loc": q_loc, "q_logscale": q_logscale}
                                )
                            )
                        else:  # z if exogenous prior
                            # stat.update(dict(z=z.detach()))
                            stat.update(dict(z=z))  # if cf training
                    # ...
                    z = sample_gaussian(p_loc, p_logscale)

                    if abduct and block.cond_prior:  # for abducting z*
                        stats.append(
                            dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
                        )
                else:
                    try:  # forward fixed latents z or z*
                        z = latents[i]
                    # ...
                        z = sample_gaussian(p_loc, p_logscale)

                        if abduct and block.cond_prior:  # for abducting z*
                            stats.append(
                                dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
                            )
                    else:
                        z = p_loc  # deterministic path
            # ...
            h = self.forward_merge(block, h, z, pa_drop2)

            # if not block.cond_prior:
            if (i + 1) < len(self.blocks):
                # z independent of pa_x for next layer prior
                z = block.z_feat_proj(torch.cat([z, p_feat], dim=1))
        return h, stats

    # ...
        return block.conv(h)

    def drop_cond(self):
        opt = dist.Categorical(1 / 3 * torch.ones(3)).sample()
        if opt == 0:  # drop stochastic path
            p1, p2 = 0, 1
        elif opt == 1:  # drop deterministic path
            # ...


class DGaussNet(nn.Module):
    def __init__(self, args):
        super(DGaussNet, self).__init__()
        self.x_loc = nn.Conv2d(
            args.widths[0], args.input_channels, kernel_size=1, stride=1
        )
        self.x_logscale = nn.Conv2d(
            args.widths[0], args.input_channels, kernel_size=1, stride=1
        )

        if args.input_channels == 3:
            self.channel_coeffs = nn.Conv2d(args.widths[0], 3, kernel_size=1, stride=1)

        if args.std_init > 0:  # if std_init=0, random init weights for diag cov
            nn.init.zeros_(self.x_logscale.weight)
            nn.init.constant_(self.x_logscale.bias, np.log(args.std_init))

            covariance = args.x_like.split("_")[0]
            if covariance == "fixed":
                self.x_logscale.weight.requires_grad = False
                self.x_logscale.bias.requires_grad = False
            elif covariance == "shared":
                self.x_logscale.weight.requires_grad = False
                self.x_logscale.bias.requires_grad = True
            elif covariance == "diag":
                self.x_logscale.weight.requires_grad = True
                self.x_logscale.bias.requires_grad = True
            else:
                raise NotImplementedError(f"{args.x_like} not implemented.")

    def forward(self, h, x=None, t=None):
        loc, logscale = self.x_loc(h), self.x_logscale(h).clamp(min=EPS)
        # ...
        return loc, logscale

    def approx_cdf(self, x):
        return 0.5 * (
            1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))
        )

    def nll(self, h, x):
        loc, logscale = self.forward(h, x)
        # ...
        log_probs = torch.where(
            x < -0.999,
            log_cdf_plus,
            torch.where(
                x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))
            ),
        )
        return -1.0 * log_probs.mean(dim=(1, 2, 3))

    def sample(self, h, return_loc=True, t=None):
        if return_loc:
            # ...
        else:
            loc, logscale = self.forward(h, t)
        x = loc + torch.exp(logscale) * torch.randn_like(loc)
        x = torch.clamp(x, min=-1.0, max=1.0)
        return x, logscale.exp()


class HVAE(nn.Module):
    def __init__(self, args):
        super().__init__()
        args.vr = "light" if "ukbb" in args.hps else None  # hacky
        self.encoder = Encoder(args)
        self.decoder = Decoder(args)
        if args.x_like.split("_")[1] == "dgauss":
            self.likelihood = DGaussNet(args)
        else:
            raise NotImplementedError(f"{args.x_like} not implemented.")
        self.cond_prior = args.cond_prior
        self.free_bits = args.kl_free_bits

    # ...
            kl_pp = 0.0
            for stat in stats:
                kl_pp += torch.maximum(
                    free_bits, stat["kl"].sum(dim=(2, 3)).mean(dim=0)
                ).sum()
        else:
            kl_pp = torch.zeros_like(nll_pp)
            for i, stat in enumerate(stats):
                kl_pp += stat["kl"].sum(dim=(1, 2, 3))
            kl_pp = kl_pp / np.prod(x.shape[1:])  # per pixel
        elbo = nll_pp.mean() + beta * kl_pp.mean()  # negative elbo (free energy)
        return dict(elbo=elbo, nll=nll_pp.mean(), kl=kl_pp.mean())

    # ...

    def abduct(self, x, parents, cf_parents=None, alpha=0.5, t=None):
        acts = self.encoder(x)
        _, q_stats = self.decoder(
            x=acts, parents=parents, abduct=True, t=t
        )  # q(z|x,pa)
        q_stats = [s["z"] for s in q_stats]

        if self.cond_prior and cf_parents is not None:
            _, p_stats = self.decoder(parents=cf_parents, abduct=True, t=t)  # p(z|pa*)
            p_stats = [s["z"] for s in p_stats]

            cf_zs = []
            t = torch.tensor(t).to(x.device)  # z* sampling temperature

            for i in range(len(q_stats)):
                # from z_i ~ q(z_i | z_{<i}, x, pa)
                q_loc = q_stats[i]["q_loc"]
                q_scale = q_stats[i]["q_logscale"].exp()
                # abduct exogenous noise u ~ N(0,I)
                u = (q_stats[i]["z"] - q_loc) / q_scale
                # p(z_i | z_{<i}, pa*)
                p_loc = p_stats[i]["p_loc"]
                p_var = p_stats[i]["p_logscale"].exp().pow(2)

                # Option1: mixture distribution: r(z_i | z_{<i}, x, pa, pa*)
                #        = a*q(z_i | z_{<i}, x, pa) + (1-a)*p(z_i | z_{<i}, pa*)