awacke1 committed on
Commit
7ca272c
·
verified ·
1 Parent(s): 1d38074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -84
app.py CHANGED
@@ -7,11 +7,13 @@ import time
7
  from dataclasses import dataclass
8
  import zipfile
9
  import logging
 
 
10
 
11
- # Logging setup with custom log storage
12
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
13
  logger = logging.getLogger(__name__)
14
- log_records = [] # Custom list to store logs
15
 
16
  class LogCaptureHandler(logging.Handler):
17
  def emit(self, record):
@@ -170,12 +172,12 @@ def get_download_link(file_path, mime_type="text/plain", label="Download"):
170
  b64 = base64.b64encode(data).decode()
171
  return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
172
 
173
- def generate_filename(sequence):
174
  from datetime import datetime
175
  import pytz
176
  central = pytz.timezone('US/Central')
177
  timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
178
- return f"{sequence}{timestamp}.png"
179
 
180
  def get_gallery_files(file_types):
181
  import glob
@@ -188,29 +190,52 @@ def zip_files(files, zip_name):
188
  return zip_name
189
 
190
  # Video Processor for WebRTC
191
- class VideoSnapshot:
192
  def __init__(self):
193
  self.snapshot = None
 
 
 
 
194
  def recv(self, frame):
195
  from PIL import Image
196
  img = frame.to_image()
197
  self.snapshot = img
198
- return frame
 
 
 
199
  def take_snapshot(self):
 
200
  return self.snapshot
201
 
 
 
 
 
 
 
 
 
 
202
  # Main App
203
- st.title("SFT Tiny Titans 🚀 (Capture & Tune!)")
204
 
205
  # Sidebar Galleries
206
- st.sidebar.header("Captured Images 🎨")
207
- image_files = get_gallery_files(["png"])
208
- if image_files:
209
- cols = st.sidebar.columns(2)
210
- for idx, file in enumerate(image_files[:4]):
211
- with cols[idx % 2]:
212
- from PIL import Image
213
- st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
 
 
 
 
 
 
214
 
215
  # Sidebar Model Management
216
  st.sidebar.subheader("Model Hub 🗂️")
@@ -252,63 +277,83 @@ with tab1:
252
  st.error(f"Download failed: {str(e)}")
253
 
254
  with tab2:
255
- st.header("Camera Snap 📷 (Sequence Shots!)")
256
- from streamlit_webrtc import webrtc_streamer
257
- ctx = webrtc_streamer(
258
- key="camera",
259
- video_processor_factory=VideoSnapshot,
260
- frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
261
- )
262
- if ctx.video_processor:
263
- delay = st.slider("Delay between captures (seconds)", 0, 10, 2)
264
- if st.button("Capture 6 Frames 📸"):
265
- logger.info("Starting 6-frame capture")
266
- captured_images = []
267
- try:
268
- for i in range(6):
269
- snapshot = ctx.video_processor.take_snapshot()
270
- if snapshot:
271
- filename = generate_filename(i)
272
- snapshot.save(filename)
273
- st.image(snapshot, caption=filename, use_container_width=True)
274
- captured_images.append(filename)
275
- logger.info(f"Captured frame {i}: {filename}")
276
- time.sleep(delay)
277
- st.success("6 frames captured! 🎉")
278
- st.session_state['captured_images'] = captured_images
279
- except Exception as e:
280
- st.error(f"Capture failed: {str(e)}")
281
- logger.error(f"Error during capture: {str(e)}")
282
-
283
- if 'captured_images' in st.session_state and len(st.session_state['captured_images']) >= 2:
284
- st.subheader("Diffusion SFT Dataset 🎨")
285
- sample_texts = ["Neon Hero", "Glowing Cape", "Spark Flyer", "Dark Knight", "Iron Shine", "Thunder Bolt"]
286
- dataset = list(zip(st.session_state['captured_images'], sample_texts[:len(st.session_state['captured_images'])]))
287
- st.code("\n".join([f"{i+1}. {text} -> {img}" for i, (img, text) in enumerate(dataset)]), language="text")
288
- if st.button("Download Dataset CSV 📝"):
289
- logger.info("Generating dataset CSV")
290
- try:
291
- csv_path = f"diffusion_sft_{int(time.time())}.csv"
292
- with open(csv_path, "w", newline="") as f:
293
- writer = csv.writer(f)
294
- writer.writerow(["image", "text"])
295
- for img, text in dataset:
296
- writer.writerow([img, text])
297
- st.markdown(get_download_link(csv_path, "text/csv", "Download Dataset CSV"), unsafe_allow_html=True)
298
- logger.info("Dataset CSV generated")
299
- except Exception as e:
300
- st.error(f"CSV generation failed: {str(e)}")
301
- logger.error(f"Error generating CSV: {str(e)}")
302
- if st.button("Download Images ZIP 📦"):
303
- logger.info("Generating images ZIP")
304
- try:
305
- zip_path = f"captured_images_{int(time.time())}.zip"
306
- zip_files(st.session_state['captured_images'], zip_path)
307
- st.markdown(get_download_link(zip_path, "application/zip", "Download Images ZIP"), unsafe_allow_html=True)
308
- logger.info("Images ZIP generated")
309
- except Exception as e:
310
- st.error(f"ZIP generation failed: {str(e)}")
311
- logger.error(f"Error generating ZIP: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  with tab3:
314
  st.header("Fine-Tune Titans 🔧 (Tune Fast!)")
@@ -345,7 +390,7 @@ with tab3:
345
  st.warning("Capture at least 2 images first! ⚠️")
346
 
347
  with tab4:
348
- st.header("Test Titans 🧪 (Quick Check!)")
349
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
350
  st.warning("Load a Titan first! ⚠️")
351
  else:
@@ -360,20 +405,43 @@ with tab4:
360
  except Exception as e:
361
  st.error(f"NLP test failed: {str(e)}")
362
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
363
- st.subheader("CV Test 🎨")
364
- prompt = st.text_area("Prompt", "Neon Batman", key="cv_test")
365
- if st.button("Test CV ▶️"):
366
- logger.info("Running CV test")
367
- try:
368
- with st.spinner("Generating... ⏳"):
369
- img = st.session_state['builder'].generate(prompt)
370
- st.image(img, caption="Generated Art", use_container_width=True)
371
- except Exception as e:
372
- st.error(f"CV test failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  # Display Logs
375
  st.sidebar.subheader("Action Logs 📜")
376
  log_container = st.sidebar.empty()
377
  with log_container:
378
  for record in log_records:
379
- st.write(f"{record.asctime} - {record.levelname} - {record.message}")
 
 
 
7
  from dataclasses import dataclass
8
  import zipfile
9
  import logging
10
+ import av
11
+ from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
12
 
13
+ # Logging setup
14
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
15
  logger = logging.getLogger(__name__)
16
+ log_records = []
17
 
18
  class LogCaptureHandler(logging.Handler):
19
  def emit(self, record):
 
172
  b64 = base64.b64encode(data).decode()
173
  return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
174
 
175
+ def generate_filename(sequence, ext="png"):
176
  from datetime import datetime
177
  import pytz
178
  central = pytz.timezone('US/Central')
179
  timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
180
+ return f"{sequence}{timestamp}.{ext}"
181
 
182
  def get_gallery_files(file_types):
183
  import glob
 
190
  return zip_name
191
 
192
  # Video Processor for WebRTC
193
+ class CameraProcessor(VideoProcessorBase):
194
  def __init__(self):
195
  self.snapshot = None
196
+ self.recording = False
197
+ self.frames = []
198
+ self.start_time = None
199
+
200
  def recv(self, frame):
201
  from PIL import Image
202
  img = frame.to_image()
203
  self.snapshot = img
204
+ if self.recording and time.time() - self.start_time < 10:
205
+ self.frames.append(frame.to_ndarray(format="bgr24"))
206
+ return av.VideoFrame.from_image(img)
207
+
208
  def take_snapshot(self):
209
+ from PIL import Image
210
  return self.snapshot
211
 
212
+ def start_recording(self):
213
+ self.recording = True
214
+ self.frames = []
215
+ self.start_time = time.time()
216
+
217
+ def stop_recording(self):
218
+ self.recording = False
219
+ return self.frames
220
+
221
  # Main App
222
+ st.title("SFT Tiny Titans 🚀 (Dual Cam Action!)")
223
 
224
  # Sidebar Galleries
225
+ st.sidebar.header("Captured Media 🎨")
226
+ gallery_container = st.sidebar.empty()
227
+ def update_gallery():
228
+ media_files = get_gallery_files(["png", "mp4"])
229
+ with gallery_container:
230
+ if media_files:
231
+ cols = st.columns(2)
232
+ for idx, file in enumerate(media_files[:4]):
233
+ with cols[idx % 2]:
234
+ if file.endswith(".png"):
235
+ from PIL import Image
236
+ st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
237
+ elif file.endswith(".mp4"):
238
+ st.video(file)
239
 
240
  # Sidebar Model Management
241
  st.sidebar.subheader("Model Hub 🗂️")
 
277
  st.error(f"Download failed: {str(e)}")
278
 
279
  with tab2:
280
+ st.header("Camera Snap 📷 (Dual Live Feed!)")
281
+ cols = st.columns(2)
282
+ processors = {}
283
+ for i in range(2):
284
+ with cols[i]:
285
+ st.subheader(f"Camera {i}")
286
+ key = f"camera_{i}"
287
+ processors[key] = webrtc_streamer(
288
+ key=key,
289
+ mode=WebRtcMode.SENDRECV,
290
+ video_processor_factory=CameraProcessor,
291
+ frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
292
+ )
293
+ if processors[key].video_processor:
294
+ if st.button(f"Capture 📸 Cam {i}", key=f"snap_{i}"):
295
+ logger.info(f"Capturing snapshot from Camera {i}")
296
+ try:
297
+ snapshot = processors[key].video_processor.take_snapshot()
298
+ if snapshot:
299
+ filename = generate_filename(i)
300
+ snapshot.save(filename)
301
+ st.image(snapshot, caption=filename, use_container_width=True)
302
+ logger.info(f"Saved snapshot: {filename}")
303
+ if 'captured_images' not in st.session_state:
304
+ st.session_state['captured_images'] = []
305
+ st.session_state['captured_images'].append(filename)
306
+ update_gallery()
307
+ except Exception as e:
308
+ st.error(f"Snapshot failed: {str(e)}")
309
+ logger.error(f"Error capturing snapshot: {str(e)}")
310
+ record_key = f"record_{i}"
311
+ if record_key not in st.session_state:
312
+ st.session_state[record_key] = False
313
+ if st.button(f"{'Stop' if st.session_state[record_key] else 'Record'} 🎥 Cam {i}", key=f"rec_{i}"):
314
+ if not st.session_state[record_key]:
315
+ logger.info(f"Starting recording from Camera {i}")
316
+ try:
317
+ processors[key].video_processor.start_recording()
318
+ st.session_state[record_key] = True
319
+ except Exception as e:
320
+ st.error(f"Start recording failed: {str(e)}")
321
+ logger.error(f"Error starting recording: {str(e)}")
322
+ else:
323
+ logger.info(f"Stopping recording from Camera {i}")
324
+ try:
325
+ frames = processors[key].video_processor.stop_recording()
326
+ if frames:
327
+ mp4_filename = generate_filename(i, "mp4")
328
+ with av.open(mp4_filename, "w") as container:
329
+ stream = container.add_stream("h264", rate=30)
330
+ stream.width = frames[0].shape[1]
331
+ stream.height = frames[0].shape[0]
332
+ for frame in frames:
333
+ av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
334
+ for packet in stream.encode(av_frame):
335
+ container.mux(packet)
336
+ for packet in stream.encode():
337
+ container.mux(packet)
338
+ st.video(mp4_filename)
339
+ logger.info(f"Saved video: {mp4_filename}")
340
+ # Slice into 10 frames
341
+ sliced_images = []
342
+ step = max(1, len(frames) // 10)
343
+ for j in range(0, len(frames), step):
344
+ if len(sliced_images) < 10:
345
+ img = Image.fromarray(frames[j][:, :, ::-1]) # BGR to RGB
346
+ img_filename = generate_filename(f"{i}_{len(sliced_images)}")
347
+ img.save(img_filename)
348
+ sliced_images.append(img_filename)
349
+ st.image(img, caption=img_filename, use_container_width=True)
350
+ st.session_state['captured_images'] = st.session_state.get('captured_images', []) + sliced_images
351
+ logger.info(f"Sliced video into {len(sliced_images)} images")
352
+ update_gallery()
353
+ st.session_state[record_key] = False
354
+ except Exception as e:
355
+ st.error(f"Stop recording failed: {str(e)}")
356
+ logger.error(f"Error stopping recording: {str(e)}")
357
 
358
  with tab3:
359
  st.header("Fine-Tune Titans 🔧 (Tune Fast!)")
 
390
  st.warning("Capture at least 2 images first! ⚠️")
391
 
392
  with tab4:
393
+ st.header("Test Titans 🧪 (Image Agent Demo!)")
394
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
395
  st.warning("Load a Titan first! ⚠️")
396
  else:
 
405
  except Exception as e:
406
  st.error(f"NLP test failed: {str(e)}")
407
  elif isinstance(st.session_state['builder'], DiffusionBuilder):
408
+ st.subheader("CV Test 🎨 (Image Set Demo)")
409
+ captured_images = get_gallery_files(["png"])
410
+ if len(captured_images) >= 2:
411
+ if st.button("Run CV Demo ▶️"):
412
+ logger.info("Running CV image set demo")
413
+ try:
414
+ from PIL import Image
415
+ images = [Image.open(img) for img in captured_images[:10]]
416
+ prompts = ["Neon " + os.path.basename(img).split('.')[0] for img in captured_images[:10]]
417
+ generated_images = []
418
+ for prompt in prompts:
419
+ img = st.session_state['builder'].generate(prompt)
420
+ generated_images.append(img)
421
+ cols = st.columns(2)
422
+ for idx, (orig, gen) in enumerate(zip(images, generated_images)):
423
+ with cols[idx % 2]:
424
+ st.image(orig, caption=f"Original: {captured_images[idx]}", use_container_width=True)
425
+ st.image(gen, caption=f"Generated: {prompts[idx]}", use_container_width=True)
426
+ md_content = "# Image Set Demo\n\nScript of filenames and descriptions:\n"
427
+ for i, (img, prompt) in enumerate(zip(captured_images[:10], prompts)):
428
+ md_content += f"{i+1}. `{img}` - {prompt}\n"
429
+ md_filename = f"demo_metadata_{int(time.time())}.md"
430
+ with open(md_filename, "w") as f:
431
+ f.write(md_content)
432
+ st.markdown(get_download_link(md_filename, "text/markdown", "Download Metadata .md"), unsafe_allow_html=True)
433
+ logger.info("CV demo completed with metadata")
434
+ except Exception as e:
435
+ st.error(f"CV demo failed: {str(e)}")
436
+ logger.error(f"Error in CV demo: {str(e)}")
437
+ else:
438
+ st.warning("Capture at least 2 images first! ⚠️")
439
 
440
  # Display Logs
441
  st.sidebar.subheader("Action Logs 📜")
442
  log_container = st.sidebar.empty()
443
  with log_container:
444
  for record in log_records:
445
+ st.write(f"{record.asctime} - {record.levelname} - {record.message}")
446
+
447
+ update_gallery() # Initial gallery update