minhalvp commited on
Commit
e161624
·
1 Parent(s): a6b370f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch.nn as nn
3
+ import torchvision.transforms as T
4
+ from torchvision.utils import make_grid
5
+ import torch
6
+ from IPython.display import display
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ @torch.inference_mode()
11
+ def inference_gan():
12
+ generator = torch.jit.load("models/mnist-G-torchscript.pt").to(device)
13
+ x = torch.randn(30, 256, device='cuda')
14
+ y = generator(x)
15
+ y = y.view(-1, 1, 28, 28) # reshape y to have 1 channel
16
+ grid = make_grid(y.cpu().detach(), nrow=8)
17
+ img = T.functional.to_pil_image(grid)
18
+ return img
19
+
20
+ @torch.inference_mode()
21
+ def inference_dcgan():
22
+ generator = torch.jit.load("models/animefacedataset-G2-torchscript.pt").to(device)
23
+ def denorm(img_tensors):
24
+ stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
25
+ return img_tensors * stats[1][0] + stats[0][0]
26
+ x = torch.randn(64, 128, 1, 1, device='cuda')
27
+ y = generator(x)
28
+ y = y.view(-1, 3, 64, 64) # reshape y to have 3 channels
29
+ grid = make_grid(denorm(y.cpu().detach()), nrow=8)
30
+ img = T.functional.to_pil_image(grid)
31
+ return img
32
+ def inference_both():
33
+ inference_gan()
34
+ inference_dcgan()
35
+
36
+ st.markdown("# Image Generation with GANs and DCGANs")
37
+ st.button("Generate Images", on_click=inference_both)
38
+ st.image(inference_dcgan(), caption="", use_column_width=True)
39
+ st.image(inference_gan(), caption="", use_column_width=True)