awacke1 committed on
Commit
b8ca8a3
·
verified ·
1 Parent(s): 8ff3549

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -81,7 +81,7 @@ def get_gallery_files(file_types):
81
  import glob
82
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
83
 
84
- # Video Transformer for WebRTC
85
  class VideoSnapshot:
86
  def __init__(self):
87
  self.snapshot = None
@@ -94,7 +94,7 @@ class VideoSnapshot:
94
  return self.snapshot
95
 
96
  # Main App
97
- st.title("SFT Tiny Titans 🚀 (Fast & Furious!)")
98
 
99
  # Sidebar Galleries
100
  st.sidebar.header("Media Gallery 🎨")
@@ -107,15 +107,15 @@ for gallery_type, file_types, emoji in [("Images 📸", ["png", "jpg", "jpeg"],
107
  with cols[idx % 2]:
108
  if "Images" in gallery_type:
109
  from PIL import Image
110
- st.image(Image.open(file), caption=file.split('/')[-1], use_column_width=True)
111
  elif "Videos" in gallery_type:
112
  st.video(file)
113
 
114
  # Sidebar Model Management
115
  st.sidebar.subheader("Model Hub 🗂️")
116
  model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
117
- model_options = ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if "NLP" in model_type else ["CompVis/stable-diffusion-v1-4"]
118
- selected_model = st.sidebar.selectbox("Select Model", ["None"] + model_options)
119
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
120
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
121
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
@@ -130,7 +130,7 @@ tab1, tab2, tab3, tab4 = st.tabs(["Build Titan 🌱", "Fine-Tune Titans 🔧", "
130
  with tab1:
131
  st.header("Build Titan 🌱 (Quick Start!)")
132
  model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
133
- base_model = st.selectbox("Select Model", model_options, key="build_model")
134
  if st.button("Download Model ⬇️"):
135
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
136
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
@@ -175,7 +175,7 @@ with tab2:
175
  dataloader = DataLoader(dataset, batch_size=2)
176
  optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
177
  st.session_state['builder'].model.train()
178
- for _ in range(1): # Minimal epochs
179
  for batch in dataloader:
180
  optimizer.zero_grad()
181
  outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
@@ -194,7 +194,7 @@ with tab2:
194
  texts = text_input.splitlines()[:len(images)]
195
  optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
196
  st.session_state['builder'].pipeline.unet.train()
197
- for _ in range(1): # Minimal epochs
198
  for img, text in zip(images, texts):
199
  optimizer.zero_grad()
200
  latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
@@ -233,7 +233,11 @@ with tab3:
233
  with tab4:
234
  st.header("Camera Snap 📷 (Instant Shots!)")
235
  from streamlit_webrtc import webrtc_streamer
236
- ctx = webrtc_streamer(key="camera", video_processor_factory=VideoSnapshot, rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
 
 
 
 
237
  if ctx.video_processor:
238
  snapshot_text = st.text_input("Snapshot Text", "Live Snap")
239
  if st.button("Snap It! 📸"):
@@ -241,7 +245,7 @@ with tab4:
241
  if snapshot:
242
  filename = generate_filename(snapshot_text)
243
  snapshot.save(filename)
244
- st.image(snapshot, caption=filename)
245
  st.success("Snapped! 🎉")
246
 
247
  # Demo Dataset
 
81
  import glob
82
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
83
 
84
+ # Video Processor for WebRTC
85
  class VideoSnapshot:
86
  def __init__(self):
87
  self.snapshot = None
 
94
  return self.snapshot
95
 
96
  # Main App
97
+ st.title("SFT Tiny Titans 🚀 (Fast & Fixed!)")
98
 
99
  # Sidebar Galleries
100
  st.sidebar.header("Media Gallery 🎨")
 
107
  with cols[idx % 2]:
108
  if "Images" in gallery_type:
109
  from PIL import Image
110
+ st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
111
  elif "Videos" in gallery_type:
112
  st.video(file)
113
 
114
  # Sidebar Model Management
115
  st.sidebar.subheader("Model Hub 🗂️")
116
  model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
117
+ model_options = {"NLP (Causal LM)": "HuggingFaceTB/SmolLM-135M", "CV (Diffusion)": "CompVis/stable-diffusion-v1-4"}
118
+ selected_model = st.sidebar.selectbox("Select Model", ["None", model_options[model_type]])
119
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
120
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
121
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
 
130
  with tab1:
131
  st.header("Build Titan 🌱 (Quick Start!)")
132
  model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
133
+ base_model = st.selectbox("Select Model", [model_options[model_type]], key="build_model")
134
  if st.button("Download Model ⬇️"):
135
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
136
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
 
175
  dataloader = DataLoader(dataset, batch_size=2)
176
  optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
177
  st.session_state['builder'].model.train()
178
+ for _ in range(1):
179
  for batch in dataloader:
180
  optimizer.zero_grad()
181
  outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
 
194
  texts = text_input.splitlines()[:len(images)]
195
  optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
196
  st.session_state['builder'].pipeline.unet.train()
197
+ for _ in range(1):
198
  for img, text in zip(images, texts):
199
  optimizer.zero_grad()
200
  latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
 
233
  with tab4:
234
  st.header("Camera Snap 📷 (Instant Shots!)")
235
  from streamlit_webrtc import webrtc_streamer
236
+ ctx = webrtc_streamer(
237
+ key="camera",
238
+ video_processor_factory=VideoSnapshot,
239
+ frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
240
+ )
241
  if ctx.video_processor:
242
  snapshot_text = st.text_input("Snapshot Text", "Live Snap")
243
  if st.button("Snap It! 📸"):
 
245
  if snapshot:
246
  filename = generate_filename(snapshot_text)
247
  snapshot.save(filename)
248
+ st.image(snapshot, caption=filename, use_container_width=True)
249
  st.success("Snapped! 🎉")
250
 
251
  # Demo Dataset