Update app.py
Browse files
app.py
CHANGED
@@ -81,7 +81,7 @@ def get_gallery_files(file_types):
|
|
81 |
import glob
|
82 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
83 |
|
84 |
-
# Video
|
85 |
class VideoSnapshot:
|
86 |
def __init__(self):
|
87 |
self.snapshot = None
|
@@ -94,7 +94,7 @@ class VideoSnapshot:
|
|
94 |
return self.snapshot
|
95 |
|
96 |
# Main App
|
97 |
-
st.title("SFT Tiny Titans 🚀 (Fast &
|
98 |
|
99 |
# Sidebar Galleries
|
100 |
st.sidebar.header("Media Gallery 🎨")
|
@@ -107,15 +107,15 @@ for gallery_type, file_types, emoji in [("Images 📸", ["png", "jpg", "jpeg"],
|
|
107 |
with cols[idx % 2]:
|
108 |
if "Images" in gallery_type:
|
109 |
from PIL import Image
|
110 |
-
st.image(Image.open(file), caption=file.split('/')[-1],
|
111 |
elif "Videos" in gallery_type:
|
112 |
st.video(file)
|
113 |
|
114 |
# Sidebar Model Management
|
115 |
st.sidebar.subheader("Model Hub 🗂️")
|
116 |
model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
|
117 |
-
model_options =
|
118 |
-
selected_model = st.sidebar.selectbox("Select Model", ["None"
|
119 |
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
120 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
121 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
|
@@ -130,7 +130,7 @@ tab1, tab2, tab3, tab4 = st.tabs(["Build Titan 🌱", "Fine-Tune Titans 🔧", "
|
|
130 |
with tab1:
|
131 |
st.header("Build Titan 🌱 (Quick Start!)")
|
132 |
model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
|
133 |
-
base_model = st.selectbox("Select Model", model_options, key="build_model")
|
134 |
if st.button("Download Model ⬇️"):
|
135 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
|
136 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
@@ -175,7 +175,7 @@ with tab2:
|
|
175 |
dataloader = DataLoader(dataset, batch_size=2)
|
176 |
optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
|
177 |
st.session_state['builder'].model.train()
|
178 |
-
for _ in range(1):
|
179 |
for batch in dataloader:
|
180 |
optimizer.zero_grad()
|
181 |
outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
|
@@ -194,7 +194,7 @@ with tab2:
|
|
194 |
texts = text_input.splitlines()[:len(images)]
|
195 |
optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
|
196 |
st.session_state['builder'].pipeline.unet.train()
|
197 |
-
for _ in range(1):
|
198 |
for img, text in zip(images, texts):
|
199 |
optimizer.zero_grad()
|
200 |
latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
|
@@ -233,7 +233,11 @@ with tab3:
|
|
233 |
with tab4:
|
234 |
st.header("Camera Snap 📷 (Instant Shots!)")
|
235 |
from streamlit_webrtc import webrtc_streamer
|
236 |
-
ctx = webrtc_streamer(
|
|
|
|
|
|
|
|
|
237 |
if ctx.video_processor:
|
238 |
snapshot_text = st.text_input("Snapshot Text", "Live Snap")
|
239 |
if st.button("Snap It! 📸"):
|
@@ -241,7 +245,7 @@ with tab4:
|
|
241 |
if snapshot:
|
242 |
filename = generate_filename(snapshot_text)
|
243 |
snapshot.save(filename)
|
244 |
-
st.image(snapshot, caption=filename)
|
245 |
st.success("Snapped! 🎉")
|
246 |
|
247 |
# Demo Dataset
|
|
|
81 |
import glob
|
82 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
83 |
|
84 |
+
# Video Processor for WebRTC
|
85 |
class VideoSnapshot:
|
86 |
def __init__(self):
|
87 |
self.snapshot = None
|
|
|
94 |
return self.snapshot
|
95 |
|
96 |
# Main App
|
97 |
+
st.title("SFT Tiny Titans 🚀 (Fast & Fixed!)")
|
98 |
|
99 |
# Sidebar Galleries
|
100 |
st.sidebar.header("Media Gallery 🎨")
|
|
|
107 |
with cols[idx % 2]:
|
108 |
if "Images" in gallery_type:
|
109 |
from PIL import Image
|
110 |
+
st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
|
111 |
elif "Videos" in gallery_type:
|
112 |
st.video(file)
|
113 |
|
114 |
# Sidebar Model Management
|
115 |
st.sidebar.subheader("Model Hub 🗂️")
|
116 |
model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
|
117 |
+
model_options = {"NLP (Causal LM)": "HuggingFaceTB/SmolLM-135M", "CV (Diffusion)": "CompVis/stable-diffusion-v1-4"}
|
118 |
+
selected_model = st.sidebar.selectbox("Select Model", ["None", model_options[model_type]])
|
119 |
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
120 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
121 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
|
|
|
130 |
with tab1:
|
131 |
st.header("Build Titan 🌱 (Quick Start!)")
|
132 |
model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
|
133 |
+
base_model = st.selectbox("Select Model", [model_options[model_type]], key="build_model")
|
134 |
if st.button("Download Model ⬇️"):
|
135 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
|
136 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
|
|
175 |
dataloader = DataLoader(dataset, batch_size=2)
|
176 |
optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
|
177 |
st.session_state['builder'].model.train()
|
178 |
+
for _ in range(1):
|
179 |
for batch in dataloader:
|
180 |
optimizer.zero_grad()
|
181 |
outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
|
|
|
194 |
texts = text_input.splitlines()[:len(images)]
|
195 |
optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
|
196 |
st.session_state['builder'].pipeline.unet.train()
|
197 |
+
for _ in range(1):
|
198 |
for img, text in zip(images, texts):
|
199 |
optimizer.zero_grad()
|
200 |
latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
|
|
|
233 |
with tab4:
|
234 |
st.header("Camera Snap 📷 (Instant Shots!)")
|
235 |
from streamlit_webrtc import webrtc_streamer
|
236 |
+
ctx = webrtc_streamer(
|
237 |
+
key="camera",
|
238 |
+
video_processor_factory=VideoSnapshot,
|
239 |
+
frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
|
240 |
+
)
|
241 |
if ctx.video_processor:
|
242 |
snapshot_text = st.text_input("Snapshot Text", "Live Snap")
|
243 |
if st.button("Snap It! 📸"):
|
|
|
245 |
if snapshot:
|
246 |
filename = generate_filename(snapshot_text)
|
247 |
snapshot.save(filename)
|
248 |
+
st.image(snapshot, caption=filename, use_container_width=True)
|
249 |
st.success("Snapped! 🎉")
|
250 |
|
251 |
# Demo Dataset
|