awacke1 committed on
Commit
9218bcd
·
verified ·
1 Parent(s): a9df450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -38
app.py CHANGED
@@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
13
  from diffusers import StableDiffusionPipeline
14
  from torch.utils.data import Dataset, DataLoader
15
  import csv
16
- import fitz # PyMuPDF, pure Python library
17
  import requests
18
  from PIL import Image
19
  import cv2
@@ -27,6 +27,7 @@ from typing import Optional, Tuple
27
  import zipfile
28
  import math
29
  import random
 
30
 
31
  # Logging setup with custom buffer
32
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -61,6 +62,10 @@ if 'model_loaded' not in st.session_state:
61
  st.session_state['model_loaded'] = False
62
  if 'processing' not in st.session_state:
63
  st.session_state['processing'] = {}
 
 
 
 
64
 
65
  # Model Configuration Classes
66
  @dataclass
@@ -311,11 +316,16 @@ def generate_filename(sequence, ext="png"):
311
  timestamp = time.strftime("%d%m%Y%H%M%S")
312
  return f"{sequence}_{timestamp}.{ext}"
313
 
314
- def get_download_link(file_path, mime_type="text/plain", label="Download"):
 
 
 
 
 
315
  with open(file_path, 'rb') as f:
316
  data = f.read()
317
  b64 = base64.b64encode(data).decode()
318
- return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
319
 
320
  def zip_directory(directory_path, zip_path):
321
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
@@ -330,6 +340,9 @@ def get_model_files(model_type="causal_lm"):
330
  def get_gallery_files(file_types=["png"]):
331
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
332
 
 
 
 
333
  def download_pdf(url, output_path):
334
  try:
335
  response = requests.get(url, stream=True, timeout=10)
@@ -343,26 +356,33 @@ def download_pdf(url, output_path):
343
  return False
344
 
345
  # Async Processing Functions
346
- async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
347
  start_time = time.time()
348
  status = st.empty()
349
  status.text(f"Processing PDF Snapshot ({mode})... (0s)")
350
  try:
351
  doc = fitz.open(pdf_path)
352
  output_files = []
353
- if mode == "thumbnail":
354
  page = doc[0]
355
- pix = page.get_pixmap(matrix=fitz.Matrix(0.5, 0.5)) # 50% scale
356
- output_file = generate_filename("thumbnail", "png")
357
  pix.save(output_file)
358
  output_files.append(output_file)
359
  elif mode == "twopage":
360
  for i in range(min(2, len(doc))):
361
  page = doc[i]
362
- pix = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0)) # Full scale
363
  output_file = generate_filename(f"twopage_{i}", "png")
364
  pix.save(output_file)
365
  output_files.append(output_file)
 
 
 
 
 
 
 
366
  doc.close()
367
  elapsed = int(time.time() - start_time)
368
  status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
@@ -502,11 +522,16 @@ st.sidebar.header("Captured Files 📜")
502
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2) # Default to 2
503
  def update_gallery():
504
  media_files = get_gallery_files(["png"])
505
- if media_files:
 
 
506
  cols = st.sidebar.columns(2)
507
  for idx, file in enumerate(media_files[:gallery_size * 2]): # Limit by gallery size
508
  with cols[idx % 2]:
509
  st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
 
 
 
510
  update_gallery()
511
 
512
  st.sidebar.subheader("Model Management 🗂️")
@@ -570,23 +595,98 @@ with tab1:
570
 
571
  with tab2:
572
  st.header("Download PDFs 📥")
573
- url_input = st.text_area("Enter PDF URLs (one per line)", height=100)
574
- mode = st.selectbox("Snapshot Mode", ["Thumbnail", "Two-Page View"], key="download_mode")
575
- if st.button("Download & Snapshot 📸"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  urls = url_input.strip().split("\n")
577
- for url in urls:
 
 
 
 
578
  if url:
579
- pdf_path = generate_filename("downloaded", "pdf")
580
- if download_pdf(url, pdf_path):
581
- logger.info(f"Downloaded PDF from {url} to {pdf_path}")
582
- entry = f"Downloaded PDF: {pdf_path}"
583
- if entry not in st.session_state['history']:
584
- st.session_state['history'].append(entry)
585
- snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
586
- for snapshot in snapshots:
587
- st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
 
 
588
  else:
589
- st.error(f"Failed to download {url}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  with tab3:
592
  st.header("Build Titan 🌱")
@@ -647,11 +747,14 @@ with tab4:
647
  st.rerun()
648
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
649
  captured_files = get_gallery_files(["png"])
650
- if len(captured_files) >= 2:
 
651
  demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files]
 
 
652
  edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
653
  if st.button("Fine-Tune with Dataset 🔄"):
654
- images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
655
  texts = [row["text"] for _, row in edited_data.iterrows()]
656
  new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
657
  new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
@@ -701,6 +804,7 @@ with tab5:
701
  status_container.empty()
702
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
703
  test_prompt = st.text_area("Enter Test Prompt", "Neon Batman")
 
704
  if st.button("Run Test ▶️"):
705
  image = st.session_state['builder'].generate(test_prompt)
706
  output_file = generate_filename("diffusion_test", "png")
@@ -744,10 +848,18 @@ with tab6:
744
  with tab7:
745
  st.header("Test OCR 🔍")
746
  captured_files = get_gallery_files(["png"])
747
- if captured_files:
748
- selected_file = st.selectbox("Select Image", captured_files, key="ocr_select")
 
 
749
  if selected_file:
750
- image = Image.open(selected_file)
 
 
 
 
 
 
751
  st.image(image, caption="Input Image", use_container_width=True)
752
  if st.button("Run OCR 🚀", key="ocr_run"):
753
  output_file = generate_filename("ocr_output", "txt")
@@ -760,15 +872,23 @@ with tab7:
760
  st.success(f"OCR output saved to {output_file}")
761
  st.session_state['processing']['ocr'] = False
762
  else:
763
- st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
764
 
765
  with tab8:
766
  st.header("Test Image Gen 🎨")
767
  captured_files = get_gallery_files(["png"])
768
- if captured_files:
769
- selected_file = st.selectbox("Select Image", captured_files, key="gen_select")
 
 
770
  if selected_file:
771
- image = Image.open(selected_file)
 
 
 
 
 
 
772
  st.image(image, caption="Reference Image", use_container_width=True)
773
  prompt = st.text_area("Prompt", "Generate a similar superhero image", key="gen_prompt")
774
  if st.button("Run Image Gen 🚀", key="gen_run"):
@@ -782,16 +902,26 @@ with tab8:
782
  st.success(f"Image saved to {output_file}")
783
  st.session_state['processing']['gen'] = False
784
  else:
785
- st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
786
 
787
  with tab9:
788
  st.header("Custom Diffusion 🎨🤓")
789
  st.write("Unleash your inner artist with our tiny diffusion models!")
790
  captured_files = get_gallery_files(["png"])
791
- if captured_files:
792
- st.subheader("Select Images to Train")
793
- selected_files = st.multiselect("Pick Images", captured_files, key="diffusion_select")
794
- images = [Image.open(file) for file in selected_files]
 
 
 
 
 
 
 
 
 
 
795
 
796
  model_options = [
797
  ("PixelTickler 🎨✨", "OFA-Sys/small-stable-diffusion-v0"),
@@ -818,7 +948,7 @@ with tab9:
818
  st.success(f"Image saved to {output_file}")
819
  st.session_state['processing']['diffusion'] = False
820
  else:
821
- st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
822
 
823
  # Initial Gallery Update
824
  update_gallery()
 
13
  from diffusers import StableDiffusionPipeline
14
  from torch.utils.data import Dataset, DataLoader
15
  import csv
16
+ import fitz # PyMuPDF
17
  import requests
18
  from PIL import Image
19
  import cv2
 
27
  import zipfile
28
  import math
29
  import random
30
+ import re
31
 
32
  # Logging setup with custom buffer
33
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
62
  st.session_state['model_loaded'] = False
63
  if 'processing' not in st.session_state:
64
  st.session_state['processing'] = {}
65
+ if 'pdf_checkboxes' not in st.session_state:
66
+ st.session_state['pdf_checkboxes'] = {} # Shared cache for PDF checkboxes
67
+ if 'downloaded_pdfs' not in st.session_state:
68
+ st.session_state['downloaded_pdfs'] = {} # Cache for downloaded PDF paths
69
 
70
  # Model Configuration Classes
71
  @dataclass
 
316
  timestamp = time.strftime("%d%m%Y%H%M%S")
317
  return f"{sequence}_{timestamp}.{ext}"
318
 
319
def pdf_url_to_filename(url):
    """Derive a local ``.pdf`` file name from a full URL.

    Leading/trailing whitespace (including the ``\\r`` left behind when a
    pasted list is split on ``\\n``) is removed first, then every character
    that is illegal in file names on common platforms -- ``< > : " / \\ | ? *``
    and ASCII control characters -- is replaced by an underscore.

    Args:
        url: The URL text as entered by the user.

    Returns:
        The sanitized name with ``.pdf`` appended.
    """
    cleaned = url.strip()
    # Control characters (tab, CR, LF, NUL, ...) are never valid in a file
    # name, so they are mapped to "_" together with the reserved punctuation.
    safe_name = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', cleaned)
    return f"{safe_name}.pdf"
323
+
324
def get_download_link(file_path, mime_type="application/pdf", label="Download"):
    """Build an HTML ``<a>`` tag that downloads ``file_path`` via a data URI.

    The whole file is read into memory and embedded base64-encoded, so the
    returned string grows with the file size.

    Args:
        file_path: Path of the file to embed.
        mime_type: MIME type placed in the data URI.
        label: Visible link text. It is treated as plain text and
            HTML-escaped, because callers build it from file names.

    Returns:
        The anchor element as a string, meant to be rendered with
        ``unsafe_allow_html=True``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import html  # local import: only this helper needs it

    with open(file_path, 'rb') as f:
        b64 = base64.b64encode(f.read()).decode()
    # File names and labels come from user-supplied URLs / directory listings;
    # escape them so they cannot break out of the attribute or inject markup.
    safe_mime = html.escape(mime_type, quote=True)
    safe_name = html.escape(os.path.basename(file_path), quote=True)
    safe_label = html.escape(label)
    return f'<a href="data:{safe_mime};base64,{b64}" download="{safe_name}">{safe_label}</a>'
329
 
330
  def zip_directory(directory_path, zip_path):
331
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
 
340
def get_gallery_files(file_types=("png",)):
    """Return the sorted names of working-directory files with the given extensions.

    Args:
        file_types: Iterable of extensions without the leading dot
            (e.g. ``["png", "jpg"]``). A single extension passed as a plain
            string is accepted as well.

    Returns:
        A sorted list of matching file names, relative to the current
        working directory.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(file_types, str):
        file_types = (file_types,)
    return sorted(f for ext in file_types for f in glob.glob(f"*.{ext}"))
342
 
343
def get_pdf_files():
    """List the ``*.pdf`` files in the current working directory, sorted by name."""
    pdf_paths = glob.glob("*.pdf")
    pdf_paths.sort()
    return pdf_paths
345
+
346
  def download_pdf(url, output_path):
347
  try:
348
  response = requests.get(url, stream=True, timeout=10)
 
356
  return False
357
 
358
  # Async Processing Functions
359
+ async def process_pdf_snapshot(pdf_path, mode="single"):
360
  start_time = time.time()
361
  status = st.empty()
362
  status.text(f"Processing PDF Snapshot ({mode})... (0s)")
363
  try:
364
  doc = fitz.open(pdf_path)
365
  output_files = []
366
+ if mode == "single":
367
  page = doc[0]
368
+ pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0)) # High-res: 200% scale
369
+ output_file = generate_filename("single", "png")
370
  pix.save(output_file)
371
  output_files.append(output_file)
372
  elif mode == "twopage":
373
  for i in range(min(2, len(doc))):
374
  page = doc[i]
375
+ pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0)) # High-res: 200% scale
376
  output_file = generate_filename(f"twopage_{i}", "png")
377
  pix.save(output_file)
378
  output_files.append(output_file)
379
+ elif mode == "allthumbs":
380
+ for i in range(len(doc)):
381
+ page = doc[i]
382
+ pix = page.get_pixmap(matrix=fitz.Matrix(0.5, 0.5)) # Thumbnail: 50% scale
383
+ output_file = generate_filename(f"thumb_{i}", "png")
384
+ pix.save(output_file)
385
+ output_files.append(output_file)
386
  doc.close()
387
  elapsed = int(time.time() - start_time)
388
  status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
 
522
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2) # Default to 2
523
def update_gallery():
    """Render the sidebar gallery of captured PNG images and local PDF files.

    Uses the module-level ``gallery_size`` slider value: at most
    ``gallery_size * 2`` images are shown in two sidebar columns, followed by
    at most ``gallery_size * 2`` PDF download links. Nothing is rendered when
    neither PNG nor PDF files are found.
    """
    media_files = get_gallery_files(["png"])
    pdf_files = get_pdf_files()
    if not (media_files or pdf_files):
        return

    limit = gallery_size * 2  # Limit by gallery size
    st.sidebar.subheader("Images 📸")
    cols = st.sidebar.columns(2)
    for idx, file in enumerate(media_files[:limit]):
        with cols[idx % 2]:
            st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)

    st.sidebar.subheader("PDF Downloads 📖")
    for pdf_file in pdf_files[:limit]:
        # Write to st.sidebar explicitly: a bare st.markdown would land in
        # whichever main-page container is active when this function is
        # called, not under the sidebar subheader above.
        st.sidebar.markdown(
            get_download_link(pdf_file, "application/pdf", f"📥 Grab {os.path.basename(pdf_file)}"),
            unsafe_allow_html=True,
        )
535
  update_gallery()
536
 
537
  st.sidebar.subheader("Model Management 🗂️")
 
595
 
596
# "Download PDFs" tab: bulk-download PDFs by URL, show a thumbnail gallery with
# per-PDF checkbox / download / delete controls, and snapshot the checked PDFs
# to PNG files.
with tab2:
    st.header("Download PDFs 📥")
    # Examples button with arXiv PDF links from README.md
    if st.button("Examples 📚"):
        example_urls = [
            "https://arxiv.org/pdf/2308.03892",  # Streamlit
            "https://arxiv.org/pdf/1912.01703",  # PyTorch
            "https://arxiv.org/pdf/2408.11039",  # Qwen2-VL
            "https://arxiv.org/pdf/2109.10282",  # TrOCR
            "https://arxiv.org/pdf/2112.10752",  # LDM
            "https://arxiv.org/pdf/2308.11236",  # OpenCV
            "https://arxiv.org/pdf/1706.03762",  # Attention is All You Need
            "https://arxiv.org/pdf/2006.11239",  # DDPM
            "https://arxiv.org/pdf/2305.11207",  # Pandas
            "https://arxiv.org/pdf/2106.09685",  # LoRA
            "https://arxiv.org/pdf/2005.11401",  # RAG
            "https://arxiv.org/pdf/2106.10504",  # Fine-Tuning Vision Transformers
        ]
        st.session_state['pdf_urls'] = "\n".join(example_urls)

    # Robo-Downloader
    url_input = st.text_area("Enter PDF URLs (one per line)", value=st.session_state.get('pdf_urls', ""), height=200)
    if st.button("Robo-Download 🤖"):
        # splitlines() + strip() drops blank lines and the "\r" left by
        # Windows-style line endings, so they never reach the file name.
        urls = [line.strip() for line in url_input.splitlines() if line.strip()]
        progress_bar = st.progress(0)
        status_text = st.empty()
        total_urls = len(urls)
        existing_pdfs = set(get_pdf_files())
        for idx, url in enumerate(urls):
            output_path = pdf_url_to_filename(url)
            status_text.text(f"Fetching {idx + 1}/{total_urls}: {os.path.basename(output_path)}...")
            if output_path not in existing_pdfs:
                if download_pdf(url, output_path):
                    st.session_state['downloaded_pdfs'][url] = output_path
                    # Remember it so a repeated URL in the same batch is skipped.
                    existing_pdfs.add(output_path)
                    logger.info(f"Downloaded PDF from {url} to {output_path}")
                    entry = f"Downloaded PDF: {output_path}"
                    if entry not in st.session_state['history']:
                        st.session_state['history'].append(entry)
                else:
                    st.error(f"Failed to nab {url} 😿")
            else:
                st.info(f"Already got {os.path.basename(output_path)}! Skipping... 🐾")
                st.session_state['downloaded_pdfs'][url] = output_path
            progress_bar.progress((idx + 1) / total_urls)
        status_text.text("Robo-Download complete! 🚀")
        update_gallery()

    # PDF Gallery with Thumbnails and Checkboxes
    st.subheader("PDF Gallery 📖")
    # Several URLs can map to the same file; de-duplicate (keeping order) so
    # the per-file widget keys below stay unique.
    downloaded_pdfs = list(dict.fromkeys(st.session_state['downloaded_pdfs'].values()))
    if downloaded_pdfs:
        cols_per_row = 3
        for i in range(0, len(downloaded_pdfs), cols_per_row):
            cols = st.columns(cols_per_row)
            for j, pdf_path in enumerate(downloaded_pdfs[i:i + cols_per_row]):
                with cols[j]:
                    # Render the thumbnail and close the document right away,
                    # so no handle is open when the file is deleted below or
                    # when st.rerun() interrupts the script.
                    doc = fitz.open(pdf_path)
                    try:
                        pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))  # Thumbnail at 50% scale
                        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                    finally:
                        doc.close()
                    st.image(img, caption=os.path.basename(pdf_path), use_container_width=True)
                    # Checkbox for SFT/Input use
                    checkbox_key = f"pdf_{pdf_path}"
                    st.session_state['pdf_checkboxes'][checkbox_key] = st.checkbox(
                        "Use for SFT/Input",
                        value=st.session_state['pdf_checkboxes'].get(checkbox_key, False),
                        key=checkbox_key
                    )
                    # Download and Delete Buttons
                    st.markdown(get_download_link(pdf_path, "application/pdf", "Snag It! 📥"), unsafe_allow_html=True)
                    if st.button("Zap It! 🗑️", key=f"delete_{pdf_path}"):
                        os.remove(pdf_path)
                        # Drop every URL entry that pointed at the deleted file.
                        stale_urls = [k for k, v in st.session_state['downloaded_pdfs'].items() if v == pdf_path]
                        for url_key in stale_urls:
                            del st.session_state['downloaded_pdfs'][url_key]
                        st.session_state['pdf_checkboxes'].pop(checkbox_key, None)
                        st.success(f"PDF {os.path.basename(pdf_path)} vaporized! 💨")
                        st.rerun()
    else:
        st.info("No PDFs captured yet. Feed the robo-downloader some URLs! 🤖")

    mode = st.selectbox("Snapshot Mode", ["Single Page (High-Res)", "Two Pages (High-Res)", "All Pages (Thumbnails)"], key="download_mode")
    if st.button("Snapshot Selected 📸"):
        selected_pdfs = [
            path
            for path in dict.fromkeys(st.session_state['downloaded_pdfs'].values())
            if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)
        ]
        if selected_pdfs:
            # The mode is the same for every PDF, so resolve it once.
            mode_key = {"Single Page (High-Res)": "single", "Two Pages (High-Res)": "twopage", "All Pages (Thumbnails)": "allthumbs"}[mode]
            for pdf_path in selected_pdfs:
                snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
                for snapshot in snapshots:
                    st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
        else:
            st.warning("No PDFs selected for snapshotting! Check some boxes first. 📝")
690
 
691
  with tab3:
692
  st.header("Build Titan 🌱")
 
747
  st.rerun()
748
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
749
  captured_files = get_gallery_files(["png"])
750
+ selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
751
+ if len(captured_files) + len(selected_pdfs) >= 2:
752
  demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files]
753
+ for pdf_path in selected_pdfs:
754
+ demo_data.append({"image": pdf_path, "text": f"PDF {os.path.basename(pdf_path)}"})
755
  edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
756
  if st.button("Fine-Tune with Dataset 🔄"):
757
+ images = [Image.open(row["image"]) if row["image"].endswith('.png') else Image.frombytes("RGB", fitz.open(row["image"])[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)).size, fitz.open(row["image"])[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)).samples) for _, row in edited_data.iterrows()]
758
  texts = [row["text"] for _, row in edited_data.iterrows()]
759
  new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
760
  new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
 
804
  status_container.empty()
805
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
806
  test_prompt = st.text_area("Enter Test Prompt", "Neon Batman")
807
+ selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
808
  if st.button("Run Test ▶️"):
809
  image = st.session_state['builder'].generate(test_prompt)
810
  output_file = generate_filename("diffusion_test", "png")
 
848
  with tab7:
849
  st.header("Test OCR 🔍")
850
  captured_files = get_gallery_files(["png"])
851
+ selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
852
+ all_files = captured_files + selected_pdfs
853
+ if all_files:
854
+ selected_file = st.selectbox("Select Image or PDF", all_files, key="ocr_select")
855
  if selected_file:
856
+ if selected_file.endswith('.png'):
857
+ image = Image.open(selected_file)
858
+ else:
859
+ doc = fitz.open(selected_file)
860
+ pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
861
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
862
+ doc.close()
863
  st.image(image, caption="Input Image", use_container_width=True)
864
  if st.button("Run OCR 🚀", key="ocr_run"):
865
  output_file = generate_filename("ocr_output", "txt")
 
872
  st.success(f"OCR output saved to {output_file}")
873
  st.session_state['processing']['ocr'] = False
874
  else:
875
+ st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
876
 
877
  with tab8:
878
  st.header("Test Image Gen 🎨")
879
  captured_files = get_gallery_files(["png"])
880
+ selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
881
+ all_files = captured_files + selected_pdfs
882
+ if all_files:
883
+ selected_file = st.selectbox("Select Image or PDF", all_files, key="gen_select")
884
  if selected_file:
885
+ if selected_file.endswith('.png'):
886
+ image = Image.open(selected_file)
887
+ else:
888
+ doc = fitz.open(selected_file)
889
+ pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
890
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
891
+ doc.close()
892
  st.image(image, caption="Reference Image", use_container_width=True)
893
  prompt = st.text_area("Prompt", "Generate a similar superhero image", key="gen_prompt")
894
  if st.button("Run Image Gen 🚀", key="gen_run"):
 
902
  st.success(f"Image saved to {output_file}")
903
  st.session_state['processing']['gen'] = False
904
  else:
905
# Fallback message for the Image Gen tab when there is nothing to pick from
# (presumably the else-branch of the `if all_files:` check above).
st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
906
 
907
  with tab9:
908
  st.header("Custom Diffusion 🎨🤓")
909
  st.write("Unleash your inner artist with our tiny diffusion models!")
910
  captured_files = get_gallery_files(["png"])
911
+ selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
912
+ all_files = captured_files + selected_pdfs
913
+ if all_files:
914
+ st.subheader("Select Images or PDFs to Train")
915
+ selected_files = st.multiselect("Pick Images or PDFs", all_files, key="diffusion_select")
916
+ images = []
917
+ for file in selected_files:
918
+ if file.endswith('.png'):
919
+ images.append(Image.open(file))
920
+ else:
921
+ doc = fitz.open(file)
922
+ pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
923
+ images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
924
+ doc.close()
925
 
926
  model_options = [
927
  ("PixelTickler 🎨✨", "OFA-Sys/small-stable-diffusion-v0"),
 
948
  st.success(f"Image saved to {output_file}")
949
  st.session_state['processing']['diffusion'] = False
950
  else:
951
+ st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
952
 
953
  # Initial Gallery Update
954
  update_gallery()