dmolino commited on
Commit
ea5427d
·
verified ·
1 Parent(s): beadcf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -90
app.py CHANGED
@@ -129,15 +129,7 @@ if 'generate' not in st.session_state:
129
 
130
  # Inizializza inference_tester solo una volta
131
  if 'inference_tester' not in st.session_state:
132
- model_load_paths = ['CoDi_encoders.pth', 'CoDi_text_diffuser.pth', 'CoDi_video_diffuser_8frames.pth']
133
- st.session_state['inference_tester'] = dani_model(model='thesis_model',
134
- data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
135
- pth=model_load_paths, load_weights=False)
136
- inference_tester = st.session_state['inference_tester']
137
-
138
- # Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
139
- if 'weights_loaded' not in st.session_state:
140
- st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
141
 
142
  # Usa inference_tester dalla sessione
143
  inference_tester = st.session_state['inference_tester']
@@ -209,18 +201,12 @@ if st.session_state['step'] == 2:
209
 
210
  # Pulsante per provare un esempio
211
  with col1:
212
- if st.button("Inference"):
213
- st.session_state['step'] = 3 # Passa al passo 3
214
- st.rerun()
215
-
216
- # Pulsante per provare un esempio
217
- with col2:
218
  if st.button("Try an example"):
219
  st.session_state['step'] = 5 # Passa al passo 5
220
  st.rerun()
221
 
222
  # Pulsante per tornare all'inizio
223
- with col3:
224
  if st.button("Return to the beginning"):
225
  # Ripristina lo stato della sessione
226
  st.session_state['step'] = 1
@@ -378,79 +364,8 @@ if st.session_state['step'] == 3:
378
  st.rerun()
379
 
380
  if st.session_state['step'] == 4:
381
- # Costruzione del prompt
382
- if st.session_state['generate'] is True:
383
- conditioning = []
384
- for inp in st.session_state['inputs']:
385
- if inp == 'frontal':
386
- cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
387
- uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
388
- encode_type='encode_vision').to(device)
389
- conditioning.append(torch.cat([uim, cim]))
390
- elif inp == 'lateral':
391
- cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
392
- uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
393
- encode_type='encode_vision').to(device)
394
- conditioning.append(torch.cat([uim, cim]))
395
- elif inp == 'text':
396
- ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
397
- utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
398
- conditioning.append(torch.cat([utx, ctx]))
399
-
400
- # Costruzione delle shapes
401
- shapes = []
402
- for out in st.session_state['outputs']:
403
- if out == 'frontal' or out == 'lateral':
404
- shape = [1, 4, 256 // 8, 256 // 8]
405
- shapes.append(shape)
406
- elif out == 'text':
407
- shape = [1, 768]
408
- shapes.append(shape)
409
-
410
- progress_bar = st.progress(0)
411
-
412
- # Inferenza
413
- z, _ = inference_tester.sampler.sample(
414
- steps=50,
415
- shape=shapes,
416
- condition=conditioning,
417
- unconditional_guidance_scale=7.5,
418
- xtype=st.session_state['outputs'],
419
- condition_types=st.session_state['inputs'],
420
- eta=1,
421
- verbose=False,
422
- mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
423
- progress_bar=progress_bar)
424
-
425
- # Decoder e visualizzazione dei risultati
426
- output_cols = st.columns(len(st.session_state['outputs']))
427
-
428
- # Definire due colonne per le immagini
429
- col1, col2 = st.columns(2)
430
-
431
- # Iterare sugli output e assegnare le immagini alle colonne corrispondenti
432
- for i, out in enumerate(st.session_state['outputs']):
433
- if out == 'frontal':
434
- x = inference_tester.net.autokl_decode(z[i])
435
- x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
436
- im = x[0].cpu().numpy()
437
- with col1: # Mostrare la frontal image nella prima colonna
438
- st.image(im, caption="Generated Frontal Image")
439
- elif out == 'lateral':
440
- x = inference_tester.net.autokl_decode(z[i])
441
- x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
442
- im = x[0].cpu().numpy()
443
- with col2: # Mostrare la lateral image nella seconda colonna
444
- st.image(im, caption="Generated Lateral Image")
445
- elif out == 'text':
446
- x = inference_tester.net.optimus_decode(z[i], max_length=100)
447
- x = [a.tolist() for a in x]
448
- rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
449
- rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
450
- st.write(f"Generated Report: {rec_text}")
451
-
452
- st.write("Generation completed successfully!")
453
- st.session_state['generate'] = False
454
 
455
  if st.button("Return to the beginning"):
456
  # Ripristina lo stato della sessione
@@ -564,4 +479,4 @@ if st.session_state['step'] == 5:
564
  st.session_state['frontal_file'] = None
565
  st.session_state['lateral_file'] = None
566
  st.session_state['report'] = ""
567
- st.rerun()
 
129
 
130
  # Inizializza inference_tester solo una volta
131
  if 'inference_tester' not in st.session_state:
132
+ st.session_state['inference_tester'] = 1
 
 
 
 
 
 
 
 
133
 
134
  # Usa inference_tester dalla sessione
135
  inference_tester = st.session_state['inference_tester']
 
201
 
202
  # Pulsante per provare un esempio
203
  with col1:
 
 
 
 
 
 
204
  if st.button("Try an example"):
205
  st.session_state['step'] = 5 # Passa al passo 5
206
  st.rerun()
207
 
208
  # Pulsante per tornare all'inizio
209
+ with col2:
210
  if st.button("Return to the beginning"):
211
  # Ripristina lo stato della sessione
212
  st.session_state['step'] = 1
 
364
  st.rerun()
365
 
366
  if st.session_state['step'] == 4:
367
+ st.write("Generation completed successfully!")
368
+ st.session_state['generate'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  if st.button("Return to the beginning"):
371
  # Ripristina lo stato della sessione
 
479
  st.session_state['frontal_file'] = None
480
  st.session_state['lateral_file'] = None
481
  st.session_state['report'] = ""
482
+ st.rerun()