Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
import os
|
3 |
-
import shutil
|
4 |
import glob
|
5 |
import base64
|
6 |
import streamlit as st
|
@@ -31,7 +30,7 @@ st.set_page_config(
|
|
31 |
initial_sidebar_state="expanded",
|
32 |
menu_items={
|
33 |
'Get Help': 'https://huggingface.co/awacke1',
|
34 |
-
'Report a
|
35 |
'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
|
36 |
}
|
37 |
)
|
@@ -177,9 +176,9 @@ class DiffusionBuilder:
|
|
177 |
total_loss = 0
|
178 |
for batch in dataloader:
|
179 |
optimizer.zero_grad()
|
180 |
-
image = batch["image"].to(self.pipeline.device)
|
181 |
-
text = batch["text"]
|
182 |
-
latents = self.pipeline.vae.encode(image).latent_dist.sample()
|
183 |
noise = torch.randn_like(latents)
|
184 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
185 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
@@ -220,9 +219,10 @@ def get_model_files(model_type="causal_lm"):
|
|
220 |
def get_gallery_files(file_types):
|
221 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
222 |
|
|
|
223 |
def mock_search(query: str) -> str:
|
224 |
if "superhero" in query.lower():
|
225 |
-
return "Latest trends
|
226 |
return "No relevant results found."
|
227 |
|
228 |
class PartyPlannerAgent:
|
@@ -291,6 +291,7 @@ if media_files:
|
|
291 |
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
292 |
with cols[idx % 2]:
|
293 |
st.image(Image.open(file), caption=file, use_column_width=True)
|
|
|
294 |
|
295 |
st.sidebar.subheader("Model Management 🗂️")
|
296 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
@@ -351,7 +352,7 @@ with tab2:
|
|
351 |
f.write(img.getvalue())
|
352 |
st.session_state['cam0_frames'].append(filename)
|
353 |
logger.info(f"Saved frame {i} from Camera 0: {filename}")
|
354 |
-
time.sleep(1.0 / slice_count)
|
355 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
356 |
update_gallery()
|
357 |
for frame in st.session_state['cam0_frames']:
|
@@ -379,7 +380,7 @@ with tab2:
|
|
379 |
f.write(img.getvalue())
|
380 |
st.session_state['cam1_frames'].append(filename)
|
381 |
logger.info(f"Saved frame {i} from Camera 1: {filename}")
|
382 |
-
time.sleep(1.0 / slice_count)
|
383 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
384 |
update_gallery()
|
385 |
for frame in st.session_state['cam1_frames']:
|
@@ -420,6 +421,13 @@ with tab3:
|
|
420 |
zip_path = f"{new_config.model_path}.zip"
|
421 |
zip_directory(new_config.model_path, zip_path)
|
422 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
with tab4:
|
425 |
st.header("Test Titan 🧪")
|
@@ -456,4 +464,11 @@ with tab5:
|
|
456 |
st.dataframe(plan_df)
|
457 |
for _, row in plan_df.iterrows():
|
458 |
image = agent.generate(row["Image Idea"])
|
459 |
-
st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
import os
|
|
|
3 |
import glob
|
4 |
import base64
|
5 |
import streamlit as st
|
|
|
30 |
initial_sidebar_state="expanded",
|
31 |
menu_items={
|
32 |
'Get Help': 'https://huggingface.co/awacke1',
|
33 |
+
'Report a Bug': 'https://huggingface.co/spaces/awacke1',
|
34 |
'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
|
35 |
}
|
36 |
)
|
|
|
176 |
total_loss = 0
|
177 |
for batch in dataloader:
|
178 |
optimizer.zero_grad()
|
179 |
+
image = batch["image"][0].to(self.pipeline.device)
|
180 |
+
text = batch["text"][0]
|
181 |
+
latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
|
182 |
noise = torch.randn_like(latents)
|
183 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
184 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
|
|
219 |
def get_gallery_files(file_types):
|
220 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
221 |
|
222 |
+
# Mock Search Tool for RAG
|
223 |
def mock_search(query: str) -> str:
|
224 |
if "superhero" in query.lower():
|
225 |
+
return "Latest trends: Gold-plated Batman statues, VR superhero battles."
|
226 |
return "No relevant results found."
|
227 |
|
228 |
class PartyPlannerAgent:
|
|
|
291 |
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
292 |
with cols[idx % 2]:
|
293 |
st.image(Image.open(file), caption=file, use_column_width=True)
|
294 |
+
st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
|
295 |
|
296 |
st.sidebar.subheader("Model Management 🗂️")
|
297 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
|
|
352 |
f.write(img.getvalue())
|
353 |
st.session_state['cam0_frames'].append(filename)
|
354 |
logger.info(f"Saved frame {i} from Camera 0: {filename}")
|
355 |
+
time.sleep(1.0 / slice_count)
|
356 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
357 |
update_gallery()
|
358 |
for frame in st.session_state['cam0_frames']:
|
|
|
380 |
f.write(img.getvalue())
|
381 |
st.session_state['cam1_frames'].append(filename)
|
382 |
logger.info(f"Saved frame {i} from Camera 1: {filename}")
|
383 |
+
time.sleep(1.0 / slice_count)
|
384 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
385 |
update_gallery()
|
386 |
for frame in st.session_state['cam1_frames']:
|
|
|
421 |
zip_path = f"{new_config.model_path}.zip"
|
422 |
zip_directory(new_config.model_path, zip_path)
|
423 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
|
424 |
+
csv_path = f"sft_dataset_{int(time.time())}.csv"
|
425 |
+
with open(csv_path, "w", newline="") as f:
|
426 |
+
writer = csv.writer(f)
|
427 |
+
writer.writerow(["image", "text"])
|
428 |
+
for _, row in edited_data.iterrows():
|
429 |
+
writer.writerow([row["image"], row["text"]])
|
430 |
+
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
431 |
|
432 |
with tab4:
|
433 |
st.header("Test Titan 🧪")
|
|
|
464 |
st.dataframe(plan_df)
|
465 |
for _, row in plan_df.iterrows():
|
466 |
image = agent.generate(row["Image Idea"])
|
467 |
+
st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
|
468 |
+
|
469 |
+
# Main App
|
470 |
+
st.sidebar.subheader("Action Logs 📜")
|
471 |
+
log_container = st.sidebar.empty()
|
472 |
+
with log_container:
|
473 |
+
for record in logger.handlers[0].buffer:
|
474 |
+
st.write(f"{record.asctime} - {record.levelname} - {record.message}")
|