awacke1 commited on
Commit
301c0b7
·
verified ·
1 Parent(s): 6f8a2f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -1,6 +1,5 @@
1
  #!/usr/bin/env python3
2
  import os
3
- import shutil
4
  import glob
5
  import base64
6
  import streamlit as st
@@ -31,7 +30,7 @@ st.set_page_config(
31
  initial_sidebar_state="expanded",
32
  menu_items={
33
  'Get Help': 'https://huggingface.co/awacke1',
34
- 'Report a bug': 'https://huggingface.co/spaces/awacke1',
35
  'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
36
  }
37
  )
@@ -177,9 +176,9 @@ class DiffusionBuilder:
177
  total_loss = 0
178
  for batch in dataloader:
179
  optimizer.zero_grad()
180
- image = batch["image"].to(self.pipeline.device)
181
- text = batch["text"]
182
- latents = self.pipeline.vae.encode(image).latent_dist.sample()
183
  noise = torch.randn_like(latents)
184
  timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
185
  noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
@@ -220,9 +219,10 @@ def get_model_files(model_type="causal_lm"):
220
  def get_gallery_files(file_types):
221
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
222
 
 
223
  def mock_search(query: str) -> str:
224
  if "superhero" in query.lower():
225
- return "Latest trends for 2025: Gold-plated Batman statues, VR superhero battles."
226
  return "No relevant results found."
227
 
228
  class PartyPlannerAgent:
@@ -291,6 +291,7 @@ if media_files:
291
  for idx, file in enumerate(media_files[:gallery_size * 2]):
292
  with cols[idx % 2]:
293
  st.image(Image.open(file), caption=file, use_column_width=True)
 
294
 
295
  st.sidebar.subheader("Model Management 🗂️")
296
  model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
@@ -351,7 +352,7 @@ with tab2:
351
  f.write(img.getvalue())
352
  st.session_state['cam0_frames'].append(filename)
353
  logger.info(f"Saved frame {i} from Camera 0: {filename}")
354
- time.sleep(1.0 / slice_count) # Adjust frame rate
355
  st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
356
  update_gallery()
357
  for frame in st.session_state['cam0_frames']:
@@ -379,7 +380,7 @@ with tab2:
379
  f.write(img.getvalue())
380
  st.session_state['cam1_frames'].append(filename)
381
  logger.info(f"Saved frame {i} from Camera 1: {filename}")
382
- time.sleep(1.0 / slice_count) # Adjust frame rate
383
  st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
384
  update_gallery()
385
  for frame in st.session_state['cam1_frames']:
@@ -420,6 +421,13 @@ with tab3:
420
  zip_path = f"{new_config.model_path}.zip"
421
  zip_directory(new_config.model_path, zip_path)
422
  st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
 
 
 
 
 
 
 
423
 
424
  with tab4:
425
  st.header("Test Titan 🧪")
@@ -456,4 +464,11 @@ with tab5:
456
  st.dataframe(plan_df)
457
  for _, row in plan_df.iterrows():
458
  image = agent.generate(row["Image Idea"])
459
- st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  import os
 
3
  import glob
4
  import base64
5
  import streamlit as st
 
30
  initial_sidebar_state="expanded",
31
  menu_items={
32
  'Get Help': 'https://huggingface.co/awacke1',
33
+ 'Report a Bug': 'https://huggingface.co/spaces/awacke1',
34
  'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
35
  }
36
  )
 
176
  total_loss = 0
177
  for batch in dataloader:
178
  optimizer.zero_grad()
179
+ image = batch["image"][0].to(self.pipeline.device)
180
+ text = batch["text"][0]
181
+ latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
182
  noise = torch.randn_like(latents)
183
  timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
184
  noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
 
219
  def get_gallery_files(file_types):
220
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
221
 
222
+ # Mock Search Tool for RAG
223
  def mock_search(query: str) -> str:
224
  if "superhero" in query.lower():
225
+ return "Latest trends: Gold-plated Batman statues, VR superhero battles."
226
  return "No relevant results found."
227
 
228
  class PartyPlannerAgent:
 
291
  for idx, file in enumerate(media_files[:gallery_size * 2]):
292
  with cols[idx % 2]:
293
  st.image(Image.open(file), caption=file, use_column_width=True)
294
+ st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
295
 
296
  st.sidebar.subheader("Model Management 🗂️")
297
  model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
 
352
  f.write(img.getvalue())
353
  st.session_state['cam0_frames'].append(filename)
354
  logger.info(f"Saved frame {i} from Camera 0: {filename}")
355
+ time.sleep(1.0 / slice_count)
356
  st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
357
  update_gallery()
358
  for frame in st.session_state['cam0_frames']:
 
380
  f.write(img.getvalue())
381
  st.session_state['cam1_frames'].append(filename)
382
  logger.info(f"Saved frame {i} from Camera 1: {filename}")
383
+ time.sleep(1.0 / slice_count)
384
  st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
385
  update_gallery()
386
  for frame in st.session_state['cam1_frames']:
 
421
  zip_path = f"{new_config.model_path}.zip"
422
  zip_directory(new_config.model_path, zip_path)
423
  st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
424
+ csv_path = f"sft_dataset_{int(time.time())}.csv"
425
+ with open(csv_path, "w", newline="") as f:
426
+ writer = csv.writer(f)
427
+ writer.writerow(["image", "text"])
428
+ for _, row in edited_data.iterrows():
429
+ writer.writerow([row["image"], row["text"]])
430
+ st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
431
 
432
  with tab4:
433
  st.header("Test Titan 🧪")
 
464
  st.dataframe(plan_df)
465
  for _, row in plan_df.iterrows():
466
  image = agent.generate(row["Image Idea"])
467
+ st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
468
+
469
+ # Main App
470
+ st.sidebar.subheader("Action Logs 📜")
471
+ log_container = st.sidebar.empty()
472
+ with log_container:
473
+ for record in logger.handlers[0].buffer:
474
+ st.write(f"{record.asctime} - {record.levelname} - {record.message}")