awacke1 commited on
Commit
6ae4c84
·
verified ·
1 Parent(s): e23373e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -64
app.py CHANGED
@@ -53,10 +53,8 @@ st.set_page_config(
53
  )
54
 
55
  # Initialize st.session_state
56
- if 'captured_files' not in st.session_state:
57
- st.session_state['captured_files'] = {'cam0': None, 'cam1': None} # One file per camera
58
  if 'history' not in st.session_state:
59
- st.session_state['history'] = {'cam0': None, 'cam1': None} # One history entry per camera
60
  if 'builder' not in st.session_state:
61
  st.session_state['builder'] = None
62
  if 'model_loaded' not in st.session_state:
@@ -329,21 +327,9 @@ def get_model_files(model_type="causal_lm"):
329
  path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
330
  return [d for d in glob.glob(path) if os.path.isdir(d)]
331
 
332
- def get_gallery_files(file_types):
333
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
334
 
335
- def download_pdf(url, output_path):
336
- try:
337
- response = requests.get(url, stream=True, timeout=10)
338
- if response.status_code == 200:
339
- with open(output_path, "wb") as f:
340
- for chunk in response.iter_content(chunk_size=8192):
341
- f.write(chunk)
342
- return True
343
- except requests.RequestException as e:
344
- logger.error(f"Failed to download {url}: {e}")
345
- return False
346
-
347
  # Mock Search Tool for RAG
348
  def mock_search(query: str) -> str:
349
  if "superhero" in query.lower():
@@ -445,9 +431,6 @@ async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
445
  output_files.append(output_file)
446
  elapsed = int(time.time() - start_time)
447
  status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
448
- for file in output_files:
449
- if file not in st.session_state['captured_files'].values():
450
- st.session_state['captured_files'][f"pdf_{len(output_files)}"] = file
451
  update_gallery()
452
  return output_files
453
  except Exception as e:
@@ -465,8 +448,6 @@ async def process_ocr(image, output_file):
465
  status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
466
  async with aiofiles.open(output_file, "w") as f:
467
  await f.write(result)
468
- if output_file not in st.session_state['captured_files'].values():
469
- st.session_state['captured_files']['ocr'] = output_file
470
  update_gallery()
471
  return result
472
 
@@ -479,8 +460,6 @@ async def process_image_gen(prompt, output_file):
479
  elapsed = int(time.time() - start_time)
480
  status.text(f"Image Gen completed in {elapsed}s!")
481
  gen_image.save(output_file)
482
- if output_file not in st.session_state['captured_files'].values():
483
- st.session_state['captured_files']['gen'] = output_file
484
  update_gallery()
485
  return gen_image
486
 
@@ -496,8 +475,6 @@ async def process_custom_diffusion(images, output_file, model_name):
496
  elapsed = int(time.time() - start_time)
497
  status.text(f"{model_name} completed in {elapsed}s!")
498
  upscaled_image.save(output_file)
499
- if output_file not in st.session_state['captured_files'].values():
500
- st.session_state['captured_files']['diffusion'] = output_file
501
  update_gallery()
502
  return upscaled_image
503
 
@@ -506,18 +483,14 @@ st.title("AI Vision & SFT Titans 🚀")
506
 
507
  # Sidebar
508
  st.sidebar.header("Captured Files 📜")
509
- gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2) # Default to 2 for two cameras
510
  def update_gallery():
511
- media_files = [st.session_state['captured_files']['cam0'], st.session_state['captured_files']['cam1']]
512
- valid_files = [f for f in media_files if f and os.path.exists(f)] # Only valid files
513
- if valid_files:
514
  cols = st.sidebar.columns(2)
515
- if st.session_state['captured_files']['cam0'] in valid_files:
516
- with cols[0]:
517
- st.image(Image.open(st.session_state['captured_files']['cam0']), caption="Camera 0", use_container_width=True)
518
- if st.session_state['captured_files']['cam1'] in valid_files:
519
- with cols[1]:
520
- st.image(Image.open(st.session_state['captured_files']['cam1']), caption="Camera 1", use_container_width=True)
521
  update_gallery()
522
 
523
  st.sidebar.subheader("Model Management 🗂️")
@@ -541,8 +514,7 @@ with log_container:
541
  st.sidebar.subheader("History 📜")
542
  history_container = st.sidebar.empty()
543
  with history_container:
544
- valid_history = [st.session_state['history']['cam0'], st.session_state['history']['cam1']]
545
- for entry in [e for e in valid_history if e]: # Show only non-None entries
546
  st.write(entry)
547
 
548
  # Tabs
@@ -561,8 +533,9 @@ with tab1:
561
  filename = generate_filename("cam0")
562
  with open(filename, "wb") as f:
563
  f.write(cam0_img.getvalue())
564
- st.session_state['captured_files']['cam0'] = filename
565
- st.session_state['history']['cam0'] = f"Snapshot from Cam 0: {filename}"
 
566
  st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
567
  logger.info(f"Saved snapshot from Camera 0: {filename}")
568
  update_gallery()
@@ -572,8 +545,9 @@ with tab1:
572
  filename = generate_filename("cam1")
573
  with open(filename, "wb") as f:
574
  f.write(cam1_img.getvalue())
575
- st.session_state['captured_files']['cam1'] = filename
576
- st.session_state['history']['cam1'] = f"Snapshot from Cam 1: {filename}"
 
577
  st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
578
  logger.info(f"Saved snapshot from Camera 1: {filename}")
579
  update_gallery()
@@ -589,7 +563,9 @@ with tab2:
589
  pdf_path = generate_filename("downloaded", "pdf")
590
  if download_pdf(url, pdf_path):
591
  logger.info(f"Downloaded PDF from {url} to {pdf_path}")
592
- st.session_state['history']['pdf'] = f"Downloaded PDF: {pdf_path}"
 
 
593
  snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
594
  for snapshot in snapshots:
595
  st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
@@ -611,7 +587,9 @@ with tab3:
611
  builder.save_model(config.model_path)
612
  st.session_state['builder'] = builder
613
  st.session_state['model_loaded'] = True
614
- st.session_state['history']['build'] = f"Built {model_type} model: {model_name}"
 
 
615
  st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
616
  st.rerun()
617
 
@@ -646,13 +624,15 @@ with tab4:
646
  st.session_state['builder'].save_model(new_config.model_path)
647
  zip_path = f"{new_config.model_path}.zip"
648
  zip_directory(new_config.model_path, zip_path)
649
- st.session_state['history']['sft'] = f"Fine-tuned Causal LM: {new_model_name}"
 
 
650
  st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
651
  st.rerun()
652
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
653
- captured_files = list(st.session_state['captured_files'].values())
654
  if len(captured_files) >= 2:
655
- demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files if img]
656
  edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
657
  if st.button("Fine-Tune with Dataset 🔄"):
658
  images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
@@ -664,12 +644,14 @@ with tab4:
664
  st.session_state['builder'].save_model(new_config.model_path)
665
  zip_path = f"{new_config.model_path}.zip"
666
  zip_directory(new_config.model_path, zip_path)
667
- st.session_state['history']['sft'] = f"Fine-tuned Diffusion: {new_model_name}"
 
 
668
  st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
669
  csv_path = f"sft_dataset_{int(time.time())}.csv"
670
  with open(csv_path, "w", newline="") as f:
671
  writer = csv.writer(f)
672
- writer.writerow(["image", "text"])
673
  for _, row in edited_data.iterrows():
674
  writer.writerow([row["image"], row["text"]])
675
  st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
@@ -696,7 +678,9 @@ with tab5:
696
  if st.button("Run Test ▶️"):
697
  status_container = st.empty()
698
  result = st.session_state['builder'].evaluate(test_prompt, status_container)
699
- st.session_state['history']['test'] = f"Causal LM Test: {test_prompt} -> {result}"
 
 
700
  st.write(f"**Generated Response**: {result}")
701
  status_container.empty()
702
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
@@ -705,8 +689,9 @@ with tab5:
705
  image = st.session_state['builder'].generate(test_prompt)
706
  output_file = generate_filename("diffusion_test", "png")
707
  image.save(output_file)
708
- st.session_state['captured_files']['diffusion_test'] = output_file
709
- st.session_state['history']['test'] = f"Diffusion Test: {test_prompt} -> {output_file}"
 
710
  st.image(image, caption="Generated Image")
711
  update_gallery()
712
 
@@ -720,28 +705,31 @@ with tab6:
720
  agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
721
  task = "Plan a luxury superhero-themed party at Wayne Manor."
722
  plan_df = agent.plan_party(task)
723
- st.session_state['history']['rag'] = f"NLP RAG Demo: Planned party at Wayne Manor"
 
 
724
  st.dataframe(plan_df)
725
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
726
  if st.button("Run CV RAG Demo 🎉"):
727
  agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
728
  task = "Generate images for a luxury superhero-themed party."
729
  plan_df = agent.plan_party(task)
730
- st.session_state['history']['rag'] = f"CV RAG Demo: Generated party images"
 
 
731
  st.dataframe(plan_df)
732
  for _, row in plan_df.iterrows():
733
  image = agent.generate(row["Image Idea"])
734
  output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
735
  image.save(output_file)
736
- st.session_state['captured_files'][f"cv_rag_{row['Theme'].lower()}"] = output_file
737
  st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
738
  update_gallery()
739
 
740
  with tab7:
741
  st.header("Test OCR 🔍")
742
- captured_files = list(st.session_state['captured_files'].values())
743
  if captured_files:
744
- selected_file = st.selectbox("Select Image", [f for f in captured_files if f and f.endswith(".png")], key="ocr_select")
745
  if selected_file:
746
  image = Image.open(selected_file)
747
  st.image(image, caption="Input Image", use_container_width=True)
@@ -749,7 +737,9 @@ with tab7:
749
  output_file = generate_filename("ocr_output", "txt")
750
  st.session_state['processing']['ocr'] = True
751
  result = asyncio.run(process_ocr(image, output_file))
752
- st.session_state['history']['ocr'] = f"OCR Test: {selected_file} -> {output_file}"
 
 
753
  st.text_area("OCR Result", result, height=200, key="ocr_result")
754
  st.success(f"OCR output saved to {output_file}")
755
  st.session_state['processing']['ocr'] = False
@@ -758,9 +748,9 @@ with tab7:
758
 
759
  with tab8:
760
  st.header("Test Image Gen 🎨")
761
- captured_files = list(st.session_state['captured_files'].values())
762
  if captured_files:
763
- selected_file = st.selectbox("Select Image", [f for f in captured_files if f and f.endswith(".png")], key="gen_select")
764
  if selected_file:
765
  image = Image.open(selected_file)
766
  st.image(image, caption="Reference Image", use_container_width=True)
@@ -769,7 +759,9 @@ with tab8:
769
  output_file = generate_filename("gen_output", "png")
770
  st.session_state['processing']['gen'] = True
771
  result = asyncio.run(process_image_gen(prompt, output_file))
772
- st.session_state['history']['gen'] = f"Image Gen Test: {prompt} -> {output_file}"
 
 
773
  st.image(result, caption="Generated Image", use_container_width=True)
774
  st.success(f"Image saved to {output_file}")
775
  st.session_state['processing']['gen'] = False
@@ -779,10 +771,10 @@ with tab8:
779
  with tab9:
780
  st.header("Custom Diffusion 🎨🤓")
781
  st.write("Unleash your inner artist with our tiny diffusion models!")
782
- captured_files = list(st.session_state['captured_files'].values())
783
  if captured_files:
784
  st.subheader("Select Images to Train")
785
- selected_files = st.multiselect("Pick Images", [f for f in captured_files if f and f.endswith(".png")], key="diffusion_select")
786
  images = [Image.open(file) for file in selected_files]
787
 
788
  model_options = [
@@ -803,8 +795,9 @@ with tab9:
803
  builder.load_model(model_name)
804
  result = builder.generate("A superhero scene inspired by captured images")
805
  result.save(output_file)
806
- st.session_state['captured_files']['diffusion'] = output_file
807
- st.session_state['history']['diffusion'] = f"Custom Diffusion: {model_choice} -> {output_file}"
 
808
  st.image(result, caption=f"{model_choice} Masterpiece", use_container_width=True)
809
  st.success(f"Image saved to {output_file}")
810
  st.session_state['processing']['diffusion'] = False
 
53
  )
54
 
55
  # Initialize st.session_state
 
 
56
  if 'history' not in st.session_state:
57
+ st.session_state['history'] = [] # Flat list for history
58
  if 'builder' not in st.session_state:
59
  st.session_state['builder'] = None
60
  if 'model_loaded' not in st.session_state:
 
327
  path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
328
  return [d for d in glob.glob(path) if os.path.isdir(d)]
329
 
330
+ def get_gallery_files(file_types=["png"]):
331
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
332
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  # Mock Search Tool for RAG
334
  def mock_search(query: str) -> str:
335
  if "superhero" in query.lower():
 
431
  output_files.append(output_file)
432
  elapsed = int(time.time() - start_time)
433
  status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
 
 
 
434
  update_gallery()
435
  return output_files
436
  except Exception as e:
 
448
  status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
449
  async with aiofiles.open(output_file, "w") as f:
450
  await f.write(result)
 
 
451
  update_gallery()
452
  return result
453
 
 
460
  elapsed = int(time.time() - start_time)
461
  status.text(f"Image Gen completed in {elapsed}s!")
462
  gen_image.save(output_file)
 
 
463
  update_gallery()
464
  return gen_image
465
 
 
475
  elapsed = int(time.time() - start_time)
476
  status.text(f"{model_name} completed in {elapsed}s!")
477
  upscaled_image.save(output_file)
 
 
478
  update_gallery()
479
  return upscaled_image
480
 
 
483
 
484
  # Sidebar
485
  st.sidebar.header("Captured Files 📜")
486
+ gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2) # Default to 2
487
  def update_gallery():
488
+ media_files = get_gallery_files(["png"])
489
+ if media_files:
 
490
  cols = st.sidebar.columns(2)
491
+ for idx, file in enumerate(media_files[:gallery_size * 2]): # Limit by gallery size
492
+ with cols[idx % 2]:
493
+ st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
 
 
 
494
  update_gallery()
495
 
496
  st.sidebar.subheader("Model Management 🗂️")
 
514
  st.sidebar.subheader("History 📜")
515
  history_container = st.sidebar.empty()
516
  with history_container:
517
+ for entry in st.session_state['history'][-gallery_size * 2:]: # Limit by gallery size
 
518
  st.write(entry)
519
 
520
  # Tabs
 
533
  filename = generate_filename("cam0")
534
  with open(filename, "wb") as f:
535
  f.write(cam0_img.getvalue())
536
+ entry = f"Snapshot from Cam 0: {filename}"
537
+ if entry not in st.session_state['history']:
538
+ st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
539
  st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
540
  logger.info(f"Saved snapshot from Camera 0: {filename}")
541
  update_gallery()
 
545
  filename = generate_filename("cam1")
546
  with open(filename, "wb") as f:
547
  f.write(cam1_img.getvalue())
548
+ entry = f"Snapshot from Cam 1: {filename}"
549
+ if entry not in st.session_state['history']:
550
+ st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
551
  st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
552
  logger.info(f"Saved snapshot from Camera 1: {filename}")
553
  update_gallery()
 
563
  pdf_path = generate_filename("downloaded", "pdf")
564
  if download_pdf(url, pdf_path):
565
  logger.info(f"Downloaded PDF from {url} to {pdf_path}")
566
+ entry = f"Downloaded PDF: {pdf_path}"
567
+ if entry not in st.session_state['history']:
568
+ st.session_state['history'].append(entry)
569
  snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
570
  for snapshot in snapshots:
571
  st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
 
587
  builder.save_model(config.model_path)
588
  st.session_state['builder'] = builder
589
  st.session_state['model_loaded'] = True
590
+ entry = f"Built {model_type} model: {model_name}"
591
+ if entry not in st.session_state['history']:
592
+ st.session_state['history'].append(entry)
593
  st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
594
  st.rerun()
595
 
 
624
  st.session_state['builder'].save_model(new_config.model_path)
625
  zip_path = f"{new_config.model_path}.zip"
626
  zip_directory(new_config.model_path, zip_path)
627
+ entry = f"Fine-tuned Causal LM: {new_model_name}"
628
+ if entry not in st.session_state['history']:
629
+ st.session_state['history'].append(entry)
630
  st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
631
  st.rerun()
632
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
633
+ captured_files = get_gallery_files(["png"])
634
  if len(captured_files) >= 2:
635
+ demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files]
636
  edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
637
  if st.button("Fine-Tune with Dataset 🔄"):
638
  images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
 
644
  st.session_state['builder'].save_model(new_config.model_path)
645
  zip_path = f"{new_config.model_path}.zip"
646
  zip_directory(new_config.model_path, zip_path)
647
+ entry = f"Fine-tuned Diffusion: {new_model_name}"
648
+ if entry not in st.session_state['history']:
649
+ st.session_state['history'].append(entry)
650
  st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
651
  csv_path = f"sft_dataset_{int(time.time())}.csv"
652
  with open(csv_path, "w", newline="") as f:
653
  writer = csv.writer(f)
654
+ writer.writerow(["image", "text()])
655
  for _, row in edited_data.iterrows():
656
  writer.writerow([row["image"], row["text"]])
657
  st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
 
678
  if st.button("Run Test ▶️"):
679
  status_container = st.empty()
680
  result = st.session_state['builder'].evaluate(test_prompt, status_container)
681
+ entry = f"Causal LM Test: {test_prompt} -> {result}"
682
+ if entry not in st.session_state['history']:
683
+ st.session_state['history'].append(entry)
684
  st.write(f"**Generated Response**: {result}")
685
  status_container.empty()
686
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
 
689
  image = st.session_state['builder'].generate(test_prompt)
690
  output_file = generate_filename("diffusion_test", "png")
691
  image.save(output_file)
692
+ entry = f"Diffusion Test: {test_prompt} -> {output_file}"
693
+ if entry not in st.session_state['history']:
694
+ st.session_state['history'].append(entry)
695
  st.image(image, caption="Generated Image")
696
  update_gallery()
697
 
 
705
  agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
706
  task = "Plan a luxury superhero-themed party at Wayne Manor."
707
  plan_df = agent.plan_party(task)
708
+ entry = f"NLP RAG Demo: Planned party at Wayne Manor"
709
+ if entry not in st.session_state['history']:
710
+ st.session_state['history'].append(entry)
711
  st.dataframe(plan_df)
712
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
713
  if st.button("Run CV RAG Demo 🎉"):
714
  agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
715
  task = "Generate images for a luxury superhero-themed party."
716
  plan_df = agent.plan_party(task)
717
+ entry = f"CV RAG Demo: Generated party images"
718
+ if entry not in st.session_state['history']:
719
+ st.session_state['history'].append(entry)
720
  st.dataframe(plan_df)
721
  for _, row in plan_df.iterrows():
722
  image = agent.generate(row["Image Idea"])
723
  output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
724
  image.save(output_file)
 
725
  st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
726
  update_gallery()
727
 
728
  with tab7:
729
  st.header("Test OCR 🔍")
730
+ captured_files = get_gallery_files(["png"])
731
  if captured_files:
732
+ selected_file = st.selectbox("Select Image", captured_files, key="ocr_select")
733
  if selected_file:
734
  image = Image.open(selected_file)
735
  st.image(image, caption="Input Image", use_container_width=True)
 
737
  output_file = generate_filename("ocr_output", "txt")
738
  st.session_state['processing']['ocr'] = True
739
  result = asyncio.run(process_ocr(image, output_file))
740
+ entry = f"OCR Test: {selected_file} -> {output_file}"
741
+ if entry not in st.session_state['history']:
742
+ st.session_state['history'].append(entry)
743
  st.text_area("OCR Result", result, height=200, key="ocr_result")
744
  st.success(f"OCR output saved to {output_file}")
745
  st.session_state['processing']['ocr'] = False
 
748
 
749
  with tab8:
750
  st.header("Test Image Gen 🎨")
751
+ captured_files = get_gallery_files(["png"])
752
  if captured_files:
753
+ selected_file = st.selectbox("Select Image", captured_files, key="gen_select")
754
  if selected_file:
755
  image = Image.open(selected_file)
756
  st.image(image, caption="Reference Image", use_container_width=True)
 
759
  output_file = generate_filename("gen_output", "png")
760
  st.session_state['processing']['gen'] = True
761
  result = asyncio.run(process_image_gen(prompt, output_file))
762
+ entry = f"Image Gen Test: {prompt} -> {output_file}"
763
+ if entry not in st.session_state['history']:
764
+ st.session_state['history'].append(entry)
765
  st.image(result, caption="Generated Image", use_container_width=True)
766
  st.success(f"Image saved to {output_file}")
767
  st.session_state['processing']['gen'] = False
 
771
  with tab9:
772
  st.header("Custom Diffusion 🎨🤓")
773
  st.write("Unleash your inner artist with our tiny diffusion models!")
774
+ captured_files = get_gallery_files(["png"])
775
  if captured_files:
776
  st.subheader("Select Images to Train")
777
+ selected_files = st.multiselect("Pick Images", captured_files, key="diffusion_select")
778
  images = [Image.open(file) for file in selected_files]
779
 
780
  model_options = [
 
795
  builder.load_model(model_name)
796
  result = builder.generate("A superhero scene inspired by captured images")
797
  result.save(output_file)
798
+ entry = f"Custom Diffusion: {model_choice} -> {output_file}"
799
+ if entry not in st.session_state['history']:
800
+ st.session_state['history'].append(entry)
801
  st.image(result, caption=f"{model_choice} Masterpiece", use_container_width=True)
802
  st.success(f"Image saved to {output_file}")
803
  st.session_state['processing']['diffusion'] = False