awacke1 commited on
Commit
07943e1
·
verified ·
1 Parent(s): 428c305

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +569 -0
app.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import glob
4
+ import base64
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import csv
11
+ import time
12
+ from dataclasses import dataclass
13
+ from typing import Optional, Tuple
14
+ import zipfile
15
+ import math
16
+ from PIL import Image
17
+ import random
18
+ import logging
19
+ import numpy as np
20
+
21
+ # Logging setup with a custom buffer
22
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
23
+ logger = logging.getLogger(__name__)
24
+ log_records = []
25
+
26
+ class LogCaptureHandler(logging.Handler):
27
+ def emit(self, record):
28
+ log_records.append(record)
29
+
30
+ logger.addHandler(LogCaptureHandler())
31
+
32
+ # Page Configuration
33
+ st.set_page_config(
34
+ page_title="SFT Tiny Titans 🚀",
35
+ page_icon="🤖",
36
+ layout="wide",
37
+ initial_sidebar_state="expanded",
38
+ menu_items={
39
+ 'Get Help': 'https://huggingface.co/awacke1',
40
+ 'Report a Bug': 'https://huggingface.co/spaces/awacke1',
41
+ 'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
42
+ }
43
+ )
44
+
45
+ # Initialize st.session_state
46
+ if 'captured_images' not in st.session_state:
47
+ st.session_state['captured_images'] = []
48
+ if 'builder' not in st.session_state:
49
+ st.session_state['builder'] = None
50
+ if 'model_loaded' not in st.session_state:
51
+ st.session_state['model_loaded'] = False
52
+
53
+ # Model Configuration Classes
54
+ @dataclass
55
+ class ModelConfig:
56
+ name: str
57
+ base_model: str
58
+ size: str
59
+ domain: Optional[str] = None
60
+ model_type: str = "causal_lm"
61
+ @property
62
+ def model_path(self):
63
+ return f"models/{self.name}"
64
+
65
+ @dataclass
66
+ class DiffusionConfig:
67
+ name: str
68
+ base_model: str
69
+ size: str
70
+ @property
71
+ def model_path(self):
72
+ return f"diffusion_models/{self.name}"
73
+
74
+ # Datasets
75
+ class SFTDataset(Dataset):
76
+ def __init__(self, data, tokenizer, max_length=128):
77
+ self.data = data
78
+ self.tokenizer = tokenizer
79
+ self.max_length = max_length
80
+ def __len__(self):
81
+ return len(self.data)
82
+ def __getitem__(self, idx):
83
+ prompt = self.data[idx]["prompt"]
84
+ response = self.data[idx]["response"]
85
+ full_text = f"{prompt} {response}"
86
+ full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
87
+ prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
88
+ input_ids = full_encoding["input_ids"].squeeze()
89
+ attention_mask = full_encoding["attention_mask"].squeeze()
90
+ labels = input_ids.clone()
91
+ prompt_len = prompt_encoding["input_ids"].shape[1]
92
+ if prompt_len < self.max_length:
93
+ labels[:prompt_len] = -100
94
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
95
+
96
+ class DiffusionDataset(Dataset):
97
+ def __init__(self, images, texts):
98
+ self.images = images
99
+ self.texts = texts
100
+ def __len__(self):
101
+ return len(self.images)
102
+ def __getitem__(self, idx):
103
+ return {"image": self.images[idx], "text": self.texts[idx]}
104
+
105
+ # Model Builders
106
+ class ModelBuilder:
107
+ def __init__(self):
108
+ self.config = None
109
+ self.model = None
110
+ self.tokenizer = None
111
+ self.sft_data = None
112
+ self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
113
+ def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
114
+ with st.spinner(f"Loading {model_path}... ⏳ (Patience, young padawan!)"):
115
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
116
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
117
+ if self.tokenizer.pad_token is None:
118
+ self.tokenizer.pad_token = self.tokenizer.eos_token
119
+ if config:
120
+ self.config = config
121
+ self.model.to("cuda" if torch.cuda.is_available() else "cpu")
122
+ st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
123
+ return self
124
+ def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
125
+ self.sft_data = []
126
+ with open(csv_path, "r") as f:
127
+ reader = csv.DictReader(f)
128
+ for row in reader:
129
+ self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
130
+ dataset = SFTDataset(self.sft_data, self.tokenizer)
131
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
132
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
133
+ self.model.train()
134
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
135
+ self.model.to(device)
136
+ for epoch in range(epochs):
137
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️ (The AI is lifting weights!)"):
138
+ total_loss = 0
139
+ for batch in dataloader:
140
+ optimizer.zero_grad()
141
+ input_ids = batch["input_ids"].to(device)
142
+ attention_mask = batch["attention_mask"].to(device)
143
+ labels = batch["labels"].to(device)
144
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
145
+ loss = outputs.loss
146
+ loss.backward()
147
+ optimizer.step()
148
+ total_loss += loss.item()
149
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
150
+ st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
151
+ return self
152
+ def save_model(self, path: str):
153
+ with st.spinner("Saving model... 💾 (Packing the AI’s suitcase!)"):
154
+ os.makedirs(os.path.dirname(path), exist_ok=True)
155
+ self.model.save_pretrained(path)
156
+ self.tokenizer.save_pretrained(path)
157
+ st.success(f"Model saved at {path}! ✅ May the force be with it.")
158
+ def evaluate(self, prompt: str, status_container=None):
159
+ self.model.eval()
160
+ if status_container:
161
+ status_container.write("Preparing to evaluate... 🧠 (Titan’s warming up its circuits!)")
162
+ logger.info(f"Evaluating prompt: {prompt}")
163
+ try:
164
+ with torch.no_grad():
165
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
166
+ outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
167
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
168
+ logger.info(f"Generated response: {result}")
169
+ return result
170
+ except Exception as e:
171
+ logger.error(f"Evaluation error: {str(e)}")
172
+ if status_container:
173
+ status_container.error(f"Oops! Something broke: {str(e)} 💥 (Titan tripped over a wire!)")
174
+ return f"Error: {str(e)}"
175
+
176
+ class DiffusionBuilder:
177
+ def __init__(self):
178
+ self.config = None
179
+ self.pipeline = None
180
+ def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
181
+ from diffusers import StableDiffusionPipeline
182
+ with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
183
+ self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
184
+ self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
185
+ if config:
186
+ self.config = config
187
+ st.success(f"Diffusion model loaded! 🎨")
188
+ return self
189
+ def fine_tune_sft(self, images, texts, epochs=3):
190
+ dataset = DiffusionDataset(images, texts)
191
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
192
+ optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
193
+ self.pipeline.unet.train()
194
+ for epoch in range(epochs):
195
+ with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
196
+ total_loss = 0
197
+ for batch in dataloader:
198
+ optimizer.zero_grad()
199
+ image = batch["image"][0].to(self.pipeline.device)
200
+ text = batch["text"][0]
201
+ latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
202
+ noise = torch.randn_like(latant)
203
+ timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
204
+ noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
205
+ text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
206
+ pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
207
+ loss = torch.nn.functional.mse_loss(pred_noise, noise)
208
+ loss.backward()
209
+ optimizer.step()
210
+ total_loss += loss.item()
211
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
212
+ st.success("Diffusion SFT Fine-tuning completed! 🎨")
213
+ return self
214
+ def save_model(self, path: str):
215
+ with st.spinner("Saving diffusion model... 💾"):
216
+ os.makedirs(os.path.dirname(path), exist_ok=True)
217
+ self.pipeline.save_pretrained(path)
218
+ st.success(f"Diffusion model saved at {path}! ✅")
219
+ def generate(self, prompt: str):
220
+ return self.pipeline(prompt, num_inference_steps=50).images[0]
221
+
222
+ # Utility Functions
223
+ def generate_filename(sequence, ext="png"):
224
+ from datetime import datetime
225
+ import pytz
226
+ central = pytz.timezone('US/Central')
227
+ dt = datetime.now(central)
228
+ return f"{dt.strftime('%m-%d-%Y-%I-%M-%p')}.{ext}"
229
+
230
+ def get_download_link(file_path, mime_type="text/plain", label="Download"):
231
+ with open(file_path, 'rb') as f:
232
+ data = f.read()
233
+ b64 = base64.b64encode(data).decode()
234
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
235
+
236
+ def zip_directory(directory_path, zip_path):
237
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
238
+ for root, _, files in os.walk(directory_path):
239
+ for file in files:
240
+ zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.dirname(directory_path)))
241
+
242
+ def get_model_files(model_type="causal_lm"):
243
+ path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
244
+ return [d for d in glob.glob(path) if os.path.isdir(d)]
245
+
246
+ def get_gallery_files(file_types):
247
+ files = sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}")))) # Remove duplicates and sort
248
+ return files
249
+
250
+ def update_gallery():
251
+ media_files = get_gallery_files(["png"])
252
+ if media_files:
253
+ cols = st.sidebar.columns(2)
254
+ for idx, file in enumerate(media_files[:gallery_size * 2]):
255
+ with cols[idx % 2]:
256
+ st.image(Image.open(file), caption=file, use_container_width=True)
257
+ st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
258
+
259
+ # Mock Search Tool for RAG
260
+ def mock_search(query: str) -> str:
261
+ if "superhero" in query.lower():
262
+ return "Latest trends: Gold-plated Batman statues, VR superhero battles."
263
+ return "No relevant results found."
264
+
265
+ class PartyPlannerAgent:
266
+ def __init__(self, model, tokenizer):
267
+ self.model = model
268
+ self.tokenizer = tokenizer
269
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
270
+ self.model.to(self.device)
271
+ def generate(self, prompt: str) -> str:
272
+ self.model.eval()
273
+ with torch.no_grad():
274
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
275
+ outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
276
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
277
+ def plan_party(self, task: str) -> pd.DataFrame:
278
+ search_result = mock_search("superhero party trends")
279
+ prompt = f"Given this context: '{search_result}'\n{task}"
280
+ plan_text = self.generate(prompt)
281
+ locations = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060)}
282
+ wayne_coords = locations["Wayne Manor"]
283
+ travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
284
+ data = [
285
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues"},
286
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles"}
287
+ ]
288
+ return pd.DataFrame(data)
289
+
290
+ class CVPartyPlannerAgent:
291
+ def __init__(self, pipeline):
292
+ self.pipeline = pipeline
293
+ def generate(self, prompt: str) -> Image.Image:
294
+ return self.pipeline(prompt, num_inference_steps=50).images[0]
295
+ def plan_party(self, task: str) -> pd.DataFrame:
296
+ search_result = mock_search("superhero party trends")
297
+ prompt = f"Given this context: '{search_result}'\n{task}"
298
+ data = [
299
+ {"Theme": "Batman", "Image Idea": "Gold-plated Batman statue"},
300
+ {"Theme": "Avengers", "Image Idea": "VR superhero battle scene"}
301
+ ]
302
+ return pd.DataFrame(data)
303
+
304
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
305
+ def to_radians(degrees: float) -> float:
306
+ return degrees * (math.pi / 180)
307
+ lat1, lon1 = map(to_radians, origin_coords)
308
+ lat2, lon2 = map(to_radians, destination_coords)
309
+ EARTH_RADIUS_KM = 6371.0
310
+ dlon = lon2 - lon1
311
+ dlat = lat2 - lat1
312
+ a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
313
+ c = 2 * math.asin(math.sqrt(a))
314
+ distance = EARTH_RADIUS_KM * c
315
+ actual_distance = distance * 1.1
316
+ flight_time = (actual_distance / cruising_speed_kmh) + 1.0
317
+ return round(flight_time, 2)
318
+
319
+ # Main App
320
+ st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
321
+
322
+ # Sidebar Galleries
323
+ st.sidebar.header("Media Gallery 🎨")
324
+ gallery_size = st.sidebar.slider("Gallery Size 📸", 1, 10, 4, help="Adjust how many epic captures you see! 🌟")
325
+ update_gallery()
326
+
327
+ st.sidebar.subheader("Model Management 🗂️")
328
+ model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
329
+ model_dirs = get_model_files("causal_lm" if model_type == "Causal LM" else "diffusion")
330
+ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
331
+ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
332
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
333
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
334
+ builder.load_model(selected_model, config)
335
+ st.session_state['builder'] = builder
336
+ st.session_state['model_loaded'] = True
337
+ st.rerun()
338
+
339
+ # Tabs
340
+ tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs([
341
+ "Build Titan 🌱", "Camera Snap 📷",
342
+ "Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
343
+ "Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
344
+ ])
345
+
346
+ with tab1:
347
+ st.header("Build Titan 🌱")
348
+ model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
349
+ base_model = st.selectbox("Select Tiny Model",
350
+ ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
351
+ ["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
352
+ model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
353
+ domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪")
354
+ if st.button("Download Model ⬇️"):
355
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain if model_type == "Causal LM" else None)
356
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
357
+ builder.load_model(base_model, config)
358
+ builder.save_model(config.model_path)
359
+ st.session_state['builder'] = builder
360
+ st.session_state['model_loaded'] = True
361
+ st.rerun()
362
+
363
+ with tab2:
364
+ st.header("Camera Snap 📷 (Dual Capture!)")
365
+ slice_count = st.number_input("Image Slice Count 🎞️", min_value=1, max_value=20, value=10, help="How many snaps to dream of? (Automation’s on vacation! 😜)")
366
+ video_length = st.number_input("Video Dream Length (seconds) 🎥", min_value=1, max_value=30, value=10, help="Imagine a vid this long—sadly, we’re stuck with pics for now! 😂")
367
+ cols = st.columns(2)
368
+ with cols[0]:
369
+ st.subheader("Camera 0 🎬")
370
+ cam0_img = st.camera_input("Snap a Shot - Cam 0 📸", key="cam0", help="Click to capture a heroic moment! 🦸‍♂️")
371
+ if cam0_img:
372
+ filename = generate_filename(0)
373
+ with open(filename, "wb") as f:
374
+ f.write(cam0_img.getvalue())
375
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
376
+ logger.info(f"Saved snapshot from Camera 0: {filename}")
377
+ st.session_state['captured_images'].append(filename)
378
+ update_gallery()
379
+ st.info("🚨 Multi-frame capture’s on strike! Snap one at a time—your Titan’s too cool for automation glitches! 😎")
380
+ with cols[1]:
381
+ st.subheader("Camera 1 🎥")
382
+ cam1_img = st.camera_input("Snap a Shot - Cam 1 📸", key="cam1", help="Grab another epic frame! 🌟")
383
+ if cam1_img:
384
+ filename = generate_filename(1)
385
+ with open(filename, "wb") as f:
386
+ f.write(cam1_img.getvalue())
387
+ st.image(Image.open(filename), caption=filename, use_container_width=True)
388
+ logger.info(f"Saved snapshot from Camera 1: {filename}")
389
+ st.session_state['captured_images'].append(filename)
390
+ update_gallery()
391
+ st.info("🚨 Frame bursts? Nope, manual snaps only! One click, one masterpiece! 🎨")
392
+
393
+ with tab3: # Fine-Tune Titan (NLP)
394
+ st.header("Fine-Tune Titan (NLP) 🔧 (Teach Your Word Wizard Some Tricks!)")
395
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
396
+ st.warning("Please build or load an NLP Titan first! ⚠️ (No word wizard, no magic!)")
397
+ else:
398
+ if st.button("Generate Sample CSV 📝"):
399
+ sample_data = [
400
+ {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
401
+ {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
402
+ {"prompt": "What is a neural network?", "response": "A neural network is a brainy AI mimicking human noggins."},
403
+ ]
404
+ csv_path = f"sft_data_{int(time.time())}.csv"
405
+ with open(csv_path, "w", newline="") as f:
406
+ writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
407
+ writer.writeheader()
408
+ writer.writerows(sample_data)
409
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
410
+ st.success(f"Sample CSV generated as {csv_path}! ✅ (Fresh from the data oven!)")
411
+ uploaded_csv = st.file_uploader("Upload CSV for SFT 📜", type="csv", help="Feed your Titan some tasty prompt-response pairs! 🍽️")
412
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
413
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
414
+ with open(csv_path, "wb") as f:
415
+ f.write(uploaded_csv.read())
416
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
417
+ new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
418
+ st.session_state['builder'].config = new_config
419
+ with st.status("Fine-tuning NLP Titan... ⏳ (Whipping words into shape!)", expanded=True) as status:
420
+ st.session_state['builder'].fine_tune_sft(csv_path)
421
+ st.session_state['builder'].save_model(new_config.model_path)
422
+ status.update(label="Fine-tuning completed! 🎉 (Wordsmith Titan unleashed!)", state="complete")
423
+ zip_path = f"{new_config.model_path}.zip"
424
+ zip_directory(new_config.model_path, zip_path)
425
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned NLP Titan"), unsafe_allow_html=True)
426
+
427
+ with tab4: # Test Titan (NLP)
428
+ st.header("Test Titan (NLP) 🧪 (Put Your Word Wizard to the Test!)")
429
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
430
+ st.warning("Please build or load an NLP Titan first! ⚠️ (No word wizard, no test drive!)")
431
+ else:
432
+ if st.session_state['builder'].sft_data:
433
+ st.write("Testing with SFT Data:")
434
+ with st.spinner("Running SFT data tests... ⏳ (Titan’s flexing its word muscles!)"):
435
+ for item in st.session_state['builder'].sft_data[:3]:
436
+ prompt = item["prompt"]
437
+ expected = item["response"]
438
+ status_container = st.empty()
439
+ generated = st.session_state['builder'].evaluate(prompt, status_container)
440
+ st.write(f"**Prompt**: {prompt}")
441
+ st.write(f"**Expected**: {expected}")
442
+ st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
443
+ st.write("---")
444
+ status_container.empty()
445
+ test_prompt = st.text_area("Enter Test Prompt 🗣️", "What is AI?", help="Ask your Titan anything—it’s ready to chat! 😜")
446
+ if st.button("Run Test ▶️"):
447
+ with st.spinner("Testing your prompt... ⏳ (Titan’s pondering deeply!)"):
448
+ status_container = st.empty()
449
+ result = st.session_state['builder'].evaluate(test_prompt, status_container)
450
+ st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
451
+ status_container.empty()
452
+
453
+ with tab5: # Agentic RAG Party (NLP)
454
+ st.header("Agentic RAG Party (NLP) 🌐 (Party Like It’s 2099!)")
455
+ st.write("This demo uses your SFT-tuned NLP Titan to plan a superhero party with mock retrieval!")
456
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
457
+ st.warning("Please build or load an NLP Titan first! ⚠️ (No word wizard, no party!)")
458
+ else:
459
+ if st.button("Run NLP RAG Demo 🎉"):
460
+ with st.spinner("Loading your SFT-tuned NLP Titan... ⏳ (Titan’s suiting up!)"):
461
+ agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
462
+ st.write("Agent ready! 🦸‍♂️ (Time to plan an epic bash!)")
463
+ task = """
464
+ Plan a luxury superhero-themed party at Wayne Manor (42.3601° N, 71.0589° W).
465
+ Use mock search results for the latest superhero party trends, refine for luxury elements
466
+ (decorations, entertainment, catering), and calculate cargo travel times from key locations
467
+ (New York: 40.7128° N, 74.0060° W; LA: 34.0522° N, 118.2437° W; London: 51.5074° N, 0.1278° W)
468
+ to Wayne Manor. Create a plan with at least 6 entries in a pandas dataframe.
469
+ """
470
+ with st.spinner("Planning the ultimate superhero bash... ⏳ (Calling all caped crusaders!)"):
471
+ try:
472
+ locations = {
473
+ "Wayne Manor": (42.3601, -71.0589),
474
+ "New York": (40.7128, -74.0060),
475
+ "Los Angeles": (34.0522, -118.2437),
476
+ "London": (51.5074, -0.1278)
477
+ }
478
+ wayne_coords = locations["Wayne Manor"]
479
+ travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
480
+ search_result = mock_search("superhero party trends")
481
+ prompt = f"""
482
+ Given this context from a search: "{search_result}"
483
+ Plan a luxury superhero-themed party at Wayne Manor. Suggest luxury decorations, entertainment, and catering ideas.
484
+ """
485
+ plan_text = agent.generate(prompt)
486
+ catchphrases = ["To the Batmobile!", "Avengers, assemble!", "I am Iron Man!", "By the power of Grayskull!"]
487
+ data = [
488
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues", "Catchphrase": random.choice(catchphrases)},
489
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Holographic Avengers displays", "Catchphrase": random.choice(catchphrases)},
490
+ {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows with Iron Man suits", "Catchphrase": random.choice(catchphrases)},
491
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles", "Catchphrase": random.choice(catchphrases)},
492
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gourmet kryptonite-green cocktails", "Catchphrase": random.choice(catchphrases)},
493
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Thor’s hammer-shaped appetizers", "Catchphrase": random.choice(catchphrases)},
494
+ ]
495
+ plan_df = pd.DataFrame(data)
496
+ st.write("Agentic RAG Party Plan:")
497
+ st.dataframe(plan_df)
498
+ st.write("Party on, Wayne! 🦸‍♂️🎉")
499
+ except Exception as e:
500
+ st.error(f"Error planning party: {str(e)} (Even Superman has kryptonite days!)")
501
+
502
+ with tab6: # Fine-Tune Titan (CV)
503
+ st.header("Fine-Tune Titan (CV) 🔧 (Paint Your Titan’s Masterpiece!)")
504
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
505
+ st.warning("Please build or load a CV Titan first! ⚠️ (No artist, no canvas!)")
506
+ else:
507
+ captured_images = get_gallery_files(["png"])
508
+ if len(captured_images) >= 2:
509
+ demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_images[:min(len(captured_images), 10)]]
510
+ edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic", help="Craft your image-text pairs like a superhero artist! 🎨")
511
+ if st.button("Fine-Tune with Dataset 🔄"):
512
+ images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
513
+ texts = [row["text"] for _, row in edited_data.iterrows()]
514
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
515
+ new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
516
+ st.session_state['builder'].config = new_config
517
+ with st.status("Fine-tuning CV Titan... ⏳ (Brushing up those pixels!)", expanded=True) as status:
518
+ st.session_state['builder'].fine_tune_sft(images, texts)
519
+ st.session_state['builder'].save_model(new_config.model_path)
520
+ status.update(label="Fine-tuning completed! 🎉 (Pixel Titan unleashed!)", state="complete")
521
+ zip_path = f"{new_config.model_path}.zip"
522
+ zip_directory(new_config.model_path, zip_path)
523
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned CV Titan"), unsafe_allow_html=True)
524
+ csv_path = f"sft_dataset_{int(time.time())}.csv"
525
+ with open(csv_path, "w", newline="") as f:
526
+ writer = csv.writer(f)
527
+ writer.writerow(["image", "text"])
528
+ for _, row in edited_data.iterrows():
529
+ writer.writerow([row["image"], row["text"]])
530
+ st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
531
+
532
+ with tab7: # Test Titan (CV)
533
+ st.header("Test Titan (CV) 🧪 (Unleash Your Pixel Power!)")
534
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
535
+ st.warning("Please build or load a CV Titan first! ⚠️ (No artist, no masterpiece!)")
536
+ else:
537
+ test_prompt = st.text_area("Enter Test Prompt 🎨", "Neon Batman", help="Dream up a wild image—your Titan’s got the brush! 🖌️")
538
+ if st.button("Run Test ▶️"):
539
+ with st.spinner("Painting your masterpiece... ⏳ (Titan’s mixing colors!)"):
540
+ image = st.session_state['builder'].generate(test_prompt)
541
+ st.image(image, caption="Generated Image", use_container_width=True)
542
+
543
+ with tab8: # Agentic RAG Party (CV)
544
+ st.header("Agentic RAG Party (CV) 🌐 (Party with Pixels!)")
545
+ st.write("This demo uses your SFT-tuned CV Titan to generate superhero party images with mock retrieval!")
546
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
547
+ st.warning("Please build or load a CV Titan first! ⚠️ (No artist, no party!)")
548
+ else:
549
+ if st.button("Run CV RAG Demo 🎉"):
550
+ with st.spinner("Loading your SFT-tuned CV Titan... ⏳ (Titan’s grabbing its paintbrush!)"):
551
+ agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
552
+ st.write("Agent ready! 🎨 (Time to paint an epic bash!)")
553
+ task = "Generate images for a luxury superhero-themed party."
554
+ with st.spinner("Crafting superhero party visuals... ⏳ (Pixels assemble!)"):
555
+ plan_df = agent.plan_party(task)
556
+ st.dataframe(plan_df)
557
+ for _, row in plan_df.iterrows():
558
+ image = agent.generate(row["Image Idea"])
559
+ st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}", use_container_width=True)
560
+
561
+ # Display Logs
562
+ st.sidebar.subheader("Action Logs 📜")
563
+ log_container = st.sidebar.empty()
564
+ with log_container:
565
+ for record in log_records:
566
+ st.write(f"{record.asctime} - {record.levelname} - {record.message}")
567
+
568
+ # Initial Gallery Update
569
+ update_gallery()