awacke1 commited on
Commit
0b13d03
·
verified ·
1 Parent(s): b8ca8a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -104
app.py CHANGED
@@ -5,6 +5,7 @@ import streamlit as st
5
  import csv
6
  import time
7
  from dataclasses import dataclass
 
8
 
9
  st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="wide", initial_sidebar_state="expanded")
10
 
@@ -41,6 +42,37 @@ class ModelBuilder:
41
  self.tokenizer.pad_token = self.tokenizer.eos_token
42
  self.config = config
43
  self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def evaluate(self, prompt: str):
45
  import torch
46
  self.model.eval()
@@ -59,6 +91,25 @@ class DiffusionBuilder:
59
  self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
60
  self.pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
61
  self.config = config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def generate(self, prompt: str):
63
  return self.pipeline(prompt, num_inference_steps=20).images[0]
64
 
@@ -69,18 +120,23 @@ def get_download_link(file_path, mime_type="text/plain", label="Download"):
69
  b64 = base64.b64encode(data).decode()
70
  return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
71
 
72
- def generate_filename(text_line):
73
  from datetime import datetime
74
  import pytz
75
  central = pytz.timezone('US/Central')
76
- timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
77
- safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
78
- return f"{timestamp}_{safe_text}.png"
79
 
80
  def get_gallery_files(file_types):
81
  import glob
82
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
83
 
 
 
 
 
 
 
84
  # Video Processor for WebRTC
85
  class VideoSnapshot:
86
  def __init__(self):
@@ -94,28 +150,26 @@ class VideoSnapshot:
94
  return self.snapshot
95
 
96
  # Main App
97
- st.title("SFT Tiny Titans 🚀 (Fast & Fixed!)")
98
 
99
  # Sidebar Galleries
100
- st.sidebar.header("Media Gallery 🎨")
101
- for gallery_type, file_types, emoji in [("Images 📸", ["png", "jpg", "jpeg"], "🖼️"), ("Videos 🎥", ["mp4"], "🎬")]:
102
- st.sidebar.subheader(f"{gallery_type} {emoji}")
103
- files = get_gallery_files(file_types)
104
- if files:
105
- cols = st.sidebar.columns(2)
106
- for idx, file in enumerate(files[:4]):
107
- with cols[idx % 2]:
108
- if "Images" in gallery_type:
109
- from PIL import Image
110
- st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
111
- elif "Videos" in gallery_type:
112
- st.video(file)
113
 
114
  # Sidebar Model Management
115
  st.sidebar.subheader("Model Hub 🗂️")
116
  model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
117
- model_options = {"NLP (Causal LM)": "HuggingFaceTB/SmolLM-135M", "CV (Diffusion)": "CompVis/stable-diffusion-v1-4"}
118
- selected_model = st.sidebar.selectbox("Select Model", ["None", model_options[model_type]])
 
 
 
119
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
120
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
121
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
@@ -130,7 +184,7 @@ tab1, tab2, tab3, tab4 = st.tabs(["Build Titan 🌱", "Fine-Tune Titans 🔧", "
130
  with tab1:
131
  st.header("Build Titan 🌱 (Quick Start!)")
132
  model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
133
- base_model = st.selectbox("Select Model", [model_options[model_type]], key="build_model")
134
  if st.button("Download Model ⬇️"):
135
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
136
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
@@ -149,67 +203,22 @@ with tab2:
149
  st.subheader("NLP Tune 🧠")
150
  uploaded_csv = st.file_uploader("Upload CSV", type="csv", key="nlp_csv")
151
  if uploaded_csv and st.button("Tune NLP 🔄"):
152
- from torch.utils.data import Dataset, DataLoader
153
- import torch
154
- class SFTDataset(Dataset):
155
- def __init__(self, data, tokenizer):
156
- self.data = data
157
- self.tokenizer = tokenizer
158
- def __len__(self):
159
- return len(self.data)
160
- def __getitem__(self, idx):
161
- prompt = self.data[idx]["prompt"]
162
- response = self.data[idx]["response"]
163
- inputs = self.tokenizer(f"{prompt} {response}", return_tensors="pt", padding="max_length", max_length=128, truncation=True)
164
- labels = inputs["input_ids"].clone()
165
- labels[0, :len(self.tokenizer(prompt)["input_ids"][0])] = -100
166
- return {"input_ids": inputs["input_ids"][0], "attention_mask": inputs["attention_mask"][0], "labels": labels[0]}
167
- data = []
168
  with open("temp.csv", "wb") as f:
169
  f.write(uploaded_csv.read())
170
- with open("temp.csv", "r") as f:
171
- reader = csv.DictReader(f)
172
- for row in reader:
173
- data.append({"prompt": row["prompt"], "response": row["response"]})
174
- dataset = SFTDataset(data, st.session_state['builder'].tokenizer)
175
- dataloader = DataLoader(dataset, batch_size=2)
176
- optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
177
- st.session_state['builder'].model.train()
178
- for _ in range(1):
179
- for batch in dataloader:
180
- optimizer.zero_grad()
181
- outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
182
- outputs.loss.backward()
183
- optimizer.step()
184
  st.success("NLP sharpened! 🎉")
185
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
186
  st.subheader("CV Tune 🎨")
187
- uploaded_files = st.file_uploader("Upload Images", type=["png", "jpg"], accept_multiple_files=True, key="cv_upload")
188
- text_input = st.text_area("Text (one per image)", "Bat Neon\nIron Glow", key="cv_text")
189
- if uploaded_files and st.button("Tune CV 🔄"):
190
- import torch
191
- from PIL import Image
192
- import numpy as np
193
- images = [Image.open(f).convert("RGB") for f in uploaded_files]
194
- texts = text_input.splitlines()[:len(images)]
195
- optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
196
- st.session_state['builder'].pipeline.unet.train()
197
- for _ in range(1):
198
- for img, text in zip(images, texts):
199
- optimizer.zero_grad()
200
- latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
201
- noise = torch.randn_like(latents)
202
- timesteps = torch.randint(0, 1000, (1,), device=latents.device)
203
- noisy_latents = st.session_state['builder'].pipeline.scheduler.add_noise(latents, noise, timesteps)
204
- text_emb = st.session_state['builder'].pipeline.text_encoder(st.session_state['builder'].pipeline.tokenizer(text, return_tensors="pt").input_ids.to(st.session_state['builder'].pipeline.device))[0]
205
- pred_noise = st.session_state['builder'].pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_emb).sample
206
- loss = torch.nn.functional.mse_loss(pred_noise, noise)
207
- loss.backward()
208
- optimizer.step()
209
- for img, text in zip(images, texts):
210
- filename = generate_filename(text)
211
- img.save(filename)
212
- st.success("CV polished! 🎉")
213
 
214
  with tab3:
215
  st.header("Test Titans 🧪 (Quick Check!)")
@@ -228,10 +237,10 @@ with tab3:
228
  if st.button("Test CV ▶️"):
229
  with st.spinner("Generating... ⏳"):
230
  img = st.session_state['builder'].generate(prompt)
231
- st.image(img, caption="Generated Art")
232
 
233
  with tab4:
234
- st.header("Camera Snap 📷 (Instant Shots!)")
235
  from streamlit_webrtc import webrtc_streamer
236
  ctx = webrtc_streamer(
237
  key="camera",
@@ -239,29 +248,36 @@ with tab4:
239
  frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
240
  )
241
  if ctx.video_processor:
242
- snapshot_text = st.text_input("Snapshot Text", "Live Snap")
243
- if st.button("Snap It! 📸"):
244
- snapshot = ctx.video_processor.take_snapshot()
245
- if snapshot:
246
- filename = generate_filename(snapshot_text)
247
- snapshot.save(filename)
248
- st.image(snapshot, caption=filename, use_container_width=True)
249
- st.success("Snapped! 🎉")
 
 
 
 
 
 
250
 
251
- # Demo Dataset
252
- st.subheader("Demo CV Dataset 🎨")
253
- demo_texts = ["Bat Neon", "Iron Glow"]
254
- demo_images = [generate_filename(t) for t in demo_texts]
255
- for img, text in zip(demo_images, demo_texts):
256
- if not os.path.exists(img):
257
- from PIL import Image
258
- Image.new("RGB", (100, 100)).save(img)
259
- st.code("\n".join([f"{i+1}. {t} -> {img}" for i, (t, img) in enumerate(zip(demo_texts, demo_images))]), language="text")
260
- if st.button("Download Demo CSV 📝"):
261
- csv_path = f"demo_cv_{int(time.time())}.csv"
262
- with open(csv_path, "w", newline="") as f:
263
- writer = csv.writer(f)
264
- writer.writerow(["image", "text"])
265
- for img, text in zip(demo_images, demo_texts):
266
- writer.writerow([img, text])
267
- st.markdown(get_download_link(csv_path, "text/csv", "Download Demo CSV"), unsafe_allow_html=True)
 
 
5
  import csv
6
  import time
7
  from dataclasses import dataclass
8
+ import zipfile
9
 
10
  st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="wide", initial_sidebar_state="expanded")
11
 
 
42
  self.tokenizer.pad_token = self.tokenizer.eos_token
43
  self.config = config
44
  self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
45
+ def fine_tune(self, csv_path):
46
+ from torch.utils.data import Dataset, DataLoader
47
+ import torch
48
+ class SFTDataset(Dataset):
49
+ def __init__(self, data, tokenizer):
50
+ self.data = data
51
+ self.tokenizer = tokenizer
52
+ def __len__(self):
53
+ return len(self.data)
54
+ def __getitem__(self, idx):
55
+ prompt = self.data[idx]["prompt"]
56
+ response = self.data[idx]["response"]
57
+ inputs = self.tokenizer(f"{prompt} {response}", return_tensors="pt", padding="max_length", max_length=128, truncation=True)
58
+ labels = inputs["input_ids"].clone()
59
+ labels[0, :len(self.tokenizer(prompt)["input_ids"][0])] = -100
60
+ return {"input_ids": inputs["input_ids"][0], "attention_mask": inputs["attention_mask"][0], "labels": labels[0]}
61
+ data = []
62
+ with open(csv_path, "r") as f:
63
+ reader = csv.DictReader(f)
64
+ for row in reader:
65
+ data.append({"prompt": row["prompt"], "response": row["response"]})
66
+ dataset = SFTDataset(data, self.tokenizer)
67
+ dataloader = DataLoader(dataset, batch_size=2)
68
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
69
+ self.model.train()
70
+ for _ in range(1):
71
+ for batch in dataloader:
72
+ optimizer.zero_grad()
73
+ outputs = self.model(**{k: v.to(self.model.device) for k, v in batch.items()})
74
+ outputs.loss.backward()
75
+ optimizer.step()
76
  def evaluate(self, prompt: str):
77
  import torch
78
  self.model.eval()
 
91
  self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
92
  self.pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
93
  self.config = config
94
+ def fine_tune(self, images, texts):
95
+ import torch
96
+ from PIL import Image
97
+ import numpy as np
98
+ optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
99
+ self.pipeline.unet.train()
100
+ for _ in range(1):
101
+ for img, text in zip(images, texts):
102
+ optimizer.zero_grad()
103
+ img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)
104
+ latents = self.pipeline.vae.encode(img_tensor).latent_dist.sample()
105
+ noise = torch.randn_like(latents)
106
+ timesteps = torch.randint(0, 1000, (1,), device=latents.device)
107
+ noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
108
+ text_emb = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
109
+ pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_emb).sample
110
+ loss = torch.nn.functional.mse_loss(pred_noise, noise)
111
+ loss.backward()
112
+ optimizer.step()
113
  def generate(self, prompt: str):
114
  return self.pipeline(prompt, num_inference_steps=20).images[0]
115
 
 
120
  b64 = base64.b64encode(data).decode()
121
  return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
122
 
123
+ def generate_filename(sequence):
124
  from datetime import datetime
125
  import pytz
126
  central = pytz.timezone('US/Central')
127
+ timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
128
+ return f"{sequence}{timestamp}.png"
 
129
 
130
  def get_gallery_files(file_types):
131
  import glob
132
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
133
 
134
+ def zip_files(files, zip_name):
135
+ with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
136
+ for file in files:
137
+ zipf.write(file, os.path.basename(file))
138
+ return zip_name
139
+
140
  # Video Processor for WebRTC
141
  class VideoSnapshot:
142
  def __init__(self):
 
150
  return self.snapshot
151
 
152
  # Main App
153
+ st.title("SFT Tiny Titans 🚀 (Capture & Tune!)")
154
 
155
  # Sidebar Galleries
156
+ st.sidebar.header("Captured Images 🎨")
157
+ image_files = get_gallery_files(["png"])
158
+ if image_files:
159
+ cols = st.sidebar.columns(2)
160
+ for idx, file in enumerate(image_files[:4]):
161
+ with cols[idx % 2]:
162
+ from PIL import Image
163
+ st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
 
 
 
 
 
164
 
165
  # Sidebar Model Management
166
  st.sidebar.subheader("Model Hub 🗂️")
167
  model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
168
+ model_options = {
169
+ "NLP (Causal LM)": "HuggingFaceTB/SmolLM-135M",
170
+ "CV (Diffusion)": ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"]
171
+ }
172
+ selected_model = st.sidebar.selectbox("Select Model", ["None"] + ([model_options[model_type]] if "NLP" in model_type else model_options[model_type]))
173
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
174
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
175
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
 
184
  with tab1:
185
  st.header("Build Titan 🌱 (Quick Start!)")
186
  model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
187
+ base_model = st.selectbox("Select Model", model_options[model_type], key="build_model")
188
  if st.button("Download Model ⬇️"):
189
  config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
190
  builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
 
203
  st.subheader("NLP Tune 🧠")
204
  uploaded_csv = st.file_uploader("Upload CSV", type="csv", key="nlp_csv")
205
  if uploaded_csv and st.button("Tune NLP 🔄"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  with open("temp.csv", "wb") as f:
207
  f.write(uploaded_csv.read())
208
+ st.session_state['builder'].fine_tune("temp.csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  st.success("NLP sharpened! 🎉")
210
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
211
  st.subheader("CV Tune 🎨")
212
+ captured_images = get_gallery_files(["png"])
213
+ if len(captured_images) >= 2:
214
+ texts = ["Superhero Neon", "Hero Glow", "Cape Spark"][:len(captured_images)]
215
+ if st.button("Tune CV 🔄"):
216
+ from PIL import Image
217
+ images = [Image.open(img) for img in captured_images]
218
+ st.session_state['builder'].fine_tune(images, texts)
219
+ st.success("CV polished! 🎉")
220
+ else:
221
+ st.warning("Capture at least 2 images first! ⚠️")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  with tab3:
224
  st.header("Test Titans 🧪 (Quick Check!)")
 
237
  if st.button("Test CV ▶️"):
238
  with st.spinner("Generating... ⏳"):
239
  img = st.session_state['builder'].generate(prompt)
240
+ st.image(img, caption="Generated Art", use_container_width=True)
241
 
242
  with tab4:
243
+ st.header("Camera Snap 📷 (Sequence Shots!)")
244
  from streamlit_webrtc import webrtc_streamer
245
  ctx = webrtc_streamer(
246
  key="camera",
 
248
  frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
249
  )
250
  if ctx.video_processor:
251
+ delay = st.slider("Delay between captures (seconds)", 0, 10, 2)
252
+ if st.button("Capture 6 Frames 📸"):
253
+ captured_images = []
254
+ for i in range(6):
255
+ snapshot = ctx.video_processor.take_snapshot()
256
+ if snapshot:
257
+ filename = generate_filename(i)
258
+ snapshot.save(filename)
259
+ st.image(snapshot, caption=filename, use_container_width=True)
260
+ captured_images.append(filename)
261
+ time.sleep(delay)
262
+ st.success("6 frames captured! 🎉")
263
+ if len(captured_images) >= 2:
264
+ st.session_state['captured_images'] = captured_images
265
 
266
+ # Dataset and ZIP Download
267
+ if 'captured_images' in st.session_state and len(st.session_state['captured_images']) >= 2:
268
+ st.subheader("Diffusion SFT Dataset 🎨")
269
+ sample_texts = ["Neon Hero", "Glowing Cape", "Spark Flyer", "Dark Knight", "Iron Shine", "Thunder Bolt"]
270
+ dataset = list(zip(st.session_state['captured_images'], sample_texts[:len(st.session_state['captured_images'])]))
271
+ st.code("\n".join([f"{i+1}. {text} -> {img}" for i, (img, text) in enumerate(dataset)]), language="text")
272
+ if st.button("Download Dataset CSV 📝"):
273
+ csv_path = f"diffusion_sft_{int(time.time())}.csv"
274
+ with open(csv_path, "w", newline="") as f:
275
+ writer = csv.writer(f)
276
+ writer.writerow(["image", "text"])
277
+ for img, text in dataset:
278
+ writer.writerow([img, text])
279
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Dataset CSV"), unsafe_allow_html=True)
280
+ if st.button("Download Images ZIP 📦"):
281
+ zip_path = f"captured_images_{int(time.time())}.zip"
282
+ zip_files(st.session_state['captured_images'], zip_path)
283
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Images ZIP"), unsafe_allow_html=True)