Update app.py
Browse files
app.py
CHANGED
@@ -129,15 +129,7 @@ if 'generate' not in st.session_state:
|
|
129 |
|
130 |
# Inizializza inference_tester solo una volta
|
131 |
if 'inference_tester' not in st.session_state:
|
132 |
-
|
133 |
-
st.session_state['inference_tester'] = dani_model(model='thesis_model',
|
134 |
-
data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
|
135 |
-
pth=model_load_paths, load_weights=False)
|
136 |
-
inference_tester = st.session_state['inference_tester']
|
137 |
-
|
138 |
-
# Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
|
139 |
-
if 'weights_loaded' not in st.session_state:
|
140 |
-
st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
|
141 |
|
142 |
# Usa inference_tester dalla sessione
|
143 |
inference_tester = st.session_state['inference_tester']
|
@@ -209,18 +201,12 @@ if st.session_state['step'] == 2:
|
|
209 |
|
210 |
# Pulsante per provare un esempio
|
211 |
with col1:
|
212 |
-
if st.button("Inference"):
|
213 |
-
st.session_state['step'] = 3 # Passa al passo 3
|
214 |
-
st.rerun()
|
215 |
-
|
216 |
-
# Pulsante per provare un esempio
|
217 |
-
with col2:
|
218 |
if st.button("Try an example"):
|
219 |
st.session_state['step'] = 5 # Passa al passo 5
|
220 |
st.rerun()
|
221 |
|
222 |
# Pulsante per tornare all'inizio
|
223 |
-
with
|
224 |
if st.button("Return to the beginning"):
|
225 |
# Ripristina lo stato della sessione
|
226 |
st.session_state['step'] = 1
|
@@ -378,79 +364,8 @@ if st.session_state['step'] == 3:
|
|
378 |
st.rerun()
|
379 |
|
380 |
if st.session_state['step'] == 4:
|
381 |
-
|
382 |
-
|
383 |
-
conditioning = []
|
384 |
-
for inp in st.session_state['inputs']:
|
385 |
-
if inp == 'frontal':
|
386 |
-
cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
|
387 |
-
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
|
388 |
-
encode_type='encode_vision').to(device)
|
389 |
-
conditioning.append(torch.cat([uim, cim]))
|
390 |
-
elif inp == 'lateral':
|
391 |
-
cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
|
392 |
-
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
|
393 |
-
encode_type='encode_vision').to(device)
|
394 |
-
conditioning.append(torch.cat([uim, cim]))
|
395 |
-
elif inp == 'text':
|
396 |
-
ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
|
397 |
-
utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
|
398 |
-
conditioning.append(torch.cat([utx, ctx]))
|
399 |
-
|
400 |
-
# Costruzione delle shapes
|
401 |
-
shapes = []
|
402 |
-
for out in st.session_state['outputs']:
|
403 |
-
if out == 'frontal' or out == 'lateral':
|
404 |
-
shape = [1, 4, 256 // 8, 256 // 8]
|
405 |
-
shapes.append(shape)
|
406 |
-
elif out == 'text':
|
407 |
-
shape = [1, 768]
|
408 |
-
shapes.append(shape)
|
409 |
-
|
410 |
-
progress_bar = st.progress(0)
|
411 |
-
|
412 |
-
# Inferenza
|
413 |
-
z, _ = inference_tester.sampler.sample(
|
414 |
-
steps=50,
|
415 |
-
shape=shapes,
|
416 |
-
condition=conditioning,
|
417 |
-
unconditional_guidance_scale=7.5,
|
418 |
-
xtype=st.session_state['outputs'],
|
419 |
-
condition_types=st.session_state['inputs'],
|
420 |
-
eta=1,
|
421 |
-
verbose=False,
|
422 |
-
mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
|
423 |
-
progress_bar=progress_bar)
|
424 |
-
|
425 |
-
# Decoder e visualizzazione dei risultati
|
426 |
-
output_cols = st.columns(len(st.session_state['outputs']))
|
427 |
-
|
428 |
-
# Definire due colonne per le immagini
|
429 |
-
col1, col2 = st.columns(2)
|
430 |
-
|
431 |
-
# Iterare sugli output e assegnare le immagini alle colonne corrispondenti
|
432 |
-
for i, out in enumerate(st.session_state['outputs']):
|
433 |
-
if out == 'frontal':
|
434 |
-
x = inference_tester.net.autokl_decode(z[i])
|
435 |
-
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
436 |
-
im = x[0].cpu().numpy()
|
437 |
-
with col1: # Mostrare la frontal image nella prima colonna
|
438 |
-
st.image(im, caption="Generated Frontal Image")
|
439 |
-
elif out == 'lateral':
|
440 |
-
x = inference_tester.net.autokl_decode(z[i])
|
441 |
-
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
442 |
-
im = x[0].cpu().numpy()
|
443 |
-
with col2: # Mostrare la lateral image nella seconda colonna
|
444 |
-
st.image(im, caption="Generated Lateral Image")
|
445 |
-
elif out == 'text':
|
446 |
-
x = inference_tester.net.optimus_decode(z[i], max_length=100)
|
447 |
-
x = [a.tolist() for a in x]
|
448 |
-
rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
|
449 |
-
rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
|
450 |
-
st.write(f"Generated Report: {rec_text}")
|
451 |
-
|
452 |
-
st.write("Generation completed successfully!")
|
453 |
-
st.session_state['generate'] = False
|
454 |
|
455 |
if st.button("Return to the beginning"):
|
456 |
# Ripristina lo stato della sessione
|
@@ -564,4 +479,4 @@ if st.session_state['step'] == 5:
|
|
564 |
st.session_state['frontal_file'] = None
|
565 |
st.session_state['lateral_file'] = None
|
566 |
st.session_state['report'] = ""
|
567 |
-
st.rerun()
|
|
|
129 |
|
130 |
# Inizializza inference_tester solo una volta
|
131 |
if 'inference_tester' not in st.session_state:
|
132 |
+
st.session_state['inference_tester'] = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
# Usa inference_tester dalla sessione
|
135 |
inference_tester = st.session_state['inference_tester']
|
|
|
201 |
|
202 |
# Pulsante per provare un esempio
|
203 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
if st.button("Try an example"):
|
205 |
st.session_state['step'] = 5 # Passa al passo 5
|
206 |
st.rerun()
|
207 |
|
208 |
# Pulsante per tornare all'inizio
|
209 |
+
with col2:
|
210 |
if st.button("Return to the beginning"):
|
211 |
# Ripristina lo stato della sessione
|
212 |
st.session_state['step'] = 1
|
|
|
364 |
st.rerun()
|
365 |
|
366 |
if st.session_state['step'] == 4:
|
367 |
+
st.write("Generation completed successfully!")
|
368 |
+
st.session_state['generate'] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
|
370 |
if st.button("Return to the beginning"):
|
371 |
# Ripristina lo stato della sessione
|
|
|
479 |
st.session_state['frontal_file'] = None
|
480 |
st.session_state['lateral_file'] = None
|
481 |
st.session_state['report'] = ""
|
482 |
+
st.rerun()
|