awacke1 committed on
Commit
49639b7
·
verified ·
1 Parent(s): 4e89aed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -55
app.py CHANGED
@@ -175,17 +175,17 @@ class DiffusionBuilder:
175
  self.config = None
176
  self.pipeline = None
177
  self.model_type = None
178
- def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None, model_type: str = "StableDiffusion"):
179
- with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
180
  if model_type == "StableDiffusion":
181
- self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
182
  elif model_type == "DDPM":
183
- self.pipeline = DDPMPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
184
  self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)
185
  if config:
186
  self.config = config
187
  self.model_type = model_type
188
- st.success(f"Diffusion model loaded! 🎨")
189
  return self
190
  def fine_tune_sft(self, images, texts, epochs=3):
191
  dataset = DiffusionDataset(images, texts)
@@ -339,11 +339,35 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
339
  st.session_state['model_loaded'] = True
340
  st.rerun()
341
 
342
- # Tabs (Reordered: Camera Snap first)
343
- tab1, tab2, tab3, tab4, tab5 = st.tabs(["Camera Snap 📷", "Fine-Tune Titan 🔧", "Build Titan 🌱", "Test Titan 🧪", "Agentic RAG Party 🌐"])
344
 
345
  with tab1:
346
- st.header("Camera Snap 📷 (Dual Capture!)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
348
  video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)
349
  cols = st.columns(2)
@@ -352,24 +376,26 @@ with tab1:
352
  cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
353
  if cam0_img:
354
  filename = generate_filename(0)
355
- with open(filename, "wb") as f:
356
- f.write(cam0_img.getvalue())
357
- st.image(Image.open(filename), caption=filename, use_container_width=True)
358
- logger.info(f"Saved snapshot from Camera 0: {filename}")
359
- st.session_state['captured_images'].append(filename)
360
- update_gallery()
 
361
  if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
362
  st.session_state['cam0_frames'] = []
363
  for i in range(slice_count):
364
  img = st.camera_input(f"Frame {i} - Cam 0", key=f"cam0_frame_{i}_{time.time()}")
365
  if img:
366
  filename = generate_filename(f"0_{i}")
367
- with open(filename, "wb") as f:
368
- f.write(img.getvalue())
369
- st.session_state['cam0_frames'].append(filename)
370
- logger.info(f"Saved frame {i} from Camera 0: {filename}")
371
- time.sleep(1.0 / slice_count)
372
- st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
 
373
  update_gallery()
374
  for frame in st.session_state['cam0_frames']:
375
  st.image(Image.open(frame), caption=frame, use_container_width=True)
@@ -378,24 +404,26 @@ with tab1:
378
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
379
  if cam1_img:
380
  filename = generate_filename(1)
381
- with open(filename, "wb") as f:
382
- f.write(cam1_img.getvalue())
383
- st.image(Image.open(filename), caption=filename, use_container_width=True)
384
- logger.info(f"Saved snapshot from Camera 1: {filename}")
385
- st.session_state['captured_images'].append(filename)
386
- update_gallery()
 
387
  if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
388
  st.session_state['cam1_frames'] = []
389
  for i in range(slice_count):
390
  img = st.camera_input(f"Frame {i} - Cam 1", key=f"cam1_frame_{i}_{time.time()}")
391
  if img:
392
  filename = generate_filename(f"1_{i}")
393
- with open(filename, "wb") as f:
394
- f.write(img.getvalue())
395
- st.session_state['cam1_frames'].append(filename)
396
- logger.info(f"Saved frame {i} from Camera 1: {filename}")
397
- time.sleep(1.0 / slice_count)
398
- st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
 
399
  update_gallery()
400
  for frame in st.session_state['cam1_frames']:
401
  st.image(Image.open(frame), caption=frame, use_container_width=True)
@@ -444,28 +472,6 @@ with tab2:
444
  st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
445
 
446
  with tab3:
447
- st.header("Build Titan 🌱")
448
- model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
449
- base_model_options = {
450
- "Causal LM": ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
451
- "Diffusion": [
452
- "OFA-Sys/small-stable-diffusion-v0 (LDM/Conditional)",
453
- "google/ddpm-ema-celebahq-256 (DDPM/SDE/Autoregressive Proxy)"
454
- ]
455
- }
456
- base_model = st.selectbox("Select Tiny Model", base_model_options[model_type])
457
- model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
458
- if st.button("Download Model ⬇️"):
459
- config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model.split(" ")[0], size="small")
460
- builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
461
- model_type_for_diffusion = "StableDiffusion" if "small-stable-diffusion" in base_model else "DDPM"
462
- builder.load_model(base_model.split(" ")[0], config, model_type_for_diffusion)
463
- builder.save_model(config.model_path)
464
- st.session_state['builder'] = builder
465
- st.session_state['model_loaded'] = True
466
- st.rerun()
467
-
468
- with tab4:
469
  st.header("Test Titan 🧪")
470
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
471
  st.warning("Please build or load a Titan first! ⚠️")
@@ -487,7 +493,7 @@ with tab4:
487
  image = st.session_state['builder'].generate(prompt)
488
  st.image(image, caption=f"Generated from {selected_pipeline}")
489
 
490
- with tab5:
491
  st.header("Agentic RAG Party 🌐")
492
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
493
  st.warning("Please build or load a Titan first! ⚠️")
 
175
  self.config = None
176
  self.pipeline = None
177
  self.model_type = None
178
+ def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None, model_type: str = "StableDiffusion", download: bool = True):
179
+ with st.spinner(f"{'Downloading' if download else 'Loading'} {model_path}... ⏳"):
180
  if model_type == "StableDiffusion":
181
+ self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32, use_safetensors=True, local_files_only=not download).to("cpu")
182
  elif model_type == "DDPM":
183
+ self.pipeline = DDPMPipeline.from_pretrained(model_path, torch_dtype=torch.float32, use_safetensors=True, local_files_only=not download).to("cpu")
184
  self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)
185
  if config:
186
  self.config = config
187
  self.model_type = model_type
188
+ st.success(f"Diffusion model {'downloaded' if download else 'loaded'}! 🎨")
189
  return self
190
  def fine_tune_sft(self, images, texts, epochs=3):
191
  dataset = DiffusionDataset(images, texts)
 
339
  st.session_state['model_loaded'] = True
340
  st.rerun()
341
 
342
+ # Tabs
343
+ tab1, tab2, tab3, tab4 = st.tabs(["Build Titan & Camera Snap 🌱📷", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐"])
344
 
345
  with tab1:
346
+ st.header("Build Titan & Camera Snap 🌱📷")
347
+ st.subheader("Build Titan 🌱")
348
+ model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
349
+ base_model_options = {
350
+ "Causal LM": ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
351
+ "Diffusion": [
352
+ "OFA-Sys/small-stable-diffusion-v0 (LDM/Conditional, ~300 MB)",
353
+ "google/ddpm-ema-celebahq-256 (DDPM/SDE/Autoregressive Proxy, ~280 MB)"
354
+ ]
355
+ }
356
+ base_model = st.selectbox("Select Tiny Model", base_model_options[model_type])
357
+ action = st.radio("Action", ["Use Model", "Download Model"], index=0 if "Causal LM" in model_type else 1)
358
+ model_name = st.text_input("Model Name (for Download)", f"tiny-titan-{int(time.time())}") if action == "Download Model" else None
359
+ if st.button(f"{action} ⬇️"):
360
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name or base_model.split(" ")[0], base_model=base_model.split(" ")[0], size="small")
361
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
362
+ model_type_for_diffusion = "StableDiffusion" if "small-stable-diffusion" in base_model else "DDPM"
363
+ builder.load_model(base_model.split(" ")[0], config, model_type_for_diffusion, download=action == "Download Model")
364
+ if action == "Download Model":
365
+ builder.save_model(config.model_path)
366
+ st.session_state['builder'] = builder
367
+ st.session_state['model_loaded'] = True
368
+ st.rerun()
369
+
370
+ st.subheader("Camera Snap 📷")
371
  slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
372
  video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)
373
  cols = st.columns(2)
 
376
  cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
377
  if cam0_img:
378
  filename = generate_filename(0)
379
+ if filename not in st.session_state['captured_images']:
380
+ with open(filename, "wb") as f:
381
+ f.write(cam0_img.getvalue())
382
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
383
+ logger.info(f"Saved snapshot from Camera 0: {filename}")
384
+ st.session_state['captured_images'].append(filename)
385
+ update_gallery()
386
  if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
387
  st.session_state['cam0_frames'] = []
388
  for i in range(slice_count):
389
  img = st.camera_input(f"Frame {i} - Cam 0", key=f"cam0_frame_{i}_{time.time()}")
390
  if img:
391
  filename = generate_filename(f"0_{i}")
392
+ if filename not in st.session_state['captured_images']:
393
+ with open(filename, "wb") as f:
394
+ f.write(img.getvalue())
395
+ st.session_state['cam0_frames'].append(filename)
396
+ logger.info(f"Saved frame {i} from Camera 0: {filename}")
397
+ time.sleep(1.0 / slice_count)
398
+ st.session_state['captured_images'].extend([f for f in st.session_state['cam0_frames'] if f not in st.session_state['captured_images']])
399
  update_gallery()
400
  for frame in st.session_state['cam0_frames']:
401
  st.image(Image.open(frame), caption=frame, use_container_width=True)
 
404
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
405
  if cam1_img:
406
  filename = generate_filename(1)
407
+ if filename not in st.session_state['captured_images']:
408
+ with open(filename, "wb") as f:
409
+ f.write(cam1_img.getvalue())
410
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
411
+ logger.info(f"Saved snapshot from Camera 1: {filename}")
412
+ st.session_state['captured_images'].append(filename)
413
+ update_gallery()
414
  if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
415
  st.session_state['cam1_frames'] = []
416
  for i in range(slice_count):
417
  img = st.camera_input(f"Frame {i} - Cam 1", key=f"cam1_frame_{i}_{time.time()}")
418
  if img:
419
  filename = generate_filename(f"1_{i}")
420
+ if filename not in st.session_state['captured_images']:
421
+ with open(filename, "wb") as f:
422
+ f.write(img.getvalue())
423
+ st.session_state['cam1_frames'].append(filename)
424
+ logger.info(f"Saved frame {i} from Camera 1: {filename}")
425
+ time.sleep(1.0 / slice_count)
426
+ st.session_state['captured_images'].extend([f for f in st.session_state['cam1_frames'] if f not in st.session_state['captured_images']])
427
  update_gallery()
428
  for frame in st.session_state['cam1_frames']:
429
  st.image(Image.open(frame), caption=frame, use_container_width=True)
 
472
  st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
473
 
474
  with tab3:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  st.header("Test Titan 🧪")
476
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
477
  st.warning("Please build or load a Titan first! ⚠️")
 
493
  image = st.session_state['builder'].generate(prompt)
494
  st.image(image, caption=f"Generated from {selected_pipeline}")
495
 
496
+ with tab4:
497
  st.header("Agentic RAG Party 🌐")
498
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
499
  st.warning("Please build or load a Titan first! ⚠️")