Spaces:
Running
Running
Hikmat Farhat
commited on
Commit
·
e67ffd0
1
Parent(s):
b862576
added code
Browse files- gui.py +28 -0
- requirements.txt +2 -0
- utils.py +12 -0
gui.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as slt
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
from torchvision.utils import make_grid
|
5 |
+
from utils import norm,recover_image
|
6 |
+
# it seems streamlit reruns the whole script when an event occurs
|
7 |
+
from transformers import AutoModel
|
8 |
+
if "generator" not in slt.session_state:
|
9 |
+
slt.session_state.generator=AutoModel.from_pretrained("hikmatfarhat/WGANGP_generator",trust_remote_code=True)
|
10 |
+
slt.markdown("# Choose the number of images to generate from the sidebar")
|
11 |
+
def display_images():
|
12 |
+
with torch.no_grad():
|
13 |
+
total=slt.session_state.size
|
14 |
+
rows=int(math.sqrt(total))
|
15 |
+
noise=torch.randn(total,128, 1, 1)
|
16 |
+
generator=slt.session_state.generator
|
17 |
+
fake_images=generator(noise)
|
18 |
+
res=make_grid(fake_images,nrow=rows,padding=2,normalize=True)
|
19 |
+
norm(res)
|
20 |
+
img=recover_image(res)
|
21 |
+
slt.session_state.image=img
|
22 |
+
#slt.image(img)
|
23 |
+
if 'image' in slt.session_state:
|
24 |
+
slt.image(slt.session_state.image)
|
25 |
+
|
26 |
+
slt.sidebar.selectbox("Select number of images",[16,32,64],key="size")
|
27 |
+
|
28 |
+
slt.sidebar.button("Generate images",on_click=display_images)#,args=[slt.session_state.total])
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
def norm(img):
|
5 |
+
low=float(img.min())
|
6 |
+
high=float(img.max())
|
7 |
+
img.sub_(low).div_(max(high - low, 1e-5))
|
8 |
+
|
9 |
+
def recover_image(tensor):
|
10 |
+
tensor=tensor.cpu().numpy().transpose(1, 2,0)*255
|
11 |
+
|
12 |
+
return Image.fromarray(tensor.astype(np.uint8))
|