awacke1 committed on
Commit 2eee123 · verified · 1 Parent(s): c801f72

Update app.py

Files changed (1)
  1. app.py +458 -97
app.py CHANGED
@@ -1,21 +1,32 @@
1
  #!/usr/bin/env python3
2
  import os
3
  import glob
 
4
  import time
 
5
  import streamlit as st
6
  import fitz # PyMuPDF
7
  import requests
8
  from PIL import Image
9
- from transformers import AutoTokenizer, AutoModel
10
- from diffusers import StableDiffusionPipeline
11
  import cv2
12
  import numpy as np
13
  import logging
14
  import asyncio
15
  import aiofiles
16
  from io import BytesIO
17
 
18
- # Logging setup
19
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
20
  logger = logging.getLogger(__name__)
21
  log_records = []
@@ -28,40 +39,212 @@ logger.addHandler(LogCaptureHandler())
28
 
29
  # Page Configuration
30
  st.set_page_config(
31
- page_title="AI Vision Titans 🚀",
32
  page_icon="🤖",
33
  layout="wide",
34
  initial_sidebar_state="expanded",
35
- menu_items={'About': "AI Vision Titans: PDF Snapshots, OCR, Image Gen, Line Drawings on CPU! 🌌"}
36
  )
37
 
38
  # Initialize st.session_state
39
  if 'captured_files' not in st.session_state:
40
  st.session_state['captured_files'] = []
41
  if 'processing' not in st.session_state:
42
  st.session_state['processing'] = {}
43
 
44
  # Utility Functions
45
- def generate_filename(sequence, ext="png"):
46
- timestamp = time.strftime("%d%m%Y%H%M%S")
47
- return f"{sequence}{timestamp}.{ext}"
48
 
49
  def get_gallery_files(file_types):
50
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
51
 
52
- def update_gallery():
53
- media_files = get_gallery_files(["png", "txt"])
54
- if media_files:
55
- cols = st.sidebar.columns(2)
56
- for idx, file in enumerate(media_files[:gallery_size * 2]):
57
- with cols[idx % 2]:
58
- if file.endswith(".png"):
59
- st.image(Image.open(file), caption=file, use_container_width=True)
60
- elif file.endswith(".txt"):
61
- with open(file, "r") as f:
62
- content = f.read()
63
- st.text(content[:50] + "..." if len(content) > 50 else content, help=file)
64
-
65
  def download_pdf(url, output_path):
66
  try:
67
  response = requests.get(url, stream=True, timeout=10)
@@ -74,28 +257,86 @@ def download_pdf(url, output_path):
74
  logger.error(f"Failed to download {url}: {e}")
75
  return False
76
 
77
- # Model Loaders (CPU-focused)
78
- def load_ocr_got():
79
- model_id = "ucaslcl/GOT-OCR2_0"
80
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
81
- model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
82
- return tokenizer, model
83
-
84
- def load_image_gen():
85
- model_id = "OFA-Sys/small-stable-diffusion-v0" # ~300 MB
86
- pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu")
87
- return pipeline
88
-
89
- def load_line_drawer():
90
- def edge_detection(image, style="fine"):
91
- img_np = np.array(image.convert("RGB"))
92
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
93
- if style == "fine":
94
- edges = cv2.Canny(gray, 50, 150) # Finer lines
95
- else: # "bold"
96
- edges = cv2.Canny(gray, 100, 200) # Bolder lines
97
- return Image.fromarray(edges)
98
- return edge_detection
99
 
100
  # Async Processing Functions
101
  async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
@@ -104,21 +345,19 @@ async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
104
  status.text(f"Processing PDF Snapshot ({mode})... (0s)")
105
  doc = fitz.open(pdf_path)
106
  output_files = []
107
-
108
  if mode == "thumbnail":
109
  page = doc[0]
110
- pix = page.get_pixmap(matrix=fitz.Matrix(0.5, 0.5)) # 50% scale
111
  output_file = generate_filename("thumbnail", "png")
112
  pix.save(output_file)
113
  output_files.append(output_file)
114
  elif mode == "twopage":
115
  for i in range(min(2, len(doc))):
116
  page = doc[i]
117
- pix = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0)) # Full scale
118
  output_file = generate_filename(f"twopage_{i}", "png")
119
  pix.save(output_file)
120
  output_files.append(output_file)
121
-
122
  doc.close()
123
  elapsed = int(time.time() - start_time)
124
  status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
@@ -143,50 +382,43 @@ async def process_ocr(image, output_file):
143
  update_gallery()
144
  return result
145
 
146
- async def process_image_gen(prompt, output_file):
147
- start_time = time.time()
148
- status = st.empty()
149
- status.text("Processing Image Gen... (0s)")
150
- pipeline = load_image_gen()
151
- gen_image = pipeline(prompt, num_inference_steps=20).images[0]
152
- elapsed = int(time.time() - start_time)
153
- status.text(f"Image Gen completed in {elapsed}s!")
154
- gen_image.save(output_file)
155
- if output_file not in st.session_state['captured_files']:
156
- st.session_state['captured_files'].append(output_file)
157
- update_gallery()
158
- return gen_image
159
-
160
- async def process_line_drawing(image, style, output_file):
161
- start_time = time.time()
162
- status = st.empty()
163
- status.text(f"Processing Line Drawing ({style})... (0s)")
164
- edge_fn = load_line_drawer()
165
- line_drawing = edge_fn(image, style=style)
166
- elapsed = int(time.time() - start_time)
167
- status.text(f"Line Drawing ({style}) completed in {elapsed}s!")
168
- line_drawing.save(output_file)
169
- if output_file not in st.session_state['captured_files']:
170
- st.session_state['captured_files'].append(output_file)
171
- update_gallery()
172
- return line_drawing
173
-
174
  # Main App
175
- st.title("AI Vision Titans 🚀")
176
 
177
- # Sidebar Gallery
178
  st.sidebar.header("Captured Files 📜")
179
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
180
  update_gallery()
181
 
182
  st.sidebar.subheader("Action Logs 📜")
183
  log_container = st.sidebar.empty()
184
  with log_container:
185
  for record in log_records:
186
  st.write(f"{record.asctime} - {record.levelname} - {record.message}")
187
 
188
  # Tabs
189
- tab1, tab2, tab3, tab4, tab5 = st.tabs(["Camera Snap 📷", "Download PDFs 📥", "Test OCR 🔍", "Test Image Gen 🎨", "Test Line Drawings ✏️"])
190
 
191
  with tab1:
192
  st.header("Camera Snap 📷")
@@ -202,6 +434,7 @@ with tab1:
202
  st.image(Image.open(filename), caption=filename, use_container_width=True)
203
  logger.info(f"Saved snapshot from Camera 0: {filename}")
204
  st.session_state['captured_files'].append(filename)
 
205
  update_gallery()
206
  with cols[1]:
207
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
@@ -213,6 +446,7 @@ with tab1:
213
  st.image(Image.open(filename), caption=filename, use_container_width=True)
214
  logger.info(f"Saved snapshot from Camera 1: {filename}")
215
  st.session_state['captured_files'].append(filename)
 
216
  update_gallery()
217
 
218
  st.subheader("Burst Capture")
@@ -231,6 +465,7 @@ with tab1:
231
  f.write(img.getvalue())
232
  st.session_state['burst_frames'].append(filename)
233
  logger.info(f"Saved burst frame {i}: {filename}")
 
234
  st.image(Image.open(filename), caption=filename, use_container_width=True)
235
  time.sleep(0.5)
236
  st.session_state['captured_files'].extend([f for f in st.session_state['burst_frames'] if f not in st.session_state['captured_files']])
@@ -248,6 +483,7 @@ with tab2:
248
  pdf_path = generate_filename("downloaded", "pdf")
249
  if download_pdf(url, pdf_path):
250
  logger.info(f"Downloaded PDF from {url} to {pdf_path}")
 
251
  snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
252
  for snapshot in snapshots:
253
  st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
@@ -255,6 +491,147 @@ with tab2:
255
  st.error(f"Failed to download {url}")
256
 
257
  with tab3:
258
  st.header("Test OCR 🔍")
259
  captured_files = get_gallery_files(["png"])
260
  if captured_files:
@@ -265,13 +642,14 @@ with tab3:
265
  output_file = generate_filename("ocr_output", "txt")
266
  st.session_state['processing']['ocr'] = True
267
  result = asyncio.run(process_ocr(image, output_file))
 
268
  st.text_area("OCR Result", result, height=200, key="ocr_result")
269
  st.success(f"OCR output saved to {output_file}")
270
  st.session_state['processing']['ocr'] = False
271
  else:
272
  st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
273
 
274
- with tab4:
275
  st.header("Test Image Gen 🎨")
276
  captured_files = get_gallery_files(["png"])
277
  if captured_files:
@@ -283,29 +661,12 @@ with tab4:
283
  output_file = generate_filename("gen_output", "png")
284
  st.session_state['processing']['gen'] = True
285
  result = asyncio.run(process_image_gen(prompt, output_file))
 
286
  st.image(result, caption="Generated Image", use_container_width=True)
287
  st.success(f"Image saved to {output_file}")
288
  st.session_state['processing']['gen'] = False
289
  else:
290
  st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
291
 
292
- with tab5:
293
- st.header("Test Line Drawings ✏️")
294
- captured_files = get_gallery_files(["png"])
295
- if captured_files:
296
- selected_file = st.selectbox("Select Image", captured_files, key="line_select")
297
- image = Image.open(selected_file)
298
- st.image(image, caption="Input Image", use_container_width=True)
299
- style = st.selectbox("Line Style", ["Fine", "Bold"], key="line_style")
300
- if st.button("Run Line Drawing 🚀", key="line_run"):
301
- output_file = generate_filename(f"line_{style.lower()}", "png")
302
- st.session_state['processing']['line'] = True
303
- result = asyncio.run(process_line_drawing(image, style.lower(), output_file))
304
- st.image(result, caption=f"{style} Line Drawing", use_container_width=True)
305
- st.success(f"Line drawing saved to {output_file}")
306
- st.session_state['processing']['line'] = False
307
- else:
308
- st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
309
-
310
  # Initial Gallery Update
311
  update_gallery()
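
Note that the hunks above delete the original generate_filename and update_gallery helpers, while the updated code still calls both; their new definitions fall outside the displayed hunks. A minimal sketch of what they would need to provide, assuming they keep roughly the removed behavior (reading the gallery size from session state here is an assumption; the original read the sidebar slider variable directly):

import glob
import time

import streamlit as st
from PIL import Image

def generate_filename(sequence, ext="png"):
    # Timestamped name such as "thumbnail05032025143000.png", as in the removed helper.
    timestamp = time.strftime("%d%m%Y%H%M%S")
    return f"{sequence}{timestamp}.{ext}"

def update_gallery():
    # Sketch only: show the most recent PNG/TXT captures in the sidebar, two per row.
    media_files = sorted(f for ext in ("png", "txt") for f in glob.glob(f"*.{ext}"))
    gallery_size = st.session_state.get("gallery_size", 4)  # assumed mirror of the sidebar slider
    cols = st.sidebar.columns(2)
    for idx, file in enumerate(media_files[:gallery_size * 2]):
        with cols[idx % 2]:
            if file.endswith(".png"):
                st.image(Image.open(file), caption=file, use_container_width=True)
            else:
                with open(file) as f:
                    content = f.read()
                st.text(content[:50] + "..." if len(content) > 50 else content, help=file)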
 
1
  #!/usr/bin/env python3
2
  import os
3
  import glob
4
+ import base64
5
  import time
6
+ import shutil
7
  import streamlit as st
8
+ import pandas as pd
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
11
+ from diffusers import StableDiffusionPipeline
12
+ from torch.utils.data import Dataset, DataLoader
13
+ import csv
14
  import fitz # PyMuPDF
15
  import requests
16
  from PIL import Image
17
  import cv2
18
  import numpy as np
19
  import logging
20
  import asyncio
21
  import aiofiles
22
  from io import BytesIO
23
+ from dataclasses import dataclass
24
+ from typing import Optional, Tuple
25
+ import zipfile
26
+ import math
27
+ import random
28
 
29
+ # Logging setup with custom buffer
30
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
31
  logger = logging.getLogger(__name__)
32
  log_records = []
 
39
 
40
  # Page Configuration
41
  st.set_page_config(
42
+ page_title="AI Vision & SFT Titans 🚀",
43
  page_icon="🤖",
44
  layout="wide",
45
  initial_sidebar_state="expanded",
46
+ menu_items={
47
+ 'Get Help': 'https://huggingface.co/awacke1',
48
+ 'Report a Bug': 'https://huggingface.co/spaces/awacke1',
49
+ 'About': "AI Vision & SFT Titans: PDFs, OCR, Image Gen, Line Drawings, and SFT on CPU! 🌌"
50
+ }
51
  )
52
 
53
  # Initialize st.session_state
54
  if 'captured_files' not in st.session_state:
55
  st.session_state['captured_files'] = []
56
+ if 'builder' not in st.session_state:
57
+ st.session_state['builder'] = None
58
+ if 'model_loaded' not in st.session_state:
59
+ st.session_state['model_loaded'] = False
60
  if 'processing' not in st.session_state:
61
  st.session_state['processing'] = {}
62
+ if 'history' not in st.session_state:
63
+ st.session_state['history'] = []
64
+
65
+ # Model Configuration Classes
66
+ @dataclass
67
+ class ModelConfig:
68
+ name: str
69
+ base_model: str
70
+ size: str
71
+ domain: Optional[str] = None
72
+ model_type: str = "causal_lm"
73
+ @property
74
+ def model_path(self):
75
+ return f"models/{self.name}"
76
+
77
+ @dataclass
78
+ class DiffusionConfig:
79
+ name: str
80
+ base_model: str
81
+ size: str
82
+ @property
83
+ def model_path(self):
84
+ return f"diffusion_models/{self.name}"
85
+
86
+ # Datasets
87
+ class SFTDataset(Dataset):
88
+ def __init__(self, data, tokenizer, max_length=128):
89
+ self.data = data
90
+ self.tokenizer = tokenizer
91
+ self.max_length = max_length
92
+ def __len__(self):
93
+ return len(self.data)
94
+ def __getitem__(self, idx):
95
+ prompt = self.data[idx]["prompt"]
96
+ response = self.data[idx]["response"]
97
+ full_text = f"{prompt} {response}"
98
+ full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
99
+ prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
100
+ input_ids = full_encoding["input_ids"].squeeze()
101
+ attention_mask = full_encoding["attention_mask"].squeeze()
102
+ labels = input_ids.clone()
103
+ prompt_len = prompt_encoding["input_ids"].shape[1]
104
+ if prompt_len < self.max_length:
105
+ labels[:prompt_len] = -100
106
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
107
+
108
+ class DiffusionDataset(Dataset):
109
+ def __init__(self, images, texts):
110
+ self.images = images
111
+ self.texts = texts
112
+ def __len__(self):
113
+ return len(self.images)
114
+ def __getitem__(self, idx):
115
+ return {"image": self.images[idx], "text": self.texts[idx]}
116
+
117
+ # Model Builders
118
+ class ModelBuilder:
119
+ def __init__(self):
120
+ self.config = None
121
+ self.model = None
122
+ self.tokenizer = None
123
+ self.sft_data = None
124
+ self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
125
+ def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
126
+ with st.spinner(f"Loading {model_path}... ⏳"):
127
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
128
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
129
+ if self.tokenizer.pad_token is None:
130
+ self.tokenizer.pad_token = self.tokenizer.eos_token
131
+ if config:
132
+ self.config = config
133
+ self.model.to("cuda" if torch.cuda.is_available() else "cpu")
134
+ st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
135
+ return self
136
+ def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
137
+ self.sft_data = []
138
+ with open(csv_path, "r") as f:
139
+ reader = csv.DictReader(f)
140
+ for row in reader:
141
+ self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
142
+ dataset = SFTDataset(self.sft_data, self.tokenizer)
143
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
144
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
145
+ self.model.train()
146
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147
+ self.model.to(device)
148
+ for epoch in range(epochs):
149
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️"):
150
+ total_loss = 0
151
+ for batch in dataloader:
152
+ optimizer.zero_grad()
153
+ input_ids = batch["input_ids"].to(device)
154
+ attention_mask = batch["attention_mask"].to(device)
155
+ labels = batch["labels"].to(device)
156
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
157
+ loss = outputs.loss
158
+ loss.backward()
159
+ optimizer.step()
160
+ total_loss += loss.item()
161
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
162
+ st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
163
+ return self
164
+ def save_model(self, path: str):
165
+ with st.spinner("Saving model... 💾"):
166
+ os.makedirs(os.path.dirname(path), exist_ok=True)
167
+ self.model.save_pretrained(path)
168
+ self.tokenizer.save_pretrained(path)
169
+ st.success(f"Model saved at {path}! ✅")
170
+ def evaluate(self, prompt: str, status_container=None):
171
+ self.model.eval()
172
+ if status_container:
173
+ status_container.write("Preparing to evaluate... 🧠")
174
+ try:
175
+ with torch.no_grad():
176
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
177
+ outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
178
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
179
+ except Exception as e:
180
+ if status_container:
181
+ status_container.error(f"Oops! Something broke: {str(e)} 💥")
182
+ return f"Error: {str(e)}"
183
+
184
+ class DiffusionBuilder:
185
+ def __init__(self):
186
+ self.config = None
187
+ self.pipeline = None
188
+ def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
189
+ with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
190
+ self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
191
+ if config:
192
+ self.config = config
193
+ st.success(f"Diffusion model loaded! 🎨")
194
+ return self
195
+ def fine_tune_sft(self, images, texts, epochs=3):
196
+ dataset = DiffusionDataset(images, texts)
197
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
198
+ optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
199
+ self.pipeline.unet.train()
200
+ for epoch in range(epochs):
201
+ with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
202
+ total_loss = 0
203
+ for batch in dataloader:
204
+ optimizer.zero_grad()
205
+ image = batch["image"][0].to(self.pipeline.device)
206
+ text = batch["text"][0]
207
+ latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
208
+ noise = torch.randn_like(latents)
209
+ timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
210
+ noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
211
+ text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
212
+ pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
213
+ loss = torch.nn.functional.mse_loss(pred_noise, noise)
214
+ loss.backward()
215
+ optimizer.step()
216
+ total_loss += loss.item()
217
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
218
+ st.success("Diffusion SFT Fine-tuning completed! 🎨")
219
+ return self
220
+ def save_model(self, path: str):
221
+ with st.spinner("Saving diffusion model... 💾"):
222
+ os.makedirs(os.path.dirname(path), exist_ok=True)
223
+ self.pipeline.save_pretrained(path)
224
+ st.success(f"Diffusion model saved at {path}! ✅")
225
+ def generate(self, prompt: str):
226
+ return self.pipeline(prompt, num_inference_steps=20).images[0]
227
 
228
  # Utility Functions
229
+ def get_download_link(file_path, mime_type="text/plain", label="Download"):
230
+ with open(file_path, 'rb') as f:
231
+ data = f.read()
232
+ b64 = base64.b64encode(data).decode()
233
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
234
+
235
+ def zip_directory(directory_path, zip_path):
236
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
237
+ for root, _, files in os.walk(directory_path):
238
+ for file in files:
239
+ zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.dirname(directory_path)))
240
+
241
+ def get_model_files(model_type="causal_lm"):
242
+ path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
243
+ return [d for d in glob.glob(path) if os.path.isdir(d)]
244
 
245
  def get_gallery_files(file_types):
246
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
247
 
248
  def download_pdf(url, output_path):
249
  try:
250
  response = requests.get(url, stream=True, timeout=10)
 
257
  logger.error(f"Failed to download {url}: {e}")
258
  return False
259
 
260
+ # Mock Search Tool for RAG
261
+ def mock_search(query: str) -> str:
262
+ if "superhero" in query.lower():
263
+ return "Latest trends: Gold-plated Batman statues, VR superhero battles."
264
+ return "No relevant results found."
265
+
266
+ def mock_duckduckgo_search(query: str) -> str:
267
+ if "superhero party trends" in query.lower():
268
+ return """
269
+ Latest trends for 2025:
270
+ - Luxury decorations: Gold-plated Batman statues, holographic Avengers displays.
271
+ - Entertainment: Live stunt shows with Iron Man suits, VR superhero battles.
272
+ - Catering: Gourmet kryptonite-green cocktails, Thor’s hammer-shaped appetizers.
273
+ """
274
+ return "No relevant results found."
275
+
276
+ # Agent Classes
277
+ class PartyPlannerAgent:
278
+ def __init__(self, model, tokenizer):
279
+ self.model = model
280
+ self.tokenizer = tokenizer
281
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
282
+ self.model.to(self.device)
283
+ def generate(self, prompt: str) -> str:
284
+ self.model.eval()
285
+ with torch.no_grad():
286
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
287
+ outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
288
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
289
+ def plan_party(self, task: str) -> pd.DataFrame:
290
+ search_result = mock_duckduckgo_search("latest superhero party trends")
291
+ prompt = f"Given this context: '{search_result}'\n{task}"
292
+ plan_text = self.generate(prompt)
293
+ locations = {
294
+ "Wayne Manor": (42.3601, -71.0589),
295
+ "New York": (40.7128, -74.0060),
296
+ "Los Angeles": (34.0522, -118.2437),
297
+ "London": (51.5074, -0.1278)
298
+ }
299
+ wayne_coords = locations["Wayne Manor"]
300
+ travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
301
+ catchphrases = ["To the Batmobile!", "Avengers, assemble!", "I am Iron Man!", "By the power of Grayskull!"]
302
+ data = [
303
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues", "Catchphrase": random.choice(catchphrases)},
304
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Holographic Avengers displays", "Catchphrase": random.choice(catchphrases)},
305
+ {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows with Iron Man suits", "Catchphrase": random.choice(catchphrases)},
306
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles", "Catchphrase": random.choice(catchphrases)},
307
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gourmet kryptonite-green cocktails", "Catchphrase": random.choice(catchphrases)},
308
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Thor’s hammer-shaped appetizers", "Catchphrase": random.choice(catchphrases)},
309
+ ]
310
+ return pd.DataFrame(data)
311
+
312
+ class CVPartyPlannerAgent:
313
+ def __init__(self, pipeline):
314
+ self.pipeline = pipeline
315
+ def generate(self, prompt: str) -> Image.Image:
316
+ return self.pipeline(prompt, num_inference_steps=20).images[0]
317
+ def plan_party(self, task: str) -> pd.DataFrame:
318
+ search_result = mock_search("superhero party trends")
319
+ prompt = f"Given this context: '{search_result}'\n{task}"
320
+ data = [
321
+ {"Theme": "Batman", "Image Idea": "Gold-plated Batman statue"},
322
+ {"Theme": "Avengers", "Image Idea": "VR superhero battle scene"}
323
+ ]
324
+ return pd.DataFrame(data)
325
+
326
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
327
+ def to_radians(degrees: float) -> float:
328
+ return degrees * (math.pi / 180)
329
+ lat1, lon1 = map(to_radians, origin_coords)
330
+ lat2, lon2 = map(to_radians, destination_coords)
331
+ EARTH_RADIUS_KM = 6371.0
332
+ dlon = lon2 - lon1
333
+ dlat = lat2 - lat1
334
+ a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
335
+ c = 2 * math.asin(math.sqrt(a))
336
+ distance = EARTH_RADIUS_KM * c
337
+ actual_distance = distance * 1.1
338
+ flight_time = (actual_distance / cruising_speed_kmh) + 1.0
339
+ return round(flight_time, 2)
340
 
341
  # Async Processing Functions
342
  async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
 
345
  status.text(f"Processing PDF Snapshot ({mode})... (0s)")
346
  doc = fitz.open(pdf_path)
347
  output_files = []
 
348
  if mode == "thumbnail":
349
  page = doc[0]
350
+ pix = page.get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
351
  output_file = generate_filename("thumbnail", "png")
352
  pix.save(output_file)
353
  output_files.append(output_file)
354
  elif mode == "twopage":
355
  for i in range(min(2, len(doc))):
356
  page = doc[i]
357
+ pix = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))
358
  output_file = generate_filename(f"twopage_{i}", "png")
359
  pix.save(output_file)
360
  output_files.append(output_file)
 
361
  doc.close()
362
  elapsed = int(time.time() - start_time)
363
  status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
 
382
  update_gallery()
383
  return result
384
 
385
  # Main App
386
+ st.title("AI Vision & SFT Titans 🚀")
387
 
388
+ # Sidebar
389
  st.sidebar.header("Captured Files 📜")
390
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
391
  update_gallery()
392
 
393
+ st.sidebar.subheader("Model Management 🗂️")
394
+ model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"], key="sidebar_model_type")
395
+ model_dirs = get_model_files(model_type)
396
+ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs, key="sidebar_model_select")
397
+ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
398
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
399
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
400
+ builder.load_model(selected_model, config)
401
+ st.session_state['builder'] = builder
402
+ st.session_state['model_loaded'] = True
403
+ st.rerun()
404
+
405
  st.sidebar.subheader("Action Logs 📜")
406
  log_container = st.sidebar.empty()
407
  with log_container:
408
  for record in log_records:
409
  st.write(f"{record.asctime} - {record.levelname} - {record.message}")
410
 
411
+ st.sidebar.subheader("History 📜")
412
+ history_container = st.sidebar.empty()
413
+ with history_container:
414
+ for entry in st.session_state['history'][-5:]:
415
+ st.write(entry)
416
+
417
  # Tabs
418
+ tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs([
419
+ "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
420
+ "Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨"
421
+ ])
422
 
423
  with tab1:
424
  st.header("Camera Snap 📷")
 
434
  st.image(Image.open(filename), caption=filename, use_container_width=True)
435
  logger.info(f"Saved snapshot from Camera 0: {filename}")
436
  st.session_state['captured_files'].append(filename)
437
+ st.session_state['history'].append(f"Snapshot from Cam 0: {filename}")
438
  update_gallery()
439
  with cols[1]:
440
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
 
446
  st.image(Image.open(filename), caption=filename, use_container_width=True)
447
  logger.info(f"Saved snapshot from Camera 1: {filename}")
448
  st.session_state['captured_files'].append(filename)
449
+ st.session_state['history'].append(f"Snapshot from Cam 1: {filename}")
450
  update_gallery()
451
 
452
  st.subheader("Burst Capture")
 
465
  f.write(img.getvalue())
466
  st.session_state['burst_frames'].append(filename)
467
  logger.info(f"Saved burst frame {i}: {filename}")
468
+ st.session_state['history'].append(f"Burst frame {i}: {filename}")
469
  st.image(Image.open(filename), caption=filename, use_container_width=True)
470
  time.sleep(0.5)
471
  st.session_state['captured_files'].extend([f for f in st.session_state['burst_frames'] if f not in st.session_state['captured_files']])
 
483
  pdf_path = generate_filename("downloaded", "pdf")
484
  if download_pdf(url, pdf_path):
485
  logger.info(f"Downloaded PDF from {url} to {pdf_path}")
486
+ st.session_state['history'].append(f"Downloaded PDF: {pdf_path}")
487
  snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
488
  for snapshot in snapshots:
489
  st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
 
491
  st.error(f"Failed to download {url}")
492
 
493
  with tab3:
494
+ st.header("Build Titan 🌱")
495
+ model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
496
+ base_model = st.selectbox("Select Tiny Model",
497
+ ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
498
+ ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"])
499
+ model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
500
+ domain = st.text_input("Target Domain", "general")
501
+ if st.button("Download Model ⬇️"):
502
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain)
503
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
504
+ builder.load_model(base_model, config)
505
+ builder.save_model(config.model_path)
506
+ st.session_state['builder'] = builder
507
+ st.session_state['model_loaded'] = True
508
+ st.session_state['history'].append(f"Built {model_type} model: {model_name}")
509
+ st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
510
+ st.rerun()
511
+
512
+ with tab4:
513
+ st.header("Fine-Tune Titan 🔧")
514
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
515
+ st.warning("Please build or load a Titan first! ⚠️")
516
+ else:
517
+ if isinstance(st.session_state['builder'], ModelBuilder):
518
+ if st.button("Generate Sample CSV 📝"):
519
+ sample_data = [
520
+ {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
521
+ {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
522
+ ]
523
+ csv_path = f"sft_data_{int(time.time())}.csv"
524
+ with open(csv_path, "w", newline="") as f:
525
+ writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
526
+ writer.writeheader()
527
+ writer.writerows(sample_data)
528
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
529
+ st.success(f"Sample CSV generated as {csv_path}! ✅")
530
+
531
+ uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
532
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
533
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
534
+ with open(csv_path, "wb") as f:
535
+ f.write(uploaded_csv.read())
536
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
537
+ new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
538
+ st.session_state['builder'].config = new_config
539
+ st.session_state['builder'].fine_tune_sft(csv_path)
540
+ st.session_state['builder'].save_model(new_config.model_path)
541
+ zip_path = f"{new_config.model_path}.zip"
542
+ zip_directory(new_config.model_path, zip_path)
543
+ st.session_state['history'].append(f"Fine-tuned Causal LM: {new_model_name}")
544
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
545
+ st.rerun()
546
+ elif isinstance(st.session_state['builder'], DiffusionBuilder):
547
+ captured_files = get_gallery_files(["png"])
548
+ if len(captured_files) >= 2:
549
+ demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files[:min(len(captured_files), slice_count)]]
550
+ edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
551
+ if st.button("Fine-Tune with Dataset 🔄"):
552
+ images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
553
+ texts = [row["text"] for _, row in edited_data.iterrows()]
554
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
555
+ new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
556
+ st.session_state['builder'].config = new_config
557
+ st.session_state['builder'].fine_tune_sft(images, texts)
558
+ st.session_state['builder'].save_model(new_config.model_path)
559
+ zip_path = f"{new_config.model_path}.zip"
560
+ zip_directory(new_config.model_path, zip_path)
561
+ st.session_state['history'].append(f"Fine-tuned Diffusion: {new_model_name}")
562
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
563
+ csv_path = f"sft_dataset_{int(time.time())}.csv"
564
+ with open(csv_path, "w", newline="") as f:
565
+ writer = csv.writer(f)
566
+ writer.writerow(["image", "text"])
567
+ for _, row in edited_data.iterrows():
568
+ writer.writerow([row["image"], row["text"]])
569
+ st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
570
+
571
+ with tab5:
572
+ st.header("Test Titan 🧪")
573
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
574
+ st.warning("Please build or load a Titan first! ⚠️")
575
+ else:
576
+ if isinstance(st.session_state['builder'], ModelBuilder):
577
+ if st.session_state['builder'].sft_data:
578
+ st.write("Testing with SFT Data:")
579
+ for item in st.session_state['builder'].sft_data[:3]:
580
+ prompt = item["prompt"]
581
+ expected = item["response"]
582
+ status_container = st.empty()
583
+ generated = st.session_state['builder'].evaluate(prompt, status_container)
584
+ st.write(f"**Prompt**: {prompt}")
585
+ st.write(f"**Expected**: {expected}")
586
+ st.write(f"**Generated**: {generated}")
587
+ st.write("---")
588
+ status_container.empty()
589
+ test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
590
+ if st.button("Run Test ▶️"):
591
+ status_container = st.empty()
592
+ result = st.session_state['builder'].evaluate(test_prompt, status_container)
593
+ st.session_state['history'].append(f"Causal LM Test: {test_prompt} -> {result}")
594
+ st.write(f"**Generated Response**: {result}")
595
+ status_container.empty()
596
+ elif isinstance(st.session_state['builder'], DiffusionBuilder):
597
+ test_prompt = st.text_area("Enter Test Prompt", "Neon Batman")
598
+ if st.button("Run Test ▶️"):
599
+ image = st.session_state['builder'].generate(test_prompt)
600
+ output_file = generate_filename("diffusion_test", "png")
601
+ image.save(output_file)
602
+ st.session_state['captured_files'].append(output_file)
603
+ st.session_state['history'].append(f"Diffusion Test: {test_prompt} -> {output_file}")
604
+ st.image(image, caption="Generated Image")
605
+ update_gallery()
606
+
607
+ with tab6:
608
+ st.header("Agentic RAG Party 🌐")
609
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
610
+ st.warning("Please build or load a Titan first! ⚠️")
611
+ else:
612
+ if isinstance(st.session_state['builder'], ModelBuilder):
613
+ if st.button("Run NLP RAG Demo 🎉"):
614
+ agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
615
+ task = "Plan a luxury superhero-themed party at Wayne Manor."
616
+ plan_df = agent.plan_party(task)
617
+ st.session_state['history'].append(f"NLP RAG Demo: Planned party at Wayne Manor")
618
+ st.dataframe(plan_df)
619
+ elif isinstance(st.session_state['builder'], DiffusionBuilder):
620
+ if st.button("Run CV RAG Demo 🎉"):
621
+ agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
622
+ task = "Generate images for a luxury superhero-themed party."
623
+ plan_df = agent.plan_party(task)
624
+ st.session_state['history'].append(f"CV RAG Demo: Generated party images")
625
+ st.dataframe(plan_df)
626
+ for _, row in plan_df.iterrows():
627
+ image = agent.generate(row["Image Idea"])
628
+ output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
629
+ image.save(output_file)
630
+ st.session_state['captured_files'].append(output_file)
631
+ st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
632
+ update_gallery()
633
+
634
+ with tab7:
635
  st.header("Test OCR 🔍")
636
  captured_files = get_gallery_files(["png"])
637
  if captured_files:
 
642
  output_file = generate_filename("ocr_output", "txt")
643
  st.session_state['processing']['ocr'] = True
644
  result = asyncio.run(process_ocr(image, output_file))
645
+ st.session_state['history'].append(f"OCR Test: {selected_file} -> {output_file}")
646
  st.text_area("OCR Result", result, height=200, key="ocr_result")
647
  st.success(f"OCR output saved to {output_file}")
648
  st.session_state['processing']['ocr'] = False
649
  else:
650
  st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
651
 
652
+ with tab8:
653
  st.header("Test Image Gen 🎨")
654
  captured_files = get_gallery_files(["png"])
655
  if captured_files:
 
661
  output_file = generate_filename("gen_output", "png")
662
  st.session_state['processing']['gen'] = True
663
  result = asyncio.run(process_image_gen(prompt, output_file))
664
+ st.session_state['history'].append(f"Image Gen Test: {prompt} -> {output_file}")
665
  st.image(result, caption="Generated Image", use_container_width=True)
666
  st.success(f"Image saved to {output_file}")
667
  st.session_state['processing']['gen'] = False
668
  else:
669
  st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")
670
 
671
  # Initial Gallery Update
672
  update_gallery()
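
For reference, the calculate_cargo_travel_time helper added in this commit is pure math and can be sanity-checked on its own. The snippet below is a standalone sketch (not part of app.py) that reuses the same formula together with the Wayne Manor and New York coordinates from PartyPlannerAgent:

import math
from typing import Tuple

def calculate_cargo_travel_time(origin_coords: Tuple[float, float],
                                destination_coords: Tuple[float, float],
                                cruising_speed_kmh: float = 750.0) -> float:
    # Same formula as the commit: Haversine distance, a 10% routing factor,
    # plus a fixed one-hour overhead, rounded to two decimals.
    lat1, lon1 = map(math.radians, origin_coords)
    lat2, lon2 = map(math.radians, destination_coords)
    EARTH_RADIUS_KM = 6371.0
    dlon, dlat = lon2 - lon1, lat2 - lat1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    distance = EARTH_RADIUS_KM * 2 * math.asin(math.sqrt(a))
    return round((distance * 1.1) / cruising_speed_kmh + 1.0, 2)

wayne_manor = (42.3601, -71.0589)   # coordinates used by PartyPlannerAgent
new_york = (40.7128, -74.0060)
print(calculate_cargo_travel_time(new_york, wayne_manor))  # estimated travel time in hours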