#!/usr/bin/env python3
import os
import shutil
import glob
import base64
import streamlit as st
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import csv
import time
from dataclasses import dataclass
from typing import Optional, Tuple
import zipfile
import math
from PIL import Image
from torchvision import transforms
import random
import logging
from datetime import datetime
import pytz
from diffusers import StableDiffusionPipeline  # For diffusion models
from urllib.parse import quote

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page Configuration
st.set_page_config(
    page_title="SFT Tiny Titans ๐Ÿš€",
    page_icon="๐Ÿค–",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://huggingface.co/awacke1',
        'Report a bug': 'https://huggingface.co/spaces/awacke1',
        'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! ๐ŸŒŒ"
    }
)

# Model Configuration Classes
@dataclass
class ModelConfig:
    name: str
    base_model: str
    size: str
    domain: Optional[str] = None
    model_type: str = "causal_lm"
    
    @property
    def model_path(self):
        return f"models/{self.name}"

@dataclass
class DiffusionConfig:
    name: str
    base_model: str
    size: str
    
    @property
    def model_path(self):
        return f"diffusion_models/{self.name}"

# Datasets
class SFTDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prompt = self.data[idx]["prompt"]
        response = self.data[idx]["response"]
        full_text = f"{prompt} {response}"
        full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
        input_ids = full_encoding["input_ids"].squeeze()
        attention_mask = full_encoding["attention_mask"].squeeze()
        labels = input_ids.clone()
        prompt_len = prompt_encoding["input_ids"].shape[1]
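        # Labels of -100 are ignored by the cross-entropy loss, so masking the prompt
        # trains the model only on response tokens (skipped if the prompt alone fills the window)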
        if prompt_len < self.max_length:
            labels[:prompt_len] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

class DiffusionDataset(Dataset):
    def __init__(self, images, texts):
        self.images = images
        self.texts = texts
        # Convert PIL images to the tensor format the SD VAE expects: 3x512x512, scaled to [-1, 1]
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"image": self.transform(self.images[idx].convert("RGB")), "text": self.texts[idx]}

# Model Builder Classes
class ModelBuilder:
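    """Loads, fine-tunes (SFT), saves, and evaluates a small causal LM."""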
    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
        self.sft_data = None
        self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! ๐Ÿ˜‚", "Training complete! Time for a binary coffee break. โ˜•"]

    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
        with st.spinner(f"Loading {model_path}... โณ"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if config:
                self.config = config
        st.success(f"Model loaded! ๐ŸŽ‰ {random.choice(self.jokes)}")
        return self

    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
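        # Expects a CSV with 'prompt' and 'response' columns; trains on the concatenated pairs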
        self.sft_data = []
        with open(csv_path, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})

        dataset = SFTDataset(self.sft_data, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
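        # AdamW at 2e-5 is a reasonable default starting point for SFT on small causal LMs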
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)

        self.model.train()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        for epoch in range(epochs):
            with st.spinner(f"Training epoch {epoch + 1}/{epochs}... โš™๏ธ"):
                total_loss = 0
                for batch in dataloader:
                    optimizer.zero_grad()
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    labels = batch["labels"].to(device)
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        st.success(f"SFT Fine-tuning completed! ๐ŸŽ‰ {random.choice(self.jokes)}")
        return self

    def save_model(self, path: str):
        with st.spinner("Saving model... ๐Ÿ’พ"):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
        st.success(f"Model saved at {path}! โœ…")

    def evaluate(self, prompt: str, status_container=None):
        self.model.eval()
        if status_container:
            status_container.write("Preparing to evaluate... ๐Ÿง ")
        try:
            with torch.no_grad():
                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
                outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            if status_container:
                status_container.error(f"Oops! Something broke: {str(e)} 💥")
            return f"Error: {str(e)}"

class DiffusionBuilder:
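    """Loads and fine-tunes a Stable Diffusion pipeline (UNet weights only)."""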
    def __init__(self):
        self.config = None
        self.pipeline = None

    def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
        with st.spinner(f"Loading diffusion model {model_path}... โณ"):
            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
            if config:
                self.config = config
        st.success(f"Diffusion model loaded! ๐ŸŽจ")
        return self

    def fine_tune_sft(self, images, texts, epochs=3):
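        # Minimal DDPM-style fine-tune: only the UNet is updated; the VAE and text encoder stay frozen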
        dataset = DiffusionDataset(images, texts)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
        
        self.pipeline.unet.train()
        for epoch in range(epochs):
            with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... โš™๏ธ"):
                total_loss = 0
                for batch in dataloader:
                    optimizer.zero_grad()
                    image = batch["image"].to(self.pipeline.device)
                    text = batch["text"]
                    # Encode the image to scaled latents; the VAE is frozen, so no gradients are needed here
                    with torch.no_grad():
                        latents = self.pipeline.vae.encode(image).latent_dist.sample() * self.pipeline.vae.config.scaling_factor
                    noise = torch.randn_like(latents)
                    timesteps = torch.randint(0, self.pipeline.scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device)
                    noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
                    # Encode the caption with the frozen text encoder
                    with torch.no_grad():
                        text_inputs = self.pipeline.tokenizer(text, padding="max_length", max_length=self.pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
                        text_embeddings = self.pipeline.text_encoder(text_inputs.input_ids.to(self.pipeline.device))[0]
                    pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
                    loss = torch.nn.functional.mse_loss(pred_noise, noise)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        st.success("Diffusion SFT Fine-tuning completed! ๐ŸŽจ")
        return self

    def save_model(self, path: str):
        with st.spinner("Saving diffusion model... ๐Ÿ’พ"):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.pipeline.save_pretrained(path)
        st.success(f"Diffusion model saved at {path}! โœ…")

# Utility Functions
def get_download_link(file_path, mime_type="text/plain", label="Download"):
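    # Embed the file as a base64 data: URI so the download link needs no separate file server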
    with open(file_path, 'rb') as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} ๐Ÿ“ฅ</a>'

def zip_directory(directory_path, zip_path):
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(directory_path):
            for file in files:
                file_path = os.path.join(root, file)
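                # Paths are made relative to the parent directory so the archive keeps the top-level folder name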
                arcname = os.path.relpath(file_path, os.path.dirname(directory_path))
                zipf.write(file_path, arcname)

def get_model_files(model_type="causal_lm"):
    path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
    return [d for d in glob.glob(path) if os.path.isdir(d)]

def get_gallery_files(file_types):
    files = []
    for ext in file_types:
        files.extend(glob.glob(f"*.{ext}"))
    return sorted(files)

def generate_filename(text_line):
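    # Build a timestamped (US Central), filesystem-safe filename from the caption text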
    central = pytz.timezone('US/Central')
    timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
    safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
    return f"{timestamp}_{safe_text}.png"

def display_search_links(query):
    search_urls = {
        "ArXiv": f"https://arxiv.org/search/?query={quote(query)}",
        "Wikipedia": f"https://en.wikipedia.org/wiki/{quote(query)}",
        "Google": f"https://www.google.com/search?q={quote(query)}",
        "YouTube": f"https://www.youtube.com/results?search_query={quote(query)}"
    }
    links_md = ' '.join([f"[{name}]({url})" for name, url in search_urls.items()])
    return links_md

# Agent Class
class PartyPlannerAgent:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate(self, prompt: str) -> str:
        self.model.eval()
        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
            outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def plan_party(self, task: str) -> pd.DataFrame:
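        # Mock retrieval result; a real agentic RAG step would query a search tool here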
        search_result = "Latest trends for 2025: Gold-plated Batman statues, VR superhero battles."
        prompt = f"Given this context: '{search_result}'\n{task}"
        plan_text = self.generate(prompt)
        st.markdown(f"Search Links: {display_search_links('superhero party trends')}", unsafe_allow_html=True)
        
        locations = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060), "Los Angeles": (34.0522, -118.2437), "London": (51.5074, -0.1278)}
        wayne_coords = locations["Wayne Manor"]
        travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
        
        data = [
            {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues"},
            {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "VR superhero battles"},
            {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows"},
            {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "Holographic displays"}
        ]
        return pd.DataFrame(data)

def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
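    """Estimate flight time in hours via the haversine great-circle distance,
    padded by 10% for non-direct routing plus a fixed 1-hour overhead."""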
    def to_radians(degrees: float) -> float:
        return degrees * (math.pi / 180)
    lat1, lon1 = map(to_radians, origin_coords)
    lat2, lon2 = map(to_radians, destination_coords)
    EARTH_RADIUS_KM = 6371.0
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
    c = 2 * math.asin(math.sqrt(a))
    distance = EARTH_RADIUS_KM * c
    actual_distance = distance * 1.1
    flight_time = (actual_distance / cruising_speed_kmh) + 1.0
    return round(flight_time, 2)

# Main App
st.title("SFT Tiny Titans ๐Ÿš€ (Small but Mighty!)")

# Sidebar Galleries
st.sidebar.header("Galleries ๐ŸŽจ")
for gallery_type, file_types in [
    ("Image Gallery ๐Ÿ“ธ", ["png", "jpg", "jpeg"]),
    ("Video Gallery ๐ŸŽฅ", ["mp4"]),
    ("Audio Gallery ๐ŸŽถ", ["mp3"])
]:
    st.sidebar.subheader(gallery_type)
    files = get_gallery_files(file_types)
    if files:
        cols_num = st.sidebar.slider(f"{gallery_type} Columns", 1, 5, 3, key=f"{gallery_type}_cols")
        cols = st.sidebar.columns(cols_num)
        for idx, file in enumerate(files[:cols_num * 2]):
            with cols[idx % cols_num]:
                if "Image" in gallery_type:
                    st.image(Image.open(file), caption=file, use_container_width=True)
                elif "Video" in gallery_type:
                    st.video(file)
                elif "Audio" in gallery_type:
                    st.audio(file)

st.sidebar.subheader("Model Management ๐Ÿ—‚๏ธ")
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
model_dirs = get_model_files("causal_lm" if model_type == "Causal LM" else "diffusion")
selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
if selected_model != "None" and st.sidebar.button("Load Model ๐Ÿ“‚"):
    if 'builder' not in st.session_state:
        st.session_state['builder'] = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
    config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
    st.session_state['builder'].load_model(selected_model, config)
    st.session_state['model_loaded'] = True
    st.rerun()

# Tabs
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Build Tiny Titan 🌱", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐", "Diffusion SFT 🎨"])

with tab1:
    st.header("Build Tiny Titan ๐ŸŒฑ")
    model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
    if model_type == "Causal LM":
        base_model = st.selectbox("Select Tiny Model", ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"])
    else:
        base_model = st.selectbox("Select Tiny Diffusion Model", ["stabilityai/stable-diffusion-2-1", "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4"])
    model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
    if st.button("Download Model โฌ‡๏ธ"):
        config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small")
        builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
        builder.load_model(base_model, config)
        builder.save_model(config.model_path)
        st.session_state['builder'] = builder
        st.session_state['model_loaded'] = True
        st.rerun()

with tab2:
    st.header("Fine-Tune Titan ๐Ÿ”ง")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Please build or load a Titan first! โš ๏ธ")
    else:
        if isinstance(st.session_state['builder'], ModelBuilder):
            uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
            if uploaded_csv and st.button("Fine-Tune with Uploaded CSV ๐Ÿ”„"):
                csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
                with open(csv_path, "wb") as f:
                    f.write(uploaded_csv.read())
                new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
                new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
                st.session_state['builder'].config = new_config
                st.session_state['builder'].fine_tune_sft(csv_path)
                st.session_state['builder'].save_model(new_config.model_path)
                zip_path = f"{new_config.model_path}.zip"
                zip_directory(new_config.model_path, zip_path)
                st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)

with tab3:
    st.header("Test Titan ๐Ÿงช")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Please build or load a Titan first! โš ๏ธ")
    else:
        if isinstance(st.session_state['builder'], ModelBuilder):
            test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
            if st.button("Run Test โ–ถ๏ธ"):
                result = st.session_state['builder'].evaluate(test_prompt)
                st.write(f"**Generated Response**: {result}")

with tab4:
    st.header("Agentic RAG Party ๐ŸŒ")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
        st.warning("Please build or load a Causal LM Titan first! โš ๏ธ")
    else:
        if st.button("Run Agentic RAG Demo ๐ŸŽ‰"):
            agent = PartyPlannerAgent(model=st.session_state['builder'].model, tokenizer=st.session_state['builder'].tokenizer)
            task = "Plan a luxury superhero-themed party at Wayne Manor."
            plan_df = agent.plan_party(task)
            st.dataframe(plan_df)

with tab5:
    st.header("Diffusion SFT ๐ŸŽจ")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
        st.warning("Please build or load a Diffusion Titan first! โš ๏ธ")
    else:
        uploaded_files = st.file_uploader("Upload Images/Videos", type=["png", "jpg", "jpeg", "mp4", "mp3"], accept_multiple_files=True)
        text_input = st.text_area("Enter Text (one line per image)", "Line 1\nLine 2\nLine 3")
        if uploaded_files and st.button("Fine-Tune Diffusion Model ๐Ÿ”„"):
            images = [Image.open(f) for f in uploaded_files if f.type.startswith("image")]
            texts = text_input.splitlines()
            if len(images) > len(texts):
                texts.extend([""] * (len(images) - len(texts)))
            elif len(texts) > len(images):
                texts = texts[:len(images)]
            
            st.session_state['builder'].fine_tune_sft(images, texts)
            new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
            new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
            st.session_state['builder'].config = new_config
            st.session_state['builder'].save_model(new_config.model_path)
            
            for img, text in zip(images, texts):
                filename = generate_filename(text)
                img.save(filename)
                st.image(img, caption=filename)
            zip_path = f"{new_config.model_path}.zip"
            zip_directory(new_config.model_path, zip_path)
            st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)