Hikmat Farhat commited on
Commit
e67ffd0
·
1 Parent(s): b862576

added code

Browse files
Files changed (3) hide show
  1. gui.py +28 -0
  2. requirements.txt +2 -0
  3. utils.py +12 -0
gui.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as slt
2
+ import math
3
+ import torch
4
+ from torchvision.utils import make_grid
5
+ from utils import norm,recover_image
6
+ # it seems streamlit reruns the whole script when an event occurs
7
+ from transformers import AutoModel
8
+ if "generator" not in slt.session_state:
9
+ slt.session_state.generator=AutoModel.from_pretrained("hikmatfarhat/WGANGP_generator",trust_remote_code=True)
10
+ slt.markdown("# Choose the number of images to generate from the sidebar")
11
+ def display_images():
12
+ with torch.no_grad():
13
+ total=slt.session_state.size
14
+ rows=int(math.sqrt(total))
15
+ noise=torch.randn(total,128, 1, 1)
16
+ generator=slt.session_state.generator
17
+ fake_images=generator(noise)
18
+ res=make_grid(fake_images,nrow=rows,padding=2,normalize=True)
19
+ norm(res)
20
+ img=recover_image(res)
21
+ slt.session_state.image=img
22
+ #slt.image(img)
23
+ if 'image' in slt.session_state:
24
+ slt.image(slt.session_state.image)
25
+
26
+ slt.sidebar.selectbox("Select number of images",[16,32,64],key="size")
27
+
28
+ slt.sidebar.button("Generate images",on_click=display_images)#,args=[slt.session_state.total])
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ torchvision
utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ def norm(img):
5
+ low=float(img.min())
6
+ high=float(img.max())
7
+ img.sub_(low).div_(max(high - low, 1e-5))
8
+
9
+ def recover_image(tensor):
10
+ tensor=tensor.cpu().numpy().transpose(1, 2,0)*255
11
+
12
+ return Image.fromarray(tensor.astype(np.uint8))