awacke1 commited on
Commit
382ac48
·
verified ·
1 Parent(s): 0776d58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -16
app.py CHANGED
@@ -42,6 +42,14 @@ st.set_page_config(
42
  }
43
  )
44
 
 
 
 
 
 
 
 
 
45
  # Model Configuration Classes
46
  @dataclass
47
  class ModelConfig:
@@ -110,6 +118,7 @@ class ModelBuilder:
110
  self.tokenizer.pad_token = self.tokenizer.eos_token
111
  if config:
112
  self.config = config
 
113
  st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
114
  return self
115
  def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
@@ -233,6 +242,15 @@ def get_model_files(model_type="causal_lm"):
233
  def get_gallery_files(file_types):
234
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
235
 
 
 
 
 
 
 
 
 
 
236
  # Mock Search Tool for RAG
237
  def mock_search(query: str) -> str:
238
  if "superhero" in query.lower():
@@ -299,13 +317,7 @@ st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
299
  # Sidebar Galleries
300
  st.sidebar.header("Media Gallery 🎨")
301
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
302
- media_files = get_gallery_files(["png"])
303
- if media_files:
304
- cols = st.sidebar.columns(2)
305
- for idx, file in enumerate(media_files[:gallery_size * 2]):
306
- with cols[idx % 2]:
307
- st.image(Image.open(file), caption=file, use_column_width=True)
308
- st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
309
 
310
  st.sidebar.subheader("Model Management 🗂️")
311
  model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
@@ -350,10 +362,8 @@ with tab2:
350
  filename = generate_filename(0)
351
  with open(filename, "wb") as f:
352
  f.write(cam0_img.getvalue())
353
- st.image(Image.open(filename), caption=filename, use_column_width=True)
354
  logger.info(f"Saved snapshot from Camera 0: {filename}")
355
- if 'captured_images' not in st.session_state:
356
- st.session_state['captured_images'] = []
357
  st.session_state['captured_images'].append(filename)
358
  update_gallery()
359
  if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
@@ -370,7 +380,7 @@ with tab2:
370
  st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
371
  update_gallery()
372
  for frame in st.session_state['cam0_frames']:
373
- st.image(Image.open(frame), caption=frame, use_column_width=True)
374
  with cols[1]:
375
  st.subheader("Camera 1")
376
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
@@ -378,10 +388,8 @@ with tab2:
378
  filename = generate_filename(1)
379
  with open(filename, "wb") as f:
380
  f.write(cam1_img.getvalue())
381
- st.image(Image.open(filename), caption=filename, use_column_width=True)
382
  logger.info(f"Saved snapshot from Camera 1: {filename}")
383
- if 'captured_images' not in st.session_state:
384
- st.session_state['captured_images'] = []
385
  st.session_state['captured_images'].append(filename)
386
  update_gallery()
387
  if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
@@ -398,7 +406,7 @@ with tab2:
398
  st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
399
  update_gallery()
400
  for frame in st.session_state['cam1_frames']:
401
- st.image(Image.open(frame), caption=frame, use_column_width=True)
402
 
403
  with tab3:
404
  st.header("Fine-Tune Titan 🔧")
@@ -485,4 +493,7 @@ st.sidebar.subheader("Action Logs 📜")
485
  log_container = st.sidebar.empty()
486
  with log_container:
487
  for record in log_records:
488
- st.write(f"{record.asctime} - {record.levelname} - {record.message}")
 
 
 
 
42
  }
43
  )
44
 
45
+ # Initialize st.session_state
46
+ if 'captured_images' not in st.session_state:
47
+ st.session_state['captured_images'] = []
48
+ if 'builder' not in st.session_state:
49
+ st.session_state['builder'] = None
50
+ if 'model_loaded' not in st.session_state:
51
+ st.session_state['model_loaded'] = False
52
+
53
  # Model Configuration Classes
54
  @dataclass
55
  class ModelConfig:
 
118
  self.tokenizer.pad_token = self.tokenizer.eos_token
119
  if config:
120
  self.config = config
121
+ self.model.to("cuda" if torch.cuda.is_available() else "cpu")
122
  st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
123
  return self
124
  def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
 
242
  def get_gallery_files(file_types):
243
  return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
244
 
245
+ def update_gallery():
246
+ media_files = get_gallery_files(["png"])
247
+ if media_files:
248
+ cols = st.sidebar.columns(2)
249
+ for idx, file in enumerate(media_files[:gallery_size * 2]):
250
+ with cols[idx % 2]:
251
+ st.image(Image.open(file), caption=file, use_container_width=True)
252
+ st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
253
+
254
  # Mock Search Tool for RAG
255
  def mock_search(query: str) -> str:
256
  if "superhero" in query.lower():
 
317
  # Sidebar Galleries
318
  st.sidebar.header("Media Gallery 🎨")
319
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
320
+ update_gallery()
 
 
 
 
 
 
321
 
322
  st.sidebar.subheader("Model Management 🗂️")
323
  model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
 
362
  filename = generate_filename(0)
363
  with open(filename, "wb") as f:
364
  f.write(cam0_img.getvalue())
365
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
366
  logger.info(f"Saved snapshot from Camera 0: {filename}")
 
 
367
  st.session_state['captured_images'].append(filename)
368
  update_gallery()
369
  if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
 
380
  st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
381
  update_gallery()
382
  for frame in st.session_state['cam0_frames']:
383
+ st.image(Image.open(frame), caption=frame, use_container_width=True)
384
  with cols[1]:
385
  st.subheader("Camera 1")
386
  cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
 
388
  filename = generate_filename(1)
389
  with open(filename, "wb") as f:
390
  f.write(cam1_img.getvalue())
391
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
392
  logger.info(f"Saved snapshot from Camera 1: {filename}")
 
 
393
  st.session_state['captured_images'].append(filename)
394
  update_gallery()
395
  if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
 
406
  st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
407
  update_gallery()
408
  for frame in st.session_state['cam1_frames']:
409
+ st.image(Image.open(frame), caption=frame, use_container_width=True)
410
 
411
  with tab3:
412
  st.header("Fine-Tune Titan 🔧")
 
493
  log_container = st.sidebar.empty()
494
  with log_container:
495
  for record in log_records:
496
+ st.write(f"{record.asctime} - {record.levelname} - {record.message}")
497
+
498
+ # Initial Gallery Update
499
+ update_gallery()