awacke1 committed on
Commit ec29ecf · verified · 1 Parent(s): 481e614

Update app.py

Files changed (1)
  1. app.py +216 -102
app.py CHANGED
@@ -4,41 +4,40 @@ import glob
import base64
import time
import shutil
import streamlit as st
import pandas as pd
import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
- from diffusers import StableDiffusionPipeline
- from torch.utils.data import Dataset, DataLoader
- import csv
import fitz
import requests
from PIL import Image
- import cv2
- import numpy as np
- import logging
- import asyncio
- import aiofiles
- from io import BytesIO
- from dataclasses import dataclass
- from typing import Optional, Tuple
- import zipfile
- import math
- import random
- import re

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
log_records = []
-
class LogCaptureHandler(logging.Handler):
    def emit(self, record):
        log_records.append(record)
-
logger.addHandler(LogCaptureHandler())

st.set_page_config(
    page_title="AI Vision & SFT Titans 🚀",
    page_icon="🤖",
@@ -51,6 +50,7 @@ st.set_page_config(
    }
)

if 'history' not in st.session_state:
    st.session_state['history'] = []
if 'builder' not in st.session_state:
@@ -74,6 +74,7 @@ if 'cam0_file' not in st.session_state:
if 'cam1_file' not in st.session_state:
    st.session_state['cam1_file'] = None

@dataclass
class ModelConfig:
    name: str
@@ -95,12 +96,14 @@ class DiffusionConfig:
    def model_path(self):
        return f"diffusion_models/{self.name}"

class ModelBuilder:
    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
-         self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
        with st.spinner(f"Loading {model_path}... ⏳"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
@@ -128,7 +131,7 @@ class DiffusionBuilder:
            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
            if config:
                self.config = config
-         st.success(f"Diffusion model loaded! 🎨")
        return self
    def save_model(self, path: str):
        with st.spinner("Saving diffusion model... 💾"):
@@ -138,6 +141,7 @@ class DiffusionBuilder:
    def generate(self, prompt: str):
        return self.pipeline(prompt, num_inference_steps=20).images[0]

def generate_filename(sequence, ext="png"):
    timestamp = time.strftime("%d%m%Y%H%M%S")
    return f"{sequence}_{timestamp}.{ext}"
@@ -181,6 +185,7 @@ def download_pdf(url, output_path):
        logger.error(f"Failed to download {url}: {e}")
        return False

async def process_pdf_snapshot(pdf_path, mode="single"):
    start_time = time.time()
    status = st.empty()
@@ -223,11 +228,10 @@ async def process_ocr(image, output_file):
    status.text("Processing GOT-OCR2_0... (0s)")
    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
-     # Save image to temporary file since GOT-OCR2_0 expects a file path
    temp_file = f"temp_{int(time.time())}.png"
    image.save(temp_file)
    result = model.chat(tokenizer, temp_file, ocr_type='ocr')
-     os.remove(temp_file)  # Clean up temporary file
    elapsed = int(time.time() - start_time)
    status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
    async with aiofiles.open(output_file, "w") as f:
@@ -250,49 +254,31 @@ async def process_image_gen(prompt, output_file):
    update_gallery()
    return gen_image

- st.title("AI Vision & SFT Titans 🚀")
-
- # Sidebar
- model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"], key="sidebar_model_type", index=0 if st.session_state['selected_model_type'] == "Causal LM" else 1)
- model_dirs = get_model_files(model_type)
- if model_dirs and st.session_state['selected_model'] == "None" and "None" not in model_dirs:
-     st.session_state['selected_model'] = model_dirs[0]
- selected_model = st.sidebar.selectbox("Select Saved Model", model_dirs, key="sidebar_model_select", index=model_dirs.index(st.session_state['selected_model']) if st.session_state['selected_model'] in model_dirs else 0)
- if selected_model != "None" and st.sidebar.button("Load Model 📂"):
-     builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
-     config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
-     builder.load_model(selected_model, config)
-     st.session_state['builder'] = builder
-     st.session_state['model_loaded'] = True
-     st.rerun()
-
- st.sidebar.header("Captured Files 📜")
- cols = st.sidebar.columns(2)
- with cols[0]:
-     if st.button("Zip All 🤐"):
-         zip_path = f"all_assets_{int(time.time())}.zip"
-         with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
-             for file in get_gallery_files():
-                 zipf.write(file, os.path.basename(file))
-         st.sidebar.markdown(get_download_link(zip_path, "application/zip", "Download All Assets"), unsafe_allow_html=True)
- with cols[1]:
-     if st.button("Zap All! 🗑️"):
-         for file in get_gallery_files():
-             os.remove(file)
-         st.session_state['asset_checkboxes'].clear()
-         st.session_state['downloaded_pdfs'].clear()
-         st.session_state['cam0_file'] = None
-         st.session_state['cam1_file'] = None
-         st.sidebar.success("All assets vaporized! 💨")
-         st.rerun()
-
- gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
def update_gallery():
    all_files = get_gallery_files()
    if all_files:
        st.sidebar.subheader("Asset Gallery 📸📖")
        cols = st.sidebar.columns(2)
-         for idx, file in enumerate(all_files[:gallery_size * 2]):
            with cols[idx % 2]:
                st.session_state['unique_counter'] += 1
                unique_id = st.session_state['unique_counter']
@@ -305,46 +291,41 @@ def update_gallery():
                    st.image(img, caption=os.path.basename(file), use_container_width=True)
                    doc.close()
                checkbox_key = f"asset_{file}_{unique_id}"
-                 st.session_state['asset_checkboxes'][file] = st.checkbox(
-                     "Use for SFT/Input",
-                     value=st.session_state['asset_checkboxes'].get(file, False),
-                     key=checkbox_key
-                 )
                mime_type = "image/png" if file.endswith('.png') else "application/pdf"
                st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
                if st.button("Zap It! 🗑️", key=f"delete_{file}_{unique_id}"):
                    os.remove(file)
-                     if file in st.session_state['asset_checkboxes']:
-                         del st.session_state['asset_checkboxes'][file]
-                     if file.endswith('.pdf'):
-                         url_key = next((k for k, v in st.session_state['downloaded_pdfs'].items() if v == file), None)
-                         if url_key:
-                             del st.session_state['downloaded_pdfs'][url_key]
-                     if file == st.session_state['cam0_file']:
-                         st.session_state['cam0_file'] = None
-                     if file == st.session_state['cam1_file']:
-                         st.session_state['cam1_file'] = None
                    st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")
-                     st.rerun()
update_gallery()

st.sidebar.subheader("Action Logs 📜")
- log_container = st.sidebar.empty()
- with log_container:
    for record in log_records:
        st.write(f"{record.asctime} - {record.levelname} - {record.message}")
-
st.sidebar.subheader("History 📜")
- history_container = st.sidebar.empty()
- with history_container:
-     for entry in st.session_state['history'][-gallery_size * 2:]:
        st.write(entry)

- tab1, tab2, tab3, tab4 = st.tabs([
-     "Camera Snap 📷", "Download PDFs 📥", "Test OCR 🔍", "Build Titan 🌱"
])

- with tab1:
    st.header("Camera Snap 📷")
    st.subheader("Single Capture")
    cols = st.columns(2)
@@ -363,8 +344,6 @@ with tab1:
            st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
            logger.info(f"Saved snapshot from Camera 0: {filename}")
            update_gallery()
-         elif st.session_state['cam0_file'] and os.path.exists(st.session_state['cam0_file']):
-             st.image(Image.open(st.session_state['cam0_file']), caption="Camera 0", use_container_width=True)
    with cols[1]:
        cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
        if cam1_img:
@@ -380,10 +359,9 @@ with tab1:
            st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
            logger.info(f"Saved snapshot from Camera 1: {filename}")
            update_gallery()
-         elif st.session_state['cam1_file'] and os.path.exists(st.session_state['cam1_file']):
-             st.image(Image.open(st.session_state['cam1_file']), caption="Camera 1", use_container_width=True)

- with tab2:
    st.header("Download PDFs 📥")
    if st.button("Examples 📚"):
        example_urls = [
@@ -420,7 +398,7 @@ with tab2:
                        entry = f"Downloaded PDF: {output_path}"
                        if entry not in st.session_state['history']:
                            st.session_state['history'].append(entry)
-                         st.session_state['asset_checkboxes'][output_path] = True  # Auto-check the box
                    else:
                        st.error(f"Failed to nab {url} 😿")
                else:
@@ -429,7 +407,6 @@ with tab2:
            progress_bar.progress((idx + 1) / total_urls)
        status_text.text("Robo-Download complete! 🚀")
        update_gallery()
-
    mode = st.selectbox("Snapshot Mode", ["Single Page (High-Res)", "Two Pages (High-Res)", "All Pages (High-Res)"], key="download_mode")
    if st.button("Snapshot Selected 📸"):
        selected_pdfs = [path for path in get_gallery_files() if path.endswith('.pdf') and st.session_state['asset_checkboxes'].get(path, False)]
@@ -439,12 +416,13 @@ with tab2:
                snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
                for snapshot in snapshots:
                    st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
-                     st.session_state['asset_checkboxes'][snapshot] = True  # Auto-check new snapshots
            update_gallery()
        else:
-             st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar gallery.")

- with tab3:
    st.header("Test OCR 🔍")
    all_files = get_gallery_files()
    if all_files:
@@ -509,7 +487,8 @@ with tab3:
    else:
        st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")

- with tab4:
    st.header("Build Titan 🌱")
    model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
    base_model = st.selectbox("Select Tiny Model",
@@ -530,10 +509,10 @@ with tab4:
        if entry not in st.session_state['history']:
            st.session_state['history'].append(entry)
        st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
-         st.rerun()

- tab5 = st.tabs(["Test Image Gen 🎨"])[0]
- with tab5:
    st.header("Test Image Gen 🎨")
    all_files = get_gallery_files()
    if all_files:
@@ -560,5 +539,140 @@ with tab5:
            st.session_state['processing']['gen'] = False
    else:
        st.warning("No images or PDFs in gallery yet. Use Camera Snap or Download PDFs!")

- update_gallery()
import base64
import time
import shutil
+ import zipfile
+ import re
+ import logging
+ import asyncio
+ from io import BytesIO
+ from datetime import datetime
+ import pytz
+ from dataclasses import dataclass
+ from typing import Optional
+
import streamlit as st
import pandas as pd
import torch
import fitz
import requests
from PIL import Image
+ from diffusers import StableDiffusionPipeline
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+
+ # --- OpenAI Setup (for GPT related features) ---
+ import openai
+ openai.api_key = os.getenv('OPENAI_API_KEY')
+ openai.organization = os.getenv('OPENAI_ORG_ID')

+ # --- Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
log_records = []
class LogCaptureHandler(logging.Handler):
    def emit(self, record):
        log_records.append(record)
logger.addHandler(LogCaptureHandler())

+ # --- Streamlit Page Config ---
st.set_page_config(
    page_title="AI Vision & SFT Titans 🚀",
    page_icon="🤖",

    }
)

+ # --- Session State Defaults ---
if 'history' not in st.session_state:
    st.session_state['history'] = []
if 'builder' not in st.session_state:

if 'cam1_file' not in st.session_state:
    st.session_state['cam1_file'] = None

+ # --- Model & Diffusion DataClasses ---
@dataclass
class ModelConfig:
    name: str

    def model_path(self):
        return f"diffusion_models/{self.name}"
+ # --- Model Builders ---
class ModelBuilder:
    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
+         self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂",
+                       "Training complete! Time for a binary coffee break. ☕"]
    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
        with st.spinner(f"Loading {model_path}... ⏳"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
            if config:
                self.config = config
+         st.success("Diffusion model loaded! 🎨")
        return self
    def save_model(self, path: str):
        with st.spinner("Saving diffusion model... 💾"):

    def generate(self, prompt: str):
        return self.pipeline(prompt, num_inference_steps=20).images[0]

+ # --- Utility Functions ---
def generate_filename(sequence, ext="png"):
    timestamp = time.strftime("%d%m%Y%H%M%S")
    return f"{sequence}_{timestamp}.{ext}"

        logger.error(f"Failed to download {url}: {e}")
        return False

+ # --- Original PDF Snapshot & OCR Functions ---
async def process_pdf_snapshot(pdf_path, mode="single"):
    start_time = time.time()
    status = st.empty()

    status.text("Processing GOT-OCR2_0... (0s)")
    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
    temp_file = f"temp_{int(time.time())}.png"
    image.save(temp_file)
    result = model.chat(tokenizer, temp_file, ocr_type='ocr')
+     os.remove(temp_file)
    elapsed = int(time.time() - start_time)
    status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
    async with aiofiles.open(output_file, "w") as f:

    update_gallery()
    return gen_image

+ # --- New Function: Process an image (PIL) with a custom prompt using GPT ---
+ def process_image_with_prompt(image, prompt, model="o3-mini-high"):
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "text", "text": prompt},
+             {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}}
+         ]
+     }]
+     try:
+         response = openai.ChatCompletion.create(model=model, messages=messages)
+         return response.choices[0].message.content
+     except Exception as e:
+         return f"Error processing image with GPT: {str(e)}"
+ # --- Gallery Update ---
def update_gallery():
    all_files = get_gallery_files()
    if all_files:
        st.sidebar.subheader("Asset Gallery 📸📖")
        cols = st.sidebar.columns(2)
+         for idx, file in enumerate(all_files[:st.sidebar.slider("Gallery Size", 1, 10, 2)]):
            with cols[idx % 2]:
                st.session_state['unique_counter'] += 1
                unique_id = st.session_state['unique_counter']

                    st.image(img, caption=os.path.basename(file), use_container_width=True)
                    doc.close()
                checkbox_key = f"asset_{file}_{unique_id}"
+                 st.session_state['asset_checkboxes'][file] = st.checkbox("Use for SFT/Input", value=st.session_state['asset_checkboxes'].get(file, False), key=checkbox_key)
                mime_type = "image/png" if file.endswith('.png') else "application/pdf"
                st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
                if st.button("Zap It! 🗑️", key=f"delete_{file}_{unique_id}"):
                    os.remove(file)
+                     st.session_state['asset_checkboxes'].pop(file, None)
                    st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")
+                     st.experimental_rerun()
update_gallery()

+ # --- Sidebar Logs & History ---
st.sidebar.subheader("Action Logs 📜")
+ with st.sidebar:
    for record in log_records:
        st.write(f"{record.asctime} - {record.levelname} - {record.message}")
st.sidebar.subheader("History 📜")
+ with st.sidebar:
+     for entry in st.session_state['history']:
        st.write(entry)

+ # --- Create Tabs (Existing + New) ---
+ tabs = st.tabs([
+     "Camera Snap 📷",
+     "Download PDFs 📥",
+     "Test OCR 🔍",
+     "Build Titan 🌱",
+     "Test Image Gen 🎨",
+     "PDF Process 📄",
+     "Image Process 🖼️",
+     "MD Gallery 📚"
])
+ (tab_camera, tab_download, tab_ocr, tab_build, tab_imggen, tab_pdf_process, tab_image_process, tab_md_gallery) = tabs
+ # === Tab: Camera Snap (existing) ===
+ with tab_camera:
    st.header("Camera Snap 📷")
    st.subheader("Single Capture")
    cols = st.columns(2)

            st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
            logger.info(f"Saved snapshot from Camera 0: {filename}")
            update_gallery()
    with cols[1]:
        cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
        if cam1_img:

            st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
            logger.info(f"Saved snapshot from Camera 1: {filename}")
            update_gallery()

+ # === Tab: Download PDFs (existing) ===
+ with tab_download:
    st.header("Download PDFs 📥")
    if st.button("Examples 📚"):
        example_urls = [

                        entry = f"Downloaded PDF: {output_path}"
                        if entry not in st.session_state['history']:
                            st.session_state['history'].append(entry)
+                         st.session_state['asset_checkboxes'][output_path] = True
                    else:
                        st.error(f"Failed to nab {url} 😿")
                else:

            progress_bar.progress((idx + 1) / total_urls)
        status_text.text("Robo-Download complete! 🚀")
        update_gallery()
    mode = st.selectbox("Snapshot Mode", ["Single Page (High-Res)", "Two Pages (High-Res)", "All Pages (High-Res)"], key="download_mode")
    if st.button("Snapshot Selected 📸"):
        selected_pdfs = [path for path in get_gallery_files() if path.endswith('.pdf') and st.session_state['asset_checkboxes'].get(path, False)]

                snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
                for snapshot in snapshots:
                    st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
+                     st.session_state['asset_checkboxes'][snapshot] = True
            update_gallery()
        else:
+             st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar.")

+ # === Tab: Test OCR (existing) ===
+ with tab_ocr:
    st.header("Test OCR 🔍")
    all_files = get_gallery_files()
    if all_files:

    else:
        st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")

+ # === Tab: Build Titan (existing) ===
+ with tab_build:
    st.header("Build Titan 🌱")
    model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
    base_model = st.selectbox("Select Tiny Model",

        if entry not in st.session_state['history']:
            st.session_state['history'].append(entry)
        st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
+         st.experimental_rerun()

+ # === Tab: Test Image Gen (existing) ===
+ with tab_imggen:
    st.header("Test Image Gen 🎨")
    all_files = get_gallery_files()
    if all_files:

            st.session_state['processing']['gen'] = False
    else:
        st.warning("No images or PDFs in gallery yet. Use Camera Snap or Download PDFs!")
+     update_gallery()
+ # === New Tab: PDF Process ===
+ with tab_pdf_process:
+     st.header("PDF Process")
+     st.subheader("Upload PDFs for GPT-based text extraction")
+     uploaded_pdfs = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True, key="pdf_process_uploader")
+     view_mode = st.selectbox("View Mode", ["Single Page", "Double Page"], key="pdf_view_mode")
+     if st.button("Process Uploaded PDFs", key="process_pdfs"):
+         combined_text = ""
+         for pdf_file in uploaded_pdfs:
+             pdf_bytes = pdf_file.read()
+             temp_pdf_path = f"temp_{pdf_file.name}"
+             with open(temp_pdf_path, "wb") as f:
+                 f.write(pdf_bytes)
+             try:
+                 doc = fitz.open(temp_pdf_path)
+                 st.write(f"Processing {pdf_file.name} with {len(doc)} pages")
+                 if view_mode == "Single Page":
+                     for i, page in enumerate(doc):
+                         pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                         img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+                         st.image(img, caption=f"{pdf_file.name} Page {i+1}")
+                         gpt_text = process_image_with_prompt(img, "Extract the electronic text from image")
+                         combined_text += f"\n## {pdf_file.name} - Page {i+1}\n\n{gpt_text}\n"
+                 else:  # Double Page: combine two consecutive pages
+                     pages = list(doc)
+                     for i in range(0, len(pages), 2):
+                         if i+1 < len(pages):
+                             pix1 = pages[i].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                             img1 = Image.frombytes("RGB", [pix1.width, pix1.height], pix1.samples)
+                             pix2 = pages[i+1].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                             img2 = Image.frombytes("RGB", [pix2.width, pix2.height], pix2.samples)
+                             total_width = img1.width + img2.width
+                             max_height = max(img1.height, img2.height)
+                             combined_img = Image.new("RGB", (total_width, max_height))
+                             combined_img.paste(img1, (0, 0))
+                             combined_img.paste(img2, (img1.width, 0))
+                             st.image(combined_img, caption=f"{pdf_file.name} Pages {i+1}-{i+2}")
+                             gpt_text = process_image_with_prompt(combined_img, "Extract the electronic text from image")
+                             combined_text += f"\n## {pdf_file.name} - Pages {i+1}-{i+2}\n\n{gpt_text}\n"
+                         else:
+                             pix = pages[i].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                             img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+                             st.image(img, caption=f"{pdf_file.name} Page {i+1}")
+                             gpt_text = process_image_with_prompt(img, "Extract the electronic text from image")
+                             combined_text += f"\n## {pdf_file.name} - Page {i+1}\n\n{gpt_text}\n"
+                 doc.close()
+             except Exception as e:
+                 st.error(f"Error processing {pdf_file.name}: {str(e)}")
+             finally:
+                 os.remove(temp_pdf_path)
+         output_filename = generate_filename("processed_pdf", "md")
+         with open(output_filename, "w", encoding="utf-8") as f:
+             f.write(combined_text)
+         st.success(f"PDF processing complete. MD file saved as {output_filename}")
+         st.markdown(get_download_link(output_filename, "text/markdown", "Download Processed PDF MD"), unsafe_allow_html=True)
+
+ # === New Tab: Image Process ===
+ with tab_image_process:
+     st.header("Image Process")
+     st.subheader("Upload Images for GPT-based OCR")
+     prompt_img = st.text_input("Enter prompt for image processing", "Extract the electronic text from image", key="img_process_prompt")
+     uploaded_images = st.file_uploader("Upload image files", type=["png", "jpg", "jpeg"], accept_multiple_files=True, key="image_process_uploader")
+     if st.button("Process Uploaded Images", key="process_images"):
+         combined_text = ""
+         for img_file in uploaded_images:
+             try:
+                 img = Image.open(img_file)
+                 st.image(img, caption=img_file.name)
+                 gpt_text = process_image_with_prompt(img, prompt_img)
+                 combined_text += f"\n## {img_file.name}\n\n{gpt_text}\n"
+             except Exception as e:
+                 st.error(f"Error processing image {img_file.name}: {str(e)}")
+         output_filename = generate_filename("processed_image", "md")
+         with open(output_filename, "w", encoding="utf-8") as f:
+             f.write(combined_text)
+         st.success(f"Image processing complete. MD file saved as {output_filename}")
+         st.markdown(get_download_link(output_filename, "text/markdown", "Download Processed Image MD"), unsafe_allow_html=True)
+
+ # === New Tab: MD Gallery ===
+ with tab_md_gallery:
+     st.header("MD Gallery and GPT Processing")
+     md_files = sorted(glob.glob("*.md"))
+     if md_files:
+         st.subheader("Individual File Processing")
+         cols = st.columns(2)
+         for idx, md_file in enumerate(md_files):
+             with cols[idx % 2]:
+                 st.write(md_file)
+                 if st.button(f"Process {md_file}", key=f"process_md_{md_file}"):
+                     try:
+                         with open(md_file, "r", encoding="utf-8") as f:
+                             content = f.read()
+                         prompt_md = "Summarize this into markdown outline with emojis and number the topics 1..12"
+                         messages = [{"role": "user", "content": prompt_md + "\n\n" + content}]
+                         response = openai.ChatCompletion.create(model="o3-mini-high", messages=messages)
+                         result_text = response.choices[0].message.content
+                         st.markdown(result_text)
+                         output_filename = generate_filename(f"processed_{os.path.splitext(md_file)[0]}", "md")
+                         with open(output_filename, "w", encoding="utf-8") as f:
+                             f.write(result_text)
+                         st.markdown(get_download_link(output_filename, "text/markdown", f"Download {output_filename}"), unsafe_allow_html=True)
+                     except Exception as e:
+                         st.error(f"Error processing {md_file}: {str(e)}")
+         st.subheader("Batch Processing")
+         st.write("Select MD files to combine and process:")
+         selected_md = {}
+         for md_file in md_files:
+             selected_md[md_file] = st.checkbox(md_file, key=f"checkbox_md_{md_file}")
+         batch_prompt = st.text_input("Enter batch processing prompt", "Summarize this into markdown outline with emojis and number the topics 1..12", key="batch_prompt")
+         if st.button("Process Selected MD Files", key="process_batch_md"):
+             combined_content = ""
+             for md_file, selected in selected_md.items():
+                 if selected:
+                     try:
+                         with open(md_file, "r", encoding="utf-8") as f:
+                             combined_content += f"\n## {md_file}\n" + f.read() + "\n"
+                     except Exception as e:
+                         st.error(f"Error reading {md_file}: {str(e)}")
+             if combined_content:
+                 messages = [{"role": "user", "content": batch_prompt + "\n\n" + combined_content}]
+                 try:
+                     response = openai.ChatCompletion.create(model="o3-mini-high", messages=messages)
+                     result_text = response.choices[0].message.content
+                     st.markdown(result_text)
+                     output_filename = generate_filename("batch_processed_md", "md")
+                     with open(output_filename, "w", encoding="utf-8") as f:
+                         f.write(result_text)
+                     st.success(f"Batch processing complete. MD file saved as {output_filename}")
+                     st.markdown(get_download_link(output_filename, "text/markdown", "Download Batch Processed MD"), unsafe_allow_html=True)
+                 except Exception as e:
+                     st.error(f"Error processing batch: {str(e)}")
+             else:
+                 st.warning("No MD files selected.")
+     else:
+         st.warning("No MD files found.")