Spaces:
Runtime error
Runtime error
lmoss
commited on
Commit
·
3e26a38
1
Parent(s):
9bc9fb6
added interactive latent space dim
Browse files
app.py
CHANGED
@@ -1,18 +1,71 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
2 |
import pyvista as pv
|
3 |
-
from dcgan import DCGAN3D_G
|
4 |
import torch
|
5 |
import requests
|
6 |
-
import time
|
7 |
import numpy as np
|
8 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
st.title("Generating Porous Media with GANs")
|
11 |
|
12 |
st.markdown(
|
13 |
"""
|
14 |
### Author
|
15 |
-
|
16 |
|
17 |
## Description
|
18 |
This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
|
@@ -20,52 +73,32 @@ st.markdown(
|
|
20 |
|
21 |
The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
|
22 |
|
|
|
|
|
|
|
23 |
## The Demo
|
24 |
Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)
|
25 |
|
26 |
The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
|
27 |
Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
|
|
|
|
|
|
|
|
|
|
|
28 |
"""
|
29 |
, unsafe_allow_html=True)
|
30 |
|
31 |
-
|
|
|
32 |
|
33 |
-
|
34 |
-
resp = requests.get(url)
|
35 |
|
36 |
-
|
37 |
-
|
|
|
38 |
|
39 |
pv.set_plot_theme("document")
|
40 |
-
|
41 |
-
|
42 |
-
netG = DCGAN3D_G(64, 512, 1, 32, 1)
|
43 |
-
netG.load_state_dict(torch.load("berea_generator_epoch_24.pth", map_location=torch.device('cpu')))
|
44 |
-
z = torch.randn(1, 512, 1, 1, 1)
|
45 |
-
with torch.no_grad():
|
46 |
-
X = netG(z)
|
47 |
-
|
48 |
-
img = 1-(X[0, 0].numpy()+1)/2
|
49 |
-
|
50 |
-
a = 0.9
|
51 |
-
|
52 |
-
# create a uniform grid to sample the function with
|
53 |
-
x_min, y_min, z_min = 0, 0, 0
|
54 |
-
grid = pv.UniformGrid(
|
55 |
-
dims=img.shape,
|
56 |
-
spacing=(1, 1, 1),
|
57 |
-
origin=(x_min, y_min, z_min),
|
58 |
-
)
|
59 |
-
x, y, z = grid.points.T
|
60 |
-
|
61 |
-
# sample and plot
|
62 |
-
values = img.flatten()
|
63 |
-
grid.point_data['my_array'] = values
|
64 |
-
slices = grid.slice_orthogonal()
|
65 |
-
mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
|
66 |
-
dist = np.linalg.norm(mesh.points, axis=1)
|
67 |
-
|
68 |
-
|
69 |
pl = pv.Plotter(shape=(1, 1),
|
70 |
window_size=(400, 400))
|
71 |
_ = pl.add_mesh(slices, cmap="gray")
|
@@ -76,19 +109,21 @@ pl = pv.Plotter(shape=(1, 1),
|
|
76 |
_ = pl.add_mesh(mesh, scalars=dist)
|
77 |
pl.export_html('mesh.html')
|
78 |
|
|
|
|
|
|
|
79 |
|
80 |
view_width = 400
|
81 |
view_height = 400
|
82 |
|
83 |
HtmlFile = open("slices.html", 'r', encoding='utf-8')
|
84 |
source_code = HtmlFile.read()
|
85 |
-
|
86 |
st.header("3D Intersections")
|
87 |
components.html(source_code, width=view_width, height=view_height)
|
88 |
st.markdown("_Click and drag to spin, right click to shift._")
|
|
|
89 |
HtmlFile = open("mesh.html", 'r', encoding='utf-8')
|
90 |
source_code = HtmlFile.read()
|
91 |
-
|
92 |
st.header("3D Pore Space Mesh")
|
93 |
components.html(source_code, width=view_width, height=view_height)
|
94 |
st.markdown("_Click and drag to spin, right click to shift._")
|
|
|
1 |
import streamlit as st
|
2 |
+
import streamlit.components.v1 as components
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
import pyvista as pv
|
|
|
5 |
import torch
|
6 |
import requests
|
|
|
7 |
import numpy as np
|
8 |
+
import numpy.typing as npt
|
9 |
+
from dcgan import DCGAN3D_G
|
10 |
+
|
11 |
+
|
12 |
+
def download_checkpoint(url: str, path: str) -> None:
|
13 |
+
resp = requests.get(url)
|
14 |
+
|
15 |
+
with open(path, 'wb') as f:
|
16 |
+
f.write(resp.content)
|
17 |
+
|
18 |
+
|
19 |
+
def generate_image(path: str,
|
20 |
+
image_size: int = 64,
|
21 |
+
z_dim: int = 512,
|
22 |
+
n_channels: int = 1,
|
23 |
+
n_features: int = 32,
|
24 |
+
ngpu: int = 1,
|
25 |
+
latent_size: int = 3) -> npt.ArrayLike:
|
26 |
+
netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
|
27 |
+
netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
|
28 |
+
z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
|
29 |
+
with torch.no_grad():
|
30 |
+
X = netG(z)
|
31 |
+
img = 1 - (X[0, 0].numpy() + 1) / 2
|
32 |
+
return img
|
33 |
+
|
34 |
+
|
35 |
+
def create_uniform_mesh_marching_cubes(img: npt.ArrayLike):
|
36 |
+
grid = pv.UniformGrid(
|
37 |
+
dims=img.shape,
|
38 |
+
spacing=(1, 1, 1),
|
39 |
+
origin=(0, 0, 0),
|
40 |
+
)
|
41 |
+
|
42 |
+
values = img.flatten()
|
43 |
+
grid.point_data['my_array'] = values
|
44 |
+
slices = grid.slice_orthogonal()
|
45 |
+
mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
|
46 |
+
dist = np.linalg.norm(mesh.points, axis=1)
|
47 |
+
return slices, mesh, dist
|
48 |
+
|
49 |
+
|
50 |
+
def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
|
51 |
+
fig, ax = plt.subplots(1, 3, figsize=(18, 6))
|
52 |
+
ax[0].imshow(img[midpoint], cmap="gray", vmin=0, vmax=1)
|
53 |
+
ax[1].imshow(img[:, midpoint], cmap="gray", vmin=0, vmax=1)
|
54 |
+
ax[2].imshow(img[..., midpoint], cmap="gray", vmin=0, vmax=1)
|
55 |
+
|
56 |
+
for a, title in zip(ax, ["Front", "Right", "Top"]):
|
57 |
+
a.set_title(title, fontsize=18)
|
58 |
+
|
59 |
+
for a in ax:
|
60 |
+
a.set_axis_off()
|
61 |
+
return fig
|
62 |
|
63 |
st.title("Generating Porous Media with GANs")
|
64 |
|
65 |
st.markdown(
|
66 |
"""
|
67 |
### Author
|
68 |
+
_[Lukas Mosser](https://scholar.google.com/citations?user=y0R9snMAAAAJ&hl=en&oi=ao) (2022)_ - :bird:[porestar](https://twitter.com/porestar)
|
69 |
|
70 |
## Description
|
71 |
This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
|
|
|
73 |
|
74 |
The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
|
75 |
|
76 |
+
## Intent
|
77 |
+
I hope this encourages others to create interactive demos of their research for knowledge sharing and validation.
|
78 |
+
|
79 |
## The Demo
|
80 |
Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)
|
81 |
|
82 |
The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
|
83 |
Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
|
84 |
+
|
85 |
+
### Interactive Model Parameters
|
86 |
+
The GAN used here in this study is fully convolutional "_Look Ma' no MLP's_": Changing the spatial extent of the latent space vector _z_
|
87 |
+
allows one to generate larger synthetic images.
|
88 |
+
|
89 |
"""
|
90 |
, unsafe_allow_html=True)
|
91 |
|
92 |
+
model_fname = "berea_generator_epoch_24.pth"
|
93 |
+
checkpoint_url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/{0:}?raw=true".format(model_fname)
|
94 |
|
95 |
+
download_checkpoint(checkpoint_url, model_fname)
|
|
|
96 |
|
97 |
+
latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
|
98 |
+
img = generate_image(model_fname, latent_size=latent_size)
|
99 |
+
slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
|
100 |
|
101 |
pv.set_plot_theme("document")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
pl = pv.Plotter(shape=(1, 1),
|
103 |
window_size=(400, 400))
|
104 |
_ = pl.add_mesh(slices, cmap="gray")
|
|
|
109 |
_ = pl.add_mesh(mesh, scalars=dist)
|
110 |
pl.export_html('mesh.html')
|
111 |
|
112 |
+
st.header("2D Cross-Section of Generated Volume")
|
113 |
+
fig = create_matplotlib_figure(img, img.shape[0]//2)
|
114 |
+
st.pyplot(fig=fig)
|
115 |
|
116 |
view_width = 400
|
117 |
view_height = 400
|
118 |
|
119 |
HtmlFile = open("slices.html", 'r', encoding='utf-8')
|
120 |
source_code = HtmlFile.read()
|
|
|
121 |
st.header("3D Intersections")
|
122 |
components.html(source_code, width=view_width, height=view_height)
|
123 |
st.markdown("_Click and drag to spin, right click to shift._")
|
124 |
+
|
125 |
HtmlFile = open("mesh.html", 'r', encoding='utf-8')
|
126 |
source_code = HtmlFile.read()
|
|
|
127 |
st.header("3D Pore Space Mesh")
|
128 |
components.html(source_code, width=view_width, height=view_height)
|
129 |
st.markdown("_Click and drag to spin, right click to shift._")
|