awacke1 commited on
Commit
353aa7f
·
verified ·
1 Parent(s): d6f4b6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -163
app.py CHANGED
@@ -9,11 +9,11 @@ import pandas as pd
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor, Qwen2VLForConditionalGeneration, TrOCRProcessor, VisionEncoderDecoderModel
13
  from diffusers import StableDiffusionPipeline
14
  from torch.utils.data import Dataset, DataLoader
15
  import csv
16
- import fitz # PyMuPDF
17
  import requests
18
  from PIL import Image
19
  import cv2
@@ -28,10 +28,7 @@ import zipfile
28
  import math
29
  import random
30
  import re
31
- from datetime import datetime
32
- import pytz
33
 
34
- # Logging setup with custom buffer
35
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
36
  logger = logging.getLogger(__name__)
37
  log_records = []
@@ -42,7 +39,6 @@ class LogCaptureHandler(logging.Handler):
42
 
43
  logger.addHandler(LogCaptureHandler())
44
 
45
- # Page Configuration
46
  st.set_page_config(
47
  page_title="AI Vision & SFT Titans 🚀",
48
  page_icon="🤖",
@@ -55,9 +51,8 @@ st.set_page_config(
55
  }
56
  )
57
 
58
- # Initialize st.session_state
59
  if 'history' not in st.session_state:
60
- st.session_state['history'] = [] # Flat list for history
61
  if 'builder' not in st.session_state:
62
  st.session_state['builder'] = None
63
  if 'model_loaded' not in st.session_state:
@@ -68,10 +63,7 @@ if 'pdf_checkboxes' not in st.session_state:
68
  st.session_state['pdf_checkboxes'] = {}
69
  if 'downloaded_pdfs' not in st.session_state:
70
  st.session_state['downloaded_pdfs'] = {}
71
- if 'captured_images' not in st.session_state:
72
- st.session_state['captured_images'] = []
73
 
74
- # Model Configuration Classes
75
  @dataclass
76
  class ModelConfig:
77
  name: str
@@ -88,12 +80,11 @@ class DiffusionConfig:
88
  name: str
89
  base_model: str
90
  size: str
91
- domain: Optional[str] = None # Fixed to include domain
92
  @property
93
  def model_path(self):
94
  return f"diffusion_models/{self.name}"
95
 
96
- # Datasets
97
  class SFTDataset(Dataset):
98
  def __init__(self, data, tokenizer, max_length=128):
99
  self.data = data
@@ -132,7 +123,6 @@ class TinyDiffusionDataset(Dataset):
132
  def __getitem__(self, idx):
133
  return self.images[idx]
134
 
135
- # Custom Tiny Diffusion Model
136
  class TinyUNet(nn.Module):
137
  def __init__(self, in_channels=3, out_channels=3):
138
  super(TinyUNet, self).__init__()
@@ -205,7 +195,6 @@ class TinyDiffusion:
205
  upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
206
  return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
207
 
208
- # Model Builders
209
  class ModelBuilder:
210
  def __init__(self):
211
  self.config = None
@@ -316,10 +305,8 @@ class DiffusionBuilder:
316
  def generate(self, prompt: str):
317
  return self.pipeline(prompt, num_inference_steps=20).images[0]
318
 
319
- # Utility Functions
320
  def generate_filename(sequence, ext="png"):
321
- central = pytz.timezone('US/Central')
322
- timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
323
  return f"{sequence}_{timestamp}.{ext}"
324
 
325
  def pdf_url_to_filename(url):
@@ -342,7 +329,7 @@ def get_model_files(model_type="causal_lm"):
342
  path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
343
  return [d for d in glob.glob(path) if os.path.isdir(d)]
344
 
345
- def get_gallery_files(file_types=["png", "txt"]):
346
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
347
 
348
  def get_pdf_files():
@@ -360,33 +347,6 @@ def download_pdf(url, output_path):
360
  logger.error(f"Failed to download {url}: {e}")
361
  return False
362
 
363
- # Model Loaders for New App Features
364
- def load_ocr_qwen2vl():
365
- model_id = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
366
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
367
- model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
368
- return processor, model
369
-
370
- def load_ocr_trocr():
371
- model_id = "microsoft/trocr-small-handwritten"
372
- processor = TrOCRProcessor.from_pretrained(model_id)
373
- model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu").eval()
374
- return processor, model
375
-
376
- def load_image_gen():
377
- model_id = "OFA-Sys/small-stable-diffusion-v0"
378
- pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu")
379
- return pipeline
380
-
381
- def load_line_drawer():
382
- def edge_detection(image):
383
- img_np = np.array(image.convert("RGB"))
384
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
385
- edges = cv2.Canny(gray, 100, 200)
386
- return Image.fromarray(edges)
387
- return edge_detection
388
-
389
- # Async Processing Functions
390
  async def process_pdf_snapshot(pdf_path, mode="single"):
391
  start_time = time.time()
392
  status = st.empty()
@@ -423,31 +383,17 @@ async def process_pdf_snapshot(pdf_path, mode="single"):
423
  status.error(f"Failed to process PDF: {str(e)}")
424
  return []
425
 
426
- async def process_ocr(image, prompt, model_name, output_file):
427
  start_time = time.time()
428
  status = st.empty()
429
- status.text(f"Processing {model_name} OCR... (0s)")
430
- if model_name == "Qwen2-VL-OCR-2B":
431
- processor, model = load_ocr_qwen2vl()
432
- messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
433
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
434
- inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True).to("cpu")
435
- outputs = model.generate(**inputs, max_new_tokens=1024)
436
- result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
437
- elif model_name == "TrOCR-Small":
438
- processor, model = load_ocr_trocr()
439
- pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cpu")
440
- outputs = model.generate(pixel_values)
441
- result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
442
- else: # GOT-OCR2_0 (original from Backup 6)
443
- tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
444
- model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
445
- result = model.chat(tokenizer, image, ocr_type='ocr')
446
  elapsed = int(time.time() - start_time)
447
- status.text(f"{model_name} OCR completed in {elapsed}s!")
448
  async with aiofiles.open(output_file, "w") as f:
449
  await f.write(result)
450
- st.session_state['captured_images'].append(output_file)
451
  update_gallery()
452
  return result
453
 
@@ -455,29 +401,29 @@ async def process_image_gen(prompt, output_file):
455
  start_time = time.time()
456
  status = st.empty()
457
  status.text("Processing Image Gen... (0s)")
458
- pipeline = load_image_gen()
459
  gen_image = pipeline(prompt, num_inference_steps=20).images[0]
460
  elapsed = int(time.time() - start_time)
461
  status.text(f"Image Gen completed in {elapsed}s!")
462
  gen_image.save(output_file)
463
- st.session_state['captured_images'].append(output_file)
464
  update_gallery()
465
  return gen_image
466
 
467
- async def process_line_drawing(image, output_file):
468
  start_time = time.time()
469
  status = st.empty()
470
- status.text("Processing Line Drawing... (0s)")
471
- edge_fn = load_line_drawer()
472
- line_drawing = edge_fn(image)
 
 
 
473
  elapsed = int(time.time() - start_time)
474
- status.text(f"Line Drawing completed in {elapsed}s!")
475
- line_drawing.save(output_file)
476
- st.session_state['captured_images'].append(output_file)
477
  update_gallery()
478
- return line_drawing
479
 
480
- # Mock Search Tool for RAG
481
  def mock_search(query: str) -> str:
482
  if "superhero" in query.lower():
483
  return "Latest trends: Gold-plated Batman statues, VR superhero battles."
@@ -493,7 +439,6 @@ def mock_duckduckgo_search(query: str) -> str:
493
  """
494
  return "No relevant results found."
495
 
496
- # Agent Classes
497
  class PartyPlannerAgent:
498
  def __init__(self, model, tokenizer):
499
  self.model = model
@@ -558,26 +503,19 @@ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_
558
  flight_time = (actual_distance / cruising_speed_kmh) + 1.0
559
  return round(flight_time, 2)
560
 
561
- # Main App
562
  st.title("AI Vision & SFT Titans 🚀")
563
 
564
- # Sidebar
565
  st.sidebar.header("Captured Files 📜")
566
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
567
  def update_gallery():
568
- media_files = get_gallery_files(["png", "txt"])
569
  pdf_files = get_pdf_files()
570
  if media_files or pdf_files:
571
- st.sidebar.subheader("Images & Text 📸")
572
  cols = st.sidebar.columns(2)
573
  for idx, file in enumerate(media_files[:gallery_size * 2]):
574
  with cols[idx % 2]:
575
- if file.endswith(".png"):
576
- st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
577
- elif file.endswith(".txt"):
578
- with open(file, "r") as f:
579
- content = f.read()
580
- st.text(content[:50] + "..." if len(content) > 50 else content, help=file)
581
  st.sidebar.subheader("PDF Downloads 📖")
582
  for pdf_file in pdf_files[:gallery_size * 2]:
583
  st.markdown(get_download_link(pdf_file, "application/pdf", f"📥 Grab {os.path.basename(pdf_file)}"), unsafe_allow_html=True)
@@ -607,11 +545,9 @@ with history_container:
607
  for entry in st.session_state['history'][-gallery_size * 2:]:
608
  st.write(entry)
609
 
610
- # Tabs
611
- tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9, tab10 = st.tabs([
612
  "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
613
- "Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨",
614
- "Test Line Drawings ✏️", "Custom Diffusion 🎨🤓"
615
  ])
616
 
617
  with tab1:
@@ -622,55 +558,43 @@ with tab1:
622
  cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
623
  if cam0_img:
624
  filename = generate_filename("cam0")
625
- if filename not in st.session_state['captured_images']:
626
- with open(filename, "wb") as f:
627
- f.write(cam0_img.getvalue())
628
- st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
629
- logger.info(f"Saved snapshot from Camera 0: {filename}")
630
- st.session_state['captured_images'].append(filename)
631
- update_gallery()
 
632
  with cols[1]:
633
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
634
  if cam1_img:
635
  filename = generate_filename("cam1")
636
- if filename not in st.session_state['captured_images']:
637
- with open(filename, "wb") as f:
638
- f.write(cam1_img.getvalue())
639
- st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
640
- logger.info(f"Saved snapshot from Camera 1: {filename}")
641
- st.session_state['captured_images'].append(filename)
642
- update_gallery()
643
-
644
- st.subheader("Burst Capture")
645
- slice_count = st.number_input("Number of Frames", min_value=1, max_value=20, value=10, key="burst_count")
646
- if st.button("Start Burst Capture 📸"):
647
- st.session_state['burst_frames'] = []
648
- placeholder = st.empty()
649
- for i in range(slice_count):
650
- with placeholder.container():
651
- st.write(f"Capturing frame {i+1}/{slice_count}...")
652
- img = st.camera_input(f"Frame {i}", key=f"burst_{i}_{time.time()}")
653
- if img:
654
- filename = generate_filename(f"burst_{i}")
655
- if filename not in st.session_state['captured_images']:
656
- with open(filename, "wb") as f:
657
- f.write(img.getvalue())
658
- st.session_state['burst_frames'].append(filename)
659
- logger.info(f"Saved burst frame {i}: {filename}")
660
- st.image(Image.open(filename), caption=filename, use_container_width=True)
661
- time.sleep(0.5)
662
- st.session_state['captured_images'].extend([f for f in st.session_state['burst_frames'] if f not in st.session_state['captured_images']])
663
- update_gallery()
664
- placeholder.success(f"Captured {len(st.session_state['burst_frames'])} frames!")
665
 
666
  with tab2:
667
  st.header("Download PDFs 📥")
668
  if st.button("Examples 📚"):
669
  example_urls = [
670
- "https://arxiv.org/pdf/2308.03892", "https://arxiv.org/pdf/1912.01703", "https://arxiv.org/pdf/2408.11039",
671
- "https://arxiv.org/pdf/2109.10282", "https://arxiv.org/pdf/2112.10752", "https://arxiv.org/pdf/2308.11236",
672
- "https://arxiv.org/pdf/1706.03762", "https://arxiv.org/pdf/2006.11239", "https://arxiv.org/pdf/2305.11207",
673
- "https://arxiv.org/pdf/2106.09685", "https://arxiv.org/pdf/2005.11401", "https://arxiv.org/pdf/2106.10504"
 
 
 
 
 
 
 
 
674
  ]
675
  st.session_state['pdf_urls'] = "\n".join(example_urls)
676
 
@@ -716,7 +640,9 @@ with tab2:
716
  st.image(img, caption=os.path.basename(pdf_path), use_container_width=True)
717
  checkbox_key = f"pdf_{pdf_path}"
718
  st.session_state['pdf_checkboxes'][checkbox_key] = st.checkbox(
719
- "Use for SFT/Input", value=st.session_state['pdf_checkboxes'].get(checkbox_key, False), key=checkbox_key
 
 
720
  )
721
  st.markdown(get_download_link(pdf_path, "application/pdf", "Snag It! 📥"), unsafe_allow_html=True)
722
  if st.button("Zap It! 🗑️", key=f"delete_{pdf_path}"):
@@ -916,12 +842,13 @@ with tab7:
916
  image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
917
  doc.close()
918
  st.image(image, caption="Input Image", use_container_width=True)
919
- ocr_model = st.selectbox("Select OCR Model", ["Qwen2-VL-OCR-2B", "TrOCR-Small", "GOT-OCR2_0"], key="ocr_model_select")
920
- prompt = st.text_area("Prompt", "Extract text from the image", key="ocr_prompt")
921
  if st.button("Run OCR 🚀", key="ocr_run"):
922
  output_file = generate_filename("ocr_output", "txt")
923
  st.session_state['processing']['ocr'] = True
924
- result = asyncio.run(process_ocr(image, prompt, ocr_model, output_file))
 
 
 
925
  st.text_area("OCR Result", result, height=200, key="ocr_result")
926
  st.success(f"OCR output saved to {output_file}")
927
  st.session_state['processing']['ocr'] = False
@@ -949,6 +876,9 @@ with tab8:
949
  output_file = generate_filename("gen_output", "png")
950
  st.session_state['processing']['gen'] = True
951
  result = asyncio.run(process_image_gen(prompt, output_file))
 
 
 
952
  st.image(result, caption="Generated Image", use_container_width=True)
953
  st.success(f"Image saved to {output_file}")
954
  st.session_state['processing']['gen'] = False
@@ -956,32 +886,6 @@ with tab8:
956
  st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
957
 
958
  with tab9:
959
- st.header("Test Line Drawings ✏️")
960
- captured_files = get_gallery_files(["png"])
961
- selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
962
- all_files = captured_files + selected_pdfs
963
- if all_files:
964
- selected_file = st.selectbox("Select Image or PDF", all_files, key="line_select")
965
- if selected_file:
966
- if selected_file.endswith('.png'):
967
- image = Image.open(selected_file)
968
- else:
969
- doc = fitz.open(selected_file)
970
- pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
971
- image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
972
- doc.close()
973
- st.image(image, caption="Input Image", use_container_width=True)
974
- if st.button("Run Line Drawing 🚀", key="line_run"):
975
- output_file = generate_filename("line_output", "png")
976
- st.session_state['processing']['line'] = True
977
- result = asyncio.run(process_line_drawing(image, output_file))
978
- st.image(result, caption="Line Drawing", use_container_width=True)
979
- st.success(f"Line drawing saved to {output_file}")
980
- st.session_state['processing']['line'] = False
981
- else:
982
- st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
983
-
984
- with tab10:
985
  st.header("Custom Diffusion 🎨🤓")
986
  st.write("Unleash your inner artist with our tiny diffusion models!")
987
  captured_files = get_gallery_files(["png"])
@@ -1027,5 +931,4 @@ with tab10:
1027
  else:
1028
  st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
1029
 
1030
- # Initial Gallery Update
1031
  update_gallery()
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
13
  from diffusers import StableDiffusionPipeline
14
  from torch.utils.data import Dataset, DataLoader
15
  import csv
16
+ import fitz
17
  import requests
18
  from PIL import Image
19
  import cv2
 
28
  import math
29
  import random
30
  import re
 
 
31
 
 
32
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
33
  logger = logging.getLogger(__name__)
34
  log_records = []
 
39
 
40
  logger.addHandler(LogCaptureHandler())
41
 
 
42
  st.set_page_config(
43
  page_title="AI Vision & SFT Titans 🚀",
44
  page_icon="🤖",
 
51
  }
52
  )
53
 
 
54
  if 'history' not in st.session_state:
55
+ st.session_state['history'] = []
56
  if 'builder' not in st.session_state:
57
  st.session_state['builder'] = None
58
  if 'model_loaded' not in st.session_state:
 
63
  st.session_state['pdf_checkboxes'] = {}
64
  if 'downloaded_pdfs' not in st.session_state:
65
  st.session_state['downloaded_pdfs'] = {}
 
 
66
 
 
67
  @dataclass
68
  class ModelConfig:
69
  name: str
 
80
  name: str
81
  base_model: str
82
  size: str
83
+ domain: Optional[str] = None
84
  @property
85
  def model_path(self):
86
  return f"diffusion_models/{self.name}"
87
 
 
88
  class SFTDataset(Dataset):
89
  def __init__(self, data, tokenizer, max_length=128):
90
  self.data = data
 
123
  def __getitem__(self, idx):
124
  return self.images[idx]
125
 
 
126
  class TinyUNet(nn.Module):
127
  def __init__(self, in_channels=3, out_channels=3):
128
  super(TinyUNet, self).__init__()
 
195
  upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
196
  return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
197
 
 
198
  class ModelBuilder:
199
  def __init__(self):
200
  self.config = None
 
305
  def generate(self, prompt: str):
306
  return self.pipeline(prompt, num_inference_steps=20).images[0]
307
 
 
308
  def generate_filename(sequence, ext="png"):
309
+ timestamp = time.strftime("%d%m%Y%H%M%S")
 
310
  return f"{sequence}_{timestamp}.{ext}"
311
 
312
  def pdf_url_to_filename(url):
 
329
  path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
330
  return [d for d in glob.glob(path) if os.path.isdir(d)]
331
 
332
+ def get_gallery_files(file_types=["png"]):
333
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
334
 
335
  def get_pdf_files():
 
347
  logger.error(f"Failed to download {url}: {e}")
348
  return False
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  async def process_pdf_snapshot(pdf_path, mode="single"):
351
  start_time = time.time()
352
  status = st.empty()
 
383
  status.error(f"Failed to process PDF: {str(e)}")
384
  return []
385
 
386
+ async def process_ocr(image, output_file):
387
  start_time = time.time()
388
  status = st.empty()
389
+ status.text("Processing GOT-OCR2_0... (0s)")
390
+ tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
391
+ model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
392
+ result = model.chat(tokenizer, image, ocr_type='ocr')
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  elapsed = int(time.time() - start_time)
394
+ status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
395
  async with aiofiles.open(output_file, "w") as f:
396
  await f.write(result)
 
397
  update_gallery()
398
  return result
399
 
 
401
  start_time = time.time()
402
  status = st.empty()
403
  status.text("Processing Image Gen... (0s)")
404
+ pipeline = StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")
405
  gen_image = pipeline(prompt, num_inference_steps=20).images[0]
406
  elapsed = int(time.time() - start_time)
407
  status.text(f"Image Gen completed in {elapsed}s!")
408
  gen_image.save(output_file)
 
409
  update_gallery()
410
  return gen_image
411
 
412
+ async def process_custom_diffusion(images, output_file, model_name):
413
  start_time = time.time()
414
  status = st.empty()
415
+ status.text(f"Training {model_name}... (0s)")
416
+ unet = TinyUNet()
417
+ diffusion = TinyDiffusion(unet)
418
+ diffusion.train(images)
419
+ gen_image = diffusion.generate()
420
+ upscaled_image = diffusion.upscale(gen_image, scale_factor=2)
421
  elapsed = int(time.time() - start_time)
422
+ status.text(f"{model_name} completed in {elapsed}s!")
423
+ upscaled_image.save(output_file)
 
424
  update_gallery()
425
+ return upscaled_image
426
 
 
427
  def mock_search(query: str) -> str:
428
  if "superhero" in query.lower():
429
  return "Latest trends: Gold-plated Batman statues, VR superhero battles."
 
439
  """
440
  return "No relevant results found."
441
 
 
442
  class PartyPlannerAgent:
443
  def __init__(self, model, tokenizer):
444
  self.model = model
 
503
  flight_time = (actual_distance / cruising_speed_kmh) + 1.0
504
  return round(flight_time, 2)
505
 
 
506
  st.title("AI Vision & SFT Titans 🚀")
507
 
 
508
  st.sidebar.header("Captured Files 📜")
509
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
510
  def update_gallery():
511
+ media_files = get_gallery_files(["png"])
512
  pdf_files = get_pdf_files()
513
  if media_files or pdf_files:
514
+ st.sidebar.subheader("Images 📸")
515
  cols = st.sidebar.columns(2)
516
  for idx, file in enumerate(media_files[:gallery_size * 2]):
517
  with cols[idx % 2]:
518
+ st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
 
 
 
 
 
519
  st.sidebar.subheader("PDF Downloads 📖")
520
  for pdf_file in pdf_files[:gallery_size * 2]:
521
  st.markdown(get_download_link(pdf_file, "application/pdf", f"📥 Grab {os.path.basename(pdf_file)}"), unsafe_allow_html=True)
 
545
  for entry in st.session_state['history'][-gallery_size * 2:]:
546
  st.write(entry)
547
 
548
+ tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9 = st.tabs([
 
549
  "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
550
+ "Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨", "Custom Diffusion 🎨🤓"
 
551
  ])
552
 
553
  with tab1:
 
558
  cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
559
  if cam0_img:
560
  filename = generate_filename("cam0")
561
+ with open(filename, "wb") as f:
562
+ f.write(cam0_img.getvalue())
563
+ entry = f"Snapshot from Cam 0: {filename}"
564
+ if entry not in st.session_state['history']:
565
+ st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
566
+ st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
567
+ logger.info(f"Saved snapshot from Camera 0: {filename}")
568
+ update_gallery()
569
  with cols[1]:
570
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
571
  if cam1_img:
572
  filename = generate_filename("cam1")
573
+ with open(filename, "wb") as f:
574
+ f.write(cam1_img.getvalue())
575
+ entry = f"Snapshot from Cam 1: {filename}"
576
+ if entry not in st.session_state['history']:
577
+ st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
578
+ st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
579
+ logger.info(f"Saved snapshot from Camera 1: {filename}")
580
+ update_gallery()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
  with tab2:
583
  st.header("Download PDFs 📥")
584
  if st.button("Examples 📚"):
585
  example_urls = [
586
+ "https://arxiv.org/pdf/2308.03892",
587
+ "https://arxiv.org/pdf/1912.01703",
588
+ "https://arxiv.org/pdf/2408.11039",
589
+ "https://arxiv.org/pdf/2109.10282",
590
+ "https://arxiv.org/pdf/2112.10752",
591
+ "https://arxiv.org/pdf/2308.11236",
592
+ "https://arxiv.org/pdf/1706.03762",
593
+ "https://arxiv.org/pdf/2006.11239",
594
+ "https://arxiv.org/pdf/2305.11207",
595
+ "https://arxiv.org/pdf/2106.09685",
596
+ "https://arxiv.org/pdf/2005.11401",
597
+ "https://arxiv.org/pdf/2106.10504"
598
  ]
599
  st.session_state['pdf_urls'] = "\n".join(example_urls)
600
 
 
640
  st.image(img, caption=os.path.basename(pdf_path), use_container_width=True)
641
  checkbox_key = f"pdf_{pdf_path}"
642
  st.session_state['pdf_checkboxes'][checkbox_key] = st.checkbox(
643
+ "Use for SFT/Input",
644
+ value=st.session_state['pdf_checkboxes'].get(checkbox_key, False),
645
+ key=checkbox_key
646
  )
647
  st.markdown(get_download_link(pdf_path, "application/pdf", "Snag It! 📥"), unsafe_allow_html=True)
648
  if st.button("Zap It! 🗑️", key=f"delete_{pdf_path}"):
 
842
  image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
843
  doc.close()
844
  st.image(image, caption="Input Image", use_container_width=True)
 
 
845
  if st.button("Run OCR 🚀", key="ocr_run"):
846
  output_file = generate_filename("ocr_output", "txt")
847
  st.session_state['processing']['ocr'] = True
848
+ result = asyncio.run(process_ocr(image, output_file))
849
+ entry = f"OCR Test: {selected_file} -> {output_file}"
850
+ if entry not in st.session_state['history']:
851
+ st.session_state['history'].append(entry)
852
  st.text_area("OCR Result", result, height=200, key="ocr_result")
853
  st.success(f"OCR output saved to {output_file}")
854
  st.session_state['processing']['ocr'] = False
 
876
  output_file = generate_filename("gen_output", "png")
877
  st.session_state['processing']['gen'] = True
878
  result = asyncio.run(process_image_gen(prompt, output_file))
879
+ entry = f"Image Gen Test: {prompt} -> {output_file}"
880
+ if entry not in st.session_state['history']:
881
+ st.session_state['history'].append(entry)
882
  st.image(result, caption="Generated Image", use_container_width=True)
883
  st.success(f"Image saved to {output_file}")
884
  st.session_state['processing']['gen'] = False
 
886
  st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
887
 
888
  with tab9:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889
  st.header("Custom Diffusion 🎨🤓")
890
  st.write("Unleash your inner artist with our tiny diffusion models!")
891
  captured_files = get_gallery_files(["png"])
 
931
  else:
932
  st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
933
 
 
934
  update_gallery()