#!/usr/bin/env python3
import os
import base64
import streamlit as st
import csv
import time
from dataclasses import dataclass
import zipfile
import logging

# Logging setup, with an in-memory capture handler so the sidebar log viewer can replay records
class LogCaptureHandler(logging.Handler):
    def __init__(self):
        super().__init__()
        self.buffer = []

    def emit(self, record):
        self.buffer.append(record)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
log_capture = LogCaptureHandler()
logger.addHandler(log_capture)

st.set_page_config(page_title="SFT Tiny Titans πŸš€", page_icon="πŸ€–", layout="wide", initial_sidebar_state="expanded")

# Model Configurations
@dataclass
class ModelConfig:
    name: str
    base_model: str
    model_type: str = "causal_lm"

    @property
    def model_path(self):
        return f"models/{self.name}"

@dataclass
class DiffusionConfig:
    name: str
    base_model: str

    @property
    def model_path(self):
        return f"diffusion_models/{self.name}"

# Lazy-loaded Builders
class ModelBuilder:
    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None

    def load_model(self, model_path: str, config: ModelConfig):
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch
            logger.info(f"Loading NLP model: {model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.config = config
            self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            logger.info("NLP model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading NLP model: {str(e)}")
            raise

    def fine_tune(self, csv_path):
        try:
            from torch.utils.data import Dataset, DataLoader
            import torch
            logger.info(f"Starting NLP fine-tuning with {csv_path}")

            class SFTDataset(Dataset):
                def __init__(self, data, tokenizer):
                    self.data = data
                    self.tokenizer = tokenizer

                def __len__(self):
                    return len(self.data)

                def __getitem__(self, idx):
                    prompt = self.data[idx]["prompt"]
                    response = self.data[idx]["response"]
                    inputs = self.tokenizer(f"{prompt} {response}", return_tensors="pt", padding="max_length", max_length=128, truncation=True)
                    labels = inputs["input_ids"].clone()
                    # Mask the prompt tokens so the loss is computed on the response only
                    prompt_len = len(self.tokenizer(prompt)["input_ids"])
                    labels[0, :prompt_len] = -100
                    return {"input_ids": inputs["input_ids"][0], "attention_mask": inputs["attention_mask"][0], "labels": labels[0]}

            data = []
            with open(csv_path, "r") as f:
                reader = csv.DictReader(f)
                for row in reader:
                    data.append({"prompt": row["prompt"], "response": row["response"]})
            dataset = SFTDataset(data, self.tokenizer)
            dataloader = DataLoader(dataset, batch_size=2)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
            self.model.train()
            for _ in range(1):
                for batch in dataloader:
                    optimizer.zero_grad()
                    outputs = self.model(**{k: v.to(self.model.device) for k, v in batch.items()})
                    outputs.loss.backward()
                    optimizer.step()
            logger.info("NLP fine-tuning completed")
        except Exception as e:
            logger.error(f"Error in NLP fine-tuning: {str(e)}")
            raise

    def evaluate(self, prompt: str):
        try:
            import torch
            logger.info(f"Evaluating NLP with prompt: {prompt}")
            self.model.eval()
            with torch.no_grad():
                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
                outputs = self.model.generate(**inputs, max_new_tokens=50)
                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info(f"NLP evaluation result: {result}")
            return result
        except Exception as e:
            logger.error(f"Error in NLP evaluation: {str(e)}")
            raise
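# Illustrative example (not shipped with the app): the CSV that ModelBuilder.fine_tune
# expects only needs "prompt" and "response" columns, one example per row, e.g.:
#
#   prompt,response
#   What's a superhero?,A person with extraordinary abilities who protects others.
#   Describe a neon city.,A skyline of glowing signs and rain-slick streets.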
class DiffusionBuilder:
    def __init__(self):
        self.config = None
        self.pipeline = None

    def load_model(self, model_path: str, config: DiffusionConfig):
        try:
            from diffusers import StableDiffusionPipeline
            import torch
            logger.info(f"Loading diffusion model: {model_path}")
            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            self.pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            self.config = config
            logger.info("Diffusion model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading diffusion model: {str(e)}")
            raise

    def fine_tune(self, images, texts):
        try:
            import torch
            from PIL import Image
            import numpy as np
            logger.info("Starting diffusion fine-tuning")
            optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
            self.pipeline.unet.train()
            for _ in range(1):
                for img, text in zip(images, texts):
                    optimizer.zero_grad()
                    # Force RGB and scale pixel values to [-1, 1], the range the Stable Diffusion VAE expects
                    img_tensor = torch.tensor(np.array(img.convert("RGB"))).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device) / 127.5 - 1.0
                    latents = self.pipeline.vae.encode(img_tensor).latent_dist.sample()
                    noise = torch.randn_like(latents)
                    timesteps = torch.randint(0, self.pipeline.scheduler.config.num_train_timesteps, (1,), device=latents.device)
                    noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
                    text_emb = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
                    pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_emb).sample
                    loss = torch.nn.functional.mse_loss(pred_noise, noise)
                    loss.backward()
                    optimizer.step()
            logger.info("Diffusion fine-tuning completed")
        except Exception as e:
            logger.error(f"Error in diffusion fine-tuning: {str(e)}")
            raise

    def generate(self, prompt: str):
        try:
            logger.info(f"Generating image with prompt: {prompt}")
            img = self.pipeline(prompt, num_inference_steps=20).images[0]
            logger.info("Image generated successfully")
            return img
        except Exception as e:
            logger.error(f"Error in image generation: {str(e)}")
            raise
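# DiffusionBuilder.fine_tune above uses a minimal noise-prediction (epsilon-MSE) objective on the
# UNet only: encode the image to latents, add noise at a random timestep, and regress the UNet's
# predicted noise against the true noise. Illustrative call (not executed here; model id taken from
# the app's CV options, config name "demo" is arbitrary):
#
#   builder = DiffusionBuilder()
#   builder.load_model("runwayml/stable-diffusion-v1-5",
#                      DiffusionConfig(name="demo", base_model="runwayml/stable-diffusion-v1-5"))
#   builder.fine_tune([Image.open("hero.png")], ["Neon Hero"])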
# Utilities
def get_download_link(file_path, mime_type="text/plain", label="Download"):
    with open(file_path, 'rb') as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    # Return an HTML anchor with the file embedded as base64 data (rendered via unsafe_allow_html)
    return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} πŸ“₯</a>'

def generate_filename(sequence):
    from datetime import datetime
    import pytz
    central = pytz.timezone('US/Central')
    timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
    return f"{sequence}{timestamp}.png"

def get_gallery_files(file_types):
    import glob
    return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])

def zip_files(files, zip_name):
    with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for file in files:
            zipf.write(file, os.path.basename(file))
    return zip_name

# Video Processor for WebRTC
class VideoSnapshot:
    def __init__(self):
        self.snapshot = None

    def recv(self, frame):
        from PIL import Image
        img = frame.to_image()
        self.snapshot = img
        return frame

    def take_snapshot(self):
        return self.snapshot

# Main App
st.title("SFT Tiny Titans πŸš€ (Capture & Tune!)")

# Sidebar Galleries
st.sidebar.header("Captured Images 🎨")
image_files = get_gallery_files(["png"])
if image_files:
    cols = st.sidebar.columns(2)
    for idx, file in enumerate(image_files[:4]):
        with cols[idx % 2]:
            from PIL import Image
            st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)

# Sidebar Model Management
st.sidebar.subheader("Model Hub πŸ—‚οΈ")
model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
model_options = {
    "NLP (Causal LM)": "HuggingFaceTB/SmolLM-135M",
    "CV (Diffusion)": ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"]
}
selected_model = st.sidebar.selectbox("Select Model", ["None"] + ([model_options[model_type]] if "NLP" in model_type else model_options[model_type]))
if selected_model != "None" and st.sidebar.button("Load Model πŸ“‚"):
    builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
    config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
    with st.spinner("Loading... ⏳"):
        try:
            builder.load_model(selected_model, config)
            st.session_state['builder'] = builder
            st.session_state['model_loaded'] = True
            st.success("Model loaded! πŸŽ‰")
        except Exception as e:
            st.error(f"Load failed: {str(e)}")

# Tabs
tab1, tab2, tab3, tab4 = st.tabs(["Build Titan 🌱", "Camera Snap πŸ“·", "Fine-Tune Titans πŸ”§", "Test Titans πŸ§ͺ"])

with tab1:
    st.header("Build Titan 🌱 (Quick Start!)")
    model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
    # The NLP option is a single string, so wrap it in a list before handing it to the selectbox
    base_model = st.selectbox("Select Model", [model_options[model_type]] if "NLP" in model_type else model_options[model_type], key="build_model")
    if st.button("Download Model ⬇️"):
        config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
        builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
        with st.spinner("Fetching... ⏳"):
            try:
                builder.load_model(base_model, config)
                st.session_state['builder'] = builder
                st.session_state['model_loaded'] = True
                st.success("Titan up! πŸŽ‰")
            except Exception as e:
                st.error(f"Download failed: {str(e)}")

with tab2:
    st.header("Camera Snap πŸ“· (Sequence Shots!)")
    from streamlit_webrtc import webrtc_streamer
    ctx = webrtc_streamer(
        key="camera",
        video_processor_factory=VideoSnapshot,
        frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
    )
    if ctx.video_processor:
        delay = st.slider("Delay between captures (seconds)", 0, 10, 2)
        if st.button("Capture 6 Frames πŸ“Έ"):
            logger.info("Starting 6-frame capture")
            captured_images = []
            try:
                for i in range(6):
                    snapshot = ctx.video_processor.take_snapshot()
                    if snapshot:
                        filename = generate_filename(i)
                        snapshot.save(filename)
                        st.image(snapshot, caption=filename, use_container_width=True)
                        captured_images.append(filename)
                        logger.info(f"Captured frame {i}: {filename}")
                    time.sleep(delay)
                st.success("6 frames captured! πŸŽ‰")
                st.session_state['captured_images'] = captured_images
            except Exception as e:
                st.error(f"Capture failed: {str(e)}")
                logger.error(f"Error during capture: {str(e)}")
    if 'captured_images' in st.session_state and len(st.session_state['captured_images']) >= 2:
        st.subheader("Diffusion SFT Dataset 🎨")
        sample_texts = ["Neon Hero", "Glowing Cape", "Spark Flyer", "Dark Knight", "Iron Shine", "Thunder Bolt"]
        dataset = list(zip(st.session_state['captured_images'], sample_texts[:len(st.session_state['captured_images'])]))
        st.code("\n".join([f"{i+1}. {text} -> {img}" for i, (img, text) in enumerate(dataset)]), language="text")
        if st.button("Download Dataset CSV πŸ“"):
            logger.info("Generating dataset CSV")
            try:
                csv_path = f"diffusion_sft_{int(time.time())}.csv"
                with open(csv_path, "w", newline="") as f:
                    writer = csv.writer(f)
                    writer.writerow(["image", "text"])
                    for img, text in dataset:
                        writer.writerow([img, text])
                st.markdown(get_download_link(csv_path, "text/csv", "Download Dataset CSV"), unsafe_allow_html=True)
                logger.info("Dataset CSV generated")
            except Exception as e:
                st.error(f"CSV generation failed: {str(e)}")
                logger.error(f"Error generating CSV: {str(e)}")
        if st.button("Download Images ZIP πŸ“¦"):
            logger.info("Generating images ZIP")
            try:
                zip_path = f"captured_images_{int(time.time())}.zip"
                zip_files(st.session_state['captured_images'], zip_path)
                st.markdown(get_download_link(zip_path, "application/zip", "Download Images ZIP"), unsafe_allow_html=True)
                logger.info("Images ZIP generated")
            except Exception as e:
                st.error(f"ZIP generation failed: {str(e)}")
                logger.error(f"Error generating ZIP: {str(e)}")
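# Tabs 3 and 4 below dispatch on the concrete type of st.session_state['builder']
# (ModelBuilder vs DiffusionBuilder), so the tuning and testing UI always matches
# whichever Titan was loaded from the sidebar or in tab 1.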
with tab3:
    st.header("Fine-Tune Titans πŸ”§ (Tune Fast!)")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Load a Titan first! ⚠️")
    else:
        if isinstance(st.session_state['builder'], ModelBuilder):
            st.subheader("NLP Tune 🧠")
            uploaded_csv = st.file_uploader("Upload CSV", type="csv", key="nlp_csv")
            if uploaded_csv and st.button("Tune NLP πŸ”„"):
                logger.info("Initiating NLP fine-tune")
                try:
                    with open("temp.csv", "wb") as f:
                        f.write(uploaded_csv.read())
                    st.session_state['builder'].fine_tune("temp.csv")
                    st.success("NLP sharpened! πŸŽ‰")
                except Exception as e:
                    st.error(f"NLP fine-tune failed: {str(e)}")
        elif isinstance(st.session_state['builder'], DiffusionBuilder):
            st.subheader("CV Tune 🎨")
            captured_images = get_gallery_files(["png"])
            if len(captured_images) >= 2:
                texts = ["Superhero Neon", "Hero Glow", "Cape Spark"][:len(captured_images)]
                if st.button("Tune CV πŸ”„"):
                    logger.info("Initiating CV fine-tune")
                    try:
                        from PIL import Image
                        images = [Image.open(img) for img in captured_images]
                        st.session_state['builder'].fine_tune(images, texts)
                        st.success("CV polished! πŸŽ‰")
                    except Exception as e:
                        st.error(f"CV fine-tune failed: {str(e)}")
            else:
                st.warning("Capture at least 2 images first! ⚠️")

with tab4:
    st.header("Test Titans πŸ§ͺ (Quick Check!)")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Load a Titan first! ⚠️")
    else:
        if isinstance(st.session_state['builder'], ModelBuilder):
            st.subheader("NLP Test 🧠")
            prompt = st.text_area("Prompt", "What’s a superhero?", key="nlp_test")
            if st.button("Test NLP ▢️"):
                logger.info("Running NLP test")
                try:
                    result = st.session_state['builder'].evaluate(prompt)
                    st.write(f"**Answer**: {result}")
                except Exception as e:
                    st.error(f"NLP test failed: {str(e)}")
        elif isinstance(st.session_state['builder'], DiffusionBuilder):
            st.subheader("CV Test 🎨")
            prompt = st.text_area("Prompt", "Neon Batman", key="cv_test")
            if st.button("Test CV ▢️"):
                logger.info("Running CV test")
                try:
                    with st.spinner("Generating... ⏳"):
                        img = st.session_state['builder'].generate(prompt)
                    st.image(img, caption="Generated Art", use_container_width=True)
                except Exception as e:
                    st.error(f"CV test failed: {str(e)}")

# Display Logs
st.sidebar.subheader("Action Logs πŸ“œ")
log_container = st.sidebar.empty()
# st.empty() holds a single element, so render every captured record in one text block
log_container.text("\n".join(
    f"{time.strftime('%H:%M:%S', time.localtime(record.created))} - {record.levelname} - {record.getMessage()}"
    for record in log_capture.buffer
))
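# Dependency note (inferred from the imports above; versions are not pinned here):
# streamlit, streamlit-webrtc, transformers, torch, diffusers, pillow, numpy, pytz.
# Launch with: streamlit run <this_file>.py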