import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
import os

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

from transformers import AutoProcessor, SiglipVisionModel
from huggingface_hub import hf_hub_download


# Constants
# Largest signed 32-bit integer; upper bound for the seed slider and for
# random seed draws in process_image.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1024
# NOTE(review): DEVICE is not referenced anywhere else in this file — the
# pipeline below is moved to "cuda" unconditionally. Either use it or drop it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub repo ids for the base model and the image encoder that
# init_ipadapter loads.
model_path = 'stabilityai/stable-diffusion-3.5-large'
image_encoder_path = "google/siglip-so400m-patch14-384"
# Downloads (or reuses the cached copy of) the IP-Adapter weights and returns
# the local file path.
ipadapter_path = hf_hub_download(repo_id="InstantX/SD3.5-Large-IP-Adapter", filename="ip-adapter.bin")

# The transformer is loaded separately with the project's own
# SD3Transformer2DModel and injected into the pipeline below — presumably so
# it carries the IP-Adapter attention hooks; TODO confirm in
# models/transformer_sd3.py.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, 
    subfolder="transformer", 
    torch_dtype=torch.bfloat16
)

# Build the (project-local) IP-Adapter-aware pipeline around that transformer
# and move it to CUDA at import time; this requires a CUDA device (or a
# ZeroGPU-style environment that permits it — confirm for the target host).
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, 
    transformer=transformer, 
    torch_dtype=torch.bfloat16
).to("cuda")

# Attach the IP-Adapter weights and image encoder to the pipeline. This runs
# after .to("cuda"); whether init_ipadapter places its modules on the
# pipeline's device is not visible here — TODO confirm.
# nb_token=64 presumably is the number of image tokens the adapter produces and
# must match the checkpoint — TODO confirm against pipeline_stable_diffusion_3_ipa.
pipe.init_ipadapter(
    ip_adapter_path=ipadapter_path, 
    image_encoder_path=image_encoder_path, 
    nb_token=64, 
)

def resize_img(image, max_size=1024):
    """Scale *image* so its longer side equals ``max_size``, keeping aspect ratio.

    Images are scaled up as well as down: an image smaller than ``max_size``
    is enlarged until its longer side reaches ``max_size``.

    Args:
        image: A PIL image with non-zero width and height.
        max_size: Target length, in pixels, of the longer side.

    Returns:
        A new PIL image resized with Lanczos resampling. Each side is at
        least 1 pixel.
    """
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # Round rather than truncate: float error in ``side * (max_size / side)``
    # can land just below the exact value (e.g. 0.999... or 1023.999...), and
    # int() would then drop a whole pixel. Clamp to 1 so an extreme aspect
    # ratio never yields a zero-width or zero-height target size.
    new_width = max(1, round(width * scaling_factor))
    new_height = max(1, round(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)

@spaces.GPU
def process_image(
    image,
    prompt,
    scale,
    seed,
    randomize_seed,
    width,
    height,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image conditioned on a reference image and a text prompt.

    Args:
        image: Reference image for the IP-Adapter (PIL image, or an array
            convertible with ``Image.fromarray``), or ``None`` if the user
            has not uploaded one.
        prompt: Text prompt passed to the pipeline.
        scale: IP-Adapter conditioning strength, forwarded as
            ``ipadapter_scale``.
        seed: RNG seed; replaced by a random one when ``randomize_seed`` is
            true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        progress: Gradio progress tracker. It is not referenced directly;
            ``track_tqdm=True`` lets Gradio mirror tqdm progress bars.

    Returns:
        ``(result, seed)``: the generated image (``None`` when no reference
        image was given) and the integer seed actually used, so the UI can
        display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider/API values may arrive as floats; torch.Generator.manual_seed
    # requires an int. Coerce once so the same value is used and returned.
    seed = int(seed)

    if image is None:
        return None, seed

    # Accept array input too (e.g. API callers) by converting to PIL.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Bound the reference image's longer side before it reaches the encoder.
    image = resize_img(image)

    # torch.Generator() with no device argument is a CPU generator.
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        clip_image=image,
        prompt=prompt,
        ipadapter_scale=scale,
        # Pixel dimensions: pass ints even if the sliders delivered floats.
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    return result, seed

# Stylesheet handed to gr.Blocks: centres the element with id
# "col-container" and caps its width at 960px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 960px;\n"
    "}\n"
)

# Create the Gradio interface
# Components are laid out in the order they are created inside each
# context manager, so statement order here defines the page layout.
with gr.Blocks(css=css) as demo:
    # elem_id matches the "#col-container" rule in `css`.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# InstantX's SD3.5 IP Adapter")
        
        with gr.Row():
            # Left column: inputs and the run button.
            with gr.Column():
                # type="pil" makes Gradio pass a PIL image to process_image.
                input_image = gr.Image(
                    label="Input Image",
                    type="pil"
                )
                # Forwarded to the pipeline as ipadapter_scale.
                scale = gr.Slider(
                    label="Image Scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7,
                )
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                run_button = gr.Button("Generate", variant="primary")
            
            # Right column: generated output.
            with gr.Column():
                result = gr.Image(label="Result")
        
        with gr.Accordion("Advanced Settings", open=False):
            # Also listed in the click outputs below, so it is updated with
            # the seed process_image actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            
            # On by default: the seed slider's value is replaced by a random
            # draw unless the user unchecks this.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
    
    # `inputs` are passed positionally and must stay in the same order as
    # process_image's parameters; `outputs` receive its (image, seed) tuple.
    run_button.click(
        fn=process_image,
        inputs=[
            input_image,
            prompt,
            scale,
            seed,
            randomize_seed,
            width,
            height,
        ],
        outputs=[result, seed],
    )

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()