import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler
from src.mgd_pipelines.mgd_pipe import MGDPipe  # Your MGDPipe implementation


# Load models and pipeline
def load_models(pretrained_model_path, device):
    """
    Load the models required for the MGDPipe.
    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.

    Returns:
        MGDPipe: Initialized MGDPipe object.
    """
    # Load components of Stable Diffusion
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Handle torch.hub checkpoint loading for CPU-only environments
    map_location = torch.device("cpu") if device.type == "cpu" else None

    # Load the UNet model and force map_location for state_dict loading
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset="dresscode",  # Change to "vitonhd" if needed
    )

    # Ensure the model state dict is mapped correctly to the CPU if needed
    if device.type == "cpu":
        checkpoint_url = unet.config.get("checkpoint")
        if checkpoint_url:
            state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
            unet.load_state_dict(state_dict)

    # Move UNet to the appropriate device
    unet = unet.to(device)

    # Initialize the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )

    return pipeline


# Function to preprocess and generate images
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.
    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image.
    """
    # Preprocess the sketch
    sketch = sketch.resize((512, 384)).convert("RGB")
    sketch_tensor = torch.tensor([torch.tensor(sketch, dtype=torch.float32).permute(2, 0, 1) / 255.0]).to(device)

    # Run the pipeline
    output = pipeline(
        prompt=prompt,
        image=torch.zeros_like(sketch_tensor),  # Placeholder for masked image
        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for mask
        pose_map=torch.zeros((1, 3, 64, 48)).to(device),  # Placeholder pose map
        sketch=sketch_tensor,
        guidance_scale=7.5,
        num_inference_steps=50,
    )

    return output.images[0]


# Streamlit Interface
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    if uploaded_file and text_prompt:
        st.write("Loading models...")

        # Detect device (CPU or GPU)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required

        # Load the pipeline
        pipeline = load_models(pretrained_model_path, device)

        # Load sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")