Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update app.py
Browse files
app.py
CHANGED
@@ -175,17 +175,17 @@ class DiffusionBuilder:
|
|
175 |
self.config = None
|
176 |
self.pipeline = None
|
177 |
self.model_type = None
|
178 |
-
def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None, model_type: str = "StableDiffusion"):
|
179 |
-
with st.spinner(f"
|
180 |
if model_type == "StableDiffusion":
|
181 |
-
self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
|
182 |
elif model_type == "DDPM":
|
183 |
-
self.pipeline = DDPMPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
|
184 |
self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)
|
185 |
if config:
|
186 |
self.config = config
|
187 |
self.model_type = model_type
|
188 |
-
st.success(f"Diffusion model loaded! 🎨")
|
189 |
return self
|
190 |
def fine_tune_sft(self, images, texts, epochs=3):
|
191 |
dataset = DiffusionDataset(images, texts)
|
@@ -339,11 +339,35 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
|
339 |
st.session_state['model_loaded'] = True
|
340 |
st.rerun()
|
341 |
|
342 |
-
# Tabs
|
343 |
-
tab1, tab2, tab3, tab4
|
344 |
|
345 |
with tab1:
|
346 |
-
st.header("Camera Snap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
|
348 |
video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)
|
349 |
cols = st.columns(2)
|
@@ -352,24 +376,26 @@ with tab1:
|
|
352 |
cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
|
353 |
if cam0_img:
|
354 |
filename = generate_filename(0)
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
|
|
361 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
362 |
st.session_state['cam0_frames'] = []
|
363 |
for i in range(slice_count):
|
364 |
img = st.camera_input(f"Frame {i} - Cam 0", key=f"cam0_frame_{i}_{time.time()}")
|
365 |
if img:
|
366 |
filename = generate_filename(f"0_{i}")
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
373 |
update_gallery()
|
374 |
for frame in st.session_state['cam0_frames']:
|
375 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
@@ -378,24 +404,26 @@ with tab1:
|
|
378 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
379 |
if cam1_img:
|
380 |
filename = generate_filename(1)
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
|
|
387 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
388 |
st.session_state['cam1_frames'] = []
|
389 |
for i in range(slice_count):
|
390 |
img = st.camera_input(f"Frame {i} - Cam 1", key=f"cam1_frame_{i}_{time.time()}")
|
391 |
if img:
|
392 |
filename = generate_filename(f"1_{i}")
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
|
|
399 |
update_gallery()
|
400 |
for frame in st.session_state['cam1_frames']:
|
401 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
@@ -444,28 +472,6 @@ with tab2:
|
|
444 |
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
445 |
|
446 |
with tab3:
|
447 |
-
st.header("Build Titan 🌱")
|
448 |
-
model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
|
449 |
-
base_model_options = {
|
450 |
-
"Causal LM": ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
|
451 |
-
"Diffusion": [
|
452 |
-
"OFA-Sys/small-stable-diffusion-v0 (LDM/Conditional)",
|
453 |
-
"google/ddpm-ema-celebahq-256 (DDPM/SDE/Autoregressive Proxy)"
|
454 |
-
]
|
455 |
-
}
|
456 |
-
base_model = st.selectbox("Select Tiny Model", base_model_options[model_type])
|
457 |
-
model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
|
458 |
-
if st.button("Download Model ⬇️"):
|
459 |
-
config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model.split(" ")[0], size="small")
|
460 |
-
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
461 |
-
model_type_for_diffusion = "StableDiffusion" if "small-stable-diffusion" in base_model else "DDPM"
|
462 |
-
builder.load_model(base_model.split(" ")[0], config, model_type_for_diffusion)
|
463 |
-
builder.save_model(config.model_path)
|
464 |
-
st.session_state['builder'] = builder
|
465 |
-
st.session_state['model_loaded'] = True
|
466 |
-
st.rerun()
|
467 |
-
|
468 |
-
with tab4:
|
469 |
st.header("Test Titan 🧪")
|
470 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
471 |
st.warning("Please build or load a Titan first! ⚠️")
|
@@ -487,7 +493,7 @@ with tab4:
|
|
487 |
image = st.session_state['builder'].generate(prompt)
|
488 |
st.image(image, caption=f"Generated from {selected_pipeline}")
|
489 |
|
490 |
-
with
|
491 |
st.header("Agentic RAG Party 🌐")
|
492 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
493 |
st.warning("Please build or load a Titan first! ⚠️")
|
|
|
175 |
self.config = None
|
176 |
self.pipeline = None
|
177 |
self.model_type = None
|
178 |
+
def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None, model_type: str = "StableDiffusion", download: bool = True):
|
179 |
+
with st.spinner(f"{'Downloading' if download else 'Loading'} {model_path}... ⏳"):
|
180 |
if model_type == "StableDiffusion":
|
181 |
+
self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32, use_safetensors=True, local_files_only=not download).to("cpu")
|
182 |
elif model_type == "DDPM":
|
183 |
+
self.pipeline = DDPMPipeline.from_pretrained(model_path, torch_dtype=torch.float32, use_safetensors=True, local_files_only=not download).to("cpu")
|
184 |
self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)
|
185 |
if config:
|
186 |
self.config = config
|
187 |
self.model_type = model_type
|
188 |
+
st.success(f"Diffusion model {'downloaded' if download else 'loaded'}! 🎨")
|
189 |
return self
|
190 |
def fine_tune_sft(self, images, texts, epochs=3):
|
191 |
dataset = DiffusionDataset(images, texts)
|
|
|
339 |
st.session_state['model_loaded'] = True
|
340 |
st.rerun()
|
341 |
|
342 |
+
# Tabs
|
343 |
+
tab1, tab2, tab3, tab4 = st.tabs(["Build Titan & Camera Snap 🌱📷", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐"])
|
344 |
|
345 |
with tab1:
|
346 |
+
st.header("Build Titan & Camera Snap 🌱📷")
|
347 |
+
st.subheader("Build Titan 🌱")
|
348 |
+
model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
|
349 |
+
base_model_options = {
|
350 |
+
"Causal LM": ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
|
351 |
+
"Diffusion": [
|
352 |
+
"OFA-Sys/small-stable-diffusion-v0 (LDM/Conditional, ~300 MB)",
|
353 |
+
"google/ddpm-ema-celebahq-256 (DDPM/SDE/Autoregressive Proxy, ~280 MB)"
|
354 |
+
]
|
355 |
+
}
|
356 |
+
base_model = st.selectbox("Select Tiny Model", base_model_options[model_type])
|
357 |
+
action = st.radio("Action", ["Use Model", "Download Model"], index=0 if "Causal LM" in model_type else 1)
|
358 |
+
model_name = st.text_input("Model Name (for Download)", f"tiny-titan-{int(time.time())}") if action == "Download Model" else None
|
359 |
+
if st.button(f"{action} ⬇️"):
|
360 |
+
config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name or base_model.split(" ")[0], base_model=base_model.split(" ")[0], size="small")
|
361 |
+
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
362 |
+
model_type_for_diffusion = "StableDiffusion" if "small-stable-diffusion" in base_model else "DDPM"
|
363 |
+
builder.load_model(base_model.split(" ")[0], config, model_type_for_diffusion, download=action == "Download Model")
|
364 |
+
if action == "Download Model":
|
365 |
+
builder.save_model(config.model_path)
|
366 |
+
st.session_state['builder'] = builder
|
367 |
+
st.session_state['model_loaded'] = True
|
368 |
+
st.rerun()
|
369 |
+
|
370 |
+
st.subheader("Camera Snap 📷")
|
371 |
slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
|
372 |
video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)
|
373 |
cols = st.columns(2)
|
|
|
376 |
cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
|
377 |
if cam0_img:
|
378 |
filename = generate_filename(0)
|
379 |
+
if filename not in st.session_state['captured_images']:
|
380 |
+
with open(filename, "wb") as f:
|
381 |
+
f.write(cam0_img.getvalue())
|
382 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
383 |
+
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
384 |
+
st.session_state['captured_images'].append(filename)
|
385 |
+
update_gallery()
|
386 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
387 |
st.session_state['cam0_frames'] = []
|
388 |
for i in range(slice_count):
|
389 |
img = st.camera_input(f"Frame {i} - Cam 0", key=f"cam0_frame_{i}_{time.time()}")
|
390 |
if img:
|
391 |
filename = generate_filename(f"0_{i}")
|
392 |
+
if filename not in st.session_state['captured_images']:
|
393 |
+
with open(filename, "wb") as f:
|
394 |
+
f.write(img.getvalue())
|
395 |
+
st.session_state['cam0_frames'].append(filename)
|
396 |
+
logger.info(f"Saved frame {i} from Camera 0: {filename}")
|
397 |
+
time.sleep(1.0 / slice_count)
|
398 |
+
st.session_state['captured_images'].extend([f for f in st.session_state['cam0_frames'] if f not in st.session_state['captured_images']])
|
399 |
update_gallery()
|
400 |
for frame in st.session_state['cam0_frames']:
|
401 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
|
|
404 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
405 |
if cam1_img:
|
406 |
filename = generate_filename(1)
|
407 |
+
if filename not in st.session_state['captured_images']:
|
408 |
+
with open(filename, "wb") as f:
|
409 |
+
f.write(cam1_img.getvalue())
|
410 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
411 |
+
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
412 |
+
st.session_state['captured_images'].append(filename)
|
413 |
+
update_gallery()
|
414 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
415 |
st.session_state['cam1_frames'] = []
|
416 |
for i in range(slice_count):
|
417 |
img = st.camera_input(f"Frame {i} - Cam 1", key=f"cam1_frame_{i}_{time.time()}")
|
418 |
if img:
|
419 |
filename = generate_filename(f"1_{i}")
|
420 |
+
if filename not in st.session_state['captured_images']:
|
421 |
+
with open(filename, "wb") as f:
|
422 |
+
f.write(img.getvalue())
|
423 |
+
st.session_state['cam1_frames'].append(filename)
|
424 |
+
logger.info(f"Saved frame {i} from Camera 1: {filename}")
|
425 |
+
time.sleep(1.0 / slice_count)
|
426 |
+
st.session_state['captured_images'].extend([f for f in st.session_state['cam1_frames'] if f not in st.session_state['captured_images']])
|
427 |
update_gallery()
|
428 |
for frame in st.session_state['cam1_frames']:
|
429 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
|
|
472 |
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
473 |
|
474 |
with tab3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
475 |
st.header("Test Titan 🧪")
|
476 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
477 |
st.warning("Please build or load a Titan first! ⚠️")
|
|
|
493 |
image = st.session_state['builder'].generate(prompt)
|
494 |
st.image(image, caption=f"Generated from {selected_pipeline}")
|
495 |
|
496 |
+
with tab4:
|
497 |
st.header("Agentic RAG Party 🌐")
|
498 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
499 |
st.warning("Please build or load a Titan first! ⚠️")
|