awacke1 committed on
Commit
83f7f1b
·
verified ·
1 Parent(s): 26a04a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -122
app.py CHANGED
@@ -69,6 +69,10 @@ if 'selected_model_type' not in st.session_state:
69
  st.session_state['selected_model_type'] = "Causal LM"
70
  if 'selected_model' not in st.session_state:
71
  st.session_state['selected_model'] = "None"
 
 
 
 
72
 
73
  @dataclass
74
  class ModelConfig:
@@ -219,7 +223,11 @@ async def process_ocr(image, output_file):
219
  status.text("Processing GOT-OCR2_0... (0s)")
220
  tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
221
  model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
222
- result = model.chat(tokenizer, image, ocr_type='ocr')
 
 
 
 
223
  elapsed = int(time.time() - start_time)
224
  status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
225
  async with aiofiles.open(output_file, "w") as f:
@@ -231,7 +239,10 @@ async def process_image_gen(prompt, output_file):
231
  start_time = time.time()
232
  status = st.empty()
233
  status.text("Processing Image Gen... (0s)")
234
- pipeline = StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")
 
 
 
235
  gen_image = pipeline(prompt, num_inference_steps=20).images[0]
236
  elapsed = int(time.time() - start_time)
237
  status.text(f"Image Gen completed in {elapsed}s!")
@@ -239,101 +250,6 @@ async def process_image_gen(prompt, output_file):
239
  update_gallery()
240
  return gen_image
241
 
242
- async def process_custom_diffusion(images, output_file, model_name):
243
- start_time = time.time()
244
- status = st.empty()
245
- status.text(f"Training {model_name}... (0s)")
246
- unet = TinyUNet()
247
- diffusion = TinyDiffusion(unet)
248
- diffusion.train(images)
249
- gen_image = diffusion.generate()
250
- upscaled_image = diffusion.upscale(gen_image, scale_factor=2)
251
- elapsed = int(time.time() - start_time)
252
- status.text(f"{model_name} completed in {elapsed}s!")
253
- upscaled_image.save(output_file)
254
- update_gallery()
255
- return upscaled_image
256
-
257
- class TinyUNet(nn.Module):
258
- def __init__(self, in_channels=3, out_channels=3):
259
- super(TinyUNet, self).__init__()
260
- self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
261
- self.down2 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
262
- self.mid = nn.Conv2d(64, 128, 3, padding=1)
263
- self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
264
- self.up2 = nn.Conv2d(64 + 32, 32, 3, padding=1)
265
- self.out = nn.Conv2d(32, out_channels, 3, padding=1)
266
- self.time_embed = nn.Linear(1, 64)
267
-
268
- def forward(self, x, t):
269
- t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
270
- t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
271
-
272
- x1 = F.relu(self.down1(x))
273
- x2 = F.relu(self.down2(x1))
274
- x_mid = F.relu(self.mid(x2)) + t_embed
275
- x_up1 = F.relu(self.up1(x_mid))
276
- x_up2 = F.relu(self.up2(torch.cat([x_up1, x1], dim=1)))
277
- return self.out(x_up2)
278
-
279
- class TinyDiffusion:
280
- def __init__(self, model, timesteps=100):
281
- self.model = model
282
- self.timesteps = timesteps
283
- self.beta = torch.linspace(0.0001, 0.02, timesteps)
284
- self.alpha = 1 - self.beta
285
- self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
286
-
287
- def train(self, images, epochs=50):
288
- dataset = TinyDiffusionDataset(images)
289
- dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
290
- optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
291
- device = torch.device("cpu")
292
- self.model.to(device)
293
- for epoch in range(epochs):
294
- total_loss = 0
295
- for x in dataloader:
296
- x = x.to(device)
297
- t = torch.randint(0, self.timesteps, (x.size(0),), device=device).float()
298
- noise = torch.randn_like(x)
299
- alpha_t = self.alpha_cumprod[t.long()].view(-1, 1, 1, 1)
300
- x_noisy = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
301
- pred_noise = self.model(x_noisy, t)
302
- loss = F.mse_loss(pred_noise, noise)
303
- optimizer.zero_grad()
304
- loss.backward()
305
- optimizer.step()
306
- total_loss += loss.item()
307
- logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
308
- return self
309
-
310
- def generate(self, size=(64, 64), steps=100):
311
- device = torch.device("cpu")
312
- x = torch.randn(1, 3, size[0], size[1], device=device)
313
- for t in reversed(range(steps)):
314
- t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
315
- alpha_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
316
- pred_noise = self.model(x, t_tensor)
317
- x = (x - (1 - self.alpha[t]) / torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(self.alpha[t])
318
- if t > 0:
319
- x += torch.sqrt(self.beta[t]) * torch.randn_like(x)
320
- x = torch.clamp(x * 255, 0, 255).byte()
321
- return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
322
-
323
- def upscale(self, image, scale_factor=2):
324
- img_tensor = torch.tensor(np.array(image.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0) / 255.0
325
- upscaled = F.interpolate(img_tensor, scale_factor=scale_factor, mode='bilinear', align_corners=False)
326
- upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
327
- return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
328
-
329
- class TinyDiffusionDataset(Dataset):
330
- def __init__(self, images):
331
- self.images = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
332
- def __len__(self):
333
- return len(self.images)
334
- def __getitem__(self, idx):
335
- return self.images[idx]
336
-
337
  st.title("AI Vision & SFT Titans 🚀")
338
 
339
  # Sidebar
@@ -365,6 +281,8 @@ with cols[1]:
365
  os.remove(file)
366
  st.session_state['asset_checkboxes'].clear()
367
  st.session_state['downloaded_pdfs'].clear()
 
 
368
  st.sidebar.success("All assets vaporized! 💨")
369
  st.rerun()
370
 
@@ -402,6 +320,10 @@ def update_gallery():
402
  url_key = next((k for k, v in st.session_state['downloaded_pdfs'].items() if v == file), None)
403
  if url_key:
404
  del st.session_state['downloaded_pdfs'][url_key]
 
 
 
 
405
  st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")
406
  st.rerun()
407
  update_gallery()
@@ -418,8 +340,8 @@ with history_container:
418
  for entry in st.session_state['history'][-gallery_size * 2:]:
419
  st.write(entry)
420
 
421
- tab1, tab2, tab3, tab4, tab5 = st.tabs([
422
- "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Test OCR 🔍", "Test Image Gen 🎨"
423
  ])
424
 
425
  with tab1:
@@ -430,26 +352,36 @@ with tab1:
430
  cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
431
  if cam0_img:
432
  filename = generate_filename("cam0")
 
 
433
  with open(filename, "wb") as f:
434
  f.write(cam0_img.getvalue())
 
435
  entry = f"Snapshot from Cam 0: {filename}"
436
  if entry not in st.session_state['history']:
437
  st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
438
  st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
439
  logger.info(f"Saved snapshot from Camera 0: {filename}")
440
  update_gallery()
 
 
441
  with cols[1]:
442
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
443
  if cam1_img:
444
  filename = generate_filename("cam1")
 
 
445
  with open(filename, "wb") as f:
446
  f.write(cam1_img.getvalue())
 
447
  entry = f"Snapshot from Cam 1: {filename}"
448
  if entry not in st.session_state['history']:
449
  st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
450
  st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
451
  logger.info(f"Saved snapshot from Camera 1: {filename}")
452
  update_gallery()
 
 
453
 
454
  with tab2:
455
  st.header("Download PDFs 📥")
@@ -488,6 +420,7 @@ with tab2:
488
  entry = f"Downloaded PDF: {output_path}"
489
  if entry not in st.session_state['history']:
490
  st.session_state['history'].append(entry)
 
491
  else:
492
  st.error(f"Failed to nab {url} 😿")
493
  else:
@@ -506,33 +439,12 @@ with tab2:
506
  snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
507
  for snapshot in snapshots:
508
  st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
 
 
509
  else:
510
  st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar gallery.")
511
 
512
  with tab3:
513
- st.header("Build Titan 🌱")
514
- model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
515
- base_model = st.selectbox("Select Tiny Model",
516
- ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
517
- ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"])
518
- model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
519
- domain = st.text_input("Target Domain", "general")
520
- if st.button("Download Model ⬇️"):
521
- config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain)
522
- builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
523
- builder.load_model(base_model, config)
524
- builder.save_model(config.model_path)
525
- st.session_state['builder'] = builder
526
- st.session_state['model_loaded'] = True
527
- st.session_state['selected_model_type'] = model_type
528
- st.session_state['selected_model'] = config.model_path
529
- entry = f"Built {model_type} model: {model_name}"
530
- if entry not in st.session_state['history']:
531
- st.session_state['history'].append(entry)
532
- st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
533
- st.rerun()
534
-
535
- with tab4:
536
  st.header("Test OCR 🔍")
537
  all_files = get_gallery_files()
538
  if all_files:
@@ -597,6 +509,30 @@ with tab4:
597
  else:
598
  st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")
599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  with tab5:
601
  st.header("Test Image Gen 🎨")
602
  all_files = get_gallery_files()
 
69
  st.session_state['selected_model_type'] = "Causal LM"
70
  if 'selected_model' not in st.session_state:
71
  st.session_state['selected_model'] = "None"
72
+ if 'cam0_file' not in st.session_state:
73
+ st.session_state['cam0_file'] = None
74
+ if 'cam1_file' not in st.session_state:
75
+ st.session_state['cam1_file'] = None
76
 
77
  @dataclass
78
  class ModelConfig:
 
223
  status.text("Processing GOT-OCR2_0... (0s)")
224
  tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
225
  model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
226
+ # Save image to temporary file since GOT-OCR2_0 expects a file path
227
+ temp_file = f"temp_{int(time.time())}.png"
228
+ image.save(temp_file)
229
+ result = model.chat(tokenizer, temp_file, ocr_type='ocr')
230
+ os.remove(temp_file) # Clean up temporary file
231
  elapsed = int(time.time() - start_time)
232
  status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
233
  async with aiofiles.open(output_file, "w") as f:
 
239
  start_time = time.time()
240
  status = st.empty()
241
  status.text("Processing Image Gen... (0s)")
242
+ if st.session_state['builder'] and isinstance(st.session_state['builder'], DiffusionBuilder) and st.session_state['builder'].pipeline:
243
+ pipeline = st.session_state['builder'].pipeline
244
+ else:
245
+ pipeline = StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")
246
  gen_image = pipeline(prompt, num_inference_steps=20).images[0]
247
  elapsed = int(time.time() - start_time)
248
  status.text(f"Image Gen completed in {elapsed}s!")
 
250
  update_gallery()
251
  return gen_image
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  st.title("AI Vision & SFT Titans 🚀")
254
 
255
  # Sidebar
 
281
  os.remove(file)
282
  st.session_state['asset_checkboxes'].clear()
283
  st.session_state['downloaded_pdfs'].clear()
284
+ st.session_state['cam0_file'] = None
285
+ st.session_state['cam1_file'] = None
286
  st.sidebar.success("All assets vaporized! 💨")
287
  st.rerun()
288
 
 
320
  url_key = next((k for k, v in st.session_state['downloaded_pdfs'].items() if v == file), None)
321
  if url_key:
322
  del st.session_state['downloaded_pdfs'][url_key]
323
+ if file == st.session_state['cam0_file']:
324
+ st.session_state['cam0_file'] = None
325
+ if file == st.session_state['cam1_file']:
326
+ st.session_state['cam1_file'] = None
327
  st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")
328
  st.rerun()
329
  update_gallery()
 
340
  for entry in st.session_state['history'][-gallery_size * 2:]:
341
  st.write(entry)
342
 
343
+ tab1, tab2, tab3, tab4 = st.tabs([
344
+ "Camera Snap 📷", "Download PDFs 📥", "Test OCR 🔍", "Build Titan 🌱"
345
  ])
346
 
347
  with tab1:
 
352
  cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
353
  if cam0_img:
354
  filename = generate_filename("cam0")
355
+ if st.session_state['cam0_file'] and os.path.exists(st.session_state['cam0_file']):
356
+ os.remove(st.session_state['cam0_file'])
357
  with open(filename, "wb") as f:
358
  f.write(cam0_img.getvalue())
359
+ st.session_state['cam0_file'] = filename
360
  entry = f"Snapshot from Cam 0: {filename}"
361
  if entry not in st.session_state['history']:
362
  st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
363
  st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
364
  logger.info(f"Saved snapshot from Camera 0: {filename}")
365
  update_gallery()
366
+ elif st.session_state['cam0_file'] and os.path.exists(st.session_state['cam0_file']):
367
+ st.image(Image.open(st.session_state['cam0_file']), caption="Camera 0", use_container_width=True)
368
  with cols[1]:
369
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
370
  if cam1_img:
371
  filename = generate_filename("cam1")
372
+ if st.session_state['cam1_file'] and os.path.exists(st.session_state['cam1_file']):
373
+ os.remove(st.session_state['cam1_file'])
374
  with open(filename, "wb") as f:
375
  f.write(cam1_img.getvalue())
376
+ st.session_state['cam1_file'] = filename
377
  entry = f"Snapshot from Cam 1: {filename}"
378
  if entry not in st.session_state['history']:
379
  st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
380
  st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
381
  logger.info(f"Saved snapshot from Camera 1: {filename}")
382
  update_gallery()
383
+ elif st.session_state['cam1_file'] and os.path.exists(st.session_state['cam1_file']):
384
+ st.image(Image.open(st.session_state['cam1_file']), caption="Camera 1", use_container_width=True)
385
 
386
  with tab2:
387
  st.header("Download PDFs 📥")
 
420
  entry = f"Downloaded PDF: {output_path}"
421
  if entry not in st.session_state['history']:
422
  st.session_state['history'].append(entry)
423
+ st.session_state['asset_checkboxes'][output_path] = True # Auto-check the box
424
  else:
425
  st.error(f"Failed to nab {url} 😿")
426
  else:
 
439
  snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
440
  for snapshot in snapshots:
441
  st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
442
+ st.session_state['asset_checkboxes'][snapshot] = True # Auto-check new snapshots
443
+ update_gallery()
444
  else:
445
  st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar gallery.")
446
 
447
  with tab3:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  st.header("Test OCR 🔍")
449
  all_files = get_gallery_files()
450
  if all_files:
 
509
  else:
510
  st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")
511
 
512
+ with tab4:
513
+ st.header("Build Titan 🌱")
514
+ model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
515
+ base_model = st.selectbox("Select Tiny Model",
516
+ ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
517
+ ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"])
518
+ model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
519
+ domain = st.text_input("Target Domain", "general")
520
+ if st.button("Download Model ⬇️"):
521
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain)
522
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
523
+ builder.load_model(base_model, config)
524
+ builder.save_model(config.model_path)
525
+ st.session_state['builder'] = builder
526
+ st.session_state['model_loaded'] = True
527
+ st.session_state['selected_model_type'] = model_type
528
+ st.session_state['selected_model'] = config.model_path
529
+ entry = f"Built {model_type} model: {model_name}"
530
+ if entry not in st.session_state['history']:
531
+ st.session_state['history'].append(entry)
532
+ st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
533
+ st.rerun()
534
+
535
+ tab5 = st.tabs(["Test Image Gen 🎨"])[0]
536
  with tab5:
537
  st.header("Test Image Gen 🎨")
538
  all_files = get_gallery_files()