awacke1 commited on
Commit
8ff3549
·
verified ·
1 Parent(s): 0525548

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -43
app.py CHANGED
@@ -2,17 +2,9 @@
2
  import os
3
  import base64
4
  import streamlit as st
5
- import pandas as pd
6
  import csv
7
  import time
8
  from dataclasses import dataclass
9
- from PIL import Image
10
- from datetime import datetime
11
- import pytz
12
- from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
13
- import av
14
-
15
- # Minimal initial imports to reduce startup delay
16
 
17
  st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="wide", initial_sidebar_state="expanded")
18
 
@@ -78,41 +70,43 @@ def get_download_link(file_path, mime_type="text/plain", label="Download"):
78
  return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
79
 
80
  def generate_filename(text_line):
 
 
81
  central = pytz.timezone('US/Central')
82
  timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
83
  safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
84
  return f"{timestamp}_{safe_text}.png"
85
 
86
  def get_gallery_files(file_types):
 
87
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
88
 
89
  # Video Transformer for WebRTC
90
- class VideoSnapshot(VideoTransformerBase):
91
  def __init__(self):
92
  self.snapshot = None
93
- def transform(self, frame):
94
- img = frame.to_ndarray(format="bgr24")
95
- return img
 
 
96
  def take_snapshot(self):
97
- if self.snapshot is not None:
98
- return Image.fromarray(self.snapshot)
99
 
100
  # Main App
101
- st.title("SFT Tiny Titans 🚀 (Lean & Mean!)")
102
 
103
  # Sidebar Galleries
104
  st.sidebar.header("Media Gallery 🎨")
105
- for gallery_type, file_types, emoji in [
106
- ("Images 📸", ["png", "jpg", "jpeg"], "🖼️"),
107
- ("Videos 🎥", ["mp4"], "🎬")
108
- ]:
109
  st.sidebar.subheader(f"{gallery_type} {emoji}")
110
  files = get_gallery_files(file_types)
111
  if files:
112
- cols = st.sidebar.columns(3)
113
- for idx, file in enumerate(files[:6]):
114
- with cols[idx % 3]:
115
  if "Images" in gallery_type:
 
116
  st.image(Image.open(file), caption=file.split('/')[-1], use_column_width=True)
117
  elif "Videos" in gallery_type:
118
  st.video(file)
@@ -120,7 +114,7 @@ for gallery_type, file_types, emoji in [
120
  # Sidebar Model Management
121
  st.sidebar.subheader("Model Hub 🗂️")
122
  model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
123
- model_options = ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if "NLP" in model_type else ["stabilityai/stable-diffusion-2-1", "CompVis/stable-diffusion-v1-4"]
124
  selected_model = st.sidebar.selectbox("Select Model", ["None"] + model_options)
125
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
126
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
@@ -131,15 +125,10 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
131
  st.session_state['model_loaded'] = True
132
 
133
  # Tabs
134
- tab1, tab2, tab3, tab4 = st.tabs([
135
- "Build Titan 🌱",
136
- "Fine-Tune Titans 🔧",
137
- "Test Titans 🧪",
138
- "Camera Snap 📷"
139
- ])
140
 
141
  with tab1:
142
- st.header("Build Titan 🌱 (Start Small!)")
143
  model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
144
  base_model = st.selectbox("Select Model", model_options, key="build_model")
145
  if st.button("Download Model ⬇️"):
@@ -149,10 +138,10 @@ with tab1:
149
  builder.load_model(base_model, config)
150
  st.session_state['builder'] = builder
151
  st.session_state['model_loaded'] = True
152
- st.success("Titan ready! 🎉")
153
 
154
  with tab2:
155
- st.header("Fine-Tune Titans 🔧 (Sharpen Up!)")
156
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
157
  st.warning("Load a Titan first! ⚠️")
158
  else:
@@ -186,24 +175,26 @@ with tab2:
186
  dataloader = DataLoader(dataset, batch_size=2)
187
  optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
188
  st.session_state['builder'].model.train()
189
- for _ in range(3): # Simplified epochs
190
  for batch in dataloader:
191
  optimizer.zero_grad()
192
  outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
193
  outputs.loss.backward()
194
  optimizer.step()
195
- st.success("NLP tuned! 🎉")
196
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
197
  st.subheader("CV Tune 🎨")
198
  uploaded_files = st.file_uploader("Upload Images", type=["png", "jpg"], accept_multiple_files=True, key="cv_upload")
199
  text_input = st.text_area("Text (one per image)", "Bat Neon\nIron Glow", key="cv_text")
200
  if uploaded_files and st.button("Tune CV 🔄"):
201
  import torch
 
 
202
  images = [Image.open(f).convert("RGB") for f in uploaded_files]
203
  texts = text_input.splitlines()[:len(images)]
204
  optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
205
  st.session_state['builder'].pipeline.unet.train()
206
- for _ in range(3): # Simplified epochs
207
  for img, text in zip(images, texts):
208
  optimizer.zero_grad()
209
  latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
@@ -218,16 +209,16 @@ with tab2:
218
  for img, text in zip(images, texts):
219
  filename = generate_filename(text)
220
  img.save(filename)
221
- st.success("CV tuned! 🎉")
222
 
223
  with tab3:
224
- st.header("Test Titans 🧪 (Showtime!)")
225
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
226
  st.warning("Load a Titan first! ⚠️")
227
  else:
228
  if isinstance(st.session_state['builder'], ModelBuilder):
229
  st.subheader("NLP Test 🧠")
230
- prompt = st.text_area("Prompt", "What’s a superhero party?", key="nlp_test")
231
  if st.button("Test NLP ▶️"):
232
  result = st.session_state['builder'].evaluate(prompt)
233
  st.write(f"**Answer**: {result}")
@@ -240,12 +231,13 @@ with tab3:
240
  st.image(img, caption="Generated Art")
241
 
242
  with tab4:
243
- st.header("Camera Snap 📷 (Live Action!)")
244
- ctx = webrtc_streamer(key="camera", video_transformer_factory=VideoSnapshot, rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
245
- if ctx.video_transformer:
 
246
  snapshot_text = st.text_input("Snapshot Text", "Live Snap")
247
  if st.button("Snap It! 📸"):
248
- snapshot = ctx.video_transformer.take_snapshot()
249
  if snapshot:
250
  filename = generate_filename(snapshot_text)
251
  snapshot.save(filename)
@@ -254,10 +246,11 @@ with tab4:
254
 
255
  # Demo Dataset
256
  st.subheader("Demo CV Dataset 🎨")
257
- demo_texts = ["Bat Neon", "Iron Glow", "Thor Spark"]
258
  demo_images = [generate_filename(t) for t in demo_texts]
259
  for img, text in zip(demo_images, demo_texts):
260
  if not os.path.exists(img):
 
261
  Image.new("RGB", (100, 100)).save(img)
262
  st.code("\n".join([f"{i+1}. {t} -> {img}" for i, (t, img) in enumerate(zip(demo_texts, demo_images))]), language="text")
263
  if st.button("Download Demo CSV 📝"):
 
2
  import os
3
  import base64
4
  import streamlit as st
 
5
  import csv
6
  import time
7
  from dataclasses import dataclass
 
 
 
 
 
 
 
8
 
9
  st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="wide", initial_sidebar_state="expanded")
10
 
 
70
  return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
71
 
72
  def generate_filename(text_line):
73
+ from datetime import datetime
74
+ import pytz
75
  central = pytz.timezone('US/Central')
76
  timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
77
  safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
78
  return f"{timestamp}_{safe_text}.png"
79
 
80
  def get_gallery_files(file_types):
81
+ import glob
82
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
83
 
84
  # Video Transformer for WebRTC
85
+ class VideoSnapshot:
86
  def __init__(self):
87
  self.snapshot = None
88
+ def recv(self, frame):
89
+ from PIL import Image
90
+ img = frame.to_image()
91
+ self.snapshot = img
92
+ return frame
93
  def take_snapshot(self):
94
+ return self.snapshot
 
95
 
96
  # Main App
97
+ st.title("SFT Tiny Titans 🚀 (Fast & Furious!)")
98
 
99
  # Sidebar Galleries
100
  st.sidebar.header("Media Gallery 🎨")
101
+ for gallery_type, file_types, emoji in [("Images 📸", ["png", "jpg", "jpeg"], "🖼️"), ("Videos 🎥", ["mp4"], "🎬")]:
 
 
 
102
  st.sidebar.subheader(f"{gallery_type} {emoji}")
103
  files = get_gallery_files(file_types)
104
  if files:
105
+ cols = st.sidebar.columns(2)
106
+ for idx, file in enumerate(files[:4]):
107
+ with cols[idx % 2]:
108
  if "Images" in gallery_type:
109
+ from PIL import Image
110
  st.image(Image.open(file), caption=file.split('/')[-1], use_column_width=True)
111
  elif "Videos" in gallery_type:
112
  st.video(file)
 
114
  # Sidebar Model Management
115
  st.sidebar.subheader("Model Hub 🗂️")
116
  model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
117
+ model_options = ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if "NLP" in model_type else ["CompVis/stable-diffusion-v1-4"]
118
  selected_model = st.sidebar.selectbox("Select Model", ["None"] + model_options)
119
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
120
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
 
125
  st.session_state['model_loaded'] = True
126
 
127
  # Tabs
128
+ tab1, tab2, tab3, tab4 = st.tabs(["Build Titan 🌱", "Fine-Tune Titans 🔧", "Test Titans 🧪", "Camera Snap 📷"])
 
 
 
 
 
129
 
130
  with tab1:
131
+ st.header("Build Titan 🌱 (Quick Start!)")
132
  model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
133
  base_model = st.selectbox("Select Model", model_options, key="build_model")
134
  if st.button("Download Model ⬇️"):
 
138
  builder.load_model(base_model, config)
139
  st.session_state['builder'] = builder
140
  st.session_state['model_loaded'] = True
141
+ st.success("Titan up! 🎉")
142
 
143
  with tab2:
144
+ st.header("Fine-Tune Titans 🔧 (Tune Fast!)")
145
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
146
  st.warning("Load a Titan first! ⚠️")
147
  else:
 
175
  dataloader = DataLoader(dataset, batch_size=2)
176
  optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
177
  st.session_state['builder'].model.train()
178
+ for _ in range(1): # Minimal epochs
179
  for batch in dataloader:
180
  optimizer.zero_grad()
181
  outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
182
  outputs.loss.backward()
183
  optimizer.step()
184
+ st.success("NLP sharpened! 🎉")
185
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
186
  st.subheader("CV Tune 🎨")
187
  uploaded_files = st.file_uploader("Upload Images", type=["png", "jpg"], accept_multiple_files=True, key="cv_upload")
188
  text_input = st.text_area("Text (one per image)", "Bat Neon\nIron Glow", key="cv_text")
189
  if uploaded_files and st.button("Tune CV 🔄"):
190
  import torch
191
+ from PIL import Image
192
+ import numpy as np
193
  images = [Image.open(f).convert("RGB") for f in uploaded_files]
194
  texts = text_input.splitlines()[:len(images)]
195
  optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
196
  st.session_state['builder'].pipeline.unet.train()
197
+ for _ in range(1): # Minimal epochs
198
  for img, text in zip(images, texts):
199
  optimizer.zero_grad()
200
  latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
 
209
  for img, text in zip(images, texts):
210
  filename = generate_filename(text)
211
  img.save(filename)
212
+ st.success("CV polished! 🎉")
213
 
214
  with tab3:
215
+ st.header("Test Titans 🧪 (Quick Check!)")
216
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
217
  st.warning("Load a Titan first! ⚠️")
218
  else:
219
  if isinstance(st.session_state['builder'], ModelBuilder):
220
  st.subheader("NLP Test 🧠")
221
+ prompt = st.text_area("Prompt", "What’s a superhero?", key="nlp_test")
222
  if st.button("Test NLP ▶️"):
223
  result = st.session_state['builder'].evaluate(prompt)
224
  st.write(f"**Answer**: {result}")
 
231
  st.image(img, caption="Generated Art")
232
 
233
  with tab4:
234
+ st.header("Camera Snap 📷 (Instant Shots!)")
235
+ from streamlit_webrtc import webrtc_streamer
236
+ ctx = webrtc_streamer(key="camera", video_processor_factory=VideoSnapshot, rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
237
+ if ctx.video_processor:
238
  snapshot_text = st.text_input("Snapshot Text", "Live Snap")
239
  if st.button("Snap It! 📸"):
240
+ snapshot = ctx.video_processor.take_snapshot()
241
  if snapshot:
242
  filename = generate_filename(snapshot_text)
243
  snapshot.save(filename)
 
246
 
247
  # Demo Dataset
248
  st.subheader("Demo CV Dataset 🎨")
249
+ demo_texts = ["Bat Neon", "Iron Glow"]
250
  demo_images = [generate_filename(t) for t in demo_texts]
251
  for img, text in zip(demo_images, demo_texts):
252
  if not os.path.exists(img):
253
+ from PIL import Image
254
  Image.new("RGB", (100, 100)).save(img)
255
  st.code("\n".join([f"{i+1}. {t} -> {img}" for i, (t, img) in enumerate(zip(demo_texts, demo_images))]), language="text")
256
  if st.button("Download Demo CSV 📝"):