Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.transforms as T
|
4 |
+
from torchvision.utils import make_grid
|
5 |
+
import torch
|
6 |
+
from IPython.display import display
|
7 |
+
|
8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
+
|
10 |
+
@torch.inference_mode()
|
11 |
+
def inference_gan():
|
12 |
+
generator = torch.jit.load("models/mnist-G-torchscript.pt").to(device)
|
13 |
+
x = torch.randn(30, 256, device='cuda')
|
14 |
+
y = generator(x)
|
15 |
+
y = y.view(-1, 1, 28, 28) # reshape y to have 1 channel
|
16 |
+
grid = make_grid(y.cpu().detach(), nrow=8)
|
17 |
+
img = T.functional.to_pil_image(grid)
|
18 |
+
return img
|
19 |
+
|
20 |
+
@torch.inference_mode()
|
21 |
+
def inference_dcgan():
|
22 |
+
generator = torch.jit.load("models/animefacedataset-G2-torchscript.pt").to(device)
|
23 |
+
def denorm(img_tensors):
|
24 |
+
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
25 |
+
return img_tensors * stats[1][0] + stats[0][0]
|
26 |
+
x = torch.randn(64, 128, 1, 1, device='cuda')
|
27 |
+
y = generator(x)
|
28 |
+
y = y.view(-1, 3, 64, 64) # reshape y to have 3 channels
|
29 |
+
grid = make_grid(denorm(y.cpu().detach()), nrow=8)
|
30 |
+
img = T.functional.to_pil_image(grid)
|
31 |
+
return img
|
32 |
+
def inference_both():
|
33 |
+
inference_gan()
|
34 |
+
inference_dcgan()
|
35 |
+
|
36 |
+
st.markdown("# Image Generation with GANs and DCGANs")
|
37 |
+
st.button("Generate Images", on_click=inference_both)
|
38 |
+
st.image(inference_dcgan(), caption="", use_column_width=True)
|
39 |
+
st.image(inference_gan(), caption="", use_column_width=True)
|