Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update app.py
Browse files
app.py
CHANGED
@@ -42,6 +42,14 @@ st.set_page_config(
|
|
42 |
}
|
43 |
)
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
# Model Configuration Classes
|
46 |
@dataclass
|
47 |
class ModelConfig:
|
@@ -110,6 +118,7 @@ class ModelBuilder:
|
|
110 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
111 |
if config:
|
112 |
self.config = config
|
|
|
113 |
st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
|
114 |
return self
|
115 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
@@ -233,6 +242,15 @@ def get_model_files(model_type="causal_lm"):
|
|
233 |
def get_gallery_files(file_types):
|
234 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
# Mock Search Tool for RAG
|
237 |
def mock_search(query: str) -> str:
|
238 |
if "superhero" in query.lower():
|
@@ -299,13 +317,7 @@ st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
|
|
299 |
# Sidebar Galleries
|
300 |
st.sidebar.header("Media Gallery 🎨")
|
301 |
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
|
302 |
-
|
303 |
-
if media_files:
|
304 |
-
cols = st.sidebar.columns(2)
|
305 |
-
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
306 |
-
with cols[idx % 2]:
|
307 |
-
st.image(Image.open(file), caption=file, use_column_width=True)
|
308 |
-
st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
|
309 |
|
310 |
st.sidebar.subheader("Model Management 🗂️")
|
311 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
@@ -350,10 +362,8 @@ with tab2:
|
|
350 |
filename = generate_filename(0)
|
351 |
with open(filename, "wb") as f:
|
352 |
f.write(cam0_img.getvalue())
|
353 |
-
st.image(Image.open(filename), caption=filename,
|
354 |
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
355 |
-
if 'captured_images' not in st.session_state:
|
356 |
-
st.session_state['captured_images'] = []
|
357 |
st.session_state['captured_images'].append(filename)
|
358 |
update_gallery()
|
359 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
@@ -370,7 +380,7 @@ with tab2:
|
|
370 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
371 |
update_gallery()
|
372 |
for frame in st.session_state['cam0_frames']:
|
373 |
-
st.image(Image.open(frame), caption=frame,
|
374 |
with cols[1]:
|
375 |
st.subheader("Camera 1")
|
376 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
@@ -378,10 +388,8 @@ with tab2:
|
|
378 |
filename = generate_filename(1)
|
379 |
with open(filename, "wb") as f:
|
380 |
f.write(cam1_img.getvalue())
|
381 |
-
st.image(Image.open(filename), caption=filename,
|
382 |
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
383 |
-
if 'captured_images' not in st.session_state:
|
384 |
-
st.session_state['captured_images'] = []
|
385 |
st.session_state['captured_images'].append(filename)
|
386 |
update_gallery()
|
387 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
@@ -398,7 +406,7 @@ with tab2:
|
|
398 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
399 |
update_gallery()
|
400 |
for frame in st.session_state['cam1_frames']:
|
401 |
-
st.image(Image.open(frame), caption=frame,
|
402 |
|
403 |
with tab3:
|
404 |
st.header("Fine-Tune Titan 🔧")
|
@@ -485,4 +493,7 @@ st.sidebar.subheader("Action Logs 📜")
|
|
485 |
log_container = st.sidebar.empty()
|
486 |
with log_container:
|
487 |
for record in log_records:
|
488 |
-
st.write(f"{record.asctime} - {record.levelname} - {record.message}")
|
|
|
|
|
|
|
|
42 |
}
|
43 |
)
|
44 |
|
45 |
+
# Initialize st.session_state
|
46 |
+
if 'captured_images' not in st.session_state:
|
47 |
+
st.session_state['captured_images'] = []
|
48 |
+
if 'builder' not in st.session_state:
|
49 |
+
st.session_state['builder'] = None
|
50 |
+
if 'model_loaded' not in st.session_state:
|
51 |
+
st.session_state['model_loaded'] = False
|
52 |
+
|
53 |
# Model Configuration Classes
|
54 |
@dataclass
|
55 |
class ModelConfig:
|
|
|
118 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
119 |
if config:
|
120 |
self.config = config
|
121 |
+
self.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
122 |
st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
|
123 |
return self
|
124 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
|
|
242 |
def get_gallery_files(file_types):
|
243 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
244 |
|
245 |
+
def update_gallery():
|
246 |
+
media_files = get_gallery_files(["png"])
|
247 |
+
if media_files:
|
248 |
+
cols = st.sidebar.columns(2)
|
249 |
+
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
250 |
+
with cols[idx % 2]:
|
251 |
+
st.image(Image.open(file), caption=file, use_container_width=True)
|
252 |
+
st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
|
253 |
+
|
254 |
# Mock Search Tool for RAG
|
255 |
def mock_search(query: str) -> str:
|
256 |
if "superhero" in query.lower():
|
|
|
317 |
# Sidebar Galleries
|
318 |
st.sidebar.header("Media Gallery 🎨")
|
319 |
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
|
320 |
+
update_gallery()
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
|
322 |
st.sidebar.subheader("Model Management 🗂️")
|
323 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
|
|
362 |
filename = generate_filename(0)
|
363 |
with open(filename, "wb") as f:
|
364 |
f.write(cam0_img.getvalue())
|
365 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
366 |
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
|
|
|
|
367 |
st.session_state['captured_images'].append(filename)
|
368 |
update_gallery()
|
369 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
|
|
380 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
381 |
update_gallery()
|
382 |
for frame in st.session_state['cam0_frames']:
|
383 |
+
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
384 |
with cols[1]:
|
385 |
st.subheader("Camera 1")
|
386 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
|
|
388 |
filename = generate_filename(1)
|
389 |
with open(filename, "wb") as f:
|
390 |
f.write(cam1_img.getvalue())
|
391 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
392 |
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
|
|
|
|
393 |
st.session_state['captured_images'].append(filename)
|
394 |
update_gallery()
|
395 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
|
|
406 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
407 |
update_gallery()
|
408 |
for frame in st.session_state['cam1_frames']:
|
409 |
+
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
410 |
|
411 |
with tab3:
|
412 |
st.header("Fine-Tune Titan 🔧")
|
|
|
493 |
log_container = st.sidebar.empty()
|
494 |
with log_container:
|
495 |
for record in log_records:
|
496 |
+
st.write(f"{record.asctime} - {record.levelname} - {record.message}")
|
497 |
+
|
498 |
+
# Initial Gallery Update
|
499 |
+
update_gallery()
|