awacke1 committed
Commit a9b69b8 · verified · 1 Parent(s): d791c5b

Update app.py: replace the in-app SFT and tiny-diffusion training demos with asset tooling (multi-page PDF snapshots, GOT-OCR2_0 text extraction, PDF upload, image-generation testing), add per-asset selection and two-camera capture state to the Gradio UI, and fix the filename timestamp format (%HM%S to %H%M%S).

Files changed (1)
  1. app.py +193 -172
app.py CHANGED
@@ -3,7 +3,6 @@ import os
 import glob
 import base64
 import time
-import shutil
 import pandas as pd
 import torch
 import torch.nn as nn
@@ -15,7 +14,6 @@ import csv
 import fitz
 import requests
 from PIL import Image
-import cv2
 import numpy as np
 import logging
 import asyncio
@@ -39,7 +37,6 @@ class LogCaptureHandler(logging.Handler):
 
 logger.addHandler(LogCaptureHandler())
 
-# Data Classes and Models (unchanged from your original code)
 @dataclass
 class ModelConfig:
     name: str
@@ -61,106 +58,12 @@ class DiffusionConfig:
     def model_path(self):
         return f"diffusion_models/{self.name}"
 
-class SFTDataset(Dataset):
-    def __init__(self, data, tokenizer, max_length=128):
-        self.data = data
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-    def __len__(self):
-        return len(self.data)
-    def __getitem__(self, idx):
-        prompt = self.data[idx]["prompt"]
-        response = self.data[idx]["response"]
-        full_text = f"{prompt} {response}"
-        full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
-        prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
-        input_ids = full_encoding["input_ids"].squeeze()
-        attention_mask = full_encoding["attention_mask"].squeeze()
-        labels = input_ids.clone()
-        prompt_len = prompt_encoding["input_ids"].shape[1]
-        if prompt_len < self.max_length:
-            labels[:prompt_len] = -100
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
-
-class TinyUNet(nn.Module):
-    def __init__(self, in_channels=3, out_channels=3):
-        super(TinyUNet, self).__init__()
-        self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
-        self.down2 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
-        self.mid = nn.Conv2d(64, 128, 3, padding=1)
-        self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
-        self.up2 = nn.Conv2d(64 + 32, 32, 3, padding=1)
-        self.out = nn.Conv2d(32, out_channels, 3, padding=1)
-        self.time_embed = nn.Linear(1, 64)
-
-    def forward(self, x, t):
-        t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
-        t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
-        x1 = F.relu(self.down1(x))
-        x2 = F.relu(self.down2(x1))
-        x_mid = F.relu(self.mid(x2)) + t_embed
-        x_up1 = F.relu(self.up1(x_mid))
-        x_up2 = F.relu(self.up2(torch.cat([x_up1, x1], dim=1)))
-        return self.out(x_up2)
-
-class TinyDiffusion:
-    def __init__(self, model, timesteps=100):
-        self.model = model
-        self.timesteps = timesteps
-        self.beta = torch.linspace(0.0001, 0.02, timesteps)
-        self.alpha = 1 - self.beta
-        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
-
-    def train(self, images, epochs=50):
-        dataset = TinyDiffusionDataset(images)
-        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
-        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
-        device = torch.device("cpu")
-        self.model.to(device)
-        for epoch in range(epochs):
-            total_loss = 0
-            for x in dataloader:
-                x = x.to(device)
-                t = torch.randint(0, self.timesteps, (x.size(0),), device=device).float()
-                noise = torch.randn_like(x)
-                alpha_t = self.alpha_cumprod[t.long()].view(-1, 1, 1, 1)
-                x_noisy = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
-                pred_noise = self.model(x_noisy, t)
-                loss = F.mse_loss(pred_noise, noise)
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
-                total_loss += loss.item()
-            logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
-        return self
-
-    def generate(self, size=(64, 64), steps=100):
-        device = torch.device("cpu")
-        x = torch.randn(1, 3, size[0], size[1], device=device)
-        for t in reversed(range(steps)):
-            t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
-            alpha_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
-            pred_noise = self.model(x, t_tensor)
-            x = (x - (1 - self.alpha[t]) / torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(self.alpha[t])
-            if t > 0:
-                x += torch.sqrt(self.beta[t]) * torch.randn_like(x)
-        x = torch.clamp(x * 255, 0, 255).byte()
-        return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
-
-class TinyDiffusionDataset(Dataset):
-    def __init__(self, images):
-        self.images = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
-    def __len__(self):
-        return len(self.images)
-    def __getitem__(self, idx):
-        return self.images[idx]
-
 class ModelBuilder:
     def __init__(self):
         self.config = None
         self.model = None
         self.tokenizer = None
-        self.sft_data = None
+        self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
     def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
         self.model = AutoModelForCausalLM.from_pretrained(model_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -170,42 +73,10 @@ class ModelBuilder:
             self.config = config
         self.model.to("cuda" if torch.cuda.is_available() else "cpu")
         return self
-    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
-        self.sft_data = []
-        with open(csv_path, "r") as f:
-            reader = csv.DictReader(f)
-            for row in reader:
-                self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
-        dataset = SFTDataset(self.sft_data, self.tokenizer)
-        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
-        self.model.train()
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(device)
-        for epoch in range(epochs):
-            total_loss = 0
-            for batch in dataloader:
-                optimizer.zero_grad()
-                input_ids = batch["input_ids"].to(device)
-                attention_mask = batch["attention_mask"].to(device)
-                labels = batch["labels"].to(device)
-                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-                loss = outputs.loss
-                loss.backward()
-                optimizer.step()
-                total_loss += loss.item()
-            logger.info(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
-        return self
     def save_model(self, path: str):
         os.makedirs(os.path.dirname(path), exist_ok=True)
         self.model.save_pretrained(path)
         self.tokenizer.save_pretrained(path)
-    def evaluate(self, prompt: str):
-        self.model.eval()
-        with torch.no_grad():
-            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
-            outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
-            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 class DiffusionBuilder:
     def __init__(self):
@@ -216,12 +87,14 @@ class DiffusionBuilder:
         if config:
             self.config = config
         return self
+    def save_model(self, path: str):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        self.pipeline.save_pretrained(path)
     def generate(self, prompt: str):
         return self.pipeline(prompt, num_inference_steps=20).images[0]
 
-# Utility Functions
 def generate_filename(sequence, ext="png"):
-    timestamp = time.strftime("%d%m%Y%HM%S")
+    timestamp = time.strftime("%d%m%Y%H%M%S")
     return f"{sequence}_{timestamp}.{ext}"
 
 def pdf_url_to_filename(url):
@@ -231,6 +104,11 @@ def pdf_url_to_filename(url):
 def get_gallery_files(file_types=["png", "pdf"]):
     return sorted(list(set([f for ext in file_types for f in glob.glob(f"*.{ext}")])))  # Deduplicate files
 
+def get_model_files(model_type="causal_lm"):
+    path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
+    dirs = [d for d in glob.glob(path) if os.path.isdir(d)]
+    return dirs if dirs else ["None"]
+
 def download_pdf(url, output_path):
     try:
         response = requests.get(url, stream=True, timeout=10)
@@ -252,25 +130,72 @@ async def process_pdf_snapshot(pdf_path, mode="single"):
252
  output_file = generate_filename("single", "png")
253
  pix.save(output_file)
254
  output_files.append(output_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  doc.close()
256
  return output_files
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # Gradio Interface Functions
259
- def update_gallery(history):
260
  all_files = get_gallery_files()
261
- gallery_content = "\n".join([f"- {f}" for f in all_files[:5]])
 
 
 
 
 
 
 
 
 
262
  history.append(f"Gallery updated: {len(all_files)} files")
263
- return gallery_content, history
264
 
265
- def camera_snap(image, history):
266
  if image is not None:
267
- filename = generate_filename("cam")
268
  image.save(filename)
269
- history.append(f"Snapshot saved: {filename}")
270
- return f"Image saved as {filename}", history
271
- return "No image captured", history
 
 
 
 
272
 
273
- def download_pdfs(urls, history):
274
  urls = urls.strip().split("\n")
275
  downloaded = []
276
  for url in urls:
@@ -279,7 +204,71 @@ def download_pdfs(urls, history):
         if download_pdf(url, output_path):
             downloaded.append(output_path)
             history.append(f"Downloaded PDF: {output_path}")
-    return f"Downloaded {len(downloaded)} PDFs", history
+            asset_checkboxes[output_path] = True
+    return f"Downloaded {len(downloaded)} PDFs", history, asset_checkboxes
+
+def upload_pdfs(pdf_files, history, asset_checkboxes):
+    uploaded = []
+    for pdf_file in pdf_files:
+        if pdf_file:
+            output_path = f"uploaded_{int(time.time())}_{pdf_file.name}"
+            with open(output_path, "wb") as f:
+                f.write(pdf_file.read())
+            uploaded.append(output_path)
+            history.append(f"Uploaded PDF: {output_path}")
+            asset_checkboxes[output_path] = True
+    return f"Uploaded {len(uploaded)} PDFs", history, asset_checkboxes
+
+def snapshot_pdfs(mode, history, asset_checkboxes):
+    selected_pdfs = [path for path in get_gallery_files() if path.endswith('.pdf') and asset_checkboxes.get(path, False)]
+    if not selected_pdfs:
+        return "No PDFs selected", [], history, asset_checkboxes
+    snapshots = []
+    mode_key = {"Single Page (High-Res)": "single", "Two Pages (High-Res)": "twopage", "All Pages (High-Res)": "allpages"}[mode]
+    for pdf_path in selected_pdfs:
+        snap_files = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
+        for snap in snap_files:
+            snapshots.append(Image.open(snap))
+            asset_checkboxes[snap] = True
+            history.append(f"Snapshot {mode_key}: {snap}")
+    return f"Generated {len(snapshots)} snapshots", snapshots, history, asset_checkboxes
+
+def process_ocr_all(history, asset_checkboxes):
+    all_files = get_gallery_files()
+    if not all_files:
+        return "No assets to OCR", history, asset_checkboxes
+    full_text = "# OCR Results\n\n"
+    for file in all_files:
+        if file.endswith('.png'):
+            image = Image.open(file)
+        else:
+            doc = fitz.open(file)
+            pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+            image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+            doc.close()
+        output_file = generate_filename(f"ocr_{os.path.basename(file)}", "txt")
+        result = asyncio.run(process_ocr(image, output_file))
+        full_text += f"## {os.path.basename(file)}\n\n{result}\n\n"
+        history.append(f"OCR Test: {file} -> {output_file}")
+    md_output_file = f"full_ocr_{int(time.time())}.md"
+    with open(md_output_file, "w") as f:
+        f.write(full_text)
+    return f"Full OCR saved to {md_output_file}", history, asset_checkboxes
+
+def process_ocr_single(file_path, history, asset_checkboxes):
+    if not file_path:
+        return "No file selected", None, "", history, asset_checkboxes
+    if file_path.endswith('.png'):
+        image = Image.open(file_path)
+    else:
+        doc = fitz.open(file_path)
+        pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+        doc.close()
+    output_file = generate_filename("ocr_output", "txt")
+    result = asyncio.run(process_ocr(image, output_file))
+    history.append(f"OCR Test: {file_path} -> {output_file}")
+    return f"OCR output saved to {output_file}", image, result, history, asset_checkboxes
 
 def build_model(model_type, base_model, model_name, domain, history):
     config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain)
@@ -289,62 +278,94 @@ def build_model(model_type, base_model, model_name, domain, history):
     history.append(f"Built {model_type} model: {model_name}")
     return builder, f"Model saved to {config.model_path}", history
 
-def test_model(builder, prompt, history):
-    if builder is None:
-        return "No model loaded", history
-    if isinstance(builder, ModelBuilder):
-        result = builder.evaluate(prompt)
-        history.append(f"Tested Causal LM: {prompt} -> {result}")
-        return result, history
-    elif isinstance(builder, DiffusionBuilder):
-        image = builder.generate(prompt)
-        output_file = generate_filename("diffusion_test")
-        image.save(output_file)
-        history.append(f"Tested Diffusion: {prompt} -> {output_file}")
-        return output_file, history
+def image_gen(prompt, file_path, builder, history, asset_checkboxes):
+    if not file_path:
+        return "No file selected", None, history, asset_checkboxes
+    if file_path.endswith('.png'):
+        image = Image.open(file_path)
+    else:
+        doc = fitz.open(file_path)
+        pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+        doc.close()
+    output_file = generate_filename("gen_output", "png")
+    gen_image = asyncio.run(process_image_gen(prompt, output_file, builder))
+    history.append(f"Image Gen Test: {prompt} -> {output_file}")
+    asset_checkboxes[output_file] = True
+    return f"Image saved to {output_file}", gen_image, history, asset_checkboxes
 
 # Gradio UI
 with gr.Blocks(title="AI Vision & SFT Titans 🚀") as demo:
     gr.Markdown("# AI Vision & SFT Titans 🚀")
     history = gr.State(value=[])
     builder = gr.State(value=None)
+    asset_checkboxes = gr.State(value={})
+    cam_files = gr.State(value={})
 
     with gr.Row():
         with gr.Column(scale=1):
             gr.Markdown("## Captured Files 📜")
-            gallery_output = gr.Textbox(label="Gallery", lines=5)
-            gr.Button("Update Gallery").click(update_gallery, inputs=[history], outputs=[gallery_output, history])
-
+            gallery_output = gr.Gallery(label="Asset Gallery", columns=2, height="auto")
+            gr.Button("Update Gallery").click(update_gallery, inputs=[history, asset_checkboxes], outputs=[gallery_output, history, asset_checkboxes])
+            gr.Markdown("## History 📜")
+            history_output = gr.Textbox(label="History", lines=5, interactive=False)
+            gr.Markdown("## Action Logs 📜")
+            log_output = gr.Textbox(label="Logs", value="\n".join([f"{r.asctime} - {r.levelname} - {r.message}" for r in log_records]), lines=5, interactive=False)
+
         with gr.Column(scale=3):
             with gr.Tabs():
                 with gr.TabItem("Camera Snap 📷"):
-                    camera_input = gr.Image(type="pil", label="Take a Picture")
-                    snap_output = gr.Textbox(label="Status")
-                    gr.Button("Capture").click(camera_snap, inputs=[camera_input, history], outputs=[snap_output, history])
+                    with gr.Row():
+                        cam0_input = gr.Image(type="pil", label="Camera 0")
+                        cam1_input = gr.Image(type="pil", label="Camera 1")
+                    with gr.Row():
+                        cam0_output = gr.Textbox(label="Cam 0 Status")
+                        cam1_output = gr.Textbox(label="Cam 1 Status")
+                    with gr.Row():
+                        cam0_image = gr.Image(label="Cam 0 Preview")
+                        cam1_image = gr.Image(label="Cam 1 Preview")
+                    gr.Button("Capture Cam 0").click(camera_snap, inputs=[cam0_input, gr.State(value=0), history, asset_checkboxes, cam_files], outputs=[cam0_output, cam0_image, history, asset_checkboxes, cam_files])
+                    gr.Button("Capture Cam 1").click(camera_snap, inputs=[cam1_input, gr.State(value=1), history, asset_checkboxes, cam_files], outputs=[cam1_output, cam1_image, history, asset_checkboxes, cam_files])
 
                 with gr.TabItem("Download PDFs 📥"):
                     url_input = gr.Textbox(label="Enter PDF URLs (one per line)", lines=5)
+                    pdf_upload = gr.File(label="Upload PDFs", file_count="multiple", type="binary")
                     pdf_output = gr.Textbox(label="Status")
-                    gr.Button("Download").click(download_pdfs, inputs=[url_input, history], outputs=[pdf_output, history])
+                    snapshot_mode = gr.Dropdown(["Single Page (High-Res)", "Two Pages (High-Res)", "All Pages (High-Res)"], label="Snapshot Mode")
+                    snapshot_output = gr.Textbox(label="Snapshot Status")
+                    snapshot_images = gr.Gallery(label="Snapshots", columns=2, height="auto")
+                    gr.Button("Download URLs").click(download_pdfs, inputs=[url_input, history, asset_checkboxes], outputs=[pdf_output, history, asset_checkboxes])
+                    gr.Button("Upload PDFs").click(upload_pdfs, inputs=[pdf_upload, history, asset_checkboxes], outputs=[pdf_output, history, asset_checkboxes])
+                    gr.Button("Snapshot Selected").click(snapshot_pdfs, inputs=[snapshot_mode, history, asset_checkboxes], outputs=[snapshot_output, snapshot_images, history, asset_checkboxes])
+
+                with gr.TabItem("Test OCR 🔍"):
+                    all_files = gr.Dropdown(choices=get_gallery_files(), label="Select File")
+                    ocr_output = gr.Textbox(label="Status")
+                    ocr_image = gr.Image(label="Input Image")
+                    ocr_result = gr.Textbox(label="OCR Result", lines=5)
+                    gr.Button("OCR All Assets").click(process_ocr_all, inputs=[history, asset_checkboxes], outputs=[ocr_output, history, asset_checkboxes])
+                    gr.Button("OCR Selected").click(process_ocr_single, inputs=[all_files, history, asset_checkboxes], outputs=[ocr_output, ocr_image, ocr_result, history, asset_checkboxes])
 
                 with gr.TabItem("Build Titan 🌱"):
                     model_type = gr.Dropdown(["Causal LM", "Diffusion"], label="Model Type")
                     base_model = gr.Dropdown(
-                        choices=["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type.value == "Causal LM" else ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"],
-                        label="Base Model"
+                        choices=["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
+                        label="Base Model",
+                        value="HuggingFaceTB/SmolLM-135M"
                     )
                     model_name = gr.Textbox(label="Model Name", value=f"tiny-titan-{int(time.time())}")
-                    domain = gr.Textbox(label="Domain", value="general")
+                    domain = gr.Textbox(label="Target Domain", value="general")
                    build_output = gr.Textbox(label="Status")
                     gr.Button("Build").click(build_model, inputs=[model_type, base_model, model_name, domain, history], outputs=[builder, build_output, history])
 
-                with gr.TabItem("Test Titan 🧪"):
-                    test_prompt = gr.Textbox(label="Test Prompt", value="What is AI?")
-                    test_output = gr.Textbox(label="Result")
-                    gr.Button("Test").click(test_model, inputs=[builder, test_prompt, history], outputs=[test_output, history])
+                with gr.TabItem("Test Image Gen 🎨"):
+                    gen_file = gr.Dropdown(choices=get_gallery_files(), label="Select Reference File")
+                    gen_prompt = gr.Textbox(label="Prompt", value="Generate a neon superhero version of this image")
+                    gen_output = gr.Textbox(label="Status")
+                    gen_image = gr.Image(label="Generated Image")
+                    gr.Button("Generate").click(image_gen, inputs=[gen_prompt, gen_file, builder, history, asset_checkboxes], outputs=[gen_output, gen_image, history, asset_checkboxes])
 
-    with gr.Row():
-        gr.Markdown("## History 📜")
-        history_output = gr.Textbox(value="\n".join(history.value), label="History", lines=5, interactive=False)
+    # Update history output on every interaction
+    demo.load(lambda h: "\n".join(h[-5:]), inputs=[history], outputs=[history_output])
 
 demo.launch()
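Notes on the changes above. The new "twopage" and "allpages" snapshot modes are plain PyMuPDF page rendering at 2x zoom; the sketch below shows the same pattern standalone. It is illustrative only and assumes a local file named `sample.pdf`, which is not part of this repo.

```python
import fitz  # PyMuPDF, imported as fitz in app.py

doc = fitz.open("sample.pdf")
for i in range(min(2, len(doc))):  # "twopage" mode: at most the first two pages
    # Matrix(2.0, 2.0) doubles the render resolution on both axes ("High-Res")
    pix = doc[i].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
    pix.save(f"twopage_{i}.png")
doc.close()
```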
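`process_ocr` delegates to the GOT-OCR2_0 community model loaded with `trust_remote_code=True`; note that `model.chat` is an entry point defined by that repo's remote code, not by the core `transformers` API. A synchronous sketch of the same calls, assuming a local `page.png`:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32
).to("cpu").eval()

# chat() takes a path to an image file and returns the recognized text.
text = model.chat(tokenizer, "page.png", ocr_type="ocr")
print(text)
```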
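Finally, the one-character timestamp fix matters more than it looks: in the old format string `%d%m%Y%HM%S`, the `M` was a literal character rather than the `%M` minutes directive, so generated filenames skipped the minutes entirely and could collide within the same hour. A quick check:

```python
import time

print(time.strftime("%d%m%Y%HM%S"))   # old: hour, a literal "M", then seconds
print(time.strftime("%d%m%Y%H%M%S"))  # fixed: DDMMYYYYHHMMSS
```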