lmoss commited on
Commit
3e26a38
·
1 Parent(s): 9bc9fb6

added interactive latent space dim

Browse files
Files changed (1) hide show
  1. app.py +75 -40
app.py CHANGED
@@ -1,18 +1,71 @@
1
  import streamlit as st
 
 
2
  import pyvista as pv
3
- from dcgan import DCGAN3D_G
4
  import torch
5
  import requests
6
- import time
7
  import numpy as np
8
- import streamlit.components.v1 as components
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  st.title("Generating Porous Media with GANs")
11
 
12
  st.markdown(
13
  """
14
  ### Author
15
- _Lukas Mosser (2022)_ - :bird:[porestar](https://twitter.com/porestar)
16
 
17
  ## Description
18
  This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
@@ -20,52 +73,32 @@ st.markdown(
20
 
21
  The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
22
 
 
 
 
23
  ## The Demo
24
  Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)
25
 
26
  The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
27
  Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
 
 
 
 
 
28
  """
29
  , unsafe_allow_html=True)
30
 
31
- url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/berea_generator_epoch_24.pth?raw=true"
 
32
 
33
- # If repo is private - we need to add a token in header:
34
- resp = requests.get(url)
35
 
36
- with open('berea_generator_epoch_24.pth', 'wb') as f:
37
- f.write(resp.content)
 
38
 
39
  pv.set_plot_theme("document")
40
-
41
-
42
- netG = DCGAN3D_G(64, 512, 1, 32, 1)
43
- netG.load_state_dict(torch.load("berea_generator_epoch_24.pth", map_location=torch.device('cpu')))
44
- z = torch.randn(1, 512, 1, 1, 1)
45
- with torch.no_grad():
46
- X = netG(z)
47
-
48
- img = 1-(X[0, 0].numpy()+1)/2
49
-
50
- a = 0.9
51
-
52
- # create a uniform grid to sample the function with
53
- x_min, y_min, z_min = 0, 0, 0
54
- grid = pv.UniformGrid(
55
- dims=img.shape,
56
- spacing=(1, 1, 1),
57
- origin=(x_min, y_min, z_min),
58
- )
59
- x, y, z = grid.points.T
60
-
61
- # sample and plot
62
- values = img.flatten()
63
- grid.point_data['my_array'] = values
64
- slices = grid.slice_orthogonal()
65
- mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
66
- dist = np.linalg.norm(mesh.points, axis=1)
67
-
68
-
69
  pl = pv.Plotter(shape=(1, 1),
70
  window_size=(400, 400))
71
  _ = pl.add_mesh(slices, cmap="gray")
@@ -76,19 +109,21 @@ pl = pv.Plotter(shape=(1, 1),
76
  _ = pl.add_mesh(mesh, scalars=dist)
77
  pl.export_html('mesh.html')
78
 
 
 
 
79
 
80
  view_width = 400
81
  view_height = 400
82
 
83
  HtmlFile = open("slices.html", 'r', encoding='utf-8')
84
  source_code = HtmlFile.read()
85
-
86
  st.header("3D Intersections")
87
  components.html(source_code, width=view_width, height=view_height)
88
  st.markdown("_Click and drag to spin, right click to shift._")
 
89
  HtmlFile = open("mesh.html", 'r', encoding='utf-8')
90
  source_code = HtmlFile.read()
91
-
92
  st.header("3D Pore Space Mesh")
93
  components.html(source_code, width=view_width, height=view_height)
94
  st.markdown("_Click and drag to spin, right click to shift._")
 
1
  import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ import matplotlib.pyplot as plt
4
  import pyvista as pv
 
5
  import torch
6
  import requests
 
7
  import numpy as np
8
+ import numpy.typing as npt
9
+ from dcgan import DCGAN3D_G
10
+
11
+
12
+ def download_checkpoint(url: str, path: str) -> None:
13
+ resp = requests.get(url)
14
+
15
+ with open(path, 'wb') as f:
16
+ f.write(resp.content)
17
+
18
+
19
+ def generate_image(path: str,
20
+ image_size: int = 64,
21
+ z_dim: int = 512,
22
+ n_channels: int = 1,
23
+ n_features: int = 32,
24
+ ngpu: int = 1,
25
+ latent_size: int = 3) -> npt.ArrayLike:
26
+ netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
27
+ netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
28
+ z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
29
+ with torch.no_grad():
30
+ X = netG(z)
31
+ img = 1 - (X[0, 0].numpy() + 1) / 2
32
+ return img
33
+
34
+
35
+ def create_uniform_mesh_marching_cubes(img: npt.ArrayLike):
36
+ grid = pv.UniformGrid(
37
+ dims=img.shape,
38
+ spacing=(1, 1, 1),
39
+ origin=(0, 0, 0),
40
+ )
41
+
42
+ values = img.flatten()
43
+ grid.point_data['my_array'] = values
44
+ slices = grid.slice_orthogonal()
45
+ mesh = grid.contour(1, values, method='marching_cubes', rng=[1, 0], preference="points")
46
+ dist = np.linalg.norm(mesh.points, axis=1)
47
+ return slices, mesh, dist
48
+
49
+
50
+ def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
51
+ fig, ax = plt.subplots(1, 3, figsize=(18, 6))
52
+ ax[0].imshow(img[midpoint], cmap="gray", vmin=0, vmax=1)
53
+ ax[1].imshow(img[:, midpoint], cmap="gray", vmin=0, vmax=1)
54
+ ax[2].imshow(img[..., midpoint], cmap="gray", vmin=0, vmax=1)
55
+
56
+ for a, title in zip(ax, ["Front", "Right", "Top"]):
57
+ a.set_title(title, fontsize=18)
58
+
59
+ for a in ax:
60
+ a.set_axis_off()
61
+ return fig
62
 
63
  st.title("Generating Porous Media with GANs")
64
 
65
  st.markdown(
66
  """
67
  ### Author
68
+ _[Lukas Mosser](https://scholar.google.com/citations?user=y0R9snMAAAAJ&hl=en&oi=ao) (2022)_ - :bird:[porestar](https://twitter.com/porestar)
69
 
70
  ## Description
71
  This is a demo of the Generative Adversarial Network (GAN, [Goodfellow 2014](https://arxiv.org/abs/1406.2661)) trained for our publication [PorousMediaGAN](https://github.com/LukasMosser/PorousMediaGan)
 
73
 
74
  The model is a pretrained 3D Deep Convolutional GAN ([Radford 2015](https://arxiv.org/abs/1511.06434)) that generates a volumetric image of a porous medium, here a Berea sandstone, from a set of pretrained weights.
75
 
76
+ ## Intent
77
+ I hope this encourages others to create interactive demos of their research for knowledge sharing and validation.
78
+
79
  ## The Demo
80
  Slices through the 3D volume are rendered using [PyVista](https://www.pyvista.org/) and [PyThreeJS](https://pythreejs.readthedocs.io/en/stable/)
81
 
82
  The model itself currently runs on the :hugging_face: [Huggingface Spaces](https://huggingface.co/spaces) instance.
83
  Future migration to the :hugging_face: [Huggingface Models](https://huggingface.co/models) repository is possible.
84
+
85
+ ### Interactive Model Parameters
86
+ The GAN used here in this study is fully convolutional "_Look Ma' no MLP's_": Changing the spatial extent of the latent space vector _z_
87
+ allows one to generate larger synthetic images.
88
+
89
  """
90
  , unsafe_allow_html=True)
91
 
92
+ model_fname = "berea_generator_epoch_24.pth"
93
+ checkpoint_url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/{0:}?raw=true".format(model_fname)
94
 
95
+ download_checkpoint(checkpoint_url, model_fname)
 
96
 
97
+ latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
98
+ img = generate_image(model_fname, latent_size=latent_size)
99
+ slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
100
 
101
  pv.set_plot_theme("document")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  pl = pv.Plotter(shape=(1, 1),
103
  window_size=(400, 400))
104
  _ = pl.add_mesh(slices, cmap="gray")
 
109
  _ = pl.add_mesh(mesh, scalars=dist)
110
  pl.export_html('mesh.html')
111
 
112
+ st.header("2D Cross-Section of Generated Volume")
113
+ fig = create_matplotlib_figure(img, img.shape[0]//2)
114
+ st.pyplot(fig=fig)
115
 
116
  view_width = 400
117
  view_height = 400
118
 
119
  HtmlFile = open("slices.html", 'r', encoding='utf-8')
120
  source_code = HtmlFile.read()
 
121
  st.header("3D Intersections")
122
  components.html(source_code, width=view_width, height=view_height)
123
  st.markdown("_Click and drag to spin, right click to shift._")
124
+
125
  HtmlFile = open("mesh.html", 'r', encoding='utf-8')
126
  source_code = HtmlFile.read()
 
127
  st.header("3D Pore Space Mesh")
128
  components.html(source_code, width=view_width, height=view_height)
129
  st.markdown("_Click and drag to spin, right click to shift._")