fabio-deep committed on
Commit 146a6ea · 1 Parent(s): 8c4fe8b

added links

Files changed (7)
  1. .gitignore +1 -0
  2. app.py +343 -192
  3. app_utils.py +176 -114
  4. datasets.py +168 -123
  5. pgm/flow_pgm.py +310 -380
  6. pgm/layers.py +50 -42
  7. vae.py +136 -88
.gitignore CHANGED
@@ -1,2 +1,3 @@
+.vscode
 __pycache__
 *.pyc
app.py CHANGED
@@ -8,48 +8,55 @@ from vae import HVAE
 from datasets import morphomnist, ukbb, mimic, get_attr_max_min
 from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM
 from app_utils import (
-    mnist_graph, brain_graph, chest_graph, vae_preprocess, normalize, \
-    preprocess_brain, get_fig_arr, postprocess, MidpointNormalize
+    mnist_graph,
+    brain_graph,
+    chest_graph,
+    vae_preprocess,
+    normalize,
+    preprocess_brain,
+    get_fig_arr,
+    postprocess,
+    MidpointNormalize,
 )
 
 DATA, MODELS = {}, {}
-for k in ['Morpho-MNIST', 'Brain MRI', 'Chest X-ray']:
+for k in ["Morpho-MNIST", "Brain MRI", "Chest X-ray"]:
     DATA[k], MODELS[k] = {}, {}
 
 # mnist
 DIGITS = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
 # brain
-MRISEQ_CAT = ['T1', 'T2-FLAIR'] # 0,1
-SEX_CAT = ['female', 'male'] # 0,1
+MRISEQ_CAT = ["T1", "T2-FLAIR"]  # 0,1
+SEX_CAT = ["female", "male"]  # 0,1
 HEIGHT, WIDTH = 270, 270
 # chest
-SEX_CAT_CHEST = ['male', 'female'] # 0,1
-RACE_CAT = ['white', 'asian', 'black'] # 0,1,2
-FIND_CAT = ['no disease', 'pleural effusion']
-DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+SEX_CAT_CHEST = ["male", "female"]  # 0,1
+RACE_CAT = ["white", "asian", "black"]  # 0,1,2
+FIND_CAT = ["no disease", "pleural effusion"]
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 class Hparams:
     def update(self, dict):
         for k, v in dict.items():
-            setattr(self, k, v)
+            setattr(self, k, v)
 
 
 def get_paths(dataset_id):
-    if 'MNIST' in dataset_id:
-        data_path = './data/morphomnist'
-        pgm_path = './checkpoints/t_i_d/sup_pgm/checkpoint.pt'
-        vae_path = './checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt'
-    elif 'Brain' in dataset_id:
-        data_path = './data/ukbb_subset'
-        pgm_path = './checkpoints/m_b_v_s/sup_pgm/checkpoint.pt'
-        vae_path = './checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt'
-    elif 'Chest' in dataset_id:
-        data_path = './data/mimic_subset'
-        pgm_path = './checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt'
+    if "MNIST" in dataset_id:
+        data_path = "./data/morphomnist"
+        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
+        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
+    elif "Brain" in dataset_id:
+        data_path = "./data/ukbb_subset"
+        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
+        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
+    elif "Chest" in dataset_id:
+        data_path = "./data/mimic_subset"
+        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
         vae_path = [
-            './checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt', # base vae
-            './checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt' # cf trained DSCM
+            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt",  # base vae
+            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt",  # cf trained DSCM
         ]
     return data_path, vae_path, pgm_path
 
@@ -57,64 +64,71 @@ def get_paths(dataset_id):
 def load_pgm(dataset_id, pgm_path):
     checkpoint = torch.load(pgm_path, map_location=DEVICE)
     args = Hparams()
-    args.update(checkpoint['hparams'])
+    args.update(checkpoint["hparams"])
     args.device = DEVICE
-    if 'MNIST' in dataset_id:
+    if "MNIST" in dataset_id:
         pgm = MorphoMNISTPGM(args).to(args.device)
-    elif 'Brain' in dataset_id:
+    elif "Brain" in dataset_id:
         pgm = FlowPGM(args).to(args.device)
-    elif 'Chest' in dataset_id:
+    elif "Chest" in dataset_id:
         pgm = ChestPGM(args).to(args.device)
-    pgm.load_state_dict(checkpoint['ema_model_state_dict'])
-    MODELS[dataset_id]['pgm'] = pgm
-    MODELS[dataset_id]['pgm_args'] = args
+    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
+    MODELS[dataset_id]["pgm"] = pgm
+    MODELS[dataset_id]["pgm_args"] = args
 
 
 def load_vae(dataset_id, vae_path):
-    if 'Chest' in dataset_id:
+    if "Chest" in dataset_id:
         vae_path, dscm_path = vae_path[0], vae_path[1]
     checkpoint = torch.load(vae_path, map_location=DEVICE)
     args = Hparams()
-    args.update(checkpoint['hparams'])
+    args.update(checkpoint["hparams"])
     # backwards compatibility hack
-    if not hasattr(args, 'vae'):
-        args.vae = 'hierarchical'
-    if not hasattr(args, 'cond_prior'):
+    if not hasattr(args, "vae"):
+        args.vae = "hierarchical"
+    if not hasattr(args, "cond_prior"):
         args.cond_prior = False
-    if hasattr(args, 'free_bits'):
+    if hasattr(args, "free_bits"):
         args.kl_free_bits = args.free_bits
     args.device = DEVICE
     vae = HVAE(args).to(args.device)
 
-    if 'Chest' in dataset_id:
+    if "Chest" in dataset_id:
         dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
-        vae.load_state_dict({k[4:]: v for k, v in dscm_ckpt['ema_model_state_dict'].items() if 'vae.' in k})
+        vae.load_state_dict(
+            {
+                k[4:]: v
+                for k, v in dscm_ckpt["ema_model_state_dict"].items()
+                if "vae." in k
+            }
+        )
     else:
-        vae.load_state_dict(checkpoint['ema_model_state_dict'])
-    MODELS[dataset_id]['vae'] = vae
-    MODELS[dataset_id]['vae_args'] = args
+        vae.load_state_dict(checkpoint["ema_model_state_dict"])
+    MODELS[dataset_id]["vae"] = vae
+    MODELS[dataset_id]["vae_args"] = args
 
 
 def get_dataloader(dataset_id, data_path):
-    MODELS[dataset_id]['pgm_args'].data_dir = data_path
-    args = MODELS[dataset_id]['pgm_args']
-    if 'MNIST' in dataset_id:
+    MODELS[dataset_id]["pgm_args"].data_dir = data_path
+    args = MODELS[dataset_id]["pgm_args"]
+    if "MNIST" in dataset_id:
         datasets = morphomnist(args)
-    elif 'Brain' in dataset_id:
+    elif "Brain" in dataset_id:
         datasets = ukbb(args)
-    elif 'Chest' in dataset_id:
+    elif "Chest" in dataset_id:
         datasets = mimic(args)
-    DATA[dataset_id]['test'] = torch.utils.data.DataLoader(
-        datasets['test'], shuffle=False, batch_size=args.bs, num_workers=4)
+    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
+        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
+    )
 
 
 def load_dataset(dataset_id):
     data_path, _, pgm_path = get_paths(dataset_id)
     checkpoint = torch.load(pgm_path, map_location=DEVICE)
     args = Hparams()
-    args.update(checkpoint['hparams'])
+    args.update(checkpoint["hparams"])
     args.device = DEVICE
-    MODELS[dataset_id]['pgm_args'] = args
+    MODELS[dataset_id]["pgm_args"] = args
     get_dataloader(dataset_id, data_path)
 
 
@@ -122,167 +136,179 @@ def load_model(dataset_id):
     _, vae_path, pgm_path = get_paths(dataset_id)
     load_pgm(dataset_id, pgm_path)
     load_vae(dataset_id, vae_path)
-
+
 
 @torch.no_grad()
 def counterfactual_inference(dataset_id, obs, do_pa):
-    pa = {k: v.clone() for k, v in obs.items() if k != 'x'}
-    cf_pa = MODELS[dataset_id]['pgm'].counterfactual(obs=pa, intervention=do_pa, num_particles=1)
-    args, vae = MODELS[dataset_id]['vae_args'], MODELS[dataset_id]['vae']
+    pa = {k: v.clone() for k, v in obs.items() if k != "x"}
+    cf_pa = MODELS[dataset_id]["pgm"].counterfactual(
+        obs=pa, intervention=do_pa, num_particles=1
+    )
+    args, vae = MODELS[dataset_id]["vae_args"], MODELS[dataset_id]["vae"]
     _pa = vae_preprocess(args, {k: v.clone() for k, v in pa.items()})
-    _cf_pa = vae_preprocess(args , {k: v.clone() for k, v in cf_pa.items()})
-    z_t = 0.1 if 'mnist' in args.hps else 1.0
-    z = vae.abduct(x=obs['x'], parents=_pa, t=z_t)
+    _cf_pa = vae_preprocess(args, {k: v.clone() for k, v in cf_pa.items()})
+    z_t = 0.1 if "mnist" in args.hps else 1.0
+    z = vae.abduct(x=obs["x"], parents=_pa, t=z_t)
     if vae.cond_prior:
-        z = [z[j]['z'] for j in range(len(z))]
+        z = [z[j]["z"] for j in range(len(z))]
     px_loc, px_scale = vae.forward_latents(latents=z, parents=_pa)
     cf_loc, cf_scale = vae.forward_latents(latents=z, parents=_cf_pa)
-    u = (obs['x'] - px_loc) / px_scale.clamp(min=1e-12)
-    u_t = 0.1 if 'mnist' in args.hps else 1.0 # cf sampling temp
+    u = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
+    u_t = 0.1 if "mnist" in args.hps else 1.0  # cf sampling temp
     cf_scale = cf_scale * u_t
     cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
-    return {'cf_x': cf_x, 'rec_x': px_loc, 'cf_pa': cf_pa}
+    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_pa}
 
 
 def get_obs_item(dataset_id, idx=None):
     if idx is None:
-        n_test = len(DATA[dataset_id]['test'].dataset)
+        n_test = len(DATA[dataset_id]["test"].dataset)
         idx = torch.randperm(n_test)[0]
     idx = int(idx)
-    return idx, DATA[dataset_id]['test'].dataset.__getitem__(idx)
+    return idx, DATA[dataset_id]["test"].dataset.__getitem__(idx)
 
 
 def get_mnist_obs(idx=None):
-    dataset_id = 'Morpho-MNIST'
+    dataset_id = "Morpho-MNIST"
     if not DATA[dataset_id]:
         load_dataset(dataset_id)
     idx, obs = get_obs_item(dataset_id, idx)
-    x = get_fig_arr(obs['x'].clone().squeeze().numpy())
-    t = (obs['thickness'].clone() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526
-    i = (obs['intensity'].clone() + 1) / 2 * (254.90317 - 66.601204) + 66.601204
-    y = DIGITS[obs['digit'].clone().argmax(-1)]
+    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
+    t = (obs["thickness"].clone() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526
+    i = (obs["intensity"].clone() + 1) / 2 * (254.90317 - 66.601204) + 66.601204
+    y = DIGITS[obs["digit"].clone().argmax(-1)]
     return (idx, x, float(np.round(t, 2)), float(np.round(i, 2)), y)
 
 
 def get_brain_obs(idx=None):
-    dataset_id = 'Brain MRI'
+    dataset_id = "Brain MRI"
     if not DATA[dataset_id]:
         load_dataset(dataset_id)
     idx, obs = get_obs_item(dataset_id, idx)
-    x = get_fig_arr(obs['x'].clone().squeeze().numpy())
-    m = MRISEQ_CAT[int(obs['mri_seq'].clone().item())]
-    s = SEX_CAT[int(obs['sex'].clone().item())]
-    a = obs['age'].clone().item()
-    b = obs['brain_volume'].clone().item() / 1000 # in ml
-    v = obs['ventricle_volume'].clone().item() / 1000 # in ml
+    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
+    m = MRISEQ_CAT[int(obs["mri_seq"].clone().item())]
+    s = SEX_CAT[int(obs["sex"].clone().item())]
+    a = obs["age"].clone().item()
+    b = obs["brain_volume"].clone().item() / 1000  # in ml
+    v = obs["ventricle_volume"].clone().item() / 1000  # in ml
     return (idx, x, m, s, a, float(np.round(b, 2)), float(np.round(v, 2)))
 
 
 def get_chest_obs(idx=None):
-    dataset_id = 'Chest X-ray'
+    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
-        load_dataset(dataset_id)
+        load_dataset(dataset_id)
     idx, obs = get_obs_item(dataset_id, idx)
-    x = get_fig_arr(postprocess(obs['x'].clone()))
-    s = SEX_CAT_CHEST[int(obs['sex'].clone().squeeze().numpy())]
-    f = FIND_CAT[int(obs['finding'].clone().squeeze().numpy())]
-    r = RACE_CAT[obs['race'].clone().squeeze().numpy().argmax(-1)]
-    a = (obs['age'].clone().squeeze().numpy()+1)*50
+    x = get_fig_arr(postprocess(obs["x"].clone()))
+    s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
+    f = FIND_CAT[int(obs["finding"].clone().squeeze().numpy())]
+    r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
+    a = (obs["age"].clone().squeeze().numpy() + 1) * 50
     return (idx, x, r, s, f, float(np.round(a, 1)))
 
 
 def infer_mnist_cf(*args):
-    dataset_id = 'Morpho-MNIST'
+    dataset_id = "Morpho-MNIST"
     idx, _, t, i, y, do_t, do_i, do_y = args
     n_particles = 32
     # preprocess
-    obs = DATA[dataset_id]['test'].dataset.__getitem__(int(idx))
-    obs['x'] = (obs['x'] - 127.5) / 127.5
+    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
+    obs["x"] = (obs["x"] - 127.5) / 127.5
     for k, v in obs.items():
         obs[k] = v.view(1, 1) if len(v.shape) < 1 else v.unsqueeze(0)
-        obs[k] = obs[k].to(MODELS[dataset_id]['vae_args'].device).float()
+        obs[k] = obs[k].to(MODELS[dataset_id]["vae_args"].device).float()
         if n_particles > 1:
-            ndims = (1,)*3 if k == 'x' else (1,)
+            ndims = (1,) * 3 if k == "x" else (1,)
             obs[k] = obs[k].repeat(n_particles, *ndims)
     # intervention(s)
     do_pa = {}
     if do_t:
-        do_pa['thickness'] = torch.tensor(normalize(t, x_max=6.255515, x_min=0.87598526)).view(1, 1)
+        do_pa["thickness"] = torch.tensor(
+            normalize(t, x_max=6.255515, x_min=0.87598526)
+        ).view(1, 1)
     if do_i:
-        do_pa['intensity'] = torch.tensor(normalize(i, x_max=254.90317, x_min=66.601204)).view(1, 1)
+        do_pa["intensity"] = torch.tensor(
+            normalize(i, x_max=254.90317, x_min=66.601204)
+        ).view(1, 1)
     if do_y:
-        do_pa['digit'] = F.one_hot(torch.tensor(DIGITS.index(y)), num_classes=10).view(1, 10)
-
+        do_pa["digit"] = F.one_hot(torch.tensor(DIGITS.index(y)), num_classes=10).view(
+            1, 10
+        )
+
     for k, v in do_pa.items():
-        do_pa[k] = v.to(MODELS[dataset_id]['vae_args'].device).float().repeat(n_particles, 1)
+        do_pa[k] = (
+            v.to(MODELS[dataset_id]["vae_args"].device).float().repeat(n_particles, 1)
+        )
     # infer counterfactual
     out = counterfactual_inference(dataset_id, obs, do_pa)
     # avg cf particles
-    cf_x = out['cf_x'].mean(0)
-    cf_x_std = out['cf_x'].std(0)
-    rec_x = out['rec_x'].mean(0)
-    cf_t = out['cf_pa']['thickness'].mean(0)
-    cf_i = out['cf_pa']['intensity'].mean(0)
-    cf_y = out['cf_pa']['digit'].mean(0)
+    cf_x = out["cf_x"].mean(0)
+    cf_x_std = out["cf_x"].std(0)
+    rec_x = out["rec_x"].mean(0)
+    cf_t = out["cf_pa"]["thickness"].mean(0)
+    cf_i = out["cf_pa"]["intensity"].mean(0)
+    cf_y = out["cf_pa"]["digit"].mean(0)
     # post process
     cf_x = postprocess(cf_x)
     cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
     rec_x = postprocess(rec_x)
     cf_t = np.round((cf_t.item() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526, 2)
-    cf_i = np.round((cf_i.item() + 1) / 2 * (254.90317 - 66.601204) + 66.601204, 2)
+    cf_i = np.round((cf_i.item() + 1) / 2 * (254.90317 - 66.601204) + 66.601204, 2)
     cf_y = DIGITS[cf_y.argmax(-1)]
-    # plots
+    # plots
     # plt.close('all')
     effect = cf_x - rec_x
-    effect = get_fig_arr(effect, cmap='RdBu_r',
-                         norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255))
+    effect = get_fig_arr(
+        effect, cmap="RdBu_r", norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255)
+    )
     cf_x = get_fig_arr(cf_x)
-    cf_x_std = get_fig_arr(cf_x_std, cmap='jet')
+    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
     return (cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)
 
 
 def infer_brain_cf(*args):
-    dataset_id = 'Brain MRI'
+    dataset_id = "Brain MRI"
     idx, _, m, s, a, b, v = args[:7]
     do_m, do_s, do_a, do_b, do_v = args[7:]
     n_particles = 16
     # preprocessing
-    obs = DATA[dataset_id]['test'].dataset.__getitem__(int(idx))
-    obs.pop('pa')
-    obs = preprocess_brain(MODELS[dataset_id]['vae_args'], obs)
+    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
+    obs = preprocess_brain(MODELS[dataset_id]["vae_args"], obs)
     for k, _v in obs.items():
         if n_particles > 1:
-            ndims = (1,)*3 if k == 'x' else (1,)
+            ndims = (1,) * 3 if k == "x" else (1,)
             obs[k] = _v.repeat(n_particles, *ndims)
     # interventions(s)
     do_pa = {}
     if do_m:
-        do_pa['mri_seq'] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
+        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
     if do_s:
-        do_pa['sex'] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
+        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
     if do_a:
-        do_pa['age'] = torch.tensor(a).view(1, 1)
+        do_pa["age"] = torch.tensor(a).view(1, 1)
     if do_b:
-        do_pa['brain_volume'] = torch.tensor(b * 1000).view(1, 1)
+        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
     if do_v:
-        do_pa['ventricle_volume'] = torch.tensor(v * 1000).view(1, 1)
+        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)
     # normalize continuous attributes
-    for k in ['age', 'brain_volume', 'ventricle_volume']:
+    for k in ["age", "brain_volume", "ventricle_volume"]:
         if k in do_pa.keys():
             k_max, k_min = get_attr_max_min(k)
             do_pa[k] = (do_pa[k] - k_min) / (k_max - k_min)  # [0,1]
             do_pa[k] = 2 * do_pa[k] - 1  # [-1,1]
 
     for k, _v in do_pa.items():
-        do_pa[k] = _v.to(MODELS[dataset_id]['vae_args'].device).float().repeat(n_particles, 1)
+        do_pa[k] = (
+            _v.to(MODELS[dataset_id]["vae_args"].device).float().repeat(n_particles, 1)
+        )
     # infer counterfactual
     out = counterfactual_inference(dataset_id, obs, do_pa)
     # avg cf particles
-    cf_x = out['cf_x'].mean(0)
-    cf_x_std = out['cf_x'].std(0)
-    rec_x = out['rec_x'].mean(0)
-    cf_m = out['cf_pa']['mri_seq'].mean(0)
-    cf_s = out['cf_pa']['sex'].mean(0)
+    cf_x = out["cf_x"].mean(0)
+    cf_x_std = out["cf_x"].std(0)
+    rec_x = out["rec_x"].mean(0)
+    cf_m = out["cf_pa"]["mri_seq"].mean(0)
+    cf_s = out["cf_pa"]["sex"].mean(0)
     # post process
     cf_x = postprocess(cf_x)
     cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
@@ -290,54 +316,70 @@ def infer_brain_cf(*args):
     cf_m = MRISEQ_CAT[int(cf_m.item())]
     cf_s = SEX_CAT[int(cf_s.item())]
     cf_ = {}
-    for k in ['age', 'brain_volume', 'ventricle_volume']: # unnormalize
+    for k in ["age", "brain_volume", "ventricle_volume"]:  # unnormalize
         k_max, k_min = get_attr_max_min(k)
-        cf_[k] = (out['cf_pa'][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
+        cf_[k] = (out["cf_pa"][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
     # plots
-    # plt.close('all')
+    # plt.close('all')
     effect = cf_x - rec_x
-    effect = get_fig_arr(effect, cmap='RdBu_r',
-                         norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()))
+    effect = get_fig_arr(
+        effect,
+        cmap="RdBu_r",
+        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
+    )
     cf_x = get_fig_arr(cf_x)
-    cf_x_std = get_fig_arr(cf_x_std, cmap='jet')
-    return (cf_x, cf_x_std, effect, cf_m, cf_s, np.round(cf_['age'], 1), np.round(cf_['brain_volume'] / 1000, 2), np.round(cf_['ventricle_volume'] / 1000, 2))
+    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
+    return (
+        cf_x,
+        cf_x_std,
+        effect,
+        cf_m,
+        cf_s,
+        np.round(cf_["age"], 1),
+        np.round(cf_["brain_volume"] / 1000, 2),
+        np.round(cf_["ventricle_volume"] / 1000, 2),
+    )
 
 
 def infer_chest_cf(*args):
-    dataset_id = 'Chest X-ray'
+    dataset_id = "Chest X-ray"
     idx, _, r, s, f, a = args[:6]
     do_r, do_s, do_f, do_a = args[6:]
     n_particles = 16
     # preprocessing
-    obs = DATA[dataset_id]['test'].dataset.__getitem__(int(idx))
+    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
     for k, v in obs.items():
-        obs[k] = v.to(MODELS[dataset_id]['vae_args'].device).float()
+        obs[k] = v.to(MODELS[dataset_id]["vae_args"].device).float()
         if n_particles > 1:
-            ndims = (1,)*3 if k == 'x' else (1,)
+            ndims = (1,) * 3 if k == "x" else (1,)
             obs[k] = obs[k].repeat(n_particles, *ndims)
     # intervention(s)
     do_pa = {}
     with torch.no_grad():
         if do_s:
-            do_pa['sex'] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
+            do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
         if do_f:
-            do_pa['finding'] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
+            do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
         if do_r:
-            do_pa['race'] = F.one_hot(torch.tensor(RACE_CAT.index(r)), num_classes=3).view(1, 3)
+            do_pa["race"] = F.one_hot(
+                torch.tensor(RACE_CAT.index(r)), num_classes=3
+            ).view(1, 3)
         if do_a:
-            do_pa['age'] = torch.tensor(a/100*2-1).view(1,1)
+            do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
     for k, v in do_pa.items():
-        do_pa[k] = v.to(MODELS[dataset_id]['vae_args'].device).float().repeat(n_particles, 1)
+        do_pa[k] = (
+            v.to(MODELS[dataset_id]["vae_args"].device).float().repeat(n_particles, 1)
+        )
     # infer counterfactual
     out = counterfactual_inference(dataset_id, obs, do_pa)
     # avg cf particles
-    cf_x = out['cf_x'].mean(0)
-    cf_x_std = out['cf_x'].std(0)
-    rec_x = out['rec_x'].mean(0)
-    cf_r = out['cf_pa']['race'].mean(0)
-    cf_s = out['cf_pa']['sex'].mean(0)
-    cf_f = out['cf_pa']['finding'].mean(0)
-    cf_a = out['cf_pa']['age'].mean(0)
+    cf_x = out["cf_x"].mean(0)
+    cf_x_std = out["cf_x"].std(0)
+    rec_x = out["rec_x"].mean(0)
+    cf_r = out["cf_pa"]["race"].mean(0)
+    cf_s = out["cf_pa"]["sex"].mean(0)
+    cf_f = out["cf_pa"]["finding"].mean(0)
+    cf_a = out["cf_pa"]["age"].mean(0)
     # post process
     cf_x = postprocess(cf_x)
     cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
@@ -349,10 +391,13 @@ def infer_chest_cf(*args):
     # plots
     # plt.close('all')
     effect = cf_x - rec_x
-    effect = get_fig_arr(effect, cmap='RdBu_r',
-                         norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()))
+    effect = get_fig_arr(
+        effect,
+        cmap="RdBu_r",
+        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
+    )
     cf_x = get_fig_arr(cf_x)
-    cf_x_std = get_fig_arr(cf_x_std, cmap='jet')
+    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
     return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
 
 
@@ -364,33 +409,59 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
             with gr.Row().style(equal_height=True):
                 idx = gr.Number(value=0, visible=False)
                 with gr.Column(scale=1, min_width=200):
-                    x = gr.Image(label='Observation', interactive=False).style(height=HEIGHT)
+                    x = gr.Image(label="Observation", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
-                    cf_x = gr.Image(label='Counterfactual', interactive=False).style(height=HEIGHT)
+                    cf_x = gr.Image(label="Counterfactual", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
-                    cf_x_std = gr.Image(label='Counterfactual Uncertainty', interactive=False).style(height=HEIGHT)
+                    cf_x_std = gr.Image(
+                        label="Counterfactual Uncertainty", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
-                    effect = gr.Image(label='Direct Causal Effect', interactive=False).style(height=HEIGHT)
+                    effect = gr.Image(
+                        label="Direct Causal Effect", interactive=False
+                    ).style(height=HEIGHT)
             with gr.Row().style(equal_height=True):
                 with gr.Column(scale=1.75):
-                    gr.Markdown("#### Intervention")
+                    gr.Markdown(
+                        "#### Intervention"
+                        + 28 * "&emsp;"
+                        + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
+                    )
                    with gr.Column():
                         do_y = gr.Checkbox(label="do(digit)", value=False)
                         y = gr.Radio(DIGITS, label="", interactive=False)
                     with gr.Row():
                         with gr.Column(min_width=100):
                             do_t = gr.Checkbox(label="do(thickness)", value=False)
-                            t = gr.Slider(label="\u00A0", minimum=0.9, maximum=5.5, step=0.01, interactive=False)
+                            t = gr.Slider(
+                                label="\u00A0",
+                                minimum=0.9,
+                                maximum=5.5,
+                                step=0.01,
+                                interactive=False,
+                            )
                         with gr.Column(min_width=100):
                             do_i = gr.Checkbox(label="do(intensity)", value=False)
-                            i = gr.Slider(label="\u00A0", minimum=50, maximum=255, step=0.01, interactive=False)
+                            i = gr.Slider(
+                                label="\u00A0",
+                                minimum=50,
+                                maximum=255,
+                                step=0.01,
+                                interactive=False,
+                            )
                     with gr.Row():
                         new = gr.Button("New Observation")
                         reset = gr.Button("Reset", variant="stop")
                         submit = gr.Button("Submit", variant="primary")
                 with gr.Column(scale=1):
                     gr.Markdown("### &nbsp;")
-                    causal_graph = gr.Image(label='Causal Graph', interactive=False).style(height=300)
+                    causal_graph = gr.Image(
+                        label="Causal Graph", interactive=False
+                    ).style(height=300)
 
         with gr.TabItem("Brain MRI") as brain_tab:
             brain_id = gr.Textbox(value=brain_tab.label, visible=False)
@@ -398,40 +469,81 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
            with gr.Row().style(equal_height=True):
                 idx_brain = gr.Number(value=0, visible=False)
                 with gr.Column(scale=1, min_width=200):
-                    x_brain = gr.Image(label='Observation', interactive=False).style(height=HEIGHT)
+                    x_brain = gr.Image(label="Observation", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
-                    cf_x_brain = gr.Image(label='Counterfactual', interactive=False).style(height=HEIGHT)
+                    cf_x_brain = gr.Image(
+                        label="Counterfactual", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
-                    cf_x_std_brain = gr.Image(label='Counterfactual Uncertainty', interactive=False).style(height=HEIGHT)
+                    cf_x_std_brain = gr.Image(
+                        label="Counterfactual Uncertainty", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
-                    effect_brain = gr.Image(label='Direct Causal Effect', interactive=False).style(height=HEIGHT)
+                    effect_brain = gr.Image(
+                        label="Direct Causal Effect", interactive=False
+                    ).style(height=HEIGHT)
             with gr.Row():
                 with gr.Column(scale=2.55):
-                    gr.Markdown("#### Intervention")
+                    gr.Markdown(
+                        "#### Intervention"
+                        + 28 * "&emsp;"
+                        + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
+                    )
                     with gr.Row():
                         with gr.Column(min_width=200):
                             do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
-                            m = gr.Radio(["T1", "T2-FLAIR"], label="", interactive=False)
+                            m = gr.Radio(
+                                ["T1", "T2-FLAIR"], label="", interactive=False
+                            )
                         with gr.Column(min_width=200):
                             do_s = gr.Checkbox(label="do(sex)", value=False)
-                            s = gr.Radio(["female", "male"], label="", interactive=False)
+                            s = gr.Radio(
+                                ["female", "male"], label="", interactive=False
+                            )
                     with gr.Row():
                         with gr.Column(min_width=100):
                             do_a = gr.Checkbox(label="do(age)", value=False)
-                            a = gr.Slider(label="\u00A0", value=50, minimum=44, maximum=73, step=1, interactive=False)
+                            a = gr.Slider(
+                                label="\u00A0",
+                                value=50,
+                                minimum=44,
+                                maximum=73,
+                                step=1,
+                                interactive=False,
+                            )
                         with gr.Column(min_width=100):
                             do_b = gr.Checkbox(label="do(brain volume)", value=False)
-                            b = gr.Slider(label="\u00A0", value=1000, minimum=850, maximum=1550, step=20, interactive=False)
+                            b = gr.Slider(
+                                label="\u00A0",
+                                value=1000,
+                                minimum=850,
+                                maximum=1550,
+                                step=20,
+                                interactive=False,
+                            )
                         with gr.Column(min_width=100):
-                            do_v = gr.Checkbox(label="do(ventricle volume)", value=False)
-                            v = gr.Slider(label="\u00A0", value=40, minimum=10, maximum=125, step=2, interactive=False)
+                            do_v = gr.Checkbox(
+                                label="do(ventricle volume)", value=False
+                            )
+                            v = gr.Slider(
+                                label="\u00A0",
+                                value=40,
+                                minimum=10,
+                                maximum=125,
+                                step=2,
+                                interactive=False,
+                            )
                     with gr.Row():
                         new_brain = gr.Button("New Observation")
-                        reset_brain = gr.Button("Reset", variant='stop')
-                        submit_brain = gr.Button("Submit", variant='primary')
+                        reset_brain = gr.Button("Reset", variant="stop")
+                        submit_brain = gr.Button("Submit", variant="primary")
                 with gr.Column(scale=1):
                     # gr.Markdown("### &nbsp;")
-                    causal_graph_brain = gr.Image(label='Causal Graph', interactive=False).style(height=340)
+                    causal_graph_brain = gr.Image(
+                        label="Causal Graph", interactive=False
+                    ).style(height=340)
 
         with gr.TabItem("Chest X-ray") as chest_tab:
             chest_id = gr.Textbox(value=chest_tab.label, visible=False)
@@ -439,40 +551,58 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
            with gr.Row().style(equal_height=True):
                 idx_chest = gr.Number(value=0, visible=False)
                 with gr.Column(scale=1, min_width=200):
-                    x_chest = gr.Image(label='Observation', interactive=False).style(height=HEIGHT)
+                    x_chest = gr.Image(label="Observation", interactive=False).style(
+                        height=HEIGHT
+                    )
                 with gr.Column(scale=1, min_width=200):
-                    cf_x_chest = gr.Image(label='Counterfactual', interactive=False).style(height=HEIGHT)
+                    cf_x_chest = gr.Image(
+                        label="Counterfactual", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
-                    cf_x_std_chest = gr.Image(label='Counterfactual Uncertainty', interactive=False).style(height=HEIGHT)
+                    cf_x_std_chest = gr.Image(
+                        label="Counterfactual Uncertainty", interactive=False
+                    ).style(height=HEIGHT)
                 with gr.Column(scale=1, min_width=200):
-                    effect_chest = gr.Image(label='Direct Causal Effect', interactive=False).style(height=HEIGHT)
+                    effect_chest = gr.Image(
+                        label="Direct Causal Effect", interactive=False
+                    ).style(height=HEIGHT)
 
             with gr.Row():
                 with gr.Column(scale=2.55):
-                    gr.Markdown("#### Intervention")
+                    gr.Markdown(
+                        "#### Intervention"
+                        + 28 * "&emsp;"
+                        + "[arXiv paper](https://arxiv.org/abs/2306.15764) &ensp; | &ensp; [GitHub code](https://github.com/biomedia-mira/causal-gen)"
+                    )
                     with gr.Row().style(equal_height=True):
                         with gr.Column(min_width=200):
                             do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                             f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                         with gr.Column(min_width=200):
                             do_s_chest = gr.Checkbox(label="do(sex)", value=False)
-                            s_chest = gr.Radio(SEX_CAT_CHEST, label="", interactive=False)
+                            s_chest = gr.Radio(
+                                SEX_CAT_CHEST, label="", interactive=False
+                            )
 
                     with gr.Row():
                         with gr.Column(min_width=200):
                             do_r_chest = gr.Checkbox(label="do(race)", value=False)
-                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
+                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                         with gr.Column(min_width=200):
                             do_a_chest = gr.Checkbox(label="do(age)", value=False)
-                            a_chest = gr.Slider(label="\u00A0", minimum=18, maximum=98, step=1)
-
+                            a_chest = gr.Slider(
+                                label="\u00A0", minimum=18, maximum=98, step=1
+                            )
+
                     with gr.Row():
                         new_chest = gr.Button("New Observation")
                         reset_chest = gr.Button("Reset", variant="stop")
                         submit_chest = gr.Button("Submit", variant="primary")
                 with gr.Column(scale=1):
                     # gr.Markdown("### &nbsp;")
-                    causal_graph_chest = gr.Image(label='Causal Graph', interactive=False).style(height=345)
+                    causal_graph_chest = gr.Image(
+                        label="Causal Graph", interactive=False
+                    ).style(height=345)
 
     # morphomnist
     do = [do_t, do_i, do_y]
@@ -514,29 +644,41 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
 
     # "new" button: reset cf output panels
-    for _k, _v in zip([new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]):
-        _k.click(fn=lambda: (gr.update(value=None),)*3, inputs=None, outputs=_v)
+    for _k, _v in zip(
+        [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
+    ):
+        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
 
     # "reset" button: reload current observations
     reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
     reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
     reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
-
+
     # "reset" button: deselect intervention checkboxes
-    reset.click(fn=lambda: (gr.update(value=False),)*len(do), inputs=None, outputs=do)
-    reset_brain.click(fn=lambda: (gr.update(value=False),)*len(do_brain), inputs=None, outputs=do_brain)
-    reset_chest.click(fn=lambda: (gr.update(value=False),)*len(do_chest), inputs=None, outputs=do_chest)
+    reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
+    reset_brain.click(
+        fn=lambda: (gr.update(value=False),) * len(do_brain),
+        inputs=None,
+        outputs=do_brain,
+    )
+    reset_chest.click(
+        fn=lambda: (gr.update(value=False),) * len(do_chest),
+        inputs=None,
+        outputs=do_chest,
+    )
 
     # "reset" button: reset cf output panels
-    for _k, _v in zip([reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]):
-        _k.click(fn=lambda: plt.close('all'), inputs=None, outputs=None)
-        _k.click(fn=lambda: (gr.update(value=None),)*3, inputs=None, outputs=_v)
+    for _k, _v in zip(
+        [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
+    ):
+        _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
+        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
 
     # enable mnist interventions when checkbox is selected & update graph
     for _k, _v in zip(do, [t, i, y]):
         _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
         _k.change(mnist_graph, inputs=do, outputs=causal_graph)
-
+
     # enable brain interventions when checkbox is selected & update graph
     for _k, _v in zip(do_brain, [m, s, a, b, v]):
         _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
@@ -546,11 +688,20 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
         _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
         _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
-
+
    # "submit" button: infer countefactuals
     submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
-    submit_brain.click(fn=infer_brain_cf, inputs=obs_brain + do_brain, outputs=cf_out_brain + [m, s, a, b, v])
-    submit_chest.click(fn=infer_chest_cf, inputs=obs_chest + do_chest, outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest])
+    submit_brain.click(
+        fn=infer_brain_cf,
+        inputs=obs_brain + do_brain,
+        outputs=cf_out_brain + [m, s, a, b, v],
+    )
+    submit_chest.click(
+        fn=infer_chest_cf,
+        inputs=obs_chest + do_chest,
+        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
+    )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue()
+    demo.launch()
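
Aside: counterfactual_inference() in app.py follows the standard abduction-action-prediction recipe for deep structural causal models. The toy sketch below reproduces that mechanism on a one-variable additive-noise model so the three steps are visible in isolation; the mechanism() function and its constants are hypothetical stand-ins for the HVAE decoder and PGM used in this Space, not code from the repository.

# Minimal sketch of the abduction-action-prediction pattern used by
# counterfactual_inference(), on a hypothetical toy additive-noise model.
import torch

torch.manual_seed(0)

def mechanism(pa):
    # stand-in for vae.forward_latents(...): returns a location and scale
    loc = 2.0 * pa
    scale = torch.tensor(0.1)
    return loc, scale

# observed data generated from the model: x = loc(pa) + scale * u
pa = torch.tensor([1.5])
loc, scale = mechanism(pa)
x_obs = loc + scale * torch.randn(1)

# 1) abduction: invert the mechanism to recover the exogenous noise u,
#    mirroring u = (obs['x'] - px_loc) / px_scale.clamp(min=1e-12) in app.py
loc, scale = mechanism(pa)
u = (x_obs - loc) / scale.clamp(min=1e-12)

# 2) action: intervene on the parent, do(pa := 2.0)
cf_pa = torch.tensor([2.0])

# 3) prediction: push the recovered noise through the intervened mechanism
cf_loc, cf_scale = mechanism(cf_pa)
x_cf = cf_loc + cf_scale * u

print(x_obs, x_cf)  # x_cf differs from x_obs only through the new parent
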
app_utils.py CHANGED
@@ -3,16 +3,18 @@ import numpy as np
3
  import networkx as nx
4
  import matplotlib.pyplot as plt
5
 
 
 
6
  from matplotlib import rc, patches, colors
7
- rc('font', **{'family': 'serif', 'serif': ['Roman']})
8
- rc('text', usetex=True)
9
- rc('image', interpolation='none')
10
- rc('text.latex', preamble=r'\usepackage{amsmath} \usepackage{amssymb}')
 
11
 
12
  from datasets import get_attr_max_min
13
 
14
- from PIL import Image
15
- HAMMER = np.array(Image.open('./hammer.png').resize((35, 35))) / 255
16
 
17
 
18
  class MidpointNormalize(colors.Normalize):
@@ -21,7 +23,7 @@ class MidpointNormalize(colors.Normalize):
21
  colors.Normalize.__init__(self, vmin, vmax, clip)
22
 
23
  def __call__(self, value, clip=None):
24
- v_ext = np.max( [ np.abs(self.vmin), np.abs(self.vmax) ] )
25
  x, y = [-v_ext, self.midpoint, v_ext], [0, 0.5, 1]
26
  return np.ma.masked_array(np.interp(value, x, y))
27
 
@@ -31,10 +33,10 @@ def postprocess(x):
31
 
32
 
33
  def mnist_graph(*args):
34
- x, t, i, y = r'$\mathbf{x}$', r'$t$', r'$i$', r'$y$'
35
- ut, ui, uy = r'$\mathbf{U}_t$', r'$\mathbf{U}_i$', r'$\mathbf{U}_y$'
36
- zx, ex = r'$\mathbf{z}_{1:L}$', r'$\boldsymbol{\epsilon}$'
37
-
38
  G = nx.DiGraph()
39
  G.add_edge(t, x)
40
  G.add_edge(i, x)
@@ -47,31 +49,36 @@ def mnist_graph(*args):
47
  G.add_edge(ex, x)
48
 
49
  pos = {
50
- y: (0, 0), uy: (-1, 0),
51
- t: (0, 0.5), ut: (0, 1),
52
- x: (1, 0), zx: (2, 0.375), ex: (2, 0),
53
- i: (1, 0.5), ui: (1, 1),
 
 
 
 
 
54
  }
55
 
56
  node_c = {}
57
  for node in G:
58
- node_c[node] = 'lightgrey' if node in [x, t, i, y] else 'white'
59
- node_line_c = {k: 'black' for k, _ in node_c.items()}
60
- edge_c = {e: 'black' for e in G.edges}
61
 
62
  if args[0]: # do_t
63
- edge_c[(ut, t)] = 'lightgrey'
64
  # G.remove_edge(ut, t)
65
- node_line_c[t] = 'red'
66
  if args[1]: # do_i
67
- edge_c[(ui, i)] = 'lightgrey'
68
- edge_c[(t, i)] = 'lightgrey'
69
  # G.remove_edges_from([(ui, i), (t, i)])
70
- node_line_c[i] = 'red'
71
  if args[2]: # do_y
72
- edge_c[(uy, y)] = 'lightgrey'
73
  # G.remove_edge(uy, y)
74
- node_line_c[y] = 'red'
75
 
76
  fs = 30
77
  options = {
@@ -83,27 +90,36 @@ def mnist_graph(*args):
83
  "linewidths": 2,
84
  "width": 2,
85
  }
86
- plt.close('all')
87
- fig, ax = plt.subplots(1, 1, figsize=(6,4.1))#, constrained_layout=True)
88
  # fig.patch.set_visible(False)
89
  ax.margins(x=0.06, y=0.15, tight=False)
90
  ax.axis("off")
91
- nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle='-|>', ax=ax)
92
  # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
93
  x_lim = (-1.348, 2.348)
94
  y_lim = (-0.215, 1.215)
95
  ax.set_xlim(x_lim)
96
  ax.set_ylim(y_lim)
97
- rect = patches.FancyBboxPatch((1.75, -0.16), 0.5, 0.7, boxstyle="round, pad=0.05, rounding_size=0", linewidth=2, edgecolor='black', facecolor='none', linestyle='-')
 
 
 
 
 
 
 
 
 
98
  ax.add_patch(rect)
99
  ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
100
 
101
  if args[0]: # do_t
102
- fig.figimage(HAMMER, 0.26*fig.bbox.xmax, 0.525*fig.bbox.ymax, zorder=10)
103
  if args[1]: # do_i
104
- fig.figimage(HAMMER, 0.5175*fig.bbox.xmax, 0.525*fig.bbox.ymax, zorder=11)
105
  if args[2]: # do_y
106
- fig.figimage(HAMMER, 0.26*fig.bbox.xmax, 0.2*fig.bbox.ymax, zorder=12)
107
 
108
  fig.tight_layout()
109
  fig.canvas.draw()
@@ -111,10 +127,16 @@ def mnist_graph(*args):
111
 
112
 
113
  def brain_graph(*args):
114
- x, m, s, a, b, v = r'$\mathbf{x}$', r'$m$', r'$s$', r'$a$', r'$b$', r'$v$'
115
- um, us, ua, ub, uv = r'$\mathbf{U}_m$', r'$\mathbf{U}_s$', r'$\mathbf{U}_a$', r'$\mathbf{U}_b$', r'$\mathbf{U}_v$'
116
- zx, ex = r'$\mathbf{z}_{1:L}$', r'$\boldsymbol{\epsilon}$'
117
-
 
 
 
 
 
 
118
  G = nx.DiGraph()
119
  G.add_edge(m, x)
120
  G.add_edge(s, x)
@@ -132,44 +154,51 @@ def brain_graph(*args):
132
  G.add_edge(uv, v)
133
 
134
  pos = {
135
- x: (0, 0), zx: (-0.25, -1), ex: (0.25, -1),
136
- a: (0, 1), ua: (0, 2),
137
- s: (1, 0), us: (1, -1),
138
- b: (1, 1), ub: (1, 2),
139
- m: (-1, 0), um: (-1, -1),
140
- v: (-1, 1), uv: (-1, 2)
 
 
 
 
 
 
 
141
  }
142
 
143
  node_c = {}
144
  for node in G:
145
- node_c[node] = 'lightgrey' if node in [x, m, s, a, b, v] else 'white'
146
- node_line_c = {k: 'black' for k, _ in node_c.items()}
147
- edge_c = {e: 'black' for e in G.edges}
148
 
149
  if args[0]: # do_m
150
  # G.remove_edge(um, m)
151
- edge_c[(um, m)] = 'lightgrey'
152
- node_line_c[m] = 'red'
153
  if args[1]: # do_s
154
  # G.remove_edge(us, s)
155
- edge_c[(us, s)] = 'lightgrey'
156
- node_line_c[s] = 'red'
157
  if args[2]: # do_a
158
  # G.remove_edge(ua, a)
159
- edge_c[(ua, a)] = 'lightgrey'
160
- node_line_c[a] = 'red'
161
  if args[3]: # do_b
162
  # G.remove_edges_from([(ub, b), (s, b), (a, b)])
163
- edge_c[(ub, b)] = 'lightgrey'
164
- edge_c[(s, b)] = 'lightgrey'
165
- edge_c[(a, b)] = 'lightgrey'
166
- node_line_c[b] = 'red'
167
  if args[4]: # do_v
168
  # G.remove_edges_from([(uv, v), (a, v), (b, v)])
169
- edge_c[(uv, v)] = 'lightgrey'
170
- edge_c[(a, v)] = 'lightgrey'
171
- edge_c[(b, v)] = 'lightgrey'
172
- node_line_c[v] = 'red'
173
 
174
  fs = 30
175
  options = {
@@ -182,33 +211,49 @@ def brain_graph(*args):
182
  "width": 2,
183
  }
184
 
185
- plt.close('all')
186
- fig, ax = plt.subplots(1, 1, figsize=(5,5))#, constrained_layout=True)
187
  # fig.patch.set_visible(False)
188
  ax.margins(x=0.1, y=0.08, tight=False)
189
  ax.axis("off")
190
- nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle='-|>', ax=ax)
191
  # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
192
  x_lim = (-1.32, 1.32)
193
  y_lim = (-1.414, 2.414)
194
  ax.set_xlim(x_lim)
195
  ax.set_ylim(y_lim)
196
- rect = patches.FancyBboxPatch((-0.5, -1.325), 1, 0.65, boxstyle="round, pad=0.05, rounding_size=0", linewidth=2, edgecolor='black', facecolor='none', linestyle='-')
 
 
 
 
 
 
 
 
 
197
  ax.add_patch(rect)
198
  # ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
199
-
200
  if args[0]: # do_m
201
- fig.figimage(HAMMER, 0.0075*fig.bbox.xmax, 0.395*fig.bbox.ymax, zorder=10)
202
  if args[1]: # do_s
203
- fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.395*fig.bbox.ymax, zorder=11)
204
  if args[2]: # do_a
205
- fig.figimage(HAMMER, 0.363*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=12)
206
  if args[3]: # do_b
207
- fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=13)
208
  if args[4]: # do_v
209
- fig.figimage(HAMMER, 0.0075*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=14)
210
  else: # b -> v
211
- a3 = patches.FancyArrowPatch((.86, 1.21), (-.86, 1.21), connectionstyle="arc3,rad=.3", linewidth=2, arrowstyle='simple, head_width=10, head_length=10', color='k')
 
 
 
 
 
 
 
212
  ax.add_patch(a3)
213
  # print(ax.get_xlim())
214
  # print(ax.get_ylim())
@@ -217,12 +262,16 @@ def brain_graph(*args):
217
  return np.array(fig.canvas.renderer.buffer_rgba())
218
 
219
 
220
-
221
  def chest_graph(*args):
222
- x, a, d, r, s= r'$\mathbf{x}$', r'$a$', r'$d$', r'$r$', r'$s$'
223
- ua, ud, ur, us = r'$\mathbf{U}_a$', r'$\mathbf{U}_d$', r'$\mathbf{U}_r$', r'$\mathbf{U}_s$'
224
- zx, ex = r'$\mathbf{z}_{1:L}$', r'$\boldsymbol{\epsilon}$'
225
-
 
 
 
 
 
226
  G = nx.DiGraph()
227
  G.add_edge(ua, a)
228
  G.add_edge(ud, d)
@@ -237,7 +286,7 @@ def chest_graph(*args):
237
  G.add_edge(a, x)
238
 
239
  pos = {
240
- x: (0, 0),
241
  a: (-1, 1),
242
  d: (0, 1),
243
  r: (1, 1),
@@ -246,34 +295,34 @@ def chest_graph(*args):
246
  ud: (0, 2),
247
  ur: (1, 2),
248
  us: (1, -1),
249
- zx: (-0.25, -1),
250
  ex: (0.25, -1),
251
  }
252
 
253
  node_c = {}
254
  for node in G:
255
- node_c[node] = 'lightgrey' if node in [x, a, d, r, s] else 'white'
256
 
257
- edge_c = {e: 'black' for e in G.edges}
258
- node_line_c = {k: 'black' for k, _ in node_c.items()}
259
 
260
  if args[0]: # do_r
261
  # G.remove_edge(ur, r)
262
- edge_c[(ur, r)] = 'lightgrey'
263
- node_line_c[r] = 'red'
264
  if args[1]: # do_s
265
  # G.remove_edges_from([(us, s)])
266
- edge_c[(us, s)] = 'lightgrey'
267
- node_line_c[s] = 'red'
268
  if args[2]: # do_f (do_d)
269
  # G.remove_edges_from([(ud, d), (a, d)])
270
- edge_c[(ud, d)] = 'lightgrey'
271
- edge_c[(a, d)] = 'lightgrey'
272
- node_line_c[d] = 'red'
273
  if args[3]: # do_a
274
  # G.remove_edge(ua, a)
275
- edge_c[(ua, a)] = 'lightgrey'
276
- node_line_c[a] = 'red'
277
 
278
  fs = 30
279
  options = {
@@ -285,29 +334,38 @@ def chest_graph(*args):
285
  "linewidths": 2,
286
  "width": 2,
287
  }
288
- plt.close('all')
289
- fig, ax = plt.subplots(1, 1, figsize=(5,5))#, constrained_layout=True)
290
  # fig.patch.set_visible(False)
291
  ax.margins(x=0.1, y=0.08, tight=False)
292
  ax.axis("off")
293
- nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle='-|>', ax=ax)
294
  # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
295
  x_lim = (-1.32, 1.32)
296
  y_lim = (-1.414, 2.414)
297
  ax.set_xlim(x_lim)
298
  ax.set_ylim(y_lim)
299
- rect = patches.FancyBboxPatch((-0.5, -1.325), 1, 0.65, boxstyle="round, pad=0.05, rounding_size=0", linewidth=2, edgecolor='black', facecolor='none', linestyle='-')
 
 
 
 
 
 
 
 
 
300
  ax.add_patch(rect)
301
  ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
302
-
303
  if args[0]: # do_r
304
- fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=10)
305
  if args[1]: # do_s
306
- fig.figimage(HAMMER, 0.72*fig.bbox.xmax, 0.395*fig.bbox.ymax, zorder=11)
307
  if args[2]: # do_f
308
- fig.figimage(HAMMER, 0.363*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=12)
309
  if args[3]: # do_a
310
- fig.figimage(HAMMER, 0.0075*fig.bbox.xmax, 0.64*fig.bbox.ymax, zorder=13)
311
 
312
  fig.tight_layout()
313
  fig.canvas.draw()
@@ -315,51 +373,55 @@ def chest_graph(*args):
315
 
316
 
317
  def vae_preprocess(args, pa):
318
- if 'ukbb' in args.hps:
319
  # preprocess the UKBB parents for the VAE, which was trained on
320
  # log-standardized parents; the PGM was trained with [-1,1] normalization
321
  # first undo [-1,1] parent preprocessing back to original range
322
  for k, v in pa.items():
323
- if k != 'mri_seq' and k != 'sex':
324
  pa[k] = (v + 1) / 2 # [-1,1] -> [0,1]
325
  _max, _min = get_attr_max_min(k)
326
  pa[k] = pa[k] * (_max - _min) + _min
327
  # log_standardize parents for vae input
328
  for k, v in pa.items():
329
  logpa_k = torch.log(v.clamp(min=1e-12))
330
- if k == 'age':
331
  pa[k] = (logpa_k - 4.112339973449707) / 0.11769197136163712
332
- elif k == 'brain_volume':
333
  pa[k] = (logpa_k - 13.965583801269531) / 0.09537758678197861
334
- elif k == 'ventricle_volume':
335
  pa[k] = (logpa_k - 10.345998764038086) / 0.43127763271331787
336
  # concatenate parents expand to input res for conditioning the vae
337
- pa = torch.cat([pa[k] if len(pa[k].shape) > 1 else pa[k][..., None]
338
- for k in args.parents_x], dim=1)
339
- pa = pa[..., None, None].repeat(1, 1, *(args.input_res,)*2).to(args.device).float()
340
  return pa
341
 
342
 
343
  def preprocess_brain(args, obs):
344
- obs['x'] = (obs['x'][None,...].float().to(args.device) - 127.5) / 127.5 # [-1,1]
345
  # for all other variables except x
346
- for k in [k for k in obs.keys() if k != 'x']:
347
  obs[k] = obs[k].float().to(args.device).view(1, 1)
348
- if k in ['age', 'brain_volume', 'ventricle_volume']:
349
  k_max, k_min = get_attr_max_min(k)
350
  obs[k] = (obs[k] - k_min) / (k_max - k_min) # [0,1]
351
  obs[k] = 2 * obs[k] - 1 # [-1,1]
352
  return obs
353
 
354
 
355
- def get_fig_arr(x, width=4, height=4, dpi=144, cmap='Greys_r', norm=None):
356
  fig = plt.figure(figsize=(width, height), dpi=dpi)
357
- ax = plt.axes([0,0,1,1], frameon=False)
358
- if cmap == 'Greys_r':
359
  ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
360
  else:
361
  ax.imshow(x, cmap=cmap, norm=norm)
362
- ax.axis('off')
363
  fig.canvas.draw()
364
  return np.array(fig.canvas.renderer.buffer_rgba())
365
 
@@ -370,4 +432,4 @@ def normalize(x, x_min=None, x_max=None, zero_one=False):
370
  if x_max is None:
371
  x_max = x.max()
372
  x = (x - x_min) / (x_max - x_min) # [0,1]
373
- return x if zero_one else 2 * x - 1 # else [-1,1]
 
3
  import networkx as nx
4
  import matplotlib.pyplot as plt
5
 
6
+ from PIL import Image
7
+
8
  from matplotlib import rc, patches, colors
9
+
10
+ rc("font", **{"family": "serif", "serif": ["Roman"]})
11
+ rc("text", usetex=True)
12
+ rc("image", interpolation="none")
13
+ rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
14
 
15
  from datasets import get_attr_max_min
16
 
17
+ HAMMER = np.array(Image.open("./hammer.png").resize((35, 35))) / 255
 
18
 
19
 
20
  class MidpointNormalize(colors.Normalize):
 
23
  colors.Normalize.__init__(self, vmin, vmax, clip)
24
 
25
  def __call__(self, value, clip=None):
26
+ v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
27
  x, y = [-v_ext, self.midpoint, v_ext], [0, 0.5, 1]
28
  return np.ma.masked_array(np.interp(value, x, y))
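
A minimal usage sketch of the MidpointNormalize above: it anchors a diverging colormap at `midpoint`, which is useful for signed difference maps (e.g. counterfactual minus observation). The array below is invented, and the (vmin, vmax, midpoint) constructor is assumed from the partially shown __init__.

import numpy as np
import matplotlib.pyplot as plt

diff = np.random.randn(64, 64)  # hypothetical signed effect map
norm = MidpointNormalize(vmin=diff.min(), vmax=diff.max(), midpoint=0)  # assumed signature
plt.imshow(diff, cmap="RdBu_r", norm=norm)  # colormap midpoint pinned to zero
plt.colorbar()
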
29
 
 
33
 
34
 
35
  def mnist_graph(*args):
36
+ x, t, i, y = r"$\mathbf{x}$", r"$t$", r"$i$", r"$y$"
37
+ ut, ui, uy = r"$\mathbf{U}_t$", r"$\mathbf{U}_i$", r"$\mathbf{U}_y$"
38
+ zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"
39
+
40
  G = nx.DiGraph()
41
  G.add_edge(t, x)
42
  G.add_edge(i, x)
 
49
  G.add_edge(ex, x)
50
 
51
  pos = {
52
+ y: (0, 0),
53
+ uy: (-1, 0),
54
+ t: (0, 0.5),
55
+ ut: (0, 1),
56
+ x: (1, 0),
57
+ zx: (2, 0.375),
58
+ ex: (2, 0),
59
+ i: (1, 0.5),
60
+ ui: (1, 1),
61
  }
62
 
63
  node_c = {}
64
  for node in G:
65
+ node_c[node] = "lightgrey" if node in [x, t, i, y] else "white"
66
+ node_line_c = {k: "black" for k, _ in node_c.items()}
67
+ edge_c = {e: "black" for e in G.edges}
68
 
69
  if args[0]: # do_t
70
+ edge_c[(ut, t)] = "lightgrey"
71
  # G.remove_edge(ut, t)
72
+ node_line_c[t] = "red"
73
  if args[1]: # do_i
74
+ edge_c[(ui, i)] = "lightgrey"
75
+ edge_c[(t, i)] = "lightgrey"
76
  # G.remove_edges_from([(ui, i), (t, i)])
77
+ node_line_c[i] = "red"
78
  if args[2]: # do_y
79
+ edge_c[(uy, y)] = "lightgrey"
80
  # G.remove_edge(uy, y)
81
+ node_line_c[y] = "red"
82
 
83
  fs = 30
84
  options = {
 
90
  "linewidths": 2,
91
  "width": 2,
92
  }
93
+ plt.close("all")
94
+ fig, ax = plt.subplots(1, 1, figsize=(6, 4.1)) # , constrained_layout=True)
95
  # fig.patch.set_visible(False)
96
  ax.margins(x=0.06, y=0.15, tight=False)
97
  ax.axis("off")
98
+ nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
99
  # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
100
  x_lim = (-1.348, 2.348)
101
  y_lim = (-0.215, 1.215)
102
  ax.set_xlim(x_lim)
103
  ax.set_ylim(y_lim)
104
+ rect = patches.FancyBboxPatch(
105
+ (1.75, -0.16),
106
+ 0.5,
107
+ 0.7,
108
+ boxstyle="round, pad=0.05, rounding_size=0",
109
+ linewidth=2,
110
+ edgecolor="black",
111
+ facecolor="none",
112
+ linestyle="-",
113
+ )
114
  ax.add_patch(rect)
115
  ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
116
 
117
  if args[0]: # do_t
118
+ fig.figimage(HAMMER, 0.26 * fig.bbox.xmax, 0.525 * fig.bbox.ymax, zorder=10)
119
  if args[1]: # do_i
120
+ fig.figimage(HAMMER, 0.5175 * fig.bbox.xmax, 0.525 * fig.bbox.ymax, zorder=11)
121
  if args[2]: # do_y
122
+ fig.figimage(HAMMER, 0.26 * fig.bbox.xmax, 0.2 * fig.bbox.ymax, zorder=12)
123
 
124
  fig.tight_layout()
125
  fig.canvas.draw()
 
127
 
128
 
129
  def brain_graph(*args):
130
+ x, m, s, a, b, v = r"$\mathbf{x}$", r"$m$", r"$s$", r"$a$", r"$b$", r"$v$"
131
+ um, us, ua, ub, uv = (
132
+ r"$\mathbf{U}_m$",
133
+ r"$\mathbf{U}_s$",
134
+ r"$\mathbf{U}_a$",
135
+ r"$\mathbf{U}_b$",
136
+ r"$\mathbf{U}_v$",
137
+ )
138
+ zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"
139
+
140
  G = nx.DiGraph()
141
  G.add_edge(m, x)
142
  G.add_edge(s, x)
 
154
  G.add_edge(uv, v)
155
 
156
  pos = {
157
+ x: (0, 0),
158
+ zx: (-0.25, -1),
159
+ ex: (0.25, -1),
160
+ a: (0, 1),
161
+ ua: (0, 2),
162
+ s: (1, 0),
163
+ us: (1, -1),
164
+ b: (1, 1),
165
+ ub: (1, 2),
166
+ m: (-1, 0),
167
+ um: (-1, -1),
168
+ v: (-1, 1),
169
+ uv: (-1, 2),
170
  }
171
 
172
  node_c = {}
173
  for node in G:
174
+ node_c[node] = "lightgrey" if node in [x, m, s, a, b, v] else "white"
175
+ node_line_c = {k: "black" for k, _ in node_c.items()}
176
+ edge_c = {e: "black" for e in G.edges}
177
 
178
  if args[0]: # do_m
179
  # G.remove_edge(um, m)
180
+ edge_c[(um, m)] = "lightgrey"
181
+ node_line_c[m] = "red"
182
  if args[1]: # do_s
183
  # G.remove_edge(us, s)
184
+ edge_c[(us, s)] = "lightgrey"
185
+ node_line_c[s] = "red"
186
  if args[2]: # do_a
187
  # G.remove_edge(ua, a)
188
+ edge_c[(ua, a)] = "lightgrey"
189
+ node_line_c[a] = "red"
190
  if args[3]: # do_b
191
  # G.remove_edges_from([(ub, b), (s, b), (a, b)])
192
+ edge_c[(ub, b)] = "lightgrey"
193
+ edge_c[(s, b)] = "lightgrey"
194
+ edge_c[(a, b)] = "lightgrey"
195
+ node_line_c[b] = "red"
196
  if args[4]: # do_v
197
  # G.remove_edges_from([(uv, v), (a, v), (b, v)])
198
+ edge_c[(uv, v)] = "lightgrey"
199
+ edge_c[(a, v)] = "lightgrey"
200
+ edge_c[(b, v)] = "lightgrey"
201
+ node_line_c[v] = "red"
202
 
203
  fs = 30
204
  options = {
 
211
  "width": 2,
212
  }
213
 
214
+ plt.close("all")
215
+ fig, ax = plt.subplots(1, 1, figsize=(5, 5)) # , constrained_layout=True)
216
  # fig.patch.set_visible(False)
217
  ax.margins(x=0.1, y=0.08, tight=False)
218
  ax.axis("off")
219
+ nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
220
  # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
221
  x_lim = (-1.32, 1.32)
222
  y_lim = (-1.414, 2.414)
223
  ax.set_xlim(x_lim)
224
  ax.set_ylim(y_lim)
225
+ rect = patches.FancyBboxPatch(
226
+ (-0.5, -1.325),
227
+ 1,
228
+ 0.65,
229
+ boxstyle="round, pad=0.05, rounding_size=0",
230
+ linewidth=2,
231
+ edgecolor="black",
232
+ facecolor="none",
233
+ linestyle="-",
234
+ )
235
  ax.add_patch(rect)
236
  # ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
237
+
238
  if args[0]: # do_m
239
+ fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=10)
240
  if args[1]: # do_s
241
+ fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
242
  if args[2]: # do_a
243
+ fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
244
  if args[3]: # do_b
245
+ fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)
246
  if args[4]: # do_v
247
+ fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=14)
248
  else: # b -> v
249
+ a3 = patches.FancyArrowPatch(
250
+ (0.86, 1.21),
251
+ (-0.86, 1.21),
252
+ connectionstyle="arc3,rad=.3",
253
+ linewidth=2,
254
+ arrowstyle="simple, head_width=10, head_length=10",
255
+ color="k",
256
+ )
257
  ax.add_patch(a3)
258
  # print(ax.get_xlim())
259
  # print(ax.get_ylim())
 
262
  return np.array(fig.canvas.renderer.buffer_rgba())
263
 
264
 
 
265
  def chest_graph(*args):
266
+ x, a, d, r, s = r"$\mathbf{x}$", r"$a$", r"$d$", r"$r$", r"$s$"
267
+ ua, ud, ur, us = (
268
+ r"$\mathbf{U}_a$",
269
+ r"$\mathbf{U}_d$",
270
+ r"$\mathbf{U}_r$",
271
+ r"$\mathbf{U}_s$",
272
+ )
273
+ zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"
274
+
275
  G = nx.DiGraph()
276
  G.add_edge(ua, a)
277
  G.add_edge(ud, d)
 
286
  G.add_edge(a, x)
287
 
288
  pos = {
289
+ x: (0, 0),
290
  a: (-1, 1),
291
  d: (0, 1),
292
  r: (1, 1),
 
295
  ud: (0, 2),
296
  ur: (1, 2),
297
  us: (1, -1),
298
+ zx: (-0.25, -1),
299
  ex: (0.25, -1),
300
  }
301
 
302
  node_c = {}
303
  for node in G:
304
+ node_c[node] = "lightgrey" if node in [x, a, d, r, s] else "white"
305
 
306
+ edge_c = {e: "black" for e in G.edges}
307
+ node_line_c = {k: "black" for k, _ in node_c.items()}
308
 
309
  if args[0]: # do_r
310
  # G.remove_edge(ur, r)
311
+ edge_c[(ur, r)] = "lightgrey"
312
+ node_line_c[r] = "red"
313
  if args[1]: # do_s
314
  # G.remove_edges_from([(us, s)])
315
+ edge_c[(us, s)] = "lightgrey"
316
+ node_line_c[s] = "red"
317
  if args[2]: # do_f (do_d)
318
  # G.remove_edges_from([(ud, d), (a, d)])
319
+ edge_c[(ud, d)] = "lightgrey"
320
+ edge_c[(a, d)] = "lightgrey"
321
+ node_line_c[d] = "red"
322
  if args[3]: # do_a
323
  # G.remove_edge(ua, a)
324
+ edge_c[(ua, a)] = "lightgrey"
325
+ node_line_c[a] = "red"
326
 
327
  fs = 30
328
  options = {
 
334
  "linewidths": 2,
335
  "width": 2,
336
  }
337
+ plt.close("all")
338
+ fig, ax = plt.subplots(1, 1, figsize=(5, 5)) # , constrained_layout=True)
339
  # fig.patch.set_visible(False)
340
  ax.margins(x=0.1, y=0.08, tight=False)
341
  ax.axis("off")
342
+ nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
343
  # need to reuse x, y limits so that the graphs plot the same way before and after removing edges
344
  x_lim = (-1.32, 1.32)
345
  y_lim = (-1.414, 2.414)
346
  ax.set_xlim(x_lim)
347
  ax.set_ylim(y_lim)
348
+ rect = patches.FancyBboxPatch(
349
+ (-0.5, -1.325),
350
+ 1,
351
+ 0.65,
352
+ boxstyle="round, pad=0.05, rounding_size=0",
353
+ linewidth=2,
354
+ edgecolor="black",
355
+ facecolor="none",
356
+ linestyle="-",
357
+ )
358
  ax.add_patch(rect)
359
  ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)
360
+
361
  if args[0]: # do_r
362
+ fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=10)
363
  if args[1]: # do_s
364
+ fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
365
  if args[2]: # do_f
366
+ fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
367
  if args[3]: # do_a
368
+ fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)
369
 
370
  fig.tight_layout()
371
  fig.canvas.draw()
 
373
 
374
 
375
  def vae_preprocess(args, pa):
376
+ if "ukbb" in args.hps:
377
  # preprocess the UKBB parents for the VAE, which was trained on
378
  # log-standardized parents; the PGM was trained with [-1,1] normalization
379
  # first undo [-1,1] parent preprocessing back to original range
380
  for k, v in pa.items():
381
+ if k != "mri_seq" and k != "sex":
382
  pa[k] = (v + 1) / 2 # [-1,1] -> [0,1]
383
  _max, _min = get_attr_max_min(k)
384
  pa[k] = pa[k] * (_max - _min) + _min
385
  # log_standardize parents for vae input
386
  for k, v in pa.items():
387
  logpa_k = torch.log(v.clamp(min=1e-12))
388
+ if k == "age":
389
  pa[k] = (logpa_k - 4.112339973449707) / 0.11769197136163712
390
+ elif k == "brain_volume":
391
  pa[k] = (logpa_k - 13.965583801269531) / 0.09537758678197861
392
+ elif k == "ventricle_volume":
393
  pa[k] = (logpa_k - 10.345998764038086) / 0.43127763271331787
394
  # concatenate parents expand to input res for conditioning the vae
395
+ pa = torch.cat(
396
+ [pa[k] if len(pa[k].shape) > 1 else pa[k][..., None] for k in args.parents_x],
397
+ dim=1,
398
+ )
399
+ pa = (
400
+ pa[..., None, None].repeat(1, 1, *(args.input_res,) * 2).to(args.device).float()
401
+ )
402
  return pa
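
The hard-coded constants above are presumably the mean and standard deviation of each log-transformed parent over the training set; a sketch of how such values would be derived, with invented ages:

import torch

age = torch.tensor([48.0, 55.0, 61.0, 70.0])  # hypothetical training ages (years)
log_age = torch.log(age.clamp(min=1e-12))
mu, sigma = log_age.mean(), log_age.std()
age_std = (log_age - mu) / sigma  # log-standardized, as the VAE expects
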
403
 
404
 
405
  def preprocess_brain(args, obs):
406
+ obs["x"] = (obs["x"][None, ...].float().to(args.device) - 127.5) / 127.5 # [-1,1]
407
  # for all other variables except x
408
+ for k in [k for k in obs.keys() if k != "x"]:
409
  obs[k] = obs[k].float().to(args.device).view(1, 1)
410
+ if k in ["age", "brain_volume", "ventricle_volume"]:
411
  k_max, k_min = get_attr_max_min(k)
412
  obs[k] = (obs[k] - k_min) / (k_max - k_min) # [0,1]
413
  obs[k] = 2 * obs[k] - 1 # [-1,1]
414
  return obs
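
A hedged sketch of the observation dict this function expects: one [0,255] image under 'x' plus scalar parents, all moved to args.device and mapped to [-1,1]; shapes and values below are invented.

import torch

obs = {
    "x": torch.randint(0, 256, (1, 192, 192)).float(),  # HxW image in uint8 range
    "sex": torch.tensor(1.0),
    "mri_seq": torch.tensor(0.0),
    "age": torch.tensor(52.0),
    "brain_volume": torch.tensor(1.2e6),
    "ventricle_volume": torch.tensor(2.4e4),
}
obs = preprocess_brain(args, obs)  # args as loaded from a checkpoint (needs .device)
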
415
 
416
 
417
+ def get_fig_arr(x, width=4, height=4, dpi=144, cmap="Greys_r", norm=None):
418
  fig = plt.figure(figsize=(width, height), dpi=dpi)
419
+ ax = plt.axes([0, 0, 1, 1], frameon=False)
420
+ if cmap == "Greys_r":
421
  ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
422
  else:
423
  ax.imshow(x, cmap=cmap, norm=norm)
424
+ ax.axis("off")
425
  fig.canvas.draw()
426
  return np.array(fig.canvas.renderer.buffer_rgba())
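
A quick usage sketch: render a [0,255] grayscale array to the RGBA buffer that the Gradio image outputs consume (input array invented):

import numpy as np

img = (np.random.rand(192, 192) * 255).astype(np.uint8)  # hypothetical image
rgba = get_fig_arr(img)  # uint8 RGBA array of shape (height*dpi, width*dpi, 4)
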
427
 
 
432
  if x_max is None:
433
  x_max = x.max()
434
  x = (x - x_min) / (x_max - x_min) # [0,1]
435
+ return x if zero_one else 2 * x - 1 # else [-1,1]
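
To make the convention concrete, a round-trip sketch using the UKBB age range from get_attr_max_min; the inverse is the same computation vae_preprocess applies:

import torch

age = torch.tensor([44.0, 58.5, 73.0])
age_n = normalize(age, x_min=44, x_max=73)   # -> [-1, 1]
age_back = (age_n + 1) / 2 * (73 - 44) + 44  # undoes the mapping
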
datasets.py CHANGED
@@ -23,44 +23,45 @@ def normalize(x, x_min=None, x_max=None, zero_one=False):
23
  x_min = x.min()
24
  if x_max is None:
25
  x_max = x.max()
26
- print(f'max: {x_max}, min: {x_min}')
27
  x = (x - x_min) / (x_max - x_min) # [0,1]
28
  return x if zero_one else 2 * x - 1 # else [-1,1]
29
 
30
 
31
  class UKBBDataset(Dataset):
32
- def __init__(self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True):
 
 
33
  super().__init__()
34
  self.root = root
35
  self.transform = transform
36
  self.concat_pa = concat_pa # return concatenated parents
37
 
38
- print(f'\nLoading csv data: {csv_file}')
39
  self.df = pd.read_csv(csv_file)
40
  self.columns = columns
41
  if self.columns is None:
42
  # ['eid', 'sex', 'age', 'brain_volume', 'ventricle_volume', 'mri_seq']
43
  self.columns = list(self.df.columns) # return all
44
  self.columns.pop(0) # remove redundant 'index' column
45
- print(f'columns: {self.columns}')
46
- self.samples = {i: torch.as_tensor(
47
- self.df[i]).float() for i in self.columns}
48
 
49
- for k in ['age', 'brain_volume', 'ventricle_volume']:
50
- print(f'{k} normalization: {norm}')
51
  if k in self.columns:
52
- if norm == '[-1,1]':
53
  self.samples[k] = normalize(self.samples[k])
54
- elif norm == '[0,1]':
55
  self.samples[k] = normalize(self.samples[k], zero_one=True)
56
- elif norm == 'log_standard':
57
  self.samples[k] = log_standardize(self.samples[k])
58
  elif norm is None:
59
  pass
60
  else:
61
- NotImplementedError(f'{norm} not implemented.')
62
- print(f'#samples: {len(self.df)}')
63
- self.return_x = True if 'eid' in self.columns else False
64
 
65
  def __len__(self):
66
  return len(self.df)
@@ -69,31 +70,32 @@ class UKBBDataset(Dataset):
69
  sample = {k: v[idx] for k, v in self.samples.items()}
70
 
71
  if self.return_x:
72
- mri_seq = 'T1' if sample['mri_seq'] == 0. else 'T2_FLAIR'
73
  # Load scan
74
- filename = f'{int(sample["eid"])}_' + \
75
- mri_seq+'_unbiased_brain_rigid_to_mni.png'
76
- x = Image.open(os.path.join(self.root, 'thumbs_192x192', filename))
 
77
 
78
  if self.transform is not None:
79
- sample['x'] = self.transform(x)
80
- sample.pop('eid', None)
81
 
82
  if self.concat_pa:
83
- sample['pa'] = torch.cat([
84
- torch.tensor([sample[k]]) for k in self.columns if k != 'eid'
85
- ], dim=0)
86
 
87
  return sample
88
 
89
 
90
  def get_attr_max_min(attr):
91
  # some ukbb dataset (max, min) stats
92
- if attr == 'age':
93
  return 73, 44
94
- elif attr == 'brain_volume':
95
  return 1629520, 841919
96
- elif attr == 'ventricle_volume':
97
  return 157075, 7613.27001953125
98
  else:
99
  raise NotImplementedError
@@ -102,37 +104,43 @@ def get_attr_max_min(attr):
102
  def ukbb(args):
103
  csv_dir = args.data_dir
104
  augmentation = {
105
- 'train': TF.Compose([
106
- TF.Resize((args.input_res, args.input_res), antialias=None),
107
- TF.RandomCrop(size=(args.input_res, args.input_res),
108
- padding=[2*args.pad, args.pad]),
109
- TF.RandomHorizontalFlip(p=args.hflip),
110
- TF.PILToTensor()
111
- ]),
112
- 'eval': TF.Compose([
113
- TF.Resize((args.input_res, args.input_res), antialias=None),
114
- TF.PILToTensor()
115
- ])
116
  }
117
 
118
  datasets = {}
119
  # for split in ['train', 'valid', 'test']:
120
- for split in ['test']:
121
  datasets[split] = UKBBDataset(
122
  root=args.data_dir,
123
- csv_file=os.path.join(csv_dir, split+'.csv'),
124
- transform=augmentation[('eval' if split != 'train' else split)],
125
- columns=(None if not args.parents_x else ['eid'] + args.parents_x),
126
- norm=(None if not hasattr(args, 'context_norm')
127
- else args.context_norm),
128
- concat_pa=(True if not hasattr(args, 'concat_pa') else args.concat_pa))
129
 
130
  return datasets
131
 
132
 
133
  def _load_uint8(f):
134
- idx_dtype, ndim = struct.unpack('BBBB', f.read(4))[2:]
135
- shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
136
  buffer_length = int(np.prod(shape))
137
  data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
138
  return data
@@ -152,8 +160,8 @@ def load_idx(path: str) -> np.ndarray:
152
  ----------
153
  http://yann.lecun.com/exdb/mnist/
154
  """
155
- open_fcn = gzip.open if path.endswith('.gz') else open
156
- with open_fcn(path, 'rb') as f:
157
  return _load_uint8(f)
158
 
159
 
@@ -168,8 +176,9 @@ def _get_paths(root_dir, train):
168
  return images_path, labels_path, metrics_path
169
 
170
 
171
- def load_morphomnist_like(root_dir, train: bool = True, columns=None) \
172
- -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
 
173
  """
174
  Args:
175
  root_dir: path to data directory
@@ -184,68 +193,84 @@ def load_morphomnist_like(root_dir, train: bool = True, columns=None) \
184
  images = load_idx(images_path)
185
  labels = load_idx(labels_path)
186
 
187
- if columns is not None and 'index' not in columns:
188
- usecols = ['index'] + list(columns)
189
  else:
190
  usecols = columns
191
- metrics = pd.read_csv(metrics_path, usecols=usecols, index_col='index')
192
  return images, labels, metrics
193
 
194
 
195
  class MorphoMNIST(Dataset):
196
- def __init__(self, root_dir, train=True, transform=None, columns=None, norm=None, concat_pa=True):
197
  self.train = train
198
  self.transform = transform
199
  self.columns = columns
200
  self.concat_pa = concat_pa
201
  self.norm = norm
202
 
203
- cols_not_digit = [c for c in self.columns if c != 'digit']
204
  images, labels, metrics_df = load_morphomnist_like(
205
- root_dir, train, cols_not_digit)
 
206
  self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
207
  self.labels = F.one_hot(
208
- torch.from_numpy(np.array(labels)).long(), num_classes=10)
 
209
 
210
  if self.columns is None:
211
  self.columns = metrics_df.columns
212
  self.samples = {k: torch.tensor(metrics_df[k]) for k in cols_not_digit}
213
 
214
  self.min_max = {
215
- 'thickness': [0.87598526, 6.255515],
216
- 'intensity': [66.601204, 254.90317]
217
  }
218
 
219
  for k, v in self.samples.items(): # optional preprocessing
220
- print(f'{k} normalization: {norm}')
221
- if norm == '[-1,1]':
222
- self.samples[k] = normalize(v,
223
- x_min=self.min_max[k][0], x_max=self.min_max[k][1])
224
- elif norm == '[0,1]':
225
- self.samples[k] = normalize(v,
226
- x_min=self.min_max[k][0], x_max=self.min_max[k][1], zero_one=True)
 
 
227
  elif norm is None:
228
  pass
229
  else:
230
- NotImplementedError(f'{norm} not implemented.')
231
- print(f'#samples: {len(metrics_df)}\n')
232
 
233
- self.samples.update({'digit': self.labels})
234
 
235
  def __len__(self):
236
  return len(self.images)
237
 
238
  def __getitem__(self, idx):
239
  sample = {}
240
- sample['x'] = self.images[idx]
241
 
242
  if self.transform is not None:
243
- sample['x'] = self.transform(sample['x'])
244
 
245
  if self.concat_pa:
246
- sample['pa'] = torch.cat([
247
- v[idx] if k == 'digit' else torch.tensor([v[idx]]) for k, v in self.samples.items()],
248
- dim=0)
249
  else:
250
  sample.update({k: v[idx] for k, v in self.samples.items()})
251
  return sample
@@ -254,39 +279,43 @@ class MorphoMNIST(Dataset):
254
  def morphomnist(args):
255
  # Load data
256
  augmentation = {
257
- 'train': TF.Compose([
258
- TF.RandomCrop((args.input_res, args.input_res), padding=args.pad),
259
- ]),
260
- 'eval': TF.Compose([
261
- TF.Pad(padding=2), # (32, 32)
262
- ])
 
  }
264
 
265
  datasets = {}
266
  # for split in ['train', 'valid', 'test']:
267
- for split in ['test']:
268
  datasets[split] = MorphoMNIST(
269
  root_dir=args.data_dir,
270
- train=(split == 'train'), # test set is valid set
271
- transform=augmentation[('eval' if split != 'train' else split)],
272
  columns=args.parents_x,
273
  norm=args.context_norm,
274
- concat_pa=args.concat_pa
275
  )
276
  return datasets
277
 
278
 
279
  def preproc_mimic(batch):
280
  for k, v in batch.items():
281
- if k == 'x':
282
- batch['x'] = (batch['x'].float() - 127.5) / 127.5 # [-1,1]
283
- elif k in ['age']:
284
  batch[k] = batch[k].float().unsqueeze(-1)
285
- batch[k] = batch[k] / 100.
286
  batch[k] = batch[k] * 2 - 1 # [-1,1]
287
- elif k in ['race']:
288
  batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
289
- elif k in ['finding']:
290
  batch[k] = batch[k].unsqueeze(-1).float()
291
  else:
292
  batch[k] = batch[k].float().unsqueeze(-1)
@@ -294,39 +323,52 @@ def preproc_mimic(batch):
294
 
295
 
296
  class MIMICDataset(Dataset):
297
- def __init__(self, root, csv_file, transform=None, columns=None, concat_pa=True, only_pleural_eff=True):
298
  self.data = pd.read_csv(csv_file)
299
  self.transform = transform
300
- self.disease_labels = ['No Finding', 'Other', 'Pleural Effusion', 'Lung Opacity']
301
  self.samples = {
302
- 'age': [],
303
- 'sex': [],
304
- 'finding': [],
305
- 'x': [],
306
- 'race': [],
307
- 'lung_opacity': [],
308
- 'pleural_effusion': [],
309
  }
310
 
311
- for idx, _ in enumerate(tqdm(range(len(self.data)), desc='Loading MIMIC Data')):
312
- if only_pleural_eff and self.data.loc[idx, 'disease'] == 'Other':
313
  continue
314
- img_path = os.path.join(root, self.data.loc[idx, 'path_preproc'])
315
 
316
- lung_opacity = self.data.loc[idx, 'Lung Opacity']
317
- self.samples['lung_opacity'].append(lung_opacity)
318
 
319
- pleural_effusion = self.data.loc[idx, 'Pleural Effusion']
320
- self.samples['pleural_effusion'].append(pleural_effusion)
321
 
322
- disease = self.data.loc[idx, 'disease']
323
- finding = 0 if disease == 'No Finding' else 1
324
 
325
- self.samples['x'].append(img_path)
326
- self.samples['finding'].append(finding)
327
- self.samples['age'].append(self.data.loc[idx, 'age'])
328
- self.samples['race'].append(self.data.loc[idx, 'race_label'])
329
- self.samples['sex'].append(self.data.loc[idx, 'sex_label'])
330
 
331
  self.columns = columns
332
  if self.columns is None:
@@ -336,33 +378,36 @@ class MIMICDataset(Dataset):
336
  self.concat_pa = concat_pa
337
 
338
  def __len__(self):
339
- return len(self.samples['x'])
340
 
341
  def __getitem__(self, idx):
342
  sample = {k: v[idx] for k, v in self.samples.items()}
343
- sample['x'] = imread(sample['x']).astype(np.float32)[None, ...]
344
 
345
  for k, v in sample.items():
346
  sample[k] = torch.tensor(v)
347
 
348
  if self.transform:
349
- sample['x'] = self.transform(sample['x'])
350
 
351
  sample = preproc_mimic(sample)
352
  if self.concat_pa:
353
- sample['pa'] = torch.cat([sample[k] for k in self.columns], dim=0)
354
  return sample
355
 
356
 
357
  def mimic(args):
358
  args.csv_dir = args.data_dir
359
  datasets = {}
360
- datasets['test'] = MIMICDataset(
361
  root=args.data_dir,
362
- csv_file=os.path.join(args.csv_dir, 'mimic.sample.test.csv'),
363
  columns=args.parents_x,
364
- transform=TF.Compose([
365
- TF.Resize((args.input_res, args.input_res), antialias=None),
366
- ])
 
 
 
367
  )
368
- return datasets
 
23
  x_min = x.min()
24
  if x_max is None:
25
  x_max = x.max()
26
+ print(f"max: {x_max}, min: {x_min}")
27
  x = (x - x_min) / (x_max - x_min) # [0,1]
28
  return x if zero_one else 2 * x - 1 # else [-1,1]
29
 
30
 
31
  class UKBBDataset(Dataset):
32
+ def __init__(
33
+ self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True
34
+ ):
35
  super().__init__()
36
  self.root = root
37
  self.transform = transform
38
  self.concat_pa = concat_pa # return concatenated parents
39
 
40
+ print(f"\nLoading csv data: {csv_file}")
41
  self.df = pd.read_csv(csv_file)
42
  self.columns = columns
43
  if self.columns is None:
44
  # ['eid', 'sex', 'age', 'brain_volume', 'ventricle_volume', 'mri_seq']
45
  self.columns = list(self.df.columns) # return all
46
  self.columns.pop(0) # remove redundant 'index' column
47
+ print(f"columns: {self.columns}")
48
+ self.samples = {i: torch.as_tensor(self.df[i]).float() for i in self.columns}
 
49
 
50
+ for k in ["age", "brain_volume", "ventricle_volume"]:
51
+ print(f"{k} normalization: {norm}")
52
  if k in self.columns:
53
+ if norm == "[-1,1]":
54
  self.samples[k] = normalize(self.samples[k])
55
+ elif norm == "[0,1]":
56
  self.samples[k] = normalize(self.samples[k], zero_one=True)
57
+ elif norm == "log_standard":
58
  self.samples[k] = log_standardize(self.samples[k])
59
  elif norm is None:
60
  pass
61
  else:
62
+ raise NotImplementedError(f"{norm} not implemented.")
63
+ print(f"#samples: {len(self.df)}")
64
+ self.return_x = True if "eid" in self.columns else False
65
 
66
  def __len__(self):
67
  return len(self.df)
 
70
  sample = {k: v[idx] for k, v in self.samples.items()}
71
 
72
  if self.return_x:
73
+ mri_seq = "T1" if sample["mri_seq"] == 0.0 else "T2_FLAIR"
74
  # Load scan
75
+ filename = (
76
+ f'{int(sample["eid"])}_' + mri_seq + "_unbiased_brain_rigid_to_mni.png"
77
+ )
78
+ x = Image.open(os.path.join(self.root, "thumbs_192x192", filename))
79
 
80
  if self.transform is not None:
81
+ sample["x"] = self.transform(x)
82
+ sample.pop("eid", None)
83
 
84
  if self.concat_pa:
85
+ sample["pa"] = torch.cat(
86
+ [torch.tensor([sample[k]]) for k in self.columns if k != "eid"], dim=0
87
+ )
88
 
89
  return sample
90
 
91
 
92
  def get_attr_max_min(attr):
93
  # some ukbb dataset (max, min) stats
94
+ if attr == "age":
95
  return 73, 44
96
+ elif attr == "brain_volume":
97
  return 1629520, 841919
98
+ elif attr == "ventricle_volume":
99
  return 157075, 7613.27001953125
100
  else:
101
  raise NotImplementedError
 
104
  def ukbb(args):
105
  csv_dir = args.data_dir
106
  augmentation = {
107
+ "train": TF.Compose(
108
+ [
109
+ TF.Resize((args.input_res, args.input_res), antialias=None),
110
+ TF.RandomCrop(
111
+ size=(args.input_res, args.input_res),
112
+ padding=[2 * args.pad, args.pad],
113
+ ),
114
+ TF.RandomHorizontalFlip(p=args.hflip),
115
+ TF.PILToTensor(),
116
+ ]
117
+ ),
118
+ "eval": TF.Compose(
119
+ [
120
+ TF.Resize((args.input_res, args.input_res), antialias=None),
121
+ TF.PILToTensor(),
122
+ ]
123
+ ),
124
  }
125
 
126
  datasets = {}
127
  # for split in ['train', 'valid', 'test']:
128
+ for split in ["test"]:
129
  datasets[split] = UKBBDataset(
130
  root=args.data_dir,
131
+ csv_file=os.path.join(csv_dir, split + ".csv"),
132
+ transform=augmentation[("eval" if split != "train" else split)],
133
+ columns=(None if not args.parents_x else ["eid"] + args.parents_x),
134
+ norm=(None if not hasattr(args, "context_norm") else args.context_norm),
135
+ concat_pa=False,
136
+ )
137
 
138
  return datasets
139
 
140
 
141
  def _load_uint8(f):
142
+ idx_dtype, ndim = struct.unpack("BBBB", f.read(4))[2:]
143
+ shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
144
  buffer_length = int(np.prod(shape))
145
  data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
146
  return data
 
160
  ----------
161
  http://yann.lecun.com/exdb/mnist/
162
  """
163
+ open_fcn = gzip.open if path.endswith(".gz") else open
164
+ with open_fcn(path, "rb") as f:
165
  return _load_uint8(f)
166
 
167
 
 
176
  return images_path, labels_path, metrics_path
177
 
178
 
179
+ def load_morphomnist_like(
180
+ root_dir, train: bool = True, columns=None
181
+ ) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
182
  """
183
  Args:
184
  root_dir: path to data directory
 
193
  images = load_idx(images_path)
194
  labels = load_idx(labels_path)
195
 
196
+ if columns is not None and "index" not in columns:
197
+ usecols = ["index"] + list(columns)
198
  else:
199
  usecols = columns
200
+ metrics = pd.read_csv(metrics_path, usecols=usecols, index_col="index")
201
  return images, labels, metrics
202
 
203
 
204
  class MorphoMNIST(Dataset):
205
+ def __init__(
206
+ self,
207
+ root_dir,
208
+ train=True,
209
+ transform=None,
210
+ columns=None,
211
+ norm=None,
212
+ concat_pa=True,
213
+ ):
214
  self.train = train
215
  self.transform = transform
216
  self.columns = columns
217
  self.concat_pa = concat_pa
218
  self.norm = norm
219
 
220
+ cols_not_digit = [c for c in self.columns if c != "digit"]
221
  images, labels, metrics_df = load_morphomnist_like(
222
+ root_dir, train, cols_not_digit
223
+ )
224
  self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
225
  self.labels = F.one_hot(
226
+ torch.from_numpy(np.array(labels)).long(), num_classes=10
227
+ )
228
 
229
  if self.columns is None:
230
  self.columns = metrics_df.columns
231
  self.samples = {k: torch.tensor(metrics_df[k]) for k in cols_not_digit}
232
 
233
  self.min_max = {
234
+ "thickness": [0.87598526, 6.255515],
235
+ "intensity": [66.601204, 254.90317],
236
  }
237
 
238
  for k, v in self.samples.items(): # optional preprocessing
239
+ print(f"{k} normalization: {norm}")
240
+ if norm == "[-1,1]":
241
+ self.samples[k] = normalize(
242
+ v, x_min=self.min_max[k][0], x_max=self.min_max[k][1]
243
+ )
244
+ elif norm == "[0,1]":
245
+ self.samples[k] = normalize(
246
+ v, x_min=self.min_max[k][0], x_max=self.min_max[k][1], zero_one=True
247
+ )
248
  elif norm is None:
249
  pass
250
  else:
251
+ raise NotImplementedError(f"{norm} not implemented.")
252
+ print(f"#samples: {len(metrics_df)}\n")
253
 
254
+ self.samples.update({"digit": self.labels})
255
 
256
  def __len__(self):
257
  return len(self.images)
258
 
259
  def __getitem__(self, idx):
260
  sample = {}
261
+ sample["x"] = self.images[idx]
262
 
263
  if self.transform is not None:
264
+ sample["x"] = self.transform(sample["x"])
265
 
266
  if self.concat_pa:
267
+ sample["pa"] = torch.cat(
268
+ [
269
+ v[idx] if k == "digit" else torch.tensor([v[idx]])
270
+ for k, v in self.samples.items()
271
+ ],
272
+ dim=0,
273
+ )
274
  else:
275
  sample.update({k: v[idx] for k, v in self.samples.items()})
276
  return sample
 
279
  def morphomnist(args):
280
  # Load data
281
  augmentation = {
282
+ "train": TF.Compose(
283
+ [
284
+ TF.RandomCrop((args.input_res, args.input_res), padding=args.pad),
285
+ ]
286
+ ),
287
+ "eval": TF.Compose(
288
+ [
289
+ TF.Pad(padding=2), # (32, 32)
290
+ ]
291
+ ),
292
  }
293
 
294
  datasets = {}
295
  # for split in ['train', 'valid', 'test']:
296
+ for split in ["test"]:
297
  datasets[split] = MorphoMNIST(
298
  root_dir=args.data_dir,
299
+ train=(split == "train"), # test set is valid set
300
+ transform=augmentation[("eval" if split != "train" else split)],
301
  columns=args.parents_x,
302
  norm=args.context_norm,
303
+ concat_pa=False,
304
  )
305
  return datasets
306
 
307
 
308
  def preproc_mimic(batch):
309
  for k, v in batch.items():
310
+ if k == "x":
311
+ batch["x"] = (batch["x"].float() - 127.5) / 127.5 # [-1,1]
312
+ elif k in ["age"]:
313
  batch[k] = batch[k].float().unsqueeze(-1)
314
+ batch[k] = batch[k] / 100.0
315
  batch[k] = batch[k] * 2 - 1 # [-1,1]
316
+ elif k in ["race"]:
317
  batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
318
+ elif k in ["finding"]:
319
  batch[k] = batch[k].unsqueeze(-1).float()
320
  else:
321
  batch[k] = batch[k].float().unsqueeze(-1)
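
As a sanity check of the age convention above (years / 100, then affinely mapped to [-1,1]); the batch is invented and the function mutates it in place:

import torch

batch = {"age": torch.tensor([0.0, 50.0, 100.0])}
preproc_mimic(batch)
# batch["age"] is now tensor([[-1.], [0.], [1.]])
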
 
323
 
324
 
325
  class MIMICDataset(Dataset):
326
+ def __init__(
327
+ self,
328
+ root,
329
+ csv_file,
330
+ transform=None,
331
+ columns=None,
332
+ concat_pa=True,
333
+ only_pleural_eff=True,
334
+ ):
335
  self.data = pd.read_csv(csv_file)
336
  self.transform = transform
337
+ self.disease_labels = [
338
+ "No Finding",
339
+ "Other",
340
+ "Pleural Effusion",
341
+ # "Lung Opacity",
342
+ ]
343
  self.samples = {
344
+ "age": [],
345
+ "sex": [],
346
+ "finding": [],
347
+ "x": [],
348
+ "race": [],
349
+ # "lung_opacity": [],
350
+ # "pleural_effusion": [],
351
  }
352
 
353
+ for idx, _ in enumerate(tqdm(range(len(self.data)), desc="Loading MIMIC Data")):
354
+ if only_pleural_eff and self.data.loc[idx, "disease"] == "Other":
355
  continue
356
+ img_path = os.path.join(root, self.data.loc[idx, "path_preproc"])
357
 
358
+ # lung_opacity = self.data.loc[idx, "Lung Opacity"]
359
+ # self.samples["lung_opacity"].append(lung_opacity)
360
 
361
+ # pleural_effusion = self.data.loc[idx, "Pleural Effusion"]
362
+ # self.samples["pleural_effusion"].append(pleural_effusion)
363
 
364
+ disease = self.data.loc[idx, "disease"]
365
+ finding = 0 if disease == "No Finding" else 1
366
 
367
+ self.samples["x"].append(img_path)
368
+ self.samples["finding"].append(finding)
369
+ self.samples["age"].append(self.data.loc[idx, "age"])
370
+ self.samples["race"].append(self.data.loc[idx, "race_label"])
371
+ self.samples["sex"].append(self.data.loc[idx, "sex_label"])
372
 
373
  self.columns = columns
374
  if self.columns is None:
 
378
  self.concat_pa = concat_pa
379
 
380
  def __len__(self):
381
+ return len(self.samples["x"])
382
 
383
  def __getitem__(self, idx):
384
  sample = {k: v[idx] for k, v in self.samples.items()}
385
+ sample["x"] = imread(sample["x"]).astype(np.float32)[None, ...]
386
 
387
  for k, v in sample.items():
388
  sample[k] = torch.tensor(v)
389
 
390
  if self.transform:
391
+ sample["x"] = self.transform(sample["x"])
392
 
393
  sample = preproc_mimic(sample)
394
  if self.concat_pa:
395
+ sample["pa"] = torch.cat([sample[k] for k in self.columns], dim=0)
396
  return sample
397
 
398
 
399
  def mimic(args):
400
  args.csv_dir = args.data_dir
401
  datasets = {}
402
+ datasets["test"] = MIMICDataset(
403
  root=args.data_dir,
404
+ csv_file=os.path.join(args.csv_dir, "mimic.sample.test.csv"),
405
  columns=args.parents_x,
406
+ transform=TF.Compose(
407
+ [
408
+ TF.Resize((args.input_res, args.input_res), antialias=None),
409
+ ]
410
+ ),
411
+ concat_pa=False,
412
  )
413
+ return datasets
pgm/flow_pgm.py CHANGED
@@ -1,7 +1,8 @@
1
- import numpy as np
2
 
 
3
  import torch
4
- import torch.nn as nn
5
  import torch.nn.functional as F
6
 
7
  import pyro
@@ -15,53 +16,69 @@ from pyro.distributions.conditional import ConditionalTransformedDistribution
15
  from .layers import (
16
  ConditionalTransformedDistributionGumbelMax,
17
  ConditionalGumbelMax,
18
- ConditionalAffineTransform, MLP, CNN,
 
 
19
  )
20
 
21
 
22
  class BasePGM(nn.Module):
23
  def __init__(self):
24
  super().__init__()
25
 
26
  def scm(self, *args, **kwargs):
27
  def config(msg):
28
- if isinstance(msg['fn'], dist.TransformedDistribution):
29
  return TransformReparam()
30
  else:
31
  return None
 
32
  return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
33
 
34
- def sample_scm(self, n_samples=1, t=None):
35
- with pyro.plate('obs', n_samples):
36
- samples = self.scm(t)
37
  return samples
38
 
39
- def sample(self, n_samples=1, t=None):
40
- with pyro.plate('obs', n_samples):
41
- samples = self.model(t) # model defined in parent class
42
  return samples
43
 
44
- def infer_exogeneous(self, obs):
45
  batch_size = list(obs.values())[0].shape[0]
46
  # assuming that we use transformed distributions for everything:
47
  cond_model = pyro.condition(self.sample, data=obs)
48
- cond_trace = pyro.poutine.trace(
49
- cond_model).get_trace(batch_size)
50
 
51
  output = {}
52
  for name, node in cond_trace.nodes.items():
53
- if 'z' in name or 'fn' not in node.keys():
54
  continue
55
- fn = node['fn']
56
  if isinstance(fn, dist.Independent):
57
  fn = fn.base_dist
58
  if isinstance(fn, dist.TransformedDistribution):
59
  # compute exogenous base dist (created with TransformReparam) at all sites
60
- output[name + '_base'] = T.ComposeTransform(
61
- fn.transforms).inv(node['value'])
 
62
  return output
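
The inversion this performs, in isolation: for a site generated as x = f(u) through a TransformedDistribution, abduction recovers u = f^{-1}(x). A standalone toy example with a single affine transform (numbers invented):

import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

f = T.AffineTransform(loc=2.0, scale=3.0)  # x = 2 + 3u
px = dist.TransformedDistribution(dist.Normal(0.0, 1.0), [f])
x = torch.tensor(5.0)
u = T.ComposeTransform(px.transforms).inv(x)  # (5 - 2) / 3 = 1.0
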
63
 
64
- def counterfactual(self, obs, intervention, num_particles=1, detach=True, t=None):
65
  dag_variables = self.variables.keys()
66
  assert set(obs.keys()) == set(dag_variables)
67
  avg_cfs = {k: torch.zeros_like(obs[k]) for k in obs.keys()}
@@ -70,43 +87,50 @@ class BasePGM(nn.Module):
70
  for _ in range(num_particles):
71
  # Abduction
72
  exo_noise = self.infer_exogeneous(obs)
73
- exo_noise = {k: v.detach() if detach else v for k,
74
- v in exo_noise.items()}
75
  # condition on root node variables (no exogeneous noise available)
76
  for k in dag_variables:
77
  if k not in intervention.keys():
78
- if k not in [i.split('_base')[0] for i in exo_noise.keys()]:
79
  exo_noise[k] = obs[k]
80
  # Abducted SCM
81
- abducted_scm = pyro.poutine.condition(
82
- self.sample_scm, data=exo_noise)
83
  # Action
84
- counterfactual_scm = pyro.poutine.do(
85
- abducted_scm, data=intervention)
86
  # Prediction
87
- counterfactuals = counterfactual_scm(batch_size, t)
88
 
89
  for k, v in counterfactuals.items():
90
- avg_cfs[k] += (v / num_particles)
91
  return avg_cfs
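
A hedged usage sketch of the abduction-action-prediction loop above, assuming `pgm` is a trained FlowPGM and every variable is already in the PGM's [-1,1] convention (all numbers are placeholders):

import torch

obs = {
    "sex": torch.ones(1, 1),
    "mri_seq": torch.zeros(1, 1),
    "age": torch.zeros(1, 1),
    "brain_volume": torch.zeros(1, 1),
    "ventricle_volume": torch.zeros(1, 1),
}
do = {"age": torch.full((1, 1), 0.5)}  # intervene on age only
cfs = pgm.counterfactual(obs, intervention=do, num_particles=4)
# descendants of age (brain/ventricle volume) respond; non-descendants persist
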
92
 
93
 
94
  class FlowPGM(BasePGM):
95
- def __init__(self, args):
96
  super().__init__()
97
  self.variables = {
98
- 'sex': 'binary',
99
- 'mri_seq': 'binary',
100
- 'age': 'continuous',
101
- 'brain_volume': 'continuous',
102
- 'ventricle_volume': 'continuous'
103
  }
104
  # priors: s, m, a, b and v
105
  self.s_logit = nn.Parameter(torch.zeros(1))
106
  self.m_logit = nn.Parameter(torch.zeros(1))
107
- for k in ['a', 'b', 'v']:
108
- self.register_buffer(f'{k}_base_loc', torch.zeros(1))
109
- self.register_buffer(f'{k}_base_scale', torch.ones(1))
110
 
111
  # constraint, assumes data is [-1,1] normalized
112
  # normalize_transform = T.ComposeTransform([
@@ -116,23 +140,19 @@ class FlowPGM(BasePGM):
116
 
117
  # age flow
118
  self.age_module = T.ComposeTransformModule(
119
- [T.Spline(1, count_bins=4, order='linear')])
120
- self.age_flow = T.ComposeTransform([
121
- self.age_module])
122
  # self.age_module, normalize_transform])
123
 
124
  # brain volume (conditional) flow: (sex, age) -> brain_vol
125
- bvol_net = DenseNN(
126
- 2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
127
- self.bvol_flow = ConditionalAffineTransform(
128
- context_nn=bvol_net, event_dim=0)
129
  # self.bvol_flow = [self.bvol_flow, normalize_transform]
130
 
131
  # ventricle volume (conditional) flow: (brain_vol, age) -> ventricle_vol
132
- vvol_net = DenseNN(
133
- 2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
134
- self.vvol_flow = ConditionalAffineTransform(
135
- context_nn=vvol_net, event_dim=0)
136
  # self.vvol_flow = [self.vvol_transf, normalize_transform]
137
 
138
  # if args.setup != 'sup_pgm':
@@ -148,148 +168,152 @@ class FlowPGM(BasePGM):
148
  self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1)
149
  # q(v | x) = Normal(mu(x), sigma(x))
150
  self.encoder_v = CNN(input_shape, num_outputs=2)
151
- self.f = lambda x: args.std_fixed * \
152
- torch.ones_like(x) if args.std_fixed > 0 else F.softplus(x)
 
 
 
153
 
154
- def model(self, t=None):
155
  # p(s), sex dist
156
  ps = dist.Bernoulli(logits=self.s_logit).to_event(1)
157
- sex = pyro.sample('sex', ps)
158
 
159
  # p(m), mri_seq dist
160
  pm = dist.Bernoulli(logits=self.m_logit).to_event(1)
161
- mri_seq = pyro.sample('mri_seq', pm)
162
 
163
  # p(a), age flow
164
  pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
165
  pa = dist.TransformedDistribution(pa_base, self.age_flow)
166
- age = pyro.sample('age', pa)
167
 
168
  # p(b | s, a), brain volume flow
169
- pb_sa_base = dist.Normal(
170
- self.b_base_loc, self.b_base_scale).to_event(1)
171
  pb_sa = ConditionalTransformedDistribution(
172
- pb_sa_base, [self.bvol_flow]).condition(torch.cat([sex, age], dim=1))
173
- bvol = pyro.sample('brain_volume', pb_sa)
 
174
  # _ = self.bvol_transf # register with pyro
175
 
176
  # p(v | b, a), ventricle volume flow
177
- pv_ba_base = dist.Normal(
178
- self.v_base_loc, self.v_base_scale).to_event(1)
179
  pv_ba = ConditionalTransformedDistribution(
180
- pv_ba_base, [self.vvol_flow]).condition(torch.cat([bvol, age], dim=1))
181
- vvol = pyro.sample('ventricle_volume', pv_ba)
 
182
  # _ = self.vvol_transf # register with pyro
183
 
184
  return {
185
- 'sex': sex,
186
- 'mri_seq': mri_seq,
187
- 'age': age,
188
- 'brain_volume': bvol,
189
- 'ventricle_volume': vvol,
190
  }
191
 
192
- def guide(self, **obs):
193
  # guide for (optional) semi-supervised learning
194
- pyro.module('FlowPGM', self)
195
- with pyro.plate('observations', obs['x'].shape[0]):
196
  # q(m | x)
197
- if obs['mri_seq'] is None:
198
- m_prob = torch.sigmoid(self.encoder_m(obs['x']))
199
- m = pyro.sample('mri_seq', dist.Bernoulli(
200
- probs=m_prob).to_event(1))
201
 
202
  # q(v | x)
203
- if obs['ventricle_volume'] is None:
204
- v_loc, v_logscale = self.encoder_v(obs['x']).chunk(2, dim=-1)
205
  qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
206
- obs['ventricle_volume'] = pyro.sample('ventricle_volume', qv_x)
207
 
208
  # q(b | x, v)
209
- if obs['brain_volume'] is None:
210
  b_loc, b_logscale = self.encoder_b(
211
- obs['x'], y=obs['ventricle_volume']).chunk(2, dim=-1)
 
212
  qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
213
- obs['brain_volume'] = pyro.sample('brain_volume', qb_xv)
214
 
215
  # q(s | x, b)
216
- if obs['sex'] is None:
217
- s_prob = torch.sigmoid(self.encoder_s(
218
- obs['x'], y=obs['brain_volume'])) # .squeeze()
219
- pyro.sample('sex', dist.Bernoulli(probs=s_prob).to_event(1))
 
220
 
221
  # q(a | b, v)
222
- if obs['age'] is None:
223
- ctx = torch.cat(
224
- [obs['brain_volume'], obs['ventricle_volume']], dim=-1)
225
  a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
226
- pyro.sample('age', dist.Normal(
227
- a_loc, self.f(a_logscale)).to_event(1))
228
 
229
- def model_anticausal(self, **obs):
230
  # assumes all variables are observed
231
- pyro.module('FlowPGM', self)
232
- with pyro.plate('observations', obs['x'].shape[0]):
233
  # q(v | x)
234
- v_loc, v_logscale = self.encoder_v(obs['x']).chunk(2, dim=-1)
235
  qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
236
- pyro.sample('ventricle_volume_aux', qv_x,
237
- obs=obs['ventricle_volume'])
238
 
239
  # q(b | x, v)
240
  b_loc, b_logscale = self.encoder_b(
241
- obs['x'], y=obs['ventricle_volume']).chunk(2, dim=-1)
 
242
  qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
243
- pyro.sample('brain_volume_aux', qb_xv, obs=obs['brain_volume'])
244
 
245
  # q(a | b, v)
246
- ctx = torch.cat(
247
- [obs['brain_volume'], obs['ventricle_volume']], dim=-1)
248
  a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
249
- pyro.sample('age_aux', dist.Normal(
250
- a_loc, self.f(a_logscale)).to_event(1), obs=obs['age'])
 
 
 
251
 
252
  # q(s | x, b)
253
- s_prob = torch.sigmoid(self.encoder_s(
254
- obs['x'], y=obs['brain_volume']))
255
  qs_xb = dist.Bernoulli(probs=s_prob).to_event(1)
256
- pyro.sample('sex_aux', qs_xb, obs=obs['sex'])
257
 
258
  # q(m | x)
259
- m_prob = torch.sigmoid(self.encoder_m(obs['x']))
260
  qm_x = dist.Bernoulli(probs=m_prob).to_event(1)
261
- pyro.sample('mri_seq_aux', qm_x, obs=obs['mri_seq'])
262
 
263
- def predict(self, **obs):
264
  # q(v | x)
265
- v_loc, v_logscale = self.encoder_v(obs['x']).chunk(2, dim=-1)
266
  # v_loc = torch.tanh(v_loc)
267
  # q(b | x, v)
268
- b_loc, b_logscale = self.encoder_b(
269
- obs['x'], y=obs['ventricle_volume']).chunk(2, dim=-1)
 
270
  # b_loc = torch.tanh(b_loc)
271
  # q(a | b, v)
272
- ctx = torch.cat([obs['brain_volume'], obs['ventricle_volume']], dim=-1)
273
  a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
274
  # a_loc = torch.tanh(b_loc)
275
  # q(s | x, b)
276
- s_prob = torch.sigmoid(self.encoder_s(obs['x'], y=obs['brain_volume']))
277
  # q(m | x)
278
- m_prob = torch.sigmoid(self.encoder_m(obs['x']))
279
 
280
  return {
281
- 'sex': s_prob,
282
- 'mri_seq': m_prob,
283
- 'age': a_loc,
284
- 'brain_volume': b_loc,
285
- 'ventricle_volume': v_loc,
286
  }
287
 
288
- def svi_model(self, **obs):
289
- with pyro.plate('observations', obs['x'].shape[0]):
290
  pyro.condition(self.model, data=obs)()
291
 
292
- def guide_pass(self, **obs):
293
  pass
294
 
295
 
@@ -297,173 +321,174 @@ class MorphoMNISTPGM(BasePGM):
297
  def __init__(self, args):
298
  super().__init__()
299
  self.variables = {
300
- 'thickness': 'continuous',
301
- 'intensity': 'continuous',
302
- 'digit': 'categorical',
303
  }
304
  # priors
305
  self.digit_logits = nn.Parameter(torch.zeros(1, 10)) # uniform prior
306
- for k in ['t', 'i']: # thickness, intensity, standard Gaussian
307
- self.register_buffer(f'{k}_base_loc', torch.zeros(1))
308
- self.register_buffer(f'{k}_base_scale', torch.ones(1))
309
 
310
  # constraint, assumes data is [-1,1] normalized
311
- normalize_transform = T.ComposeTransform([
312
- T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)])
 
313
 
314
  # thickness flow
315
  self.thickness_module = T.ComposeTransformModule(
316
- [T.Spline(1, count_bins=4, order='linear')])
317
- self.thickness_flow = T.ComposeTransform([
318
- self.thickness_module, normalize_transform])
 
 
319
 
320
  # intensity (conditional) flow: thickness -> intensity
321
- intensity_net = DenseNN(
322
- 1, args.widths, [1, 1], nonlinearity=nn.GELU())
323
  self.context_nn = ConditionalAffineTransform(
324
- context_nn=intensity_net, event_dim=0)
 
325
  self.intensity_flow = [self.context_nn, normalize_transform]
326
 
327
- if args.setup != 'sup_pgm':
328
  # anticausal predictors
329
  input_shape = (args.input_channels, args.input_res, args.input_res)
330
  # q(t | x, i) = Normal(mu(x, i), sigma(x, i)), 2 outputs: loc & scale
331
- self.encoder_t = CNN(input_shape, num_outputs=2,
332
- context_dim=1, width=8)
333
  # q(i | x) = Normal(mu(x), sigma(x))
334
  self.encoder_i = CNN(input_shape, num_outputs=2, width=8)
335
  # q(y | x) = Categorical(pi(x))
336
  self.encoder_y = CNN(input_shape, num_outputs=10, width=8)
337
- self.f = lambda x: args.std_fixed * \
338
- torch.ones_like(x) if args.std_fixed > 0 else F.softplus(x)
339
-
340
- def model(self, t=None):
341
- pyro.module('MorphoMNISTPGM', self)
 
 
 
342
  # p(y), digit label prior dist
343
  py = dist.OneHotCategorical(
344
- probs=F.softmax(self.digit_logits, dim=-1)).to_event(1)
 
345
  # with pyro.poutine.scale(scale=0.05):
346
- digit = pyro.sample('digit', py)
347
 
348
  # p(t), thickness flow
349
  pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1)
350
  pt = dist.TransformedDistribution(pt_base, self.thickness_flow)
351
- thickness = pyro.sample('thickness', pt)
352
 
353
  # p(i | t), intensity conditional flow
354
  pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1)
355
  pi_t = ConditionalTransformedDistribution(
356
- pi_t_base, self.intensity_flow).condition(thickness)
357
- intensity = pyro.sample('intensity', pi_t)
 
358
  _ = self.context_nn
359
 
360
- return {'thickness': thickness, 'intensity': intensity, 'digit': digit}
361
 
362
- def guide(self, **obs):
363
  # guide for (optional) semi-supervised learning
364
- with pyro.plate('observations', obs['x'].shape[0]):
365
  # q(i | x)
366
- if obs['intensity'] is None:
367
- i_loc, i_logscale = self.encoder_i(obs['x']).chunk(2, dim=-1)
368
- qi_t = dist.Normal(torch.tanh(
369
- i_loc), self.f(i_logscale)).to_event(1)
370
- obs['intensity'] = pyro.sample('intensity', qi_t)
371
 
372
  # q(t | x, i)
373
- if obs['thickness'] is None:
374
- t_loc, t_logscale = self.encoder_t(
375
- obs['x'], y=obs['intensity']).chunk(2, dim=-1)
376
- qt_x = dist.Normal(torch.tanh(
377
- t_loc), self.f(t_logscale)).to_event(1)
378
- obs['thickness'] = pyro.sample('thickness', qt_x)
379
 
380
  # q(y | x)
381
- if obs['digit'] is None:
382
- y_prob = F.softmax(self.encoder_y(obs['x']), dim=-1)
383
- qy_x = dist.OneHotCategorical(probs=y_prob).to_event(1)
384
- pyro.sample('digit', qy_x)
385
 
386
- def model_anticausal(self, **obs):
387
  # assumes all variables are observed & continuous ones are in [-1,1]
388
- pyro.module('MorphoMNISTPGM', self)
389
- with pyro.plate('observations', obs['x'].shape[0]):
390
  # q(t | x, i)
391
- t_loc, t_logscale = self.encoder_t(
392
- obs['x'], y=obs['intensity']).chunk(2, dim=-1)
393
- qt_x = dist.Normal(torch.tanh(
394
- t_loc), self.f(t_logscale)).to_event(1)
395
- pyro.sample('thickness_aux', qt_x, obs=obs['thickness'])
396
 
397
  # q(i | x)
398
- i_loc, i_logscale = self.encoder_i(obs['x']).chunk(2, dim=-1)
399
- qi_t = dist.Normal(torch.tanh(
400
- i_loc), self.f(i_logscale)).to_event(1)
401
- pyro.sample('intensity_aux', qi_t, obs=obs['intensity'])
402
 
403
  # q(y | x)
404
- y_prob = F.softmax(self.encoder_y(obs['x']), dim=-1)
405
- qy_x = dist.OneHotCategorical(probs=y_prob).to_event(1)
406
- pyro.sample('digit_aux', qy_x, obs=obs['digit'])
407
 
408
- def predict(self, **obs):
409
  # q(t | x, i)
410
- t_loc, t_logscale = self.encoder_t(
411
- obs['x'], y=obs['intensity']).chunk(2, dim=-1)
 
412
  t_loc = torch.tanh(t_loc)
413
  # q(i | x)
414
- i_loc, i_logscale = self.encoder_i(obs['x']).chunk(2, dim=-1)
415
  i_loc = torch.tanh(i_loc)
416
  # q(y | x)
417
- y_prob = F.softmax(self.encoder_y(obs['x']), dim=-1)
418
- return {'thickness': t_loc, 'intensity': i_loc, 'digit': y_prob}
419
 
420
- def svi_model(self, **obs):
421
- with pyro.plate('observations', obs['x'].shape[0]):
422
  pyro.condition(self.model, data=obs)()
423
 
424
- def guide_pass(self, **obs):
425
  pass
426
 
427
 
428
- class ChestPGM(nn.Module):
429
- def __init__(self, args):
430
  super().__init__()
431
  self.variables = {
432
- 'race': 'categorical',
433
- 'sex': 'binary',
434
- 'finding': 'binary',
435
- 'age': 'continuous',
436
  }
437
  # Discrete variables that are not root nodes
438
- self.discrete_variables = {
439
- 'finding': 'binary',
440
- }
441
-
442
- # prior age
443
- for k in ['a']:
444
- self.register_buffer(f'{k}_base_loc', torch.zeros(1))
445
- self.register_buffer(f'{k}_base_scale', torch.ones(1))
446
-
447
- # age flow
448
  self.age_flow_components = T.ComposeTransformModule([T.Spline(1)])
449
  # self.age_constraints = T.ComposeTransform([
450
  # T.AffineTransform(loc=4.09541458484, scale=0.32548387126),
451
  # T.ExpTransform()])
452
- self.age_flow = T.ComposeTransform([
453
- self.age_flow_components,
454
- # self.age_constraints,
455
- ])
456
-
 
457
  # Finding (conditional) via MLP, a -> f
458
- finding_net = DenseNN(
459
- 1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid())#.cuda()
460
  self.finding_transform_GumbelMax = ConditionalGumbelMax(
461
- context_nn=finding_net,
462
- event_dim=0)
463
  # log space for sex and race
464
- self.sex_logit = nn.Parameter(torch.zeros(1))
465
- # self.sex_logit = pyro.param(torch.zeros(1))
466
- self.race_logits = nn.Parameter(np.log(1/3)*torch.ones(1, 3))
467
 
468
  input_shape = (args.input_channels, args.input_res, args.input_res)
469
 
@@ -477,207 +502,112 @@ class ChestPGM(nn.Module):
477
  # q(a | x, f) ~ Normal(mu(x), sigma(x))
478
  self.encoder_a = CNN(input_shape, num_outputs=1, context_dim=1)
479
 
480
- def model(self, t=None):
 
481
  # p(s), sex dist
482
  ps = dist.Bernoulli(logits=self.sex_logit).to_event(1)
483
- sex = pyro.sample('sex', ps)
484
 
485
  # p(a), age flow
486
  pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
487
  pa = dist.TransformedDistribution(pa_base, self.age_flow)
488
- age = pyro.sample('age', pa)
489
  # age_ = self.age_constraints.inv(age)
490
  _ = self.age_flow_components # register with pyro
491
 
492
  # p(r), race dist
493
- race_dist = dist.OneHotCategorical(logits=self.race_logits).to_event(0)
494
- race = pyro.sample('race', race_dist)
495
 
496
  # p(f | a), finding as OneHotCategorical conditioned on age
497
- finding_dist_base = dist.Gumbel(
498
- torch.zeros(1), torch.ones(1)).to_event(1)
 
499
  finding_dist = ConditionalTransformedDistributionGumbelMax(
500
- finding_dist_base,
501
- [self.finding_transform_GumbelMax]).condition(age)
502
  finding = pyro.sample("finding", finding_dist)
503
 
504
  return {
505
- 'sex': sex,
506
- 'race': race,
507
- 'age': age,
508
- 'finding': finding,
509
  }
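
For reference, the Gumbel-max construction the conditional finding distribution relies on: a categorical draw equals the argmax of logits plus independent Gumbel(0,1) noise, which keeps the exogenous noise explicit and hence abducible. A standalone sketch with invented logits:

import torch

logits = torch.tensor([0.2, 1.5])  # e.g. [no finding, finding]
g = -torch.log(-torch.log(torch.rand_like(logits)))  # Gumbel(0, 1) samples
y = torch.argmax(logits + g)  # distributed as Categorical(softmax(logits))
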
510
 
511
- def guide(self, **obs):
512
- # print([k for k, v in obs.items() if v is not None])
513
- pyro.module('ChestPGM', self)
514
- with pyro.plate('observations', obs['x'].shape[0]):
515
  # q(s | x)
516
- if obs['sex'] is None:
517
- s_prob = torch.sigmoid(self.encoder_s(obs['x']))
518
- s = pyro.sample('sex', dist.Bernoulli(
519
- probs=s_prob).to_event(1))
520
  # q(r | x)
521
- if obs['race'] is None:
522
- r_logits = F.softmax(self.encoder_r(
523
- obs['x']), dim=-1) # .squeeze()
524
- r = pyro.sample('race', dist.OneHotCategorical(
525
- logits=r_logits).to_event(1))
526
  # q(f | x)
527
- if obs['finding'] is None:
528
- f_prob = torch.sigmoid(self.encoder_ff(obs['x']))
529
- f = pyro.sample('finding', dist.Bernoulli(
530
- probs=f_prob).to_event(1))
531
  # q(a | x, f)
532
- if obs['age'] is None:
533
- a_loc = self.encoder_a(
534
- obs['x'], y=obs['finding'])
535
- pyro.sample('age', dist.Normal(
536
- a_loc, torch.ones_like(a_loc)).to_event(1))
537
-
538
- def model_anticausal(self, **obs):
 
539
  # assumes all variables are observed; trains the anticausal classifiers
540
- pyro.module('ChestPGM', self)
541
- with pyro.plate('observations', obs['x'].shape[0]):
542
- # q(s | x)
543
- s_prob = torch.sigmoid(self.encoder_s(obs['x']))
544
- s = pyro.sample('sex', dist.Bernoulli(
545
- probs=s_prob).to_event(1))
 
546
 
547
  # q(r | x)
548
- r_logits = F.softmax(self.encoder_r(
549
- obs['x']), dim=-1) # .squeeze()
550
- r = pyro.sample('race', dist.OneHotCategorical(
551
- logits=r_logits).to_event(1))
552
 
553
  # q(f | x)
554
- f_prob = torch.sigmoid(self.encoder_f(obs['x']))
555
  qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
556
- obs['finding'] = pyro.sample('finding', qf_x)
557
 
558
  # q(a | x, f)
559
- a_loc = self.encoder_a(
560
- obs['x'], y=obs['finding'])
561
- pyro.sample('age', dist.Normal(
562
- a_loc, torch.ones_like(a_loc)).to_event(1))
563
-
564
- def predict(self, **obs):
565
- # q(s | x)
566
- s_prob = torch.sigmoid(self.encoder_s(obs['x']))
567
- # q(r | x)
568
- r_logits = F.softmax(self.encoder_r(obs['x']), dim=-1) # .squeeze()
569
- # q(f | x)
570
- f_prob = torch.sigmoid(self.encoder_f(obs['x']))
571
- # q(a | x, f)
572
- a_loc = self.encoder_a(
573
- obs['x'], y=obs['finding'])
574
-
575
- return {
576
- 'sex': s_prob,
577
- 'race': r_logits,
578
- 'age': a_loc,
579
- 'finding': f_prob,
580
- }
581
-
582
- def predict_unnorm(self, **obs):
583
  # q(s | x)
584
- s_prob = self.encoder_s(obs['x'])
585
  # q(r | x)
586
- r_logits = self.encoder_r(obs['x'])
587
  # q(f | x)
588
- f_prob = self.encoder_f(obs['x'])
589
- qf_x = dist.Bernoulli(probs=torch.sigmoid(f_prob)).to_event(1)
590
- obs_finding = pyro.sample('finding', qf_x)
591
  # q(a | x, f)
592
- a_loc = self.encoder_a(
593
- obs['x'],
594
- # y=obs['finding'],
595
- y=obs_finding,
596
- )
597
 
598
  return {
599
- 'sex': s_prob,
600
- 'race': r_logits,
601
- 'age': a_loc,
602
- 'finding': f_prob,
603
  }
604
 
605
- def svi_model(self, **obs):
606
- with pyro.plate('observations', obs['x'].shape[0]):
607
  pyro.condition(self.model, data=obs)()
608
 
609
- def guide_pass(self, **obs):
610
  pass
611
-
612
- def infer_exogeneous(self, obs):
613
- batch_size = list(obs.values())[0].shape[0]
614
- # assuming that we use transformed distributions for everything:
615
- cond_model = pyro.condition(self.sample, data=obs)
616
- cond_trace = pyro.poutine.trace(
617
- cond_model).get_trace(batch_size)
618
-
619
- output = {}
620
- for name, node in cond_trace.nodes.items():
621
- if 'z' in name or 'fn' not in node.keys():
622
- continue
623
- fn = node['fn']
624
- if isinstance(fn, dist.Independent):
625
- fn = fn.base_dist
626
- if isinstance(fn, dist.TransformedDistribution):
627
- # compute exogenous base dist (created with TransformReparam) at all sites
628
- output[name + '_base'] = T.ComposeTransform(
629
- fn.transforms).inv(node['value'])
630
- return output
631
-
632
- def scm(self, *args, **kwargs):
633
- def config(msg):
634
- if isinstance(msg['fn'], dist.TransformedDistribution):
635
- return TransformReparam()
636
- else:
637
- return None
638
- return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
639
-
640
- def sample_scm(self, n_samples=1, t=None):
641
- with pyro.plate('obs', n_samples):
642
- samples = self.scm(t)
643
- return samples
644
-
645
- def sample(self, n_samples=1, t=None):
646
- with pyro.plate('obs', n_samples):
647
- samples = self.model(t)
648
- return samples
649
-
650
- def counterfactual(self, obs, intervention, num_particles=1, detach=True, t=None):
651
- dag_variables = self.variables.keys()
652
- obs_ = {k: v for k, v in obs.items() if k in dag_variables}
653
- assert set(obs_.keys()) == set(dag_variables)
654
- # For continuous variables
655
- avg_cfs = {k: torch.zeros_like(obs_[k]) for k in obs_.keys()}
656
- batch_size = list(obs_.values())[0].shape[0]
657
-
658
- for _ in range(num_particles):
659
- # Abduction
660
- exo_noise = self.infer_exogeneous(obs_)
661
- exo_noise = {k: v.detach() if detach else v for k,
662
- v in exo_noise.items()}
663
- # condition on root node variables (no exogenous noise available)
664
- for k in dag_variables:
665
- if k not in intervention.keys():
666
- if k not in [i.split('_base')[0] for i in exo_noise.keys()]:
667
- exo_noise[k] = obs_[k]
668
- # Abducted SCM
669
- abducted_scm = pyro.poutine.condition(
670
- self.sample_scm, data=exo_noise)
671
- # Action
672
- counterfactual_scm = pyro.poutine.do(
673
- abducted_scm, data=intervention)
674
- # Prediction
675
- counterfactuals = counterfactual_scm(batch_size, t)
676
- # Check if we should change "finding", i.e. if its parents and itself are not intervened,
677
- # then we use its observed value. This is needed due to stochastic abduction of discrete variables.
678
- if 'age' not in intervention.keys() and 'finding' not in intervention.keys():
679
- counterfactuals['finding'] = obs_['finding']
680
-
681
- for k, v in counterfactuals.items():
682
- avg_cfs[k] += (v / num_particles)
683
- return avg_cfs
 
1
+ from typing import Dict
2
 
3
+ import numpy as np
4
  import torch
5
+ from torch import nn, Tensor
6
  import torch.nn.functional as F
7
 
8
  import pyro
 
16
  from .layers import (
17
  ConditionalTransformedDistributionGumbelMax,
18
  ConditionalGumbelMax,
19
+ ConditionalAffineTransform,
20
+ MLP,
21
+ CNN,
22
  )
23
 
24
 
25
+ class Hparams:
26
+ def update(self, dict):
27
+ for k, v in dict.items():
28
+ setattr(self, k, v)
29
+
30
+
31
  class BasePGM(nn.Module):
32
  def __init__(self):
33
  super().__init__()
34
 
35
  def scm(self, *args, **kwargs):
36
  def config(msg):
37
+ if isinstance(msg["fn"], dist.TransformedDistribution):
38
  return TransformReparam()
39
  else:
40
  return None
41
+
42
  return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
43
 
44
+ def sample_scm(self, n_samples: int = 1):
45
+ with pyro.plate("obs", n_samples):
46
+ samples = self.scm()
47
  return samples
48
 
49
+ def sample(self, n_samples: int = 1):
50
+ with pyro.plate("obs", n_samples):
51
+ samples = self.model() # NOTE: not ideal as model is defined in child class
52
  return samples
53
 
54
+ def infer_exogeneous(self, obs: Dict[str, Tensor]) -> Dict[str, Tensor]:
55
  batch_size = list(obs.values())[0].shape[0]
56
  # assuming that we use transformed distributions for everything:
57
  cond_model = pyro.condition(self.sample, data=obs)
58
+ cond_trace = pyro.poutine.trace(cond_model).get_trace(batch_size)
 
59
 
60
  output = {}
61
  for name, node in cond_trace.nodes.items():
62
+ if "z" in name or "fn" not in node.keys():
63
  continue
64
+ fn = node["fn"]
65
  if isinstance(fn, dist.Independent):
66
  fn = fn.base_dist
67
  if isinstance(fn, dist.TransformedDistribution):
68
  # compute exogenous base dist (created with TransformReparam) at all sites
69
+ output[name + "_base"] = T.ComposeTransform(fn.transforms).inv(
70
+ node["value"]
71
+ )
72
  return output
73
 
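For intuition, infer_exogeneous recovers each site's base-distribution noise by inverting that site's transform stack; below is a minimal, self-contained sketch of the same inversion with a hypothetical affine flow (not one of the repo's transforms).

import torch
from pyro.distributions import transforms as T

# x = loc + scale * u, so inverting the transform abducts u = (x - loc) / scale
flow = T.ComposeTransform(
    [torch.distributions.transforms.AffineTransform(loc=2.0, scale=3.0)]
)
x = torch.tensor([5.0])
u = flow.inv(x)                    # abducted exogenous noise: (5 - 2) / 3 = 1
assert torch.allclose(flow(u), x)  # re-applying the flow reproduces the observation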
74
+ def counterfactual(
75
+ self,
76
+ obs: Dict[str, Tensor],
77
+ intervention: Dict[str, Tensor],
78
+ num_particles: int = 1,
79
+ detach: bool = True,
80
+ ) -> Dict[str, Tensor]:
81
+ # NOTE: not ideal as "variables" is defined in child class
82
  dag_variables = self.variables.keys()
83
  assert set(obs.keys()) == set(dag_variables)
84
  avg_cfs = {k: torch.zeros_like(obs[k]) for k in obs.keys()}
 
87
  for _ in range(num_particles):
88
  # Abduction
89
  exo_noise = self.infer_exogeneous(obs)
90
+ exo_noise = {k: v.detach() if detach else v for k, v in exo_noise.items()}
 
91
  # condition on root node variables (no exogenous noise available)
92
  for k in dag_variables:
93
  if k not in intervention.keys():
94
+ if k not in [i.split("_base")[0] for i in exo_noise.keys()]:
95
  exo_noise[k] = obs[k]
96
  # Abducted SCM
97
+ abducted_scm = pyro.poutine.condition(self.sample_scm, data=exo_noise)
 
98
  # Action
99
+ counterfactual_scm = pyro.poutine.do(abducted_scm, data=intervention)
 
100
  # Prediction
101
+ counterfactuals = counterfactual_scm(batch_size)
102
+
103
+ if hasattr(self, "discrete_variables"): # hack for MIMIC
104
+ # Check if we should change "finding", i.e. if its parents and/or
105
+ # itself are not intervened on, then we use its observed value.
106
+ # This is used due to stochastic abduction of discrete variables
107
+ if (
108
+ "age" not in intervention.keys()
109
+ and "finding" not in intervention.keys()
110
+ ):
111
+ counterfactuals["finding"] = obs["finding"]
112
 
113
  for k, v in counterfactuals.items():
114
+ avg_cfs[k] += v / num_particles
115
  return avg_cfs
116
 
117
 
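A usage sketch for the abduction-action-prediction loop above, assuming a trained ChestPGM instance named pgm (tensor shapes and values are illustrative only):

import torch
import torch.nn.functional as F

obs = {
    "sex": torch.ones(8, 1),
    "race": F.one_hot(torch.zeros(8, dtype=torch.long), 3).float(),
    "age": torch.randn(8, 1),               # already normalized
    "finding": torch.zeros(8, 1),
}
do_age = {"age": torch.full((8, 1), 1.5)}   # intervene on age only
cfs = pgm.counterfactual(obs=obs, intervention=do_age, num_particles=4)
# root nodes (sex, race) keep their observed values, while "finding" is
# re-derived from its abducted Gumbel noise because its parent (age) changed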
118
  class FlowPGM(BasePGM):
119
+ def __init__(self, args: Hparams):
120
  super().__init__()
121
  self.variables = {
122
+ "sex": "binary",
123
+ "mri_seq": "binary",
124
+ "age": "continuous",
125
+ "brain_volume": "continuous",
126
+ "ventricle_volume": "continuous",
127
  }
128
  # priors: s, m, a, b and v
129
  self.s_logit = nn.Parameter(torch.zeros(1))
130
  self.m_logit = nn.Parameter(torch.zeros(1))
131
+ for k in ["a", "b", "v"]:
132
+ self.register_buffer(f"{k}_base_loc", torch.zeros(1))
133
+ self.register_buffer(f"{k}_base_scale", torch.ones(1))
134
 
135
  # constraint, assumes data is [-1,1] normalized
136
  # normalize_transform = T.ComposeTransform([
 
140
 
141
  # age flow
142
  self.age_module = T.ComposeTransformModule(
143
+ [T.Spline(1, count_bins=4, order="linear")]
144
+ )
145
+ self.age_flow = T.ComposeTransform([self.age_module])
146
  # self.age_module, normalize_transform])
147
 
148
  # brain volume (conditional) flow: (sex, age) -> brain_vol
149
+ bvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
150
+ self.bvol_flow = ConditionalAffineTransform(context_nn=bvol_net, event_dim=0)
 
 
151
  # self.bvol_flow = [self.bvol_flow, normalize_transform]
152
 
153
  # ventricle volume (conditional) flow: (brain_vol, age) -> ventricle_vol
154
+ vvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
155
+ self.vvol_flow = ConditionalAffineTransform(context_nn=vvol_net, event_dim=0)
 
 
156
  # self.vvol_flow = [self.vvol_transf, normalize_transform]
157
 
158
  # if args.setup != 'sup_pgm':
 
168
  self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1)
169
  # q(v | x) = Normal(mu(x), sigma(x))
170
  self.encoder_v = CNN(input_shape, num_outputs=2)
171
+ self.f = (
172
+ lambda x: args.std_fixed * torch.ones_like(x)
173
+ if args.std_fixed > 0
174
+ else F.softplus(x)
175
+ )
176
 
177
+ def model(self) -> Dict[str, Tensor]:
178
  # p(s), sex dist
179
  ps = dist.Bernoulli(logits=self.s_logit).to_event(1)
180
+ sex = pyro.sample("sex", ps)
181
 
182
  # p(m), mri_seq dist
183
  pm = dist.Bernoulli(logits=self.m_logit).to_event(1)
184
+ mri_seq = pyro.sample("mri_seq", pm)
185
 
186
  # p(a), age flow
187
  pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
188
  pa = dist.TransformedDistribution(pa_base, self.age_flow)
189
+ age = pyro.sample("age", pa)
190
 
191
  # p(b | s, a), brain volume flow
192
+ pb_sa_base = dist.Normal(self.b_base_loc, self.b_base_scale).to_event(1)
 
193
  pb_sa = ConditionalTransformedDistribution(
194
+ pb_sa_base, [self.bvol_flow]
195
+ ).condition(torch.cat([sex, age], dim=1))
196
+ bvol = pyro.sample("brain_volume", pb_sa)
197
  # _ = self.bvol_transf # register with pyro
198
 
199
  # p(v | b, a), ventricle volume flow
200
+ pv_ba_base = dist.Normal(self.v_base_loc, self.v_base_scale).to_event(1)
 
201
  pv_ba = ConditionalTransformedDistribution(
202
+ pv_ba_base, [self.vvol_flow]
203
+ ).condition(torch.cat([bvol, age], dim=1))
204
+ vvol = pyro.sample("ventricle_volume", pv_ba)
205
  # _ = self.vvol_transf # register with pyro
206
 
207
  return {
208
+ "sex": sex,
209
+ "mri_seq": mri_seq,
210
+ "age": age,
211
+ "brain_volume": bvol,
212
+ "ventricle_volume": vvol,
213
  }
214
 
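The mechanisms above encode the DAG sex -> brain_volume, age -> brain_volume, age -> ventricle_volume, brain_volume -> ventricle_volume, with mri_seq an independent root; drawing from the joint prior is then a one-liner (illustrative, pgm assumed to be a trained FlowPGM):

import torch

with torch.no_grad():
    prior = pgm.sample(n_samples=32)   # dict with all five variables above
print(prior["brain_volume"].shape)     # torch.Size([32, 1])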
215
+ def guide(self, **obs) -> None:
216
  # guide for (optional) semi-supervised learning
217
+ pyro.module("FlowPGM", self)
218
+ with pyro.plate("observations", obs["x"].shape[0]):
219
  # q(m | x)
220
+ if obs["mri_seq"] is None:
221
+ m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
222
+ m = pyro.sample("mri_seq", dist.Bernoulli(probs=m_prob).to_event(1))
 
223
 
224
  # q(v | x)
225
+ if obs["ventricle_volume"] is None:
226
+ v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
227
  qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
228
+ obs["ventricle_volume"] = pyro.sample("ventricle_volume", qv_x)
229
 
230
  # q(b | x, v)
231
+ if obs["brain_volume"] is None:
232
  b_loc, b_logscale = self.encoder_b(
233
+ obs["x"], y=obs["ventricle_volume"]
234
+ ).chunk(2, dim=-1)
235
  qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
236
+ obs["brain_volume"] = pyro.sample("brain_volume", qb_xv)
237
 
238
  # q(s | x, b)
239
+ if obs["sex"] is None:
240
+ s_prob = torch.sigmoid(
241
+ self.encoder_s(obs["x"], y=obs["brain_volume"])
242
+ ) # .squeeze()
243
+ pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))
244
 
245
  # q(a | b, v)
246
+ if obs["age"] is None:
247
+ ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
 
248
  a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
249
+ pyro.sample("age", dist.Normal(a_loc, self.f(a_logscale)).to_event(1))
 
250
 
251
+ def model_anticausal(self, **obs) -> None:
252
  # assumes all variables are observed
253
+ pyro.module("FlowPGM", self)
254
+ with pyro.plate("observations", obs["x"].shape[0]):
255
  # q(v | x)
256
+ v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
257
  qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
258
+ pyro.sample("ventricle_volume_aux", qv_x, obs=obs["ventricle_volume"])
 
259
 
260
  # q(b | x, v)
261
  b_loc, b_logscale = self.encoder_b(
262
+ obs["x"], y=obs["ventricle_volume"]
263
+ ).chunk(2, dim=-1)
264
  qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
265
+ pyro.sample("brain_volume_aux", qb_xv, obs=obs["brain_volume"])
266
 
267
  # q(a | b, v)
268
+ ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
 
269
  a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
270
+ pyro.sample(
271
+ "age_aux",
272
+ dist.Normal(a_loc, self.f(a_logscale)).to_event(1),
273
+ obs=obs["age"],
274
+ )
275
 
276
  # q(s | x, b)
277
+ s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
 
278
  qs_xb = dist.Bernoulli(probs=s_prob).to_event(1)
279
+ pyro.sample("sex_aux", qs_xb, obs=obs["sex"])
280
 
281
  # q(m | x)
282
+ m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
283
  qm_x = dist.Bernoulli(probs=m_prob).to_event(1)
284
+ pyro.sample("mri_seq_aux", qm_x, obs=obs["mri_seq"])
285
 
286
+ def predict(self, **obs) -> Dict[str, Tensor]:
287
  # q(v | x)
288
+ v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
289
  # v_loc = torch.tanh(v_loc)
290
  # q(b | x, v)
291
+ b_loc, b_logscale = self.encoder_b(obs["x"], y=obs["ventricle_volume"]).chunk(
292
+ 2, dim=-1
293
+ )
294
  # b_loc = torch.tanh(b_loc)
295
  # q(a | b, v)
296
+ ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
297
  a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
298
  # a_loc = torch.tanh(b_loc)
299
  # q(s | x, b)
300
+ s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
301
  # q(m | x)
302
+ m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
303
 
304
  return {
305
+ "sex": s_prob,
306
+ "mri_seq": m_prob,
307
+ "age": a_loc,
308
+ "brain_volume": b_loc,
309
+ "ventricle_volume": v_loc,
310
  }
311
 
312
+ def svi_model(self, **obs) -> None:
313
+ with pyro.plate("observations", obs["x"].shape[0]):
314
  pyro.condition(self.model, data=obs)()
315
 
316
+ def guide_pass(self, **obs) -> None:
317
  pass
318
 
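The model_anticausal/guide_pass pair is what Pyro's SVI consumes when training the anticausal predictors; a minimal training-step sketch (optimizer settings are placeholders; pgm is a FlowPGM and batch a dict of observed tensors):

import pyro
from pyro.infer import SVI, Trace_ELBO

svi = SVI(
    model=pgm.model_anticausal,  # scores each q(. | ...) head against its label
    guide=pgm.guide_pass,        # empty guide: the ELBO reduces to a supervised log-likelihood
    optim=pyro.optim.Adam({"lr": 1e-3}),
    loss=Trace_ELBO(),
)
loss = svi.step(**batch)  # batch: x, sex, mri_seq, age, brain_volume, ventricle_volume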
319
 
320
  class MorphoMNISTPGM(BasePGM):
321
  def __init__(self, args):
322
  super().__init__()
323
  self.variables = {
324
+ "thickness": "continuous",
325
+ "intensity": "continuous",
326
+ "digit": "categorical",
327
  }
328
  # priors
329
  self.digit_logits = nn.Parameter(torch.zeros(1, 10)) # uniform prior
330
+ for k in ["t", "i"]: # thickness, intensity, standard Gaussian
331
+ self.register_buffer(f"{k}_base_loc", torch.zeros(1))
332
+ self.register_buffer(f"{k}_base_scale", torch.ones(1))
333
 
334
  # constraint, assumes data is [-1,1] normalized
335
+ normalize_transform = T.ComposeTransform(
336
+ [T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)]
337
+ )
338
 
339
  # thickness flow
340
  self.thickness_module = T.ComposeTransformModule(
341
+ [T.Spline(1, count_bins=4, order="linear")]
342
+ )
343
+ self.thickness_flow = T.ComposeTransform(
344
+ [self.thickness_module, normalize_transform]
345
+ )
346
 
347
  # intensity (conditional) flow: thickness -> intensity
348
+ intensity_net = DenseNN(1, args.widths, [1, 1], nonlinearity=nn.GELU())
 
349
  self.context_nn = ConditionalAffineTransform(
350
+ context_nn=intensity_net, event_dim=0
351
+ )
352
  self.intensity_flow = [self.context_nn, normalize_transform]
353
 
354
+ if args.setup != "sup_pgm":
355
  # anticausal predictors
356
  input_shape = (args.input_channels, args.input_res, args.input_res)
357
  # q(t | x, i) = Normal(mu(x, i), sigma(x, i)), 2 outputs: loc & scale
358
+ self.encoder_t = CNN(input_shape, num_outputs=2, context_dim=1, width=8)
 
359
  # q(i | x) = Normal(mu(x), sigma(x))
360
  self.encoder_i = CNN(input_shape, num_outputs=2, width=8)
361
  # q(y | x) = Categorical(pi(x))
362
  self.encoder_y = CNN(input_shape, num_outputs=10, width=8)
363
+ self.f = (
364
+ lambda x: args.std_fixed * torch.ones_like(x)
365
+ if args.std_fixed > 0
366
+ else F.softplus(x)
367
+ )
368
+
369
+ def model(self) -> Dict[str, Tensor]:
370
+ pyro.module("MorphoMNISTPGM", self)
371
  # p(y), digit label prior dist
372
  py = dist.OneHotCategorical(
373
+ probs=F.softmax(self.digit_logits, dim=-1)
374
+ ) # .to_event(1)
375
  # with pyro.poutine.scale(scale=0.05):
376
+ digit = pyro.sample("digit", py)
377
 
378
  # p(t), thickness flow
379
  pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1)
380
  pt = dist.TransformedDistribution(pt_base, self.thickness_flow)
381
+ thickness = pyro.sample("thickness", pt)
382
 
383
  # p(i | t), intensity conditional flow
384
  pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1)
385
  pi_t = ConditionalTransformedDistribution(
386
+ pi_t_base, self.intensity_flow
387
+ ).condition(thickness)
388
+ intensity = pyro.sample("intensity", pi_t)
389
  _ = self.context_nn
390
 
391
+ return {"thickness": thickness, "intensity": intensity, "digit": digit}
392
 
393
+ def guide(self, **obs) -> None:
394
  # guide for (optional) semi-supervised learning
395
+ with pyro.plate("observations", obs["x"].shape[0]):
396
  # q(i | x)
397
+ if obs["intensity"] is None:
398
+ i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
399
+ qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
400
+ obs["intensity"] = pyro.sample("intensity", qi_t)
 
401
 
402
  # q(t | x, i)
403
+ if obs["thickness"] is None:
404
+ t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
405
+ 2, dim=-1
406
+ )
407
+ qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
408
+ obs["thickness"] = pyro.sample("thickness", qt_x)
409
 
410
  # q(y | x)
411
+ if obs["digit"] is None:
412
+ y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
413
+ qy_x = dist.OneHotCategorical(probs=y_prob) # .to_event(1)
414
+ pyro.sample("digit", qy_x)
415
 
416
+ def model_anticausal(self, **obs) -> None:
417
  # assumes all variables are observed & continuous ones are in [-1,1]
418
+ pyro.module("MorphoMNISTPGM", self)
419
+ with pyro.plate("observations", obs["x"].shape[0]):
420
  # q(t | x, i)
421
+ t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
422
+ 2, dim=-1
423
+ )
424
+ qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
425
+ pyro.sample("thickness_aux", qt_x, obs=obs["thickness"])
426
 
427
  # q(i | x)
428
+ i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
429
+ qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
430
+ pyro.sample("intensity_aux", qi_t, obs=obs["intensity"])
 
431
 
432
  # q(y | x)
433
+ y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
434
+ qy_x = dist.OneHotCategorical(probs=y_prob) # .to_event(1)
435
+ pyro.sample("digit_aux", qy_x, obs=obs["digit"])
436
 
437
+ def predict(self, **obs) -> Dict[str, Tensor]:
438
  # q(t | x, i)
439
+ t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
440
+ 2, dim=-1
441
+ )
442
  t_loc = torch.tanh(t_loc)
443
  # q(i | x)
444
+ i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
445
  i_loc = torch.tanh(i_loc)
446
  # q(y | x)
447
+ y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
448
+ return {"thickness": t_loc, "intensity": i_loc, "digit": y_prob}
449
 
450
+ def svi_model(self, **obs) -> None:
451
+ with pyro.plate("observations", obs["x"].shape[0]):
452
  pyro.condition(self.model, data=obs)()
453
 
454
+ def guide_pass(self, **obs) -> None:
455
  pass
456
 
457
 
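Sampling the learned MorphoMNIST mechanisms end-to-end (illustrative; pgm is a trained MorphoMNISTPGM):

import torch

with torch.no_grad():
    s = pgm.sample(n_samples=16)  # {"thickness": ..., "intensity": ..., "digit": ...}
# thickness and intensity land in [-1, 1] via the Sigmoid/Affine normalize
# transform, so they are mapped back through the dataset's attribute min/max
# (cf. get_attr_max_min) before plotting.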
458
+ class ChestPGM(BasePGM):
459
+ def __init__(self, args: Hparams):
460
  super().__init__()
461
  self.variables = {
462
+ "race": "categorical",
463
+ "sex": "binary",
464
+ "finding": "binary",
465
+ "age": "continuous",
466
  }
467
  # Discrete variables that are not root nodes
468
+ self.discrete_variables = {"finding": "binary"}
469
+ # define base distributions
470
+ for k in ["a"]: # , "f"]:
471
+ self.register_buffer(f"{k}_base_loc", torch.zeros(1))
472
+ self.register_buffer(f"{k}_base_scale", torch.ones(1))
473
+ # age spline flow
 
 
 
 
474
  self.age_flow_components = T.ComposeTransformModule([T.Spline(1)])
475
  # self.age_constraints = T.ComposeTransform([
476
  # T.AffineTransform(loc=4.09541458484, scale=0.32548387126),
477
  # T.ExpTransform()])
478
+ self.age_flow = T.ComposeTransform(
479
+ [
480
+ self.age_flow_components,
481
+ # self.age_constraints,
482
+ ]
483
+ )
484
  # Finding (conditional) via MLP, a -> f
485
+ finding_net = DenseNN(1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid())
 
486
  self.finding_transform_GumbelMax = ConditionalGumbelMax(
487
+ context_nn=finding_net, event_dim=0
488
+ )
489
  # log space for sex and race
490
+ self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
491
+ self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
 
492
 
493
  input_shape = (args.input_channels, args.input_res, args.input_res)
494
 
 
502
  # q(a | x, f) = Normal(mu(x, f), sigma)
503
  self.encoder_a = CNN(input_shape, num_outputs=1, context_dim=1)
504
 
505
+ def model(self) -> Dict[str, Tensor]:
506
+ pyro.module("ChestPGM", self)
507
  # p(s), sex dist
508
  ps = dist.Bernoulli(logits=self.sex_logit).to_event(1)
509
+ sex = pyro.sample("sex", ps)
510
 
511
  # p(a), age flow
512
  pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
513
  pa = dist.TransformedDistribution(pa_base, self.age_flow)
514
+ age = pyro.sample("age", pa)
515
  # age_ = self.age_constraints.inv(age)
516
  _ = self.age_flow_components # register with pyro
517
 
518
  # p(r), race dist
519
+ pr = dist.OneHotCategorical(logits=self.race_logits) # .to_event(1)
520
+ race = pyro.sample("race", pr)
521
 
522
  # p(f | a), finding as OneHotCategorical conditioned on age
523
+ # finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
524
+ finding_dist_base = dist.Gumbel(torch.zeros(1), torch.ones(1)).to_event(1)
525
+
526
  finding_dist = ConditionalTransformedDistributionGumbelMax(
527
+ finding_dist_base, [self.finding_transform_GumbelMax]
528
+ ).condition(age)
529
  finding = pyro.sample("finding", finding_dist)
530
 
531
  return {
532
+ "sex": sex,
533
+ "race": race,
534
+ "age": age,
535
+ "finding": finding,
536
  }
537
 
538
+ def guide(self, **obs) -> None:
539
+ with pyro.plate("observations", obs["x"].shape[0]):
 
 
540
  # q(s | x)
541
+ if obs["sex"] is None:
542
+ s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
543
+ pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))
 
544
  # q(r | x)
545
+ if obs["race"] is None:
546
+ r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
547
+ qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1)
548
+ pyro.sample("race", qr_x)
 
549
  # q(f | x)
550
+ if obs["finding"] is None:
551
+ f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
552
+ qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
553
+ obs["finding"] = pyro.sample("finding", qf_x)
554
  # q(a | x, f)
555
+ if obs["age"] is None:
556
+ a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
557
+ 2, dim=-1
558
+ )
559
+ qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
560
+ pyro.sample("age_aux", qa_xf)
561
+
562
+ def model_anticausal(self, **obs) -> None:
563
  # assumes all variables are observed; train classifiers
564
+ pyro.module("ChestPGM", self)
565
+ with pyro.plate("observations", obs["x"].shape[0]):
566
+ # q(s | x)
567
+ s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
568
+ qs_x = dist.Bernoulli(probs=s_prob).to_event(1)
569
+ # with pyro.poutine.scale(scale=0.8):
570
+ pyro.sample("sex_aux", qs_x, obs=obs["sex"])
571
 
572
  # q(r | x)
573
+ r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
574
+ qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1)
575
+ # with pyro.poutine.scale(scale=0.5):
576
+ pyro.sample("race_aux", qr_x, obs=obs["race"])
577
 
578
  # q(f | x)
579
+ f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
580
  qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
581
+ pyro.sample("finding_aux", qf_x, obs=obs["finding"])
582
 
583
  # q(a | x, f)
584
+ a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
585
+ 2, dim=-1
586
+ )
587
+ qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
588
+ # with pyro.poutine.scale(scale=2):
589
+ pyro.sample("age_aux", qa_xf, obs=obs["age"])
590
+
591
+ def predict(self, **obs) -> Dict[str, Tensor]:
 
 
592
  # q(s | x)
593
+ s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
594
  # q(r | x)
595
+ r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
596
  # q(f | x)
597
+ f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
 
 
598
  # q(a | x, f)
599
+ a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
 
 
 
 
600
 
601
  return {
602
+ "sex": s_prob,
603
+ "race": r_probs,
604
+ "finding": f_prob,
605
+ "age": a_loc,
606
  }
607
 
608
+ def svi_model(self, **obs) -> None:
609
+ with pyro.plate("observations", obs["x"].shape[0]):
610
  pyro.condition(self.model, data=obs)()
611
 
612
+ def guide_pass(self, **obs) -> None:
613
  pass
 
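Why parameterizing p(finding | age) with Gumbel noise matters for counterfactuals, in miniature (a self-contained sketch, not repo code):

import torch

logits = torch.tensor([2.0, 0.5])          # logits(age): [no disease, effusion]
g = -torch.log(-torch.log(torch.rand(2)))  # exogenous Gumbel(0, 1) noise
f = torch.argmax(g + logits)               # observed finding
logits_cf = torch.tensor([0.5, 2.0])       # logits(age*) after do(age)
f_cf = torch.argmax(g + logits_cf)         # counterfactual finding, same noise g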
 
 
pgm/layers.py CHANGED
@@ -7,7 +7,7 @@ from typing import Dict
7
  from pyro.distributions.conditional import (
8
  ConditionalTransformModule,
9
  ConditionalTransformedDistribution,
10
- TransformedDistribution
11
  )
12
  from pyro.distributions.torch_distribution import TorchDistributionMixin
13
 
@@ -25,7 +25,8 @@ class ConditionalAffineTransform(ConditionalTransformModule):
25
  def condition(self, context):
26
  loc, log_scale = self.context_nn(context)
27
  return torch.distributions.transforms.AffineTransform(
28
- loc, log_scale.exp(), event_dim=self.event_dim)
 
29
 
30
 
31
  class MLP(nn.Module):
@@ -58,27 +59,27 @@ class CNN(nn.Module):
58
  nn.BatchNorm2d(width),
59
  activation,
60
  (nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
61
- nn.Conv2d(width, 2*width, 3, 2, 1, bias=False),
62
- nn.BatchNorm2d(2*width),
 
 
 
63
  activation,
64
- nn.Conv2d(2*width, 2*width, 3, 1, 1, bias=False),
65
- nn.BatchNorm2d(2*width),
66
  activation,
67
- nn.Conv2d(2*width, 4*width, 3, 2, 1, bias=False),
68
- nn.BatchNorm2d(4*width),
69
  activation,
70
- nn.Conv2d(4*width, 4*width, 3, 1, 1, bias=False),
71
- nn.BatchNorm2d(4*width),
72
  activation,
73
- nn.Conv2d(4*width, 8*width, 3, 2, 1, bias=False),
74
- nn.BatchNorm2d(8*width),
75
- activation
76
  )
77
  self.fc = nn.Sequential(
78
- nn.Linear(8*width + context_dim, 8*width, bias=False),
79
- nn.BatchNorm1d(8*width),
80
  activation,
81
- nn.Linear(8*width, num_outputs)
82
  )
83
 
84
  def forward(self, x, y=None):
@@ -96,7 +97,9 @@ class ArgMaxGumbelMax(Transform):
96
  super(ArgMaxGumbelMax, self).__init__(cache_size=cache_size)
97
  self.logits = logits
98
  self._event_dim = event_dim
99
- self._categorical = pyro.distributions.torch.Categorical(logits=self.logits).to_event(0)
 
 
100
 
101
  @property
102
  def event_dim(self):
@@ -104,7 +107,7 @@ class ArgMaxGumbelMax(Transform):
104
 
105
  def __call__(self, gumbels):
106
  """
107
- Computes the forward transform
108
  """
109
  assert self.logits is not None, "Logits not defined."
110
 
@@ -126,8 +129,8 @@ class ArgMaxGumbelMax(Transform):
126
 
127
  @property
128
  def domain(self):
129
- """"
130
- Domain of input(gumbel variables), Real
131
  """
132
  if self.event_dim == 0:
133
  return constraints.real
@@ -135,8 +138,8 @@ class ArgMaxGumbelMax(Transform):
135
 
136
  @property
137
  def codomain(self):
138
- """"
139
- Domain of output(categorical variables), should be natural numbers, but set to Real for now
140
  """
141
  if self.event_dim == 0:
142
  return constraints.real
@@ -147,28 +150,32 @@ class ArgMaxGumbelMax(Transform):
147
  assert self.logits is not None, "Logits not defined."
148
 
149
  uniforms = torch.rand(
150
- self.logits.shape, dtype=self.logits.dtype, device=self.logits.device)
 
151
  gumbels = -((-(uniforms.log())).log())
152
  # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
153
  # (batch_size, num_classes) mask to select kth class
154
  # print(f'k : {k.size()}')
155
- mask = F.one_hot(k.squeeze(-1).to(torch.int64),
156
- num_classes=self.logits.shape[-1])
 
157
  # print(f'mask: {mask.size()}, {mask.dtype}')
158
  # (batch_size, 1) select topgumbel for truncation of other classes
159
- topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - \
160
- (mask * self.logits).sum(dim=-1, keepdim=True)
 
161
  mask = 1 - mask # invert mask to select other != k classes
162
  g = gumbels + self.logits
163
  # (batch_size, num_classes)
164
- epsilons = -torch.log(mask * torch.exp(-g) +
165
- torch.exp(-topgumbel)) - (mask * self.logits)
 
166
  return epsilons
167
 
168
  def log_abs_det_jacobian(self, x, y):
169
  """We use the log_abs_det_jacobian to account for the categorical prob
170
- x: Gumbels; y: argmax(x+logits)
171
- P(y) = softmax
172
  """
173
  # print(f"logits: {torch.log(F.softmax(self.logits, dim=-1)).size()}")
174
  # print(f'y: {y.size()} ')
@@ -188,7 +195,8 @@ class ConditionalGumbelMax(ConditionalTransformModule):
188
  def condition(self, context):
189
  """Given context (age), output the Categorical results"""
190
  logits = self.context_nn(
191
- context) # The logits for calculating argmax(Gumbel + logits)
 
192
  return ArgMaxGumbelMax(logits)
193
 
194
  def _logits(self, context):
@@ -197,8 +205,8 @@ class ConditionalGumbelMax(ConditionalTransformModule):
197
 
198
  @property
199
  def domain(self):
200
- """"
201
- Domain of input(gumbel variables), Real
202
  """
203
  if self.event_dim == 0:
204
  return constraints.real
@@ -206,8 +214,8 @@ class ConditionalGumbelMax(ConditionalTransformModule):
206
 
207
  @property
208
  def codomain(self):
209
- """"
210
- Domain of output(categorical variables), should be natural numbers, but set to Real for now
211
  """
212
  if self.event_dim == 0:
213
  return constraints.real
@@ -215,8 +223,7 @@ class ConditionalGumbelMax(ConditionalTransformModule):
215
 
216
 
217
  class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
218
- r""" Define a TransformedDistribution class for Gumbel max
219
- """
220
  arg_constraints: Dict[str, constraints.Constraint] = {}
221
 
222
  def log_prob(self, value):
@@ -233,15 +240,16 @@ class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributio
233
  for transform in reversed(self.transforms):
234
  x = transform.inv(y)
235
  event_dim += transform.domain.event_dim - transform.codomain.event_dim
236
- log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
237
- event_dim - transform.domain.event_dim)
 
 
238
  y = x
239
  # print(f"log_prob: {log_prob.size()}")
240
  return log_prob
241
 
242
 
243
  class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
244
-
245
  def condition(self, context):
246
  base_dist = self.base_dist.condition(context)
247
  transforms = [t.condition(context) for t in self.transforms]
@@ -249,4 +257,4 @@ class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribu
249
  return TransformedDistributionGumbelMax(base_dist, transforms)
250
 
251
  def clear_cache(self):
252
- pass
 
7
  from pyro.distributions.conditional import (
8
  ConditionalTransformModule,
9
  ConditionalTransformedDistribution,
10
+ TransformedDistribution,
11
  )
12
  from pyro.distributions.torch_distribution import TorchDistributionMixin
13
 
 
25
  def condition(self, context):
26
  loc, log_scale = self.context_nn(context)
27
  return torch.distributions.transforms.AffineTransform(
28
+ loc, log_scale.exp(), event_dim=self.event_dim
29
+ )
30
 
31
 
32
  class MLP(nn.Module):
 
59
  nn.BatchNorm2d(width),
60
  activation,
61
  (nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
62
+ nn.Conv2d(width, 2 * width, 3, 2, 1, bias=False),
63
+ nn.BatchNorm2d(2 * width),
64
+ activation,
65
+ nn.Conv2d(2 * width, 2 * width, 3, 1, 1, bias=False),
66
+ nn.BatchNorm2d(2 * width),
67
  activation,
68
+ nn.Conv2d(2 * width, 4 * width, 3, 2, 1, bias=False),
69
+ nn.BatchNorm2d(4 * width),
70
  activation,
71
+ nn.Conv2d(4 * width, 4 * width, 3, 1, 1, bias=False),
72
+ nn.BatchNorm2d(4 * width),
73
  activation,
74
+ nn.Conv2d(4 * width, 8 * width, 3, 2, 1, bias=False),
75
+ nn.BatchNorm2d(8 * width),
76
  activation,
 
 
 
77
  )
78
  self.fc = nn.Sequential(
79
+ nn.Linear(8 * width + context_dim, 8 * width, bias=False),
80
+ nn.BatchNorm1d(8 * width),
81
  activation,
82
+ nn.Linear(8 * width, num_outputs),
83
  )
84
 
85
  def forward(self, x, y=None):
 
97
  super(ArgMaxGumbelMax, self).__init__(cache_size=cache_size)
98
  self.logits = logits
99
  self._event_dim = event_dim
100
+ self._categorical = pyro.distributions.torch.Categorical(
101
+ logits=self.logits
102
+ ).to_event(0)
103
 
104
  @property
105
  def event_dim(self):
 
107
 
108
  def __call__(self, gumbels):
109
  """
110
+ Computes the forward transform
111
  """
112
  assert self.logits is not None, "Logits not defined."
113
 
 
129
 
130
  @property
131
  def domain(self):
132
+ """ "
133
+ Domain of input(gumbel variables), Real
134
  """
135
  if self.event_dim == 0:
136
  return constraints.real
 
138
 
139
  @property
140
  def codomain(self):
141
+ """ "
142
+ Domain of output(categorical variables), should be natural numbers, but set to Real for now
143
  """
144
  if self.event_dim == 0:
145
  return constraints.real
 
150
  assert self.logits is not None, "Logits not defined."
151
 
152
  uniforms = torch.rand(
153
+ self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
154
+ )
155
  gumbels = -((-(uniforms.log())).log())
156
  # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
157
  # (batch_size, num_classes) mask to select kth class
158
  # print(f'k : {k.size()}')
159
+ mask = F.one_hot(
160
+ k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
161
+ )
162
  # print(f'mask: {mask.size()}, {mask.dtype}')
163
  # (batch_size, 1) select topgumbel for truncation of other classes
164
+ topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
165
+ mask * self.logits
166
+ ).sum(dim=-1, keepdim=True)
167
  mask = 1 - mask # invert mask to select other != k classes
168
  g = gumbels + self.logits
169
  # (batch_size, num_classes)
170
+ epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
171
+ mask * self.logits
172
+ )
173
  return epsilons
174
 
175
  def log_abs_det_jacobian(self, x, y):
176
  """We use the log_abs_det_jacobian to account for the categorical prob
177
+ x: Gumbels; y: argmax(x+logits)
178
+ P(y) = softmax
179
  """
180
  # print(f"logits: {torch.log(F.softmax(self.logits, dim=-1)).size()}")
181
  # print(f'y: {y.size()} ')
 
195
  def condition(self, context):
196
  """Given context (age), output the Categorical results"""
197
  logits = self.context_nn(
198
+ context
199
+ ) # The logits for calculating argmax(Gumbel + logits)
200
  return ArgMaxGumbelMax(logits)
201
 
202
  def _logits(self, context):
 
205
 
206
  @property
207
  def domain(self):
208
+ """ "
209
+ Domain of input(gumbel variables), Real
210
  """
211
  if self.event_dim == 0:
212
  return constraints.real
 
214
 
215
  @property
216
  def codomain(self):
217
+ """ "
218
+ Domain of output(categorical variables), should be natural numbers, but set to Real for now
219
  """
220
  if self.event_dim == 0:
221
  return constraints.real
 
223
 
224
 
225
  class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
226
+ r"""Define a TransformedDistribution class for Gumbel max"""
 
227
  arg_constraints: Dict[str, constraints.Constraint] = {}
228
 
229
  def log_prob(self, value):
 
240
  for transform in reversed(self.transforms):
241
  x = transform.inv(y)
242
  event_dim += transform.domain.event_dim - transform.codomain.event_dim
243
+ log_prob = log_prob - _sum_rightmost(
244
+ transform.log_abs_det_jacobian(x, y),
245
+ event_dim - transform.domain.event_dim,
246
+ )
247
  y = x
248
  # print(f"log_prob: {log_prob.size()}")
249
  return log_prob
250
 
251
 
252
  class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
 
253
  def condition(self, context):
254
  base_dist = self.base_dist.condition(context)
255
  transforms = [t.condition(context) for t in self.transforms]
 
257
  return TransformedDistributionGumbelMax(base_dist, transforms)
258
 
259
  def clear_cache(self):
260
+ pass
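The _inverse pass of ArgMaxGumbelMax follows the top-down (truncated) Gumbel construction used in Gumbel-max SCMs; an equivalent, self-contained sketch of the idea (not the exact repo formula):

import torch

logits = torch.tensor([1.0, 0.0, -1.0])
k = 1                                                      # observed argmax class
logZ = torch.logsumexp(logits, dim=-1)
top = logZ - torch.log(-torch.log(torch.rand(())))         # max value ~ Gumbel(logZ)
fresh = logits - torch.log(-torch.log(torch.rand(3)))      # fresh Gumbel(logits_j) draws
shifted = -torch.log(torch.exp(-fresh) + torch.exp(-top))  # truncate below the max
shifted[k] = top                                 # the observed class attains the max
noise = shifted - logits                         # abducted exogenous noise
assert (noise + logits).argmax().item() == k     # consistent with the observation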
vae.py CHANGED
@@ -6,9 +6,17 @@ import torch.distributions as dist
6
 
7
  EPS = -9 # minimum logscale
8
 
 
9
  @torch.jit.script
10
  def gaussian_kl(q_loc, q_logscale, p_loc, p_logscale):
11
- return -0.5 + p_logscale - q_logscale + 0.5 * (q_logscale.exp().pow(2) + (q_loc - p_loc).pow(2)) / p_logscale.exp().pow(2)
 
 
12
 
13
 
14
  @torch.jit.script
@@ -17,20 +25,28 @@ def sample_gaussian(loc, logscale):
17
 
18
 
19
  class Block(nn.Module):
20
- def __init__(self, in_width, bottleneck, out_width, kernel_size=3, residual=True,
21
- down_rate=None, version=None):
 
 
 
22
  super().__init__()
23
  self.d = down_rate
24
  self.residual = residual
25
  padding = 0 if kernel_size == 1 else 1
26
 
27
- if version == 'light': # for ukbb
28
  activation = nn.ReLU()
29
  self.conv = nn.Sequential(
30
  activation,
31
  nn.Conv2d(in_width, bottleneck, kernel_size, 1, padding),
32
  activation,
33
- nn.Conv2d(bottleneck, out_width, kernel_size, 1, padding)
34
  )
35
  else: # for morphomnist
36
  activation = nn.GELU()
@@ -42,7 +58,7 @@ class Block(nn.Module):
42
  activation,
43
  nn.Conv2d(bottleneck, bottleneck, kernel_size, 1, padding),
44
  activation,
45
- nn.Conv2d(bottleneck, out_width, 1, 1)
46
  )
47
 
48
  if self.residual and (self.d or in_width > out_width):
@@ -67,30 +83,41 @@ class Encoder(nn.Module):
67
  super().__init__()
68
  # parse architecture
69
  stages = []
70
- for i, stage in enumerate(args.enc_arch.split(',')):
71
- start = stage.index('b') + 1
72
- end = stage.index('d') if 'd' in stage else None
73
  n_blocks = int(stage[start:end])
74
 
75
  if i == 0: # define network stem
76
- if n_blocks == 0 and 'd' not in stage:
77
- print('Using stride=2 conv encoder stem.')
78
- self.stem = nn.Conv2d(args.input_channels, args.widths[1],
79
- kernel_size=7, stride=2, padding=3)
 
 
 
 
 
80
  continue
81
  else:
82
- self.stem = nn.Conv2d(args.input_channels, args.widths[0],
83
- kernel_size=7, stride=1, padding=3)
 
 
 
 
 
84
 
85
  stages += [(args.widths[i], None) for _ in range(n_blocks)]
86
- if 'd' in stage: # downsampling block
87
- stages += [(args.widths[i+1], int(stage[stage.index('d') + 1]))]
88
  blocks = []
89
  for i, (width, d) in enumerate(stages):
90
- prev_width = stages[max(0, i-1)][0]
91
  bottleneck = int(prev_width / args.bottleneck)
92
- blocks.append(Block(prev_width, bottleneck, width, down_rate=d,
93
- version=args.vr))
 
94
  # scale weights of last conv layer in each block
95
  for b in blocks:
96
  b.conv[-1].weight.data *= np.sqrt(1 / len(blocks))
@@ -113,7 +140,7 @@ class DecoderBlock(nn.Module):
113
  super().__init__()
114
  bottleneck = int(in_width / args.bottleneck)
115
  self.res = resolution
116
- self.stochastic = (self.res <= args.z_max_res)
117
  self.z_dim = args.z_dim
118
  self.cond_prior = args.cond_prior
119
  k = 3 if self.res > 2 else 1
@@ -125,21 +152,35 @@ class DecoderBlock(nn.Module):
125
  # self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
126
  self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
127
 
128
- self.prior = Block(p_in_width, bottleneck, 2*self.z_dim + in_width,
129
- kernel_size=k, residual=False, version=args.vr)
 
 
 
 
 
 
130
  if self.stochastic:
131
- self.posterior = Block(2*in_width + args.context_dim, bottleneck, 2*self.z_dim,
132
- kernel_size=k, residual=False, version=args.vr)
 
 
 
 
 
 
133
  self.z_proj = nn.Conv2d(self.z_dim + args.context_dim, in_width, 1)
134
- self.conv = Block(in_width, bottleneck, out_width, kernel_size=k, version=args.vr)
 
 
135
 
136
  def forward_prior(self, z, pa=None, t=None):
137
  if self.cond_prior:
138
  z = torch.cat([z, pa], dim=1)
139
  z = self.prior(z)
140
- p_loc = z[:, :self.z_dim, ...]
141
- p_logscale = z[:, self.z_dim:2*self.z_dim, ...]
142
- p_features = z[:, 2*self.z_dim:, ...]
143
  if t is not None:
144
  p_logscale = p_logscale + torch.tensor(t).to(z.device).log()
145
  return p_loc, p_logscale, p_features
@@ -157,28 +198,27 @@ class Decoder(nn.Module):
157
  super().__init__()
158
  # parse architecture
159
  stages = []
160
- for i, stage in enumerate(args.dec_arch.split(',')):
161
- res = int(stage.split('b')[0])
162
- n_blocks = int(stage[stage.index('b') + 1:])
163
  stages += [(res, args.widths[::-1][i]) for _ in range(n_blocks)]
164
  self.blocks = []
165
  for i, (res, width) in enumerate(stages):
166
- next_width = stages[min(len(stages)-1, i+1)][1]
167
  self.blocks.append(DecoderBlock(args, width, next_width, res))
168
  self._scale_weights()
169
  self.blocks = nn.ModuleList(self.blocks)
170
  # bias params
171
- self.all_res = list(np.unique([stages[i][0]
172
- for i in range(len(stages))]))
173
  bias = []
174
  for i, res in enumerate(self.all_res):
175
  if res <= args.bias_max_res:
176
- bias.append(nn.Parameter(
177
- torch.zeros(1, args.widths[::-1][i], res, res)
178
- ))
179
  self.bias = nn.ParameterList(bias)
180
  self.cond_prior = args.cond_prior
181
- self.is_drop_cond = True if 'mnist' in args.hps else False # hacky
182
 
183
  def _scale_weights(self):
184
  scale = np.sqrt(1 / len(self.blocks))
@@ -200,28 +240,29 @@ class Decoder(nn.Module):
200
  res = block.res # current block resolution, e.g. 64x64
201
  pa = parents[..., :res, :res].clone() # select parents @ res
202
 
203
- if self.is_drop_cond: # for morphomnist w/ conditioning dropout. Hacky, clean up later
 
 
204
  pa_drop1 = pa.clone()
205
- pa_drop1[:,2:,...] = pa_drop1[:,2:,...] * p1
206
  pa_drop2 = pa.clone()
207
- pa_drop2[:,2:,...] = pa_drop2[:,2:,...] * p2
208
  else: # for ukbb
209
  pa_drop1 = pa_drop2 = pa
210
 
211
  if h.size(-1) < res: # upsample previous layer output
212
  b = bias[res] if res in bias.keys() else 0 # broadcasting
213
- h = b + F.interpolate(h, scale_factor=res/h.shape[-1])
214
 
215
  if block.cond_prior: # conditional prior: p(z_i | z_<i, pa_x)
216
  # w/ posterior correction
217
  # p_loc, p_logscale, p_feat = block.forward_prior(h, pa_drop1, t=t)
218
  if z.size(-1) < res: # w/o posterior correction
219
- z = b + F.interpolate(z, scale_factor=res/z.shape[-1])
220
- p_loc, p_logscale, p_feat = block.forward_prior(
221
- z, pa_drop1, t=t)
222
  else: # exogenous prior: p(z_i | z_<i)
223
  if z.size(-1) < res:
224
- z = b + F.interpolate(z, scale_factor=res/z.shape[-1])
225
  p_loc, p_logscale, p_feat = block.forward_prior(z, t=t)
226
 
227
  # computation tree:
@@ -239,16 +280,17 @@ class Decoder(nn.Module):
239
 
240
  if block.stochastic:
241
  if x is not None: # z_i ~ q(z_i | z_<i, pa_x, x)
242
- q_loc, q_logscale = block.forward_posterior(
243
- h, pa, x[res], t=t)
244
  z = sample_gaussian(q_loc, q_logscale)
245
- stat = dict(kl=gaussian_kl(
246
- q_loc, q_logscale, p_loc, p_logscale))
247
  # abduct exogenous noise
248
  if abduct:
249
  if block.cond_prior: # z* if conditional prior
250
- stat.update(dict(z={
251
- 'z': z, 'q_loc': q_loc, 'q_logscale': q_logscale}))
 
 
 
252
  else: # z if exogenous prior
253
  # stat.update(dict(z=z.detach()))
254
  stat.update(dict(z=z)) # if cf training
@@ -258,8 +300,9 @@ class Decoder(nn.Module):
258
  z = sample_gaussian(p_loc, p_logscale)
259
 
260
  if abduct and block.cond_prior: # for abducting z*
261
- stats.append(dict(z={
262
- 'p_loc': p_loc, 'p_logscale': p_logscale}))
 
263
  else:
264
  try: # forward fixed latents z or z*
265
  z = latents[i]
@@ -267,8 +310,9 @@ class Decoder(nn.Module):
267
  z = sample_gaussian(p_loc, p_logscale)
268
 
269
  if abduct and block.cond_prior: # for abducting z*
270
- stats.append(dict(z={
271
- 'p_loc': p_loc, 'p_logscale': p_logscale}))
 
272
  else:
273
  z = p_loc # deterministic path
274
 
@@ -276,7 +320,7 @@ class Decoder(nn.Module):
276
  h = self.forward_merge(block, h, z, pa_drop2)
277
 
278
  # if not block.cond_prior:
279
- if (i+1) < len(self.blocks):
280
  # z independent of pa_x for next layer prior
281
  z = block.z_feat_proj(torch.cat([z, p_feat], dim=1))
282
  return h, stats
@@ -287,7 +331,7 @@ class Decoder(nn.Module):
287
  return block.conv(h)
288
 
289
  def drop_cond(self):
290
- opt = dist.Categorical(1/3*torch.ones(3)).sample()
291
  if opt == 0: # drop stochastic path
292
  p1, p2 = 0, 1
293
  elif opt == 1: # drop deterministic path
@@ -301,30 +345,31 @@ class DGaussNet(nn.Module):
301
  def __init__(self, args):
302
  super(DGaussNet, self).__init__()
303
  self.x_loc = nn.Conv2d(
304
- args.widths[0], args.input_channels, kernel_size=1, stride=1)
 
305
  self.x_logscale = nn.Conv2d(
306
- args.widths[0], args.input_channels, kernel_size=1, stride=1)
 
307
 
308
  if args.input_channels == 3:
309
- self.channel_coeffs = nn.Conv2d(
310
- args.widths[0], 3, kernel_size=1, stride=1)
311
 
312
  if args.std_init > 0: # if std_init=0, random init weights for diag cov
313
  nn.init.zeros_(self.x_logscale.weight)
314
  nn.init.constant_(self.x_logscale.bias, np.log(args.std_init))
315
 
316
- covariance = args.x_like.split('_')[0]
317
- if covariance == 'fixed':
318
  self.x_logscale.weight.requires_grad = False
319
  self.x_logscale.bias.requires_grad = False
320
- elif covariance == 'shared':
321
  self.x_logscale.weight.requires_grad = False
322
  self.x_logscale.bias.requires_grad = True
323
- elif covariance == 'diag':
324
  self.x_logscale.weight.requires_grad = True
325
  self.x_logscale.bias.requires_grad = True
326
  else:
327
- NotImplementedError(f'{args.x_like} not implemented.')
328
 
329
  def forward(self, h, x=None, t=None):
330
  loc, logscale = self.x_loc(h), self.x_logscale(h).clamp(min=EPS)
@@ -351,7 +396,9 @@ class DGaussNet(nn.Module):
351
  return loc, logscale
352
 
353
  def approx_cdf(self, x):
354
- return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 
355
 
356
  def nll(self, h, x):
357
  loc, logscale = self.forward(h, x)
@@ -367,10 +414,11 @@ class DGaussNet(nn.Module):
367
  log_probs = torch.where(
368
  x < -0.999,
369
  log_cdf_plus,
370
- torch.where(x > 0.999, log_one_minus_cdf_min,
371
- torch.log(cdf_delta.clamp(min=1e-12))),
 
372
  )
373
- return -1. * log_probs.mean(dim=(1, 2, 3))
374
 
375
  def sample(self, h, return_loc=True, t=None):
376
  if return_loc:
@@ -378,20 +426,20 @@ class DGaussNet(nn.Module):
378
  else:
379
  loc, logscale = self.forward(h, t)
380
  x = loc + torch.exp(logscale) * torch.randn_like(loc)
381
- x = torch.clamp(x, min=-1., max=1.)
382
  return x, logscale.exp()
383
 
384
 
385
  class HVAE(nn.Module):
386
  def __init__(self, args):
387
  super().__init__()
388
- args.vr = 'light' if 'ukbb' in args.hps else None # hacky
389
  self.encoder = Encoder(args)
390
  self.decoder = Decoder(args)
391
- if args.x_like.split('_')[1] == 'dgauss':
392
  self.likelihood = DGaussNet(args)
393
  else:
394
- NotImplementedError(f'{args.x_like} not implemented.')
395
  self.cond_prior = args.cond_prior
396
  self.free_bits = args.kl_free_bits
397
 
@@ -404,12 +452,12 @@ class HVAE(nn.Module):
404
  kl_pp = 0.0
405
  for stat in stats:
406
  kl_pp += torch.maximum(
407
- free_bits, stat['kl'].sum(dim=(2, 3)).mean(dim=0)
408
  ).sum()
409
  else:
410
  kl_pp = torch.zeros_like(nll_pp)
411
  for i, stat in enumerate(stats):
412
- kl_pp += stat['kl'].sum(dim=(1, 2, 3))
413
  kl_pp = kl_pp / np.prod(x.shape[1:]) # per pixel
414
  elbo = nll_pp.mean() + beta * kl_pp.mean() # negative elbo (free energy)
415
  return dict(elbo=elbo, nll=nll_pp.mean(), kl=kl_pp.mean())
@@ -421,26 +469,26 @@ class HVAE(nn.Module):
421
  def abduct(self, x, parents, cf_parents=None, alpha=0.5, t=None):
422
  acts = self.encoder(x)
423
  _, q_stats = self.decoder(
424
- x=acts, parents=parents, abduct=True, t=t) # q(z|x,pa)
425
- q_stats = [s['z'] for s in q_stats]
 
426
 
427
  if self.cond_prior and cf_parents is not None:
428
- _, p_stats = self.decoder(
429
- parents=cf_parents, abduct=True, t=t) # p(z|pa*)
430
- p_stats = [s['z'] for s in p_stats]
431
 
432
  cf_zs = []
433
  t = torch.tensor(t).to(x.device) # z* sampling temperature
434
 
435
  for i in range(len(q_stats)):
436
  # from z_i ~ q(z_i | z_{<i}, x, pa)
437
- q_loc = q_stats[i]['q_loc']
438
- q_scale = q_stats[i]['q_logscale'].exp()
439
  # abduct exogenous noise u ~ N(0, I)
440
- u = (q_stats[i]['z'] - q_loc) / q_scale
441
  # p(z_i | z_{<i}, pa*)
442
- p_loc = p_stats[i]['p_loc']
443
- p_var = p_stats[i]['p_logscale'].exp().pow(2)
444
 
445
  # Option1: mixture distribution: r(z_i | z_{<i}, x, pa, pa*)
446
  # = a*q(z_i | z_{<i}, x, pa) + (1-a)*p(z_i | z_{<i}, pa*)
 
6
 
7
  EPS = -9 # minimum logscale
8
 
9
+
10
  @torch.jit.script
11
  def gaussian_kl(q_loc, q_logscale, p_loc, p_logscale):
12
+ return (
13
+ -0.5
14
+ + p_logscale
15
+ - q_logscale
16
+ + 0.5
17
+ * (q_logscale.exp().pow(2) + (q_loc - p_loc).pow(2))
18
+ / p_logscale.exp().pow(2)
19
+ )
20
 
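This is the closed-form KL between diagonal Gaussians, log(sigma_p/sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2; a quick sanity check against torch.distributions (illustrative, using gaussian_kl as defined above):

import torch
import torch.distributions as dist

q = dist.Normal(torch.tensor(0.3), torch.tensor(0.5))
p = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
manual = gaussian_kl(q.loc, q.scale.log(), p.loc, p.scale.log())
assert torch.allclose(manual, dist.kl_divergence(q, p))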
21
 
22
  @torch.jit.script
 
25
 
26
 
27
  class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ in_width,
31
+ bottleneck,
32
+ out_width,
33
+ kernel_size=3,
34
+ residual=True,
35
+ down_rate=None,
36
+ version=None,
37
+ ):
38
  super().__init__()
39
  self.d = down_rate
40
  self.residual = residual
41
  padding = 0 if kernel_size == 1 else 1
42
 
43
+ if version == "light": # for ukbb
44
  activation = nn.ReLU()
45
  self.conv = nn.Sequential(
46
  activation,
47
  nn.Conv2d(in_width, bottleneck, kernel_size, 1, padding),
48
  activation,
49
+ nn.Conv2d(bottleneck, out_width, kernel_size, 1, padding),
50
  )
51
  else: # for morphomnist
52
  activation = nn.GELU()
 
58
  activation,
59
  nn.Conv2d(bottleneck, bottleneck, kernel_size, 1, padding),
60
  activation,
61
+ nn.Conv2d(bottleneck, out_width, 1, 1),
62
  )
63
 
64
  if self.residual and (self.d or in_width > out_width):
 
83
  super().__init__()
84
  # parse architecture
85
  stages = []
86
+ for i, stage in enumerate(args.enc_arch.split(",")):
87
+ start = stage.index("b") + 1
88
+ end = stage.index("d") if "d" in stage else None
89
  n_blocks = int(stage[start:end])
90
 
91
  if i == 0: # define network stem
92
+ if n_blocks == 0 and "d" not in stage:
93
+ print("Using stride=2 conv encoder stem.")
94
+ self.stem = nn.Conv2d(
95
+ args.input_channels,
96
+ args.widths[1],
97
+ kernel_size=7,
98
+ stride=2,
99
+ padding=3,
100
+ )
101
  continue
102
  else:
103
+ self.stem = nn.Conv2d(
104
+ args.input_channels,
105
+ args.widths[0],
106
+ kernel_size=7,
107
+ stride=1,
108
+ padding=3,
109
+ )
110
 
111
  stages += [(args.widths[i], None) for _ in range(n_blocks)]
112
+ if "d" in stage: # downsampling block
113
+ stages += [(args.widths[i + 1], int(stage[stage.index("d") + 1]))]
114
  blocks = []
115
  for i, (width, d) in enumerate(stages):
116
+ prev_width = stages[max(0, i - 1)][0]
117
  bottleneck = int(prev_width / args.bottleneck)
118
+ blocks.append(
119
+ Block(prev_width, bottleneck, width, down_rate=d, version=args.vr)
120
+ )
121
  # scale weights of last conv layer in each block
122
  for b in blocks:
123
  b.conv[-1].weight.data *= np.sqrt(1 / len(blocks))
 
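The architecture-string grammar implied by the parser above, spelled out on a hypothetical value (the real enc_arch strings live in the hyperparameter configs):

# enc_arch = "64b2d2,32b2"   (hypothetical)
#   "64b2d2" -> two Blocks at args.widths[0], then a downsampling Block
#               (down_rate=2) at args.widths[1]
#   "32b2"   -> two Blocks at args.widths[1]
# "b<N>" sets the block count; an optional "d<R>" appends a down_rate=R block.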
140
  super().__init__()
141
  bottleneck = int(in_width / args.bottleneck)
142
  self.res = resolution
143
+ self.stochastic = self.res <= args.z_max_res
144
  self.z_dim = args.z_dim
145
  self.cond_prior = args.cond_prior
146
  k = 3 if self.res > 2 else 1
 
152
  # self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
153
  self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
154
 
155
+ self.prior = Block(
156
+ p_in_width,
157
+ bottleneck,
158
+ 2 * self.z_dim + in_width,
159
+ kernel_size=k,
160
+ residual=False,
161
+ version=args.vr,
162
+ )
163
  if self.stochastic:
164
+ self.posterior = Block(
165
+ 2 * in_width + args.context_dim,
166
+ bottleneck,
167
+ 2 * self.z_dim,
168
+ kernel_size=k,
169
+ residual=False,
170
+ version=args.vr,
171
+ )
172
  self.z_proj = nn.Conv2d(self.z_dim + args.context_dim, in_width, 1)
173
+ self.conv = Block(
174
+ in_width, bottleneck, out_width, kernel_size=k, version=args.vr
175
+ )
176
 
177
  def forward_prior(self, z, pa=None, t=None):
178
  if self.cond_prior:
179
  z = torch.cat([z, pa], dim=1)
180
  z = self.prior(z)
181
+ p_loc = z[:, : self.z_dim, ...]
182
+ p_logscale = z[:, self.z_dim : 2 * self.z_dim, ...]
183
+ p_features = z[:, 2 * self.z_dim :, ...]
184
  if t is not None:
185
  p_logscale = p_logscale + torch.tensor(t).to(z.device).log()
186
  return p_loc, p_logscale, p_features
 
198
  super().__init__()
199
  # parse architecture
200
  stages = []
201
+ for i, stage in enumerate(args.dec_arch.split(",")):
202
+ res = int(stage.split("b")[0])
203
+ n_blocks = int(stage[stage.index("b") + 1 :])
204
  stages += [(res, args.widths[::-1][i]) for _ in range(n_blocks)]
205
  self.blocks = []
206
  for i, (res, width) in enumerate(stages):
207
+ next_width = stages[min(len(stages) - 1, i + 1)][1]
208
  self.blocks.append(DecoderBlock(args, width, next_width, res))
209
  self._scale_weights()
210
  self.blocks = nn.ModuleList(self.blocks)
211
  # bias params
212
+ self.all_res = list(np.unique([stages[i][0] for i in range(len(stages))]))
 
213
  bias = []
214
  for i, res in enumerate(self.all_res):
215
  if res <= args.bias_max_res:
216
+ bias.append(
217
+ nn.Parameter(torch.zeros(1, args.widths[::-1][i], res, res))
218
+ )
219
  self.bias = nn.ParameterList(bias)
220
  self.cond_prior = args.cond_prior
221
+ self.is_drop_cond = True if "mnist" in args.hps else False # hacky
222
 
223
  def _scale_weights(self):
224
  scale = np.sqrt(1 / len(self.blocks))
 
240
  res = block.res # current block resolution, e.g. 64x64
241
  pa = parents[..., :res, :res].clone() # select parents @ res
242
 
243
+ if (
244
+ self.is_drop_cond
245
+ ): # for morphomnist w/ conditioning dropout. Hacky, clean up later
246
  pa_drop1 = pa.clone()
247
+ pa_drop1[:, 2:, ...] = pa_drop1[:, 2:, ...] * p1
248
  pa_drop2 = pa.clone()
249
+ pa_drop2[:, 2:, ...] = pa_drop2[:, 2:, ...] * p2
250
  else: # for ukbb
251
  pa_drop1 = pa_drop2 = pa
252
 
253
  if h.size(-1) < res: # upsample previous layer output
254
  b = bias[res] if res in bias.keys() else 0 # broadcasting
255
+ h = b + F.interpolate(h, scale_factor=res / h.shape[-1])
256
 
257
  if block.cond_prior: # conditional prior: p(z_i | z_<i, pa_x)
258
  # w/ posterior correction
259
  # p_loc, p_logscale, p_feat = block.forward_prior(h, pa_drop1, t=t)
260
  if z.size(-1) < res: # w/o posterior correction
261
+ z = b + F.interpolate(z, scale_factor=res / z.shape[-1])
262
+ p_loc, p_logscale, p_feat = block.forward_prior(z, pa_drop1, t=t)
 
263
  else: # exogenous prior: p(z_i | z_<i)
264
  if z.size(-1) < res:
265
+ z = b + F.interpolate(z, scale_factor=res / z.shape[-1])
266
  p_loc, p_logscale, p_feat = block.forward_prior(z, t=t)
267
 
268
  # computation tree:
 
280
 
281
  if block.stochastic:
282
  if x is not None: # z_i ~ q(z_i | z_<i, pa_x, x)
283
+ q_loc, q_logscale = block.forward_posterior(h, pa, x[res], t=t)
 
284
  z = sample_gaussian(q_loc, q_logscale)
285
+ stat = dict(kl=gaussian_kl(q_loc, q_logscale, p_loc, p_logscale))
 
286
  # abduct exogenous noise
287
  if abduct:
288
  if block.cond_prior: # z* if conditional prior
289
+ stat.update(
290
+ dict(
291
+ z={"z": z, "q_loc": q_loc, "q_logscale": q_logscale}
292
+ )
293
+ )
294
  else: # z if exogenous prior
295
  # stat.update(dict(z=z.detach()))
296
  stat.update(dict(z=z)) # if cf training
 
300
  z = sample_gaussian(p_loc, p_logscale)
301
 
302
  if abduct and block.cond_prior: # for abducting z*
303
+ stats.append(
304
+ dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
305
+ )
306
  else:
307
  try: # forward fixed latents z or z*
308
  z = latents[i]
 
310
  z = sample_gaussian(p_loc, p_logscale)
311
 
312
  if abduct and block.cond_prior: # for abducting z*
313
+ stats.append(
314
+ dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
315
+ )
316
  else:
317
  z = p_loc # deterministic path
318
 
 
320
  h = self.forward_merge(block, h, z, pa_drop2)
321
 
322
  # if not block.cond_prior:
323
+ if (i + 1) < len(self.blocks):
324
  # z independent of pa_x for next layer prior
325
  z = block.z_feat_proj(torch.cat([z, p_feat], dim=1))
326
  return h, stats
 
331
  return block.conv(h)
332
 
333
  def drop_cond(self):
334
+ opt = dist.Categorical(1 / 3 * torch.ones(3)).sample()
335
  if opt == 0: # drop stochastic path
336
  p1, p2 = 0, 1
337
  elif opt == 1: # drop deterministic path
 
345
  def __init__(self, args):
346
  super(DGaussNet, self).__init__()
347
  self.x_loc = nn.Conv2d(
348
+ args.widths[0], args.input_channels, kernel_size=1, stride=1
349
+ )
350
  self.x_logscale = nn.Conv2d(
351
+ args.widths[0], args.input_channels, kernel_size=1, stride=1
352
+ )
353
 
354
  if args.input_channels == 3:
355
+ self.channel_coeffs = nn.Conv2d(args.widths[0], 3, kernel_size=1, stride=1)
 
356
 
357
  if args.std_init > 0: # if std_init=0, random init weights for diag cov
358
  nn.init.zeros_(self.x_logscale.weight)
359
  nn.init.constant_(self.x_logscale.bias, np.log(args.std_init))
360
 
361
+ covariance = args.x_like.split("_")[0]
362
+ if covariance == "fixed":
363
  self.x_logscale.weight.requires_grad = False
364
  self.x_logscale.bias.requires_grad = False
365
+ elif covariance == "shared":
366
  self.x_logscale.weight.requires_grad = False
367
  self.x_logscale.bias.requires_grad = True
368
+ elif covariance == "diag":
369
  self.x_logscale.weight.requires_grad = True
370
  self.x_logscale.bias.requires_grad = True
371
  else:
372
+ raise NotImplementedError(f"{args.x_like} not implemented.")
373
 
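The covariance mode is the prefix of args.x_like; with the dgauss likelihood used here, the recognised prefixes are, as parsed above:

# "fixed_dgauss"  -> logscale frozen at log(std_init)
# "shared_dgauss" -> conv weights frozen, a single learnable logscale bias
# "diag_dgauss"   -> fully learnable per-pixel diagonal covariance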
374
  def forward(self, h, x=None, t=None):
375
  loc, logscale = self.x_loc(h), self.x_logscale(h).clamp(min=EPS)
 
396
  return loc, logscale
397
 
398
  def approx_cdf(self, x):
399
+ return 0.5 * (
400
+ 1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))
401
+ )
402
 
403
  def nll(self, h, x):
404
  loc, logscale = self.forward(h, x)
 
414
  log_probs = torch.where(
415
  x < -0.999,
416
  log_cdf_plus,
417
+ torch.where(
418
+ x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))
419
+ ),
420
  )
421
+ return -1.0 * log_probs.mean(dim=(1, 2, 3))
422
 
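Note that nll() scores a discretized Gaussian likelihood: each pixel's probability is a CDF difference over its intensity bin (via the tanh approximation in approx_cdf), with the edge bins at x ≈ ±1 absorbing the full tails, in the spirit of discretized-logistic likelihoods; the mean over (C, H, W) makes the result a per-pixel value.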
423
  def sample(self, h, return_loc=True, t=None):
424
  if return_loc:
 
426
  else:
427
  loc, logscale = self.forward(h, t)
428
  x = loc + torch.exp(logscale) * torch.randn_like(loc)
429
+ x = torch.clamp(x, min=-1.0, max=1.0)
430
  return x, logscale.exp()
431
 
432
 
433
  class HVAE(nn.Module):
434
  def __init__(self, args):
435
  super().__init__()
436
+ args.vr = "light" if "ukbb" in args.hps else None # hacky
437
  self.encoder = Encoder(args)
438
  self.decoder = Decoder(args)
439
+ if args.x_like.split("_")[1] == "dgauss":
440
  self.likelihood = DGaussNet(args)
441
  else:
442
+ raise NotImplementedError(f"{args.x_like} not implemented.")
443
  self.cond_prior = args.cond_prior
444
  self.free_bits = args.kl_free_bits
445
 
 
452
  kl_pp = 0.0
453
  for stat in stats:
454
  kl_pp += torch.maximum(
455
+ free_bits, stat["kl"].sum(dim=(2, 3)).mean(dim=0)
456
  ).sum()
457
  else:
458
  kl_pp = torch.zeros_like(nll_pp)
459
  for i, stat in enumerate(stats):
460
+ kl_pp += stat["kl"].sum(dim=(1, 2, 3))
461
  kl_pp = kl_pp / np.prod(x.shape[1:]) # per pixel
462
  elbo = nll_pp.mean() + beta * kl_pp.mean() # negative elbo (free energy)
463
  return dict(elbo=elbo, nll=nll_pp.mean(), kl=kl_pp.mean())
 
469
  def abduct(self, x, parents, cf_parents=None, alpha=0.5, t=None):
470
  acts = self.encoder(x)
471
  _, q_stats = self.decoder(
472
+ x=acts, parents=parents, abduct=True, t=t
473
+ ) # q(z|x,pa)
474
+ q_stats = [s["z"] for s in q_stats]
475
 
476
  if self.cond_prior and cf_parents is not None:
477
+ _, p_stats = self.decoder(parents=cf_parents, abduct=True, t=t) # p(z|pa*)
478
+ p_stats = [s["z"] for s in p_stats]
 
479
 
480
  cf_zs = []
481
  t = torch.tensor(t).to(x.device) # z* sampling temperature
482
 
483
  for i in range(len(q_stats)):
484
  # from z_i ~ q(z_i | z_{<i}, x, pa)
485
+ q_loc = q_stats[i]["q_loc"]
486
+ q_scale = q_stats[i]["q_logscale"].exp()
487
  # abduct exogenous noise u ~ N(0, I)
488
+ u = (q_stats[i]["z"] - q_loc) / q_scale
489
  # p(z_i | z_{<i}, pa*)
490
+ p_loc = p_stats[i]["p_loc"]
491
+ p_var = p_stats[i]["p_logscale"].exp().pow(2)
492
 
493
  # Option1: mixture distribution: r(z_i | z_{<i}, x, pa, pa*)
494
  # = a*q(z_i | z_{<i}, x, pa) + (1-a)*p(z_i | z_{<i}, pa*)