Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from matplotlib import rc, patches, colors | |
# Global matplotlib settings applied at import time: serif "Roman" font,
# LaTeX text rendering with the AMS packages loaded, and no image interpolation.
rc("font", family="serif", serif=["Roman"])
rc("text", usetex=True)
rc("image", interpolation="none")
rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
from datasets import get_attr_max_min | |
# Hammer icon overlaid on intervened nodes by the *_graph functions below:
# resized to 35x35 pixels and scaled from [0, 255] to [0, 1].
# The context manager makes sure the underlying file handle is released.
with Image.open("./hammer.png") as _hammer_file:
    HAMMER = np.array(_hammer_file.resize((35, 35))) / 255
class MidpointNormalize(colors.Normalize):
    """Colour normalisation over a range that is symmetric about zero.

    Values are mapped piecewise-linearly so that ``-v_ext -> 0``,
    ``midpoint -> 0.5`` and ``+v_ext -> 1``, where
    ``v_ext = max(|vmin|, |vmax|)``. Values outside ``[-v_ext, v_ext]`` are
    clamped to 0 or 1 by ``np.interp``.

    Args:
        vmin: Lower data limit; filled in from the data on first call if None.
        vmax: Upper data limit; filled in from the data on first call if None.
        midpoint: Data value mapped to 0.5. When None, 0 is used (the centre
            of the symmetric range).
        clip: Forwarded to ``colors.Normalize``.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Return ``value`` mapped into [0, 1] as a masked array.

        The ``clip`` argument is accepted for interface compatibility and is
        not used.
        """
        # Without this, calling an instance built with the default
        # vmin/vmax (None) fails inside np.abs.
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(value)
        v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
        # np.interp cannot take None as a sample point.
        midpoint = 0.0 if self.midpoint is None else self.midpoint
        x, y = [-v_ext, midpoint, v_ext], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))
def postprocess(x):
    """Rescale a tensor with ``(x + 1) * 127.5`` and return it as a NumPy array.

    The affine map sends -1 to 0 and 1 to 255. Singleton dimensions are
    squeezed out, and the result is detached and moved to the CPU before
    conversion.
    """
    rescaled = (x + 1.0) * 127.5
    rescaled = rescaled.squeeze()
    return rescaled.detach().cpu().numpy()
def mnist_graph(*args):
    """Draw the causal graph over ``x``, ``t``, ``i``, ``y`` and return it as an image.

    Positional flags (truthy means the variable is intervened on):
        args[0]: do(t)
        args[1]: do(i)
        args[2]: do(y)

    For an intervened node the incoming edges are drawn in light grey, the
    node outline is red, and a hammer icon is stamped next to it.

    Returns:
        NumPy array copied from the canvas RGBA buffer.
    """
    do_t, do_i, do_y = args[0], args[1], args[2]

    # LaTeX labels double as node identifiers.
    n_x, n_t, n_i, n_y = r"$\mathbf{x}$", r"$t$", r"$i$", r"$y$"
    u_t, u_i, u_y = r"$\mathbf{U}_t$", r"$\mathbf{U}_i$", r"$\mathbf{U}_y$"
    z_x, eps_x = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    # Insertion order matters: the colour lists below are aligned with it.
    graph = nx.DiGraph()
    graph.add_edges_from(
        [
            (n_t, n_x),
            (n_i, n_x),
            (n_y, n_x),
            (n_t, n_i),
            (u_t, n_t),
            (u_i, n_i),
            (u_y, n_y),
            (z_x, n_x),
            (eps_x, n_x),
        ]
    )
    layout = {
        n_y: (0, 0),
        u_y: (-1, 0),
        n_t: (0, 0.5),
        u_t: (0, 1),
        n_x: (1, 0),
        z_x: (2, 0.375),
        eps_x: (2, 0),
        n_i: (1, 0.5),
        u_i: (1, 1),
    }

    # The four named variables are filled grey; the noise terms stay white.
    shaded = (n_x, n_t, n_i, n_y)
    fill = {n: ("lightgrey" if n in shaded else "white") for n in graph}
    outline = dict.fromkeys(fill, "black")
    edge_colour = dict.fromkeys(graph.edges, "black")

    if do_t:
        edge_colour[(u_t, n_t)] = "lightgrey"
        outline[n_t] = "red"
    if do_i:
        edge_colour.update({(u_i, n_i): "lightgrey", (n_t, n_i): "lightgrey"})
        outline[n_i] = "red"
    if do_y:
        edge_colour[(u_y, n_y)] = "lightgrey"
        outline[n_y] = "red"

    label_size = 30
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(6, 4.1))
    ax.margins(x=0.06, y=0.15, tight=False)
    ax.axis("off")
    nx.draw_networkx(
        graph,
        layout,
        font_size=label_size,
        node_size=3000,
        node_color=list(fill.values()),
        edgecolors=list(outline.values()),
        edge_color=list(edge_colour.values()),
        linewidths=2,
        width=2,
        arrowsize=25,
        arrowstyle="-|>",
        ax=ax,
    )
    # Fixed limits so the graph is laid out identically whatever the flags are.
    ax.set_xlim((-1.348, 2.348))
    ax.set_ylim((-0.215, 1.215))

    # Box around z_{1:L} and epsilon, labelled U_x.
    ax.add_patch(
        patches.FancyBboxPatch(
            (1.75, -0.16),
            0.5,
            0.7,
            boxstyle="round, pad=0.05, rounding_size=0",
            linewidth=2,
            edgecolor="black",
            facecolor="none",
            linestyle="-",
        )
    )
    ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=label_size)

    # (flag, x fraction, y fraction, zorder) for each hammer icon.
    hammer_spots = (
        (do_t, 0.26, 0.525, 10),
        (do_i, 0.5175, 0.525, 11),
        (do_y, 0.26, 0.2, 12),
    )
    for active, frac_x, frac_y, z in hammer_spots:
        if active:
            fig.figimage(
                HAMMER, frac_x * fig.bbox.xmax, frac_y * fig.bbox.ymax, zorder=z
            )

    fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
def brain_graph(*args):
    """Draw the causal graph over ``x``, ``m``, ``s``, ``a``, ``b``, ``v`` as an image.

    Positional flags (truthy means the variable is intervened on):
        args[0]: do(m)
        args[1]: do(s)
        args[2]: do(a)
        args[3]: do(b)
        args[4]: do(v)

    For an intervened node the incoming graph edges are drawn in light grey,
    the node outline is red, and a hammer icon is stamped next to it.

    The ``b -> v`` dependency is not an edge of the networkx graph: it is a
    curved arrow drawn by hand, and it is left out entirely when ``v`` is
    intervened on.

    Returns:
        NumPy array copied from the canvas RGBA buffer.
    """
    do_m, do_s, do_a, do_b, do_v = args[0], args[1], args[2], args[3], args[4]

    # LaTeX labels double as node identifiers.
    x, m, s, a, b, v = r"$\mathbf{x}$", r"$m$", r"$s$", r"$a$", r"$b$", r"$v$"
    um, us, ua, ub, uv = (
        r"$\mathbf{U}_m$",
        r"$\mathbf{U}_s$",
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_b$",
        r"$\mathbf{U}_v$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    # Insertion order matters: the colour lists below are aligned with it.
    G = nx.DiGraph()
    G.add_edges_from(
        [
            (m, x),
            (s, x),
            (b, x),
            (v, x),
            (zx, x),
            (ex, x),
            (a, b),
            (a, v),
            (s, b),
            (um, m),
            (us, s),
            (ua, a),
            (ub, b),
            (uv, v),
        ]
    )
    pos = {
        x: (0, 0),
        zx: (-0.25, -1),
        ex: (0.25, -1),
        a: (0, 1),
        ua: (0, 2),
        s: (1, 0),
        us: (1, -1),
        b: (1, 1),
        ub: (1, 2),
        m: (-1, 0),
        um: (-1, -1),
        v: (-1, 1),
        uv: (-1, 2),
    }

    # The six named variables are filled grey; the noise terms stay white.
    shaded = (x, m, s, a, b, v)
    node_c = {node: ("lightgrey" if node in shaded else "white") for node in G}
    node_line_c = dict.fromkeys(node_c, "black")
    edge_c = dict.fromkeys(G.edges, "black")

    if do_m:
        edge_c[(um, m)] = "lightgrey"
        node_line_c[m] = "red"
    if do_s:
        edge_c[(us, s)] = "lightgrey"
        node_line_c[s] = "red"
    if do_a:
        edge_c[(ua, a)] = "lightgrey"
        node_line_c[a] = "red"
    if do_b:
        edge_c[(ub, b)] = "lightgrey"
        edge_c[(s, b)] = "lightgrey"
        edge_c[(a, b)] = "lightgrey"
        node_line_c[b] = "red"
    if do_v:
        # Only real graph edges are recoloured here. (b, v) is not in G, so
        # adding it would make the colour list longer than the edge list.
        edge_c[(uv, v)] = "lightgrey"
        edge_c[(a, v)] = "lightgrey"
        node_line_c[v] = "red"

    fs = 30
    options = {
        "font_size": fs,
        "node_size": 3000,
        "node_color": list(node_c.values()),
        "edgecolors": list(node_line_c.values()),
        "edge_color": list(edge_c.values()),
        "linewidths": 2,
        "width": 2,
    }
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.margins(x=0.1, y=0.08, tight=False)
    ax.axis("off")
    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
    # Fixed limits so the graph is laid out identically whatever the flags are.
    ax.set_xlim((-1.32, 1.32))
    ax.set_ylim((-1.414, 2.414))

    # Box around z_{1:L} and epsilon.
    rect = patches.FancyBboxPatch(
        (-0.5, -1.325),
        1,
        0.65,
        boxstyle="round, pad=0.05, rounding_size=0",
        linewidth=2,
        edgecolor="black",
        facecolor="none",
        linestyle="-",
    )
    ax.add_patch(rect)

    if do_m:
        fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=10)
    if do_s:
        fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
    if do_a:
        fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
    if do_b:
        fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)
    if do_v:
        fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=14)
    else:
        # Hand-drawn curved b -> v arrow, shown only while v is not intervened on.
        b_to_v = patches.FancyArrowPatch(
            (0.86, 1.21),
            (-0.86, 1.21),
            connectionstyle="arc3,rad=.3",
            linewidth=2,
            arrowstyle="simple, head_width=10, head_length=10",
            color="k",
        )
        ax.add_patch(b_to_v)

    fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
def chest_graph(*args):
    """Draw the causal graph over ``x``, ``a``, ``d``, ``r``, ``s`` and return it as an image.

    Positional flags (truthy means the variable is intervened on):
        args[0]: do(r)
        args[1]: do(s)
        args[2]: do(d)  (called ``do_f`` in the original comments)
        args[3]: do(a)

    For an intervened node the incoming edges are drawn in light grey, the
    node outline is red, and a hammer icon is stamped next to it.

    Returns:
        NumPy array copied from the canvas RGBA buffer.
    """
    do_r, do_s, do_d, do_a = args[0], args[1], args[2], args[3]

    # LaTeX labels double as node identifiers.
    n_x, n_a, n_d, n_r, n_s = r"$\mathbf{x}$", r"$a$", r"$d$", r"$r$", r"$s$"
    u_a, u_d, u_r, u_s = (
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_d$",
        r"$\mathbf{U}_r$",
        r"$\mathbf{U}_s$",
    )
    z_x, eps_x = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    # Insertion order matters: the colour lists below are aligned with it.
    graph = nx.DiGraph()
    graph.add_edges_from(
        [
            (u_a, n_a),
            (u_d, n_d),
            (u_r, n_r),
            (u_s, n_s),
            (n_a, n_d),
            (n_d, n_x),
            (n_r, n_x),
            (n_s, n_x),
            (eps_x, n_x),
            (z_x, n_x),
            (n_a, n_x),
        ]
    )
    layout = {
        n_x: (0, 0),
        n_a: (-1, 1),
        n_d: (0, 1),
        n_r: (1, 1),
        n_s: (1, 0),
        u_a: (-1, 2),
        u_d: (0, 2),
        u_r: (1, 2),
        u_s: (1, -1),
        z_x: (-0.25, -1),
        eps_x: (0.25, -1),
    }

    # The five named variables are filled grey; the noise terms stay white.
    shaded = (n_x, n_a, n_d, n_r, n_s)
    fill = {n: ("lightgrey" if n in shaded else "white") for n in graph}
    edge_colour = dict.fromkeys(graph.edges, "black")
    outline = dict.fromkeys(fill, "black")

    if do_r:
        edge_colour[(u_r, n_r)] = "lightgrey"
        outline[n_r] = "red"
    if do_s:
        edge_colour[(u_s, n_s)] = "lightgrey"
        outline[n_s] = "red"
    if do_d:
        edge_colour.update({(u_d, n_d): "lightgrey", (n_a, n_d): "lightgrey"})
        outline[n_d] = "red"
    if do_a:
        edge_colour[(u_a, n_a)] = "lightgrey"
        outline[n_a] = "red"

    label_size = 30
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.margins(x=0.1, y=0.08, tight=False)
    ax.axis("off")
    nx.draw_networkx(
        graph,
        layout,
        font_size=label_size,
        node_size=3000,
        node_color=list(fill.values()),
        edgecolors=list(outline.values()),
        edge_color=list(edge_colour.values()),
        linewidths=2,
        width=2,
        arrowsize=25,
        arrowstyle="-|>",
        ax=ax,
    )
    # Fixed limits so the graph is laid out identically whatever the flags are.
    ax.set_xlim((-1.32, 1.32))
    ax.set_ylim((-1.414, 2.414))

    # Box around z_{1:L} and epsilon, labelled U_x.
    ax.add_patch(
        patches.FancyBboxPatch(
            (-0.5, -1.325),
            1,
            0.65,
            boxstyle="round, pad=0.05, rounding_size=0",
            linewidth=2,
            edgecolor="black",
            facecolor="none",
            linestyle="-",
        )
    )
    ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=label_size)

    # (flag, x fraction, y fraction, zorder) for each hammer icon.
    hammer_spots = (
        (do_r, 0.72, 0.64, 10),
        (do_s, 0.72, 0.395, 11),
        (do_d, 0.363, 0.64, 12),
        (do_a, 0.0075, 0.64, 13),
    )
    for active, frac_x, frac_y, z in hammer_spots:
        if active:
            fig.figimage(
                HAMMER, frac_x * fig.bbox.xmax, frac_y * fig.bbox.ymax, zorder=z
            )

    fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
def vae_preprocess(args, pa):
    """Build the spatial conditioning tensor for the VAE from a dict of parents.

    When ``"ukbb"`` is in ``args.hps``, every parent except ``"mri_seq"`` and
    ``"sex"`` is first mapped from [-1, 1] back to its original range using
    ``get_attr_max_min``. ``"age"``, ``"brain_volume"`` and
    ``"ventricle_volume"`` are then log-transformed and standardised with
    fixed constants.

    The caller's ``pa`` dict is not modified, so the same dict can safely be
    passed again.

    Args:
        args: Object providing ``hps``, ``parents_x``, ``input_res`` and
            ``device``.
        pa: Mapping from parent name to tensor. 1-D tensors get a trailing
            singleton dimension before concatenation.

    Returns:
        Float tensor on ``args.device``: the parents listed in
        ``args.parents_x`` concatenated along dim 1 and repeated over an
        ``input_res x input_res`` grid.
    """
    # (offset, scale) applied to the log of each attribute; presumably the
    # mean and std of the log values the VAE was trained with -- TODO confirm.
    log_stats = {
        "age": (4.112339973449707, 0.11769197136163712),
        "brain_volume": (13.965583801269531, 0.09537758678197861),
        "ventricle_volume": (10.345998764038086, 0.43127763271331787),
    }

    if "ukbb" in args.hps:
        # Write into a fresh dict rather than overwriting the caller's entries.
        parents = {}
        for k, v in pa.items():
            if k != "mri_seq" and k != "sex":
                v = (v + 1) / 2  # [-1,1] -> [0,1]
                _max, _min = get_attr_max_min(k)
                v = v * (_max - _min) + _min  # [0,1] -> original range
            if k in log_stats:
                offset, scale = log_stats[k]
                # Clamp so the log stays finite for non-positive values.
                v = (torch.log(v.clamp(min=1e-12)) - offset) / scale
            parents[k] = v
    else:
        parents = pa

    # Concatenate parents, then expand to the input resolution.
    columns = [
        parents[k] if len(parents[k].shape) > 1 else parents[k][..., None]
        for k in args.parents_x
    ]
    cond = torch.cat(columns, dim=1)
    cond = cond[..., None, None].repeat(1, 1, args.input_res, args.input_res)
    return cond.to(args.device).float()
def preprocess_brain(args, obs):
    """Prepare one observation dict in place and return it.

    ``obs["x"]`` gets a leading dimension, is cast to float, moved to
    ``args.device`` and rescaled with ``(x - 127.5) / 127.5``. Every other
    entry is cast to float, moved to ``args.device`` and viewed as ``(1, 1)``.
    ``"age"``, ``"brain_volume"`` and ``"ventricle_volume"`` are additionally
    min-max scaled to [-1, 1] using ``get_attr_max_min``.
    """
    device = args.device
    minmax_scaled = ("age", "brain_volume", "ventricle_volume")

    image = obs["x"][None, ...].float().to(device)
    obs["x"] = (image - 127.5) / 127.5  # [-1,1]

    # Snapshot the keys first; values are reassigned while looping.
    for name in list(obs.keys()):
        if name == "x":
            continue
        obs[name] = obs[name].float().to(device).view(1, 1)
        if name not in minmax_scaled:
            continue
        upper, lower = get_attr_max_min(name)
        unit = (obs[name] - lower) / (upper - lower)  # [0,1]
        obs[name] = 2 * unit - 1  # [-1,1]
    return obs
def get_fig_arr(x, width=4, height=4, dpi=144, cmap="Greys_r", norm=None):
    """Render ``x`` with ``imshow`` on a borderless figure and return the pixels.

    Args:
        x: Image data accepted by ``Axes.imshow``.
        width: Figure width in inches.
        height: Figure height in inches.
        dpi: Figure resolution.
        cmap: Colormap name. With ``"Greys_r"`` the colour range is fixed to
            [0, 255] and ``norm`` is ignored; otherwise ``norm`` is used.
        norm: Normalisation passed to ``imshow`` for non-grey colormaps.

    Returns:
        NumPy array copied from the canvas RGBA buffer.
    """
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    try:
        # Axes fill the whole figure so no border surrounds the image.
        ax = plt.axes([0, 0, 1, 1], frameon=False)
        if cmap == "Greys_r":
            ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
        else:
            ax.imshow(x, cmap=cmap, norm=norm)
        ax.axis("off")
        fig.canvas.draw()
        # np.array copies the buffer, so the result outlives the figure.
        return np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call leaves one more open figure behind.
        plt.close(fig)
def normalize(x, x_min=None, x_max=None, zero_one=False):
    """Min-max scale ``x`` to [-1, 1], or to [0, 1] when ``zero_one`` is true.

    ``x_min`` / ``x_max`` default to ``x.min()`` / ``x.max()`` when not given.
    """
    lower = x.min() if x_min is None else x_min
    upper = x.max() if x_max is None else x_max
    unit = (x - lower) / (upper - lower)  # [0,1]
    if zero_one:
        return unit
    return 2 * unit - 1  # [-1,1]