julenalvaro commited on
Commit
fdc5234
1 Parent(s): 6d7326c
Files changed (2) hide show
  1. app.py +42 -0
  2. utils.py +15 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from utils import carga_modelo, genera
4
+
5
+ ##P谩gina principal
6
+ st.title('Generador de mariposas')
7
+ st.write('Este es un generador de mariposas creado con Huggan y Streamlit')
8
+
9
+ ##Barra lateral
10
+ st.sidebar.subheader('隆Esta mariposa no existe! 驴Puedes creerlo?')
11
+ st.sidebar.image('assets/logo.png', width = 200)
12
+ st.sidebar.caption('Demo creado en vivo')
13
+
14
+ ##Carga del modelo
15
+
16
+ repo_id = 'ceyda/butterfly_cropped_uniq1K_512'
17
+ modelo_gan = carga_modelo(repo_id)
18
+
19
+ ##Generaci贸n de im谩genes de mariposas
20
+ n_mariposas = 4
21
+
22
+ def corre():
23
+ with st.spinner('Cargando modelo...'):
24
+ ims = genera(modelo_gan, n_mariposas):
25
+ st,session_state['ims'] = ims
26
+
27
+ if 'ims' not in st.session_state:
28
+ st.session_state['ims'] = None
29
+ corre()
30
+
31
+ ims = st.session_state['ims']
32
+
33
+ corre_boton = st.button('Generar mariposas',
34
+ on_click = corre,
35
+ help = 'Estamos en vuelo, abrocha tu cintur贸n'
36
+ )
37
+
38
+ if ims is not None:
39
+ cols = st.columns(n_mariposas)
40
+ for j, im in enumerate(ims):
41
+ i = j & n_mariposas
42
+ cols[i].image(im, use_column_width=True)
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import huggan.pytorch.lightweight_gan import LightweightGan
4
+
5
+ def carga_modelo(model_name = 'ceyda/butterfly_cropped_uniq1K_512', model_version = None)
6
+ gan = LightweightGan(model_name, version = model_version)
7
+ gan.eval()
8
+ return gan
9
+
10
+ def genera(gan, batch_size = 1)
11
+ with torch.no_grad() #no queremos entrenar el modelo
12
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp(0.0,1.0) * 255 #generamos im谩genes y las aplastamos entre 0 y 1
13
+ ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8) #las pasamos a numpy
14
+ return ims
15
+