awacke1 commited on
Commit
1538fca
·
verified ·
1 Parent(s): d35a95b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +750 -0
app.py ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import glob
4
+ import base64
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import csv
11
+ import time
12
+ from dataclasses import dataclass
13
+ from typing import Optional, Tuple
14
+ import zipfile
15
+ import math
16
+ from PIL import Image
17
+ import random
18
+ import logging
19
+ import numpy as np
20
+
21
+ # Logging setup with a custom buffer
22
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
23
+ logger = logging.getLogger(__name__)
24
+ log_records = []
25
+
26
+ class LogCaptureHandler(logging.Handler):
27
+ def emit(self, record):
28
+ log_records.append(record)
29
+
30
+ logger.addHandler(LogCaptureHandler())
31
+
32
+ # Page Configuration
33
+ st.set_page_config(
34
+ page_title="SFT Tiny Titans 🚀",
35
+ page_icon="🤖",
36
+ layout="wide",
37
+ initial_sidebar_state="expanded",
38
+ menu_items={
39
+ 'Get Help': 'https://huggingface.co/awacke1',
40
+ 'Report a Bug': 'https://huggingface.co/spaces/awacke1',
41
+ 'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
42
+ }
43
+ )
44
+
45
+ # Initialize st.session_state
46
+ if 'captured_images' not in st.session_state:
47
+ st.session_state['captured_images'] = []
48
+ if 'nlp_builder' not in st.session_state:
49
+ st.session_state['nlp_builder'] = None
50
+ if 'cv_builder' not in st.session_state:
51
+ st.session_state['cv_builder'] = None
52
+ if 'nlp_loaded' not in st.session_state:
53
+ st.session_state['nlp_loaded'] = False
54
+ if 'cv_loaded' not in st.session_state:
55
+ st.session_state['cv_loaded'] = False
56
+ if 'active_tab' not in st.session_state:
57
+ st.session_state['active_tab'] = "Build Titan 🌱"
58
+
59
+ # Model Configuration Classes
60
+ @dataclass
61
+ class ModelConfig:
62
+ name: str
63
+ base_model: str
64
+ size: str
65
+ domain: Optional[str] = None
66
+ model_type: str = "causal_lm"
67
+ @property
68
+ def model_path(self):
69
+ return f"models/{self.name}"
70
+
71
+ @dataclass
72
+ class DiffusionConfig:
73
+ name: str
74
+ base_model: str
75
+ size: str
76
+ @property
77
+ def model_path(self):
78
+ return f"diffusion_models/{self.name}"
79
+
80
+ # Datasets
81
+ class SFTDataset(Dataset):
82
+ def __init__(self, data, tokenizer, max_length=128):
83
+ self.data = data
84
+ self.tokenizer = tokenizer
85
+ self.max_length = max_length
86
+ def __len__(self):
87
+ return len(self.data)
88
+ def __getitem__(self, idx):
89
+ prompt = self.data[idx]["prompt"]
90
+ response = self.data[idx]["response"]
91
+ full_text = f"{prompt} {response}"
92
+ full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
93
+ prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
94
+ input_ids = full_encoding["input_ids"].squeeze()
95
+ attention_mask = full_encoding["attention_mask"].squeeze()
96
+ labels = input_ids.clone()
97
+ prompt_len = prompt_encoding["input_ids"].shape[1]
98
+ if prompt_len < self.max_length:
99
+ labels[:prompt_len] = -100
100
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
101
+
102
+ class DiffusionDataset(Dataset):
103
+ def __init__(self, images, texts):
104
+ self.images = images
105
+ self.texts = texts
106
+ def __len__(self):
107
+ return len(self.images)
108
+ def __getitem__(self, idx):
109
+ return {"image": self.images[idx], "text": self.texts[idx]}
110
+
111
+ # Model Builders
112
+ class ModelBuilder:
113
+ def __init__(self):
114
+ self.config = None
115
+ self.model = None
116
+ self.tokenizer = None
117
+ self.sft_data = None
118
+ self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
119
+ def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
120
+ try:
121
+ with st.spinner(f"Loading {model_path}... ⏳ (Patience, young padawan!)"):
122
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
123
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
124
+ if self.tokenizer.pad_token is None:
125
+ self.tokenizer.pad_token = self.tokenizer.eos_token
126
+ if config:
127
+ self.config = config
128
+ self.model.to("cuda" if torch.cuda.is_available() else "cpu")
129
+ st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
130
+ logger.info(f"Successfully loaded Causal LM model: {model_path}")
131
+ except torch.cuda.OutOfMemoryError as e:
132
+ st.error(f"GPU memory error loading {model_path}: {str(e)} 💥 (Out of GPU juice!)")
133
+ logger.error(f"GPU memory error loading {model_path}: {str(e)}")
134
+ raise
135
+ except MemoryError as e:
136
+ st.error(f"CPU memory error loading {model_path}: {str(e)} 💥 (RAM ran away!)")
137
+ logger.error(f"CPU memory error loading {model_path}: {str(e)}")
138
+ raise
139
+ except Exception as e:
140
+ st.error(f"Failed to load {model_path}: {str(e)} 💥 (Something broke—check the logs!)")
141
+ logger.error(f"Failed to load {model_path}: {str(e)}")
142
+ raise
143
+ return self
144
+ def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
145
+ try:
146
+ self.sft_data = []
147
+ with open(csv_path, "r") as f:
148
+ reader = csv.DictReader(f)
149
+ for row in reader:
150
+ self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
151
+ dataset = SFTDataset(self.sft_data, self.tokenizer)
152
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
153
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
154
+ self.model.train()
155
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
156
+ self.model.to(device)
157
+ for epoch in range(epochs):
158
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️ (The AI is lifting weights!)"):
159
+ total_loss = 0
160
+ for batch in dataloader:
161
+ optimizer.zero_grad()
162
+ input_ids = batch["input_ids"].to(device)
163
+ attention_mask = batch["attention_mask"].to(device)
164
+ labels = batch["labels"].to(device)
165
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
166
+ loss = outputs.loss
167
+ loss.backward()
168
+ optimizer.step()
169
+ total_loss += loss.item()
170
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
171
+ st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
172
+ logger.info(f"Successfully fine-tuned Causal LM model: {self.config.name}")
173
+ except Exception as e:
174
+ st.error(f"Fine-tuning failed: {str(e)} 💥 (Training hit a snag!)")
175
+ logger.error(f"Fine-tuning failed: {str(e)}")
176
+ raise
177
+ return self
178
+ def save_model(self, path: str):
179
+ try:
180
+ with st.spinner("Saving model... 💾 (Packing the AI’s suitcase!)"):
181
+ os.makedirs(os.path.dirname(path), exist_ok=True)
182
+ self.model.save_pretrained(path)
183
+ self.tokenizer.save_pretrained(path)
184
+ st.success(f"Model saved at {path}! ✅ May the force be with it.")
185
+ logger.info(f"Model saved at {path}")
186
+ except Exception as e:
187
+ st.error(f"Failed to save model: {str(e)} 💥 (Save operation crashed!)")
188
+ logger.error(f"Failed to save model: {str(e)}")
189
+ raise
190
+ def evaluate(self, prompt: str, status_container=None):
191
+ self.model.eval()
192
+ if status_container:
193
+ status_container.write("Preparing to evaluate... 🧠 (Titan’s warming up its circuits!)")
194
+ logger.info(f"Evaluating prompt: {prompt}")
195
+ try:
196
+ with torch.no_grad():
197
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
198
+ outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
199
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
200
+ logger.info(f"Generated response: {result}")
201
+ return result
202
+ except Exception as e:
203
+ logger.error(f"Evaluation error: {str(e)}")
204
+ if status_container:
205
+ status_container.error(f"Oops! Something broke: {str(e)} 💥 (Titan tripped over a wire!)")
206
+ return f"Error: {str(e)}"
207
+
208
+ class DiffusionBuilder:
209
+ def __init__(self):
210
+ self.config = None
211
+ self.pipeline = None
212
+ def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
213
+ from diffusers import StableDiffusionPipeline
214
+ try:
215
+ with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
216
+ self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
217
+ self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
218
+ if config:
219
+ self.config = config
220
+ st.success(f"Diffusion model loaded! 🎨")
221
+ logger.info(f"Successfully loaded Diffusion model: {model_path}")
222
+ except torch.cuda.OutOfMemoryError as e:
223
+ st.error(f"GPU memory error loading {model_path}: {str(e)} 💥 (Out of GPU juice!)")
224
+ logger.error(f"GPU memory error loading {model_path}: {str(e)}")
225
+ raise
226
+ except MemoryError as e:
227
+ st.error(f"CPU memory error loading {model_path}: {str(e)} 💥 (RAM ran away!)")
228
+ logger.error(f"CPU memory error loading {model_path}: {str(e)}")
229
+ raise
230
+ except Exception as e:
231
+ st.error(f"Failed to load {model_path}: {str(e)} 💥 (Something broke—check the logs!)")
232
+ logger.error(f"Failed to load {model_path}: {str(e)}")
233
+ raise
234
+ return self
235
+ def fine_tune_sft(self, images, texts, epochs=3):
236
+ try:
237
+ dataset = DiffusionDataset(images, texts)
238
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
239
+ optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
240
+ self.pipeline.unet.train()
241
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
242
+ for epoch in range(epochs):
243
+ with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
244
+ total_loss = 0
245
+ for batch in dataloader:
246
+ optimizer.zero_grad()
247
+ image = batch["image"][0].to(device)
248
+ text = batch["text"][0]
249
+ latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(device)).latent_dist.sample()
250
+ noise = torch.randn_like(latents)
251
+ timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
252
+ noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
253
+ text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(device))[0]
254
+ pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
255
+ loss = torch.nn.functional.mse_loss(pred_noise, noise)
256
+ loss.backward()
257
+ optimizer.step()
258
+ total_loss += loss.item()
259
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
260
+ st.success("Diffusion SFT Fine-tuning completed! 🎨")
261
+ logger.info(f"Successfully fine-tuned Diffusion model: {self.config.name}")
262
+ except Exception as e:
263
+ st.error(f"Fine-tuning failed: {str(e)} 💥 (Training hit a snag!)")
264
+ logger.error(f"Fine-tuning failed: {str(e)}")
265
+ raise
266
+ return self
267
+ def save_model(self, path: str):
268
+ try:
269
+ with st.spinner("Saving diffusion model... 💾"):
270
+ os.makedirs(os.path.dirname(path), exist_ok=True)
271
+ self.pipeline.save_pretrained(path)
272
+ st.success(f"Diffusion model saved at {path}! ✅")
273
+ logger.info(f"Diffusion model saved at {path}")
274
+ except Exception as e:
275
+ st.error(f"Failed to save model: {str(e)} 💥 (Save operation crashed!)")
276
+ logger.error(f"Failed to save model: {str(e)}")
277
+ raise
278
+ def generate(self, prompt: str):
279
+ try:
280
+ return self.pipeline(prompt, num_inference_steps=50).images[0]
281
+ except Exception as e:
282
+ st.error(f"Image generation failed: {str(e)} 💥 (Pixel party pooper!)")
283
+ logger.error(f"Image generation failed: {str(e)}")
284
+ raise
285
+
286
+ # Utility Functions
287
+ def generate_filename(sequence, ext="png"):
288
+ from datetime import datetime
289
+ import pytz
290
+ central = pytz.timezone('US/Central')
291
+ dt = datetime.now(central)
292
+ return f"{dt.strftime('%m-%d-%Y-%I-%M-%S-%p')}.{ext}"
293
+
294
+ def get_download_link(file_path, mime_type="text/plain", label="Download"):
295
+ try:
296
+ with open(file_path, 'rb') as f:
297
+ data = f.read()
298
+ b64 = base64.b64encode(data).decode()
299
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
300
+ except Exception as e:
301
+ logger.error(f"Failed to generate download link for {file_path}: {str(e)}")
302
+ return f"Error: Could not generate link for {file_path}"
303
+
304
+ def zip_files(files, zip_path):
305
+ try:
306
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
307
+ for file in files:
308
+ zipf.write(file, os.path.basename(file))
309
+ logger.info(f"Created ZIP file: {zip_path}")
310
+ except Exception as e:
311
+ logger.error(f"Failed to create ZIP file {zip_path}: {str(e)}")
312
+ raise
313
+
314
+ def delete_files(files):
315
+ try:
316
+ for file in files:
317
+ os.remove(file)
318
+ logger.info(f"Deleted file: {file}")
319
+ st.session_state['captured_images'] = [f for f in st.session_state['captured_images'] if f not in files]
320
+ except Exception as e:
321
+ logger.error(f"Failed to delete files: {str(e)}")
322
+ raise
323
+
324
+ def get_model_files(model_type="causal_lm"):
325
+ path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
326
+ return [d for d in glob.glob(path) if os.path.isdir(d)]
327
+
328
+ def get_gallery_files(file_types):
329
+ return sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}"))))
330
+
331
+ def update_gallery():
332
+ media_files = get_gallery_files(["png"])
333
+ if media_files:
334
+ cols = st.sidebar.columns(2)
335
+ for idx, file in enumerate(media_files[:gallery_size * 2]):
336
+ with cols[idx % 2]:
337
+ st.image(Image.open(file), caption=file, use_container_width=True)
338
+ st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
339
+
340
+ # Mock Search Tool for RAG
341
+ def mock_search(query: str) -> str:
342
+ if "superhero" in query.lower():
343
+ return "Latest trends: Gold-plated Batman statues, VR superhero battles."
344
+ return "No relevant results found."
345
+
346
+ class PartyPlannerAgent:
347
+ def __init__(self, model, tokenizer):
348
+ self.model = model
349
+ self.tokenizer = tokenizer
350
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
351
+ self.model.to(self.device)
352
+ def generate(self, prompt: str) -> str:
353
+ self.model.eval()
354
+ with torch.no_grad():
355
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
356
+ outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
357
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
358
+ def plan_party(self, task: str) -> pd.DataFrame:
359
+ search_result = mock_search("superhero party trends")
360
+ prompt = f"Given this context: '{search_result}'\n{task}"
361
+ plan_text = self.generate(prompt)
362
+ locations = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060)}
363
+ wayne_coords = locations["Wayne Manor"]
364
+ travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
365
+ data = [
366
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues"},
367
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles"}
368
+ ]
369
+ return pd.DataFrame(data)
370
+
371
+ class CVPartyPlannerAgent:
372
+ def __init__(self, pipeline):
373
+ self.pipeline = pipeline
374
+ def generate(self, prompt: str) -> Image.Image:
375
+ return self.pipeline(prompt, num_inference_steps=50).images[0]
376
+ def plan_party(self, task: str) -> pd.DataFrame:
377
+ search_result = mock_search("superhero party trends")
378
+ prompt = f"Given this context: '{search_result}'\n{task}"
379
+ data = [
380
+ {"Theme": "Batman", "Image Idea": "Gold-plated Batman statue"},
381
+ {"Theme": "Avengers", "Image Idea": "VR superhero battle scene"}
382
+ ]
383
+ return pd.DataFrame(data)
384
+
385
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
386
+ def to_radians(degrees: float) -> float:
387
+ return degrees * (math.pi / 180)
388
+ lat1, lon1 = map(to_radians, origin_coords)
389
+ lat2, lon2 = map(to_radians, destination_coords)
390
+ EARTH_RADIUS_KM = 6371.0
391
+ dlon = lon2 - lon1
392
+ dlat = lat2 - lat1
393
+ a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
394
+ c = 2 * math.asin(math.sqrt(a))
395
+ distance = EARTH_RADIUS_KM * c
396
+ actual_distance = distance * 1.1
397
+ flight_time = (actual_distance / cruising_speed_kmh) + 1.0
398
+ return round(flight_time, 2)
399
+
400
+ # Main App
401
+ st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
402
+
403
+ # Sidebar Galleries and File Management
404
+ st.sidebar.header("Media Gallery 🎨")
405
+ gallery_size = st.sidebar.slider("Gallery Size 📸", 1, 10, 4, help="Adjust how many epic captures you see! 🌟")
406
+ update_gallery()
407
+
408
+ col1, col2 = st.sidebar.columns(2)
409
+ with col1:
410
+ if st.button("Download All 📦"):
411
+ media_files = get_gallery_files(["png"])
412
+ if media_files:
413
+ zip_path = f"snapshot_collection_{int(time.time())}.zip"
414
+ zip_files(media_files, zip_path)
415
+ st.sidebar.markdown(get_download_link(zip_path, "application/zip", "Download All Snapshots"), unsafe_allow_html=True)
416
+ st.sidebar.success("Snapshots zipped and ready! 🎉 Grab your loot!")
417
+ else:
418
+ st.sidebar.warning("No snapshots to zip! 📸 Snap some pics first!")
419
+ with col2:
420
+ if st.button("Delete All 🗑️"):
421
+ media_files = get_gallery_files(["png"])
422
+ if media_files:
423
+ delete_files(media_files)
424
+ st.sidebar.success("All snapshots vanquished! 🧹 Gallery cleared!")
425
+ update_gallery()
426
+ else:
427
+ st.sidebar.warning("Nothing to delete! 📸 Snap some pics to clear later!")
428
+
429
+ # File Uploader
430
+ uploaded_files = st.sidebar.file_uploader("Upload Files 🎵🎥🖼️📝📜", type=["mp3", "mp4", "png", "jpeg", "md", "pdf", "docx"], accept_multiple_files=True)
431
+ if uploaded_files:
432
+ for uploaded_file in uploaded_files:
433
+ filename = uploaded_file.name
434
+ with open(filename, "wb") as f:
435
+ f.write(uploaded_file.getvalue())
436
+ logger.info(f"Uploaded file: {filename}")
437
+
438
+ # Sidebar Galleries by Type
439
+ st.sidebar.subheader("Audio Gallery 🎵")
440
+ audio_files = get_gallery_files(["mp3"])
441
+ if audio_files:
442
+ for file in audio_files[:gallery_size]:
443
+ st.sidebar.audio(file, format="audio/mp3")
444
+ st.sidebar.markdown(get_download_link(file, "audio/mp3", f"Download {file}"), unsafe_allow_html=True)
445
+
446
+ st.sidebar.subheader("Video Gallery 🎥")
447
+ video_files = get_gallery_files(["mp4"])
448
+ if video_files:
449
+ for file in video_files[:gallery_size]:
450
+ st.sidebar.video(file, format="video/mp4")
451
+ st.sidebar.markdown(get_download_link(file, "video/mp4", f"Download {file}"), unsafe_allow_html=True)
452
+
453
+ st.sidebar.subheader("Image Gallery 🖼️")
454
+ image_files = get_gallery_files(["png", "jpeg"])
455
+ if image_files:
456
+ cols = st.sidebar.columns(2)
457
+ for idx, file in enumerate(image_files[:gallery_size * 2]):
458
+ with cols[idx % 2]:
459
+ st.image(Image.open(file), caption=file, use_container_width=True)
460
+ st.markdown(get_download_link(file, "image/png" if file.endswith(".png") else "image/jpeg", f"Download {file}"), unsafe_allow_html=True)
461
+
462
+ st.sidebar.subheader("Markdown Gallery 📝")
463
+ md_files = get_gallery_files(["md"])
464
+ if md_files:
465
+ for file in md_files[:gallery_size]:
466
+ with open(file, "r") as f:
467
+ st.sidebar.markdown(f.read())
468
+ st.sidebar.markdown(get_download_link(file, "text/markdown", f"Download {file}"), unsafe_allow_html=True)
469
+
470
+ st.sidebar.subheader("Document Gallery 📜")
471
+ doc_files = get_gallery_files(["pdf", "docx"])
472
+ if doc_files:
473
+ for file in doc_files[:gallery_size]:
474
+ mime_type = "application/pdf" if file.endswith(".pdf") else "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
475
+ st.sidebar.markdown(get_download_link(file, mime_type, f"Download {file}"), unsafe_allow_html=True)
476
+
477
+ st.sidebar.subheader("Model Management 🗂️")
478
+ model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
479
+ model_dirs = get_model_files("causal_lm" if model_type == "Causal LM" else "diffusion")
480
+ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
481
+ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
482
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
483
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
484
+ try:
485
+ builder.load_model(selected_model, config)
486
+ if model_type == "Causal LM":
487
+ st.session_state['nlp_builder'] = builder
488
+ st.session_state['nlp_loaded'] = True
489
+ else:
490
+ st.session_state['cv_builder'] = builder
491
+ st.session_state['cv_loaded'] = True
492
+ st.rerun()
493
+ except Exception as e:
494
+ st.error(f"Model load failed: {str(e)} 💥 (Check logs for details!)")
495
+
496
+ st.sidebar.subheader("Model Status 🚦")
497
+ st.sidebar.write(f"**NLP Model**: {'Loaded' if st.session_state['nlp_loaded'] else 'Not Loaded'} {'(Active)' if st.session_state['nlp_loaded'] and isinstance(st.session_state.get('nlp_builder'), ModelBuilder) else ''}")
498
+ st.sidebar.write(f"**CV Model**: {'Loaded' if st.session_state['cv_loaded'] else 'Not Loaded'} {'(Active)' if st.session_state['cv_loaded'] and isinstance(st.session_state.get('cv_builder'), DiffusionBuilder) else ''}")
499
+
500
+ # Tabs
501
+ tabs = [
502
+ "Build Titan 🌱", "Camera Snap 📷",
503
+ "Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
504
+ "Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
505
+ ]
506
+ tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs(tabs)
507
+
508
+ # Log Tab Switches
509
+ for i, tab in enumerate(tabs):
510
+ if st.session_state['active_tab'] != tab and st.session_state.get(f'tab{i}_active', False):
511
+ logger.info(f"Switched to tab: {tab}")
512
+ st.session_state['active_tab'] = tab
513
+ st.session_state[f'tab{i}_active'] = (st.session_state['active_tab'] == tab)
514
+
515
+ with tab1:
516
+ st.header("Build Titan 🌱")
517
+ model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
518
+ base_model = st.selectbox("Select Tiny Model",
519
+ ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
520
+ ["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
521
+ model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
522
+ domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪") if model_type == "Causal LM" else None
523
+ if st.button("Download Model ⬇️"):
524
+ config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain) if model_type == "Causal LM" else DiffusionConfig(name=model_name, base_model=base_model, size="small")
525
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
526
+ try:
527
+ builder.load_model(base_model, config)
528
+ builder.save_model(config.model_path)
529
+ if model_type == "Causal LM":
530
+ st.session_state['nlp_builder'] = builder
531
+ st.session_state['nlp_loaded'] = True
532
+ else:
533
+ st.session_state['cv_builder'] = builder
534
+ st.session_state['cv_loaded'] = True
535
+ st.rerun()
536
+ except Exception as e:
537
+ st.error(f"Model build failed: {str(e)} 💥 (Check logs for details!)")
538
+
539
+ with tab2:
540
+ st.header("Camera Snap 📷 (Dual Capture!)")
541
+ slice_count = st.number_input("Image Slice Count 🎞️", min_value=1, max_value=20, value=10, help="How many snaps to dream of? (Automation’s on vacation! 😜)")
542
+ video_length = st.number_input("Video Dream Length (seconds) 🎥", min_value=1, max_value=30, value=10, help="Imagine a vid this long—sadly, we’re stuck with pics for now! 😂")
543
+ cols = st.columns(2)
544
+ with cols[0]:
545
+ st.subheader("Camera 0 🎬")
546
+ cam0_img = st.camera_input("Snap a Shot - Cam 0 📸", key="cam0", help="Click to capture a heroic moment! 🦸‍♂️")
547
+ if cam0_img:
548
+ filename = generate_filename(0)
549
+ with open(filename, "wb") as f:
550
+ f.write(cam0_img.getvalue())
551
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
552
+ logger.info(f"Saved snapshot from Camera 0: {filename}")
553
+ st.session_state['captured_images'].append(filename)
554
+ update_gallery()
555
+ st.info("🚨 Multi-frame capture’s on strike! Snap one at a time—your Titan’s too cool for automation glitches! 😎")
556
+ with cols[1]:
557
+ st.subheader("Camera 1 🎥")
558
+ cam1_img = st.camera_input("Snap a Shot - Cam 1 📸", key="cam1", help="Grab another epic frame! 🌟")
559
+ if cam1_img:
560
+ filename = generate_filename(1)
561
+ with open(filename, "wb") as f:
562
+ f.write(cam1_img.getvalue())
563
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
564
+ logger.info(f"Saved snapshot from Camera 1: {filename}")
565
+ st.session_state['captured_images'].append(filename)
566
+ update_gallery()
567
+ st.info("🚨 Frame bursts? Nope, manual snaps only! One click, one masterpiece! 🎨")
568
+
569
+ with tab3: # Fine-Tune Titan (NLP)
570
+ st.header("Fine-Tune Titan (NLP) 🔧 (Teach Your Word Wizard Some Tricks!)")
571
+ if not st.session_state['nlp_loaded'] or not isinstance(st.session_state['nlp_builder'], ModelBuilder):
572
+ st.warning("Please build or load an NLP Titan first! ⚠️ (No word wizard, no magic!)")
573
+ else:
574
+ if st.button("Generate Sample CSV 📝"):
575
+ sample_data = [
576
+ {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
577
+ {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
578
+ {"prompt": "What is a neural network?", "response": "A neural network is a brainy AI mimicking human noggins."},
579
+ ]
580
+ csv_path = f"sft_data_{int(time.time())}.csv"
581
+ with open(csv_path, "w", newline="") as f:
582
+ writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
583
+ writer.writeheader()
584
+ writer.writerows(sample_data)
585
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
586
+ st.success(f"Sample CSV generated as {csv_path}! ✅ (Fresh from the data oven!)")
587
+ uploaded_csv = st.file_uploader("Upload CSV for SFT 📜", type="csv", help="Feed your Titan some tasty prompt-response pairs! 🍽️")
588
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
589
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
590
+ with open(csv_path, "wb") as f:
591
+ f.write(uploaded_csv.read())
592
+ new_model_name = f"{st.session_state['nlp_builder'].config.name}-sft-{int(time.time())}"
593
+ new_config = ModelConfig(name=new_model_name, base_model=st.session_state['nlp_builder'].config.base_model, size="small", domain=st.session_state['nlp_builder'].config.domain)
594
+ st.session_state['nlp_builder'].config = new_config
595
+ with st.status("Fine-tuning NLP Titan... ⏳ (Whipping words into shape!)", expanded=True) as status:
596
+ st.session_state['nlp_builder'].fine_tune_sft(csv_path)
597
+ st.session_state['nlp_builder'].save_model(new_config.model_path)
598
+ status.update(label="Fine-tuning completed! 🎉 (Wordsmith Titan unleashed!)", state="complete")
599
+ zip_path = f"{new_config.model_path}.zip"
600
+ zip_directory(new_config.model_path, zip_path)
601
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned NLP Titan"), unsafe_allow_html=True)
602
+
603
+ with tab4: # Test Titan (NLP)
604
+ st.header("Test Titan (NLP) 🧪 (Put Your Word Wizard to the Test!)")
605
+ if not st.session_state['nlp_loaded'] or not isinstance(st.session_state['nlp_builder'], ModelBuilder):
606
+ st.warning("Please build or load an NLP Titan first! ⚠️ (No word wizard, no test drive!)")
607
+ else:
608
+ if st.session_state['nlp_builder'].sft_data:
609
+ st.write("Testing with SFT Data:")
610
+ with st.spinner("Running SFT data tests... ⏳ (Titan’s flexing its word muscles!)"):
611
+ for item in st.session_state['nlp_builder'].sft_data[:3]:
612
+ prompt = item["prompt"]
613
+ expected = item["response"]
614
+ status_container = st.empty()
615
+ generated = st.session_state['nlp_builder'].evaluate(prompt, status_container)
616
+ st.write(f"**Prompt**: {prompt}")
617
+ st.write(f"**Expected**: {expected}")
618
+ st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
619
+ st.write("---")
620
+ status_container.empty()
621
+ test_prompt = st.text_area("Enter Test Prompt 🗣️", "What is AI?", help="Ask your Titan anything—it’s ready to chat! 😜")
622
+ if st.button("Run Test ▶️"):
623
+ with st.spinner("Testing your prompt... ⏳ (Titan’s pondering deeply!)"):
624
+ status_container = st.empty()
625
+ result = st.session_state['nlp_builder'].evaluate(test_prompt, status_container)
626
+ st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
627
+ status_container.empty()
628
+
629
+ with tab5: # Agentic RAG Party (NLP)
630
+ st.header("Agentic RAG Party (NLP) 🌐 (Party Like It’s 2099!)")
631
+ st.write("This demo uses your SFT-tuned NLP Titan to plan a superhero party with mock retrieval!")
632
+ if not st.session_state['nlp_loaded'] or not isinstance(st.session_state['nlp_builder'], ModelBuilder):
633
+ st.warning("Please build or load an NLP Titan first! ⚠️ (No word wizard, no party!)")
634
+ else:
635
+ if st.button("Run NLP RAG Demo 🎉"):
636
+ with st.spinner("Loading your SFT-tuned NLP Titan... ⏳ (Titan’s suiting up!)"):
637
+ agent = PartyPlannerAgent(st.session_state['nlp_builder'].model, st.session_state['nlp_builder'].tokenizer)
638
+ st.write("Agent ready! 🦸‍♂️ (Time to plan an epic bash!)")
639
+ task = """
640
+ Plan a luxury superhero-themed party at Wayne Manor (42.3601° N, 71.0589° W).
641
+ Use mock search results for the latest superhero party trends, refine for luxury elements
642
+ (decorations, entertainment, catering), and calculate cargo travel times from key locations
643
+ (New York: 40.7128° N, 74.0060° W; LA: 34.0522° N, 118.2437° W; London: 51.5074° N, 0.1278° W)
644
+ to Wayne Manor. Create a plan with at least 6 entries in a pandas dataframe.
645
+ """
646
+ with st.spinner("Planning the ultimate superhero bash... ⏳ (Calling all caped crusaders!)"):
647
+ try:
648
+ locations = {
649
+ "Wayne Manor": (42.3601, -71.0589),
650
+ "New York": (40.7128, -74.0060),
651
+ "Los Angeles": (34.0522, -118.2437),
652
+ "London": (51.5074, -0.1278)
653
+ }
654
+ wayne_coords = locations["Wayne Manor"]
655
+ travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
656
+ search_result = mock_search("superhero party trends")
657
+ prompt = f"""
658
+ Given this context from a search: "{search_result}"
659
+ Plan a luxury superhero-themed party at Wayne Manor. Suggest luxury decorations, entertainment, and catering ideas.
660
+ """
661
+ plan_text = agent.generate(prompt)
662
+ catchphrases = ["To the Batmobile!", "Avengers, assemble!", "I am Iron Man!", "By the power of Grayskull!"]
663
+ data = [
664
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues", "Catchphrase": random.choice(catchphrases)},
665
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Holographic Avengers displays", "Catchphrase": random.choice(catchphrases)},
666
+ {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows with Iron Man suits", "Catchphrase": random.choice(catchphrases)},
667
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles", "Catchphrase": random.choice(catchphrases)},
668
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gourmet kryptonite-green cocktails", "Catchphrase": random.choice(catchphrases)},
669
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Thor’s hammer-shaped appetizers", "Catchphrase": random.choice(catchphrases)},
670
+ ]
671
+ plan_df = pd.DataFrame(data)
672
+ st.write("Agentic RAG Party Plan:")
673
+ st.dataframe(plan_df)
674
+ st.write("Party on, Wayne! 🦸‍♂️🎉")
675
+ except Exception as e:
676
+ st.error(f"Error planning party: {str(e)} (Even Superman has kryptonite days!)")
677
+ logger.error(f"Error in NLP RAG demo: {str(e)}")
678
+
679
+ with tab6: # Fine-Tune Titan (CV)
680
+ st.header("Fine-Tune Titan (CV) 🔧 (Paint Your Titan’s Masterpiece!)")
681
+ if not st.session_state['cv_loaded'] or not isinstance(st.session_state['cv_builder'], DiffusionBuilder):
682
+ st.warning("Please build or load a CV Titan first! ⚠️ (No artist, no canvas!)")
683
+ else:
684
+ captured_images = get_gallery_files(["png"])
685
+ if len(captured_images) >= 2:
686
+ demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_images[:min(len(captured_images), 10)]]
687
+ edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic", help="Craft your image-text pairs like a superhero artist! 🎨")
688
+ if st.button("Fine-Tune with Dataset 🔄"):
689
+ images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
690
+ texts = [row["text"] for _, row in edited_data.iterrows()]
691
+ new_model_name = f"{st.session_state['cv_builder'].config.name}-sft-{int(time.time())}"
692
+ new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['cv_builder'].config.base_model, size="small")
693
+ st.session_state['cv_builder'].config = new_config
694
+ with st.status("Fine-tuning CV Titan... ⏳ (Brushing up those pixels!)", expanded=True) as status:
695
+ st.session_state['cv_builder'].fine_tune_sft(images, texts)
696
+ st.session_state['cv_builder'].save_model(new_config.model_path)
697
+ status.update(label="Fine-tuning completed! 🎉 (Pixel Titan unleashed!)", state="complete")
698
+ zip_path = f"{new_config.model_path}.zip"
699
+ zip_directory(new_config.model_path, zip_path)
700
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned CV Titan"), unsafe_allow_html=True)
701
+ csv_path = f"sft_dataset_{int(time.time())}.csv"
702
+ with open(csv_path, "w", newline="") as f:
703
+ writer = csv.writer(f)
704
+ writer.writerow(["image", "text"])
705
+ for _, row in edited_data.iterrows():
706
+ writer.writerow([row["image"], row["text"]])
707
+ st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
708
+
709
+ with tab7: # Test Titan (CV)
710
+ st.header("Test Titan (CV) 🧪 (Unleash Your Pixel Power!)")
711
+ if not st.session_state['cv_loaded'] or not isinstance(st.session_state['cv_builder'], DiffusionBuilder):
712
+ st.warning("Please build or load a CV Titan first! ⚠️ (No artist, no masterpiece!)")
713
+ else:
714
+ test_prompt = st.text_area("Enter Test Prompt 🎨", "Neon Batman", help="Dream up a wild image—your Titan’s got the brush! 🖌️")
715
+ if st.button("Run Test ▶️"):
716
+ with st.spinner("Painting your masterpiece... ⏳ (Titan’s mixing colors!)"):
717
+ image = st.session_state['cv_builder'].generate(test_prompt)
718
+ st.image(image, caption="Generated Image", use_container_width=True)
719
+
720
+ with tab8: # Agentic RAG Party (CV)
721
+ st.header("Agentic RAG Party (CV) 🌐 (Party with Pixels!)")
722
+ st.write("This demo uses your SFT-tuned CV Titan to generate superhero party images with mock retrieval!")
723
+ if not st.session_state['cv_loaded'] or not isinstance(st.session_state['cv_builder'], DiffusionBuilder):
724
+ st.warning("Please build or load a CV Titan first! ⚠️ (No artist, no party!)")
725
+ else:
726
+ if st.button("Run CV RAG Demo 🎉"):
727
+ with st.spinner("Loading your SFT-tuned CV Titan... ⏳ (Titan’s grabbing its paintbrush!)"):
728
+ agent = CVPartyPlannerAgent(st.session_state['cv_builder'].pipeline)
729
+ st.write("Agent ready! 🎨 (Time to paint an epic bash!)")
730
+ task = "Generate images for a luxury superhero-themed party."
731
+ with st.spinner("Crafting superhero party visuals... ⏳ (Pixels assemble!)"):
732
+ try:
733
+ plan_df = agent.plan_party(task)
734
+ st.dataframe(plan_df)
735
+ for _, row in plan_df.iterrows():
736
+ image = agent.generate(row["Image Idea"])
737
+ st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}", use_container_width=True)
738
+ except Exception as e:
739
+ st.error(f"Error in CV RAG demo: {str(e)} 💥 (Pixel party crashed!)")
740
+ logger.error(f"Error in CV RAG demo: {str(e)}")
741
+
742
+ # Display Logs
743
+ st.sidebar.subheader("Action Logs 📜")
744
+ log_container = st.sidebar.empty()
745
+ with log_container:
746
+ for record in log_records:
747
+ st.write(f"{record.asctime} - {record.levelname} - {record.message}")
748
+
749
+ # Initial Gallery Update
750
+ update_gallery()