import os
import json
import copy
import time
import random
import logging
import numpy as np
from typing import Any, Dict, List, Optional, Union

import torch
from PIL import Image
import gradio as gr

from diffusers import (
    DiffusionPipeline,
    AutoencoderTiny,
    AutoencoderKL,
    AutoPipelineForImage2Image,
    FluxPipeline,
    FlowMatchEulerDiscreteScheduler
)

from huggingface_hub import (
    hf_hub_download,
    HfFileSystem,
    ModelCard,
    snapshot_download
)

from diffusers.utils import load_image

import spaces

# Import the prompt enhancer generator from enhance.py
from enhance import generate as enhance_generate

# Prefer the LoRA catalogue defined in lora.py. If that module cannot be
# imported, fall back to a single placeholder entry so that `loras` is still a
# non-empty list of dicts. The placeholder carries the keys "image", "title",
# "repo", "weights" and "trigger_word".
try:
    from lora import loras
except ImportError:
    loras = [
        {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
    ]

# --- Only needed when running in a local or Colab workspace ---
# (Optional: add Hugging Face login code here.)

def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Map an image sequence length onto a shift value by linear interpolation.

    The result lies on the straight line through
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``. Lengths
    outside that interval are extrapolated along the same line, not clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length at which the result equals ``base_shift``.
        max_seq_len: Sequence length at which the result equals ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated shift value.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timesteps and their count.

    Exactly one of three modes is used:

    * custom ``timesteps`` -> ``scheduler.set_timesteps(timesteps=...)``
    * custom ``sigmas``    -> ``scheduler.set_timesteps(sigmas=...)``
    * neither              -> ``scheduler.set_timesteps(num_inference_steps)``

    Args:
        scheduler: Object exposing ``set_timesteps(...)`` and, afterwards, a
            ``timesteps`` attribute.
        num_inference_steps: Step count used when no custom schedule is given.
        device: Forwarded to ``set_timesteps`` as ``device``.
        timesteps: Optional custom timestep schedule.
        sigmas: Optional custom sigma schedule.
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        ``(timesteps, num_inference_steps)``. With a custom schedule the count
        is ``len(scheduler.timesteps)``; otherwise it is the value passed in.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are provided.
    """
    has_timesteps = timesteps is not None
    has_sigmas = sigmas is not None
    if has_timesteps and has_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive the schedule from the step count.
    if not (has_timesteps or has_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: hand the caller's schedule to the scheduler and report the
    # length of whatever it produced.
    if has_timesteps:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)

# FLUX pipeline
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    max_sequence_length: int = 512,
    good_vae: Optional[Any] = None,
):
    """Run the denoising loop as a generator that streams preview images.

    Meant to be bound to a pipeline instance (``self``). On every loop
    iteration the *current* latents are decoded with the pipeline's own
    ``self.vae`` and yielded before the scheduler step is applied. After the
    loop, the final latents are decoded once with ``good_vae`` and yielded as
    the last item.

    Args:
        prompt: Prompt string or list of prompts. May be ``None`` when
            ``prompt_embeds`` is supplied instead.
        prompt_2: Forwarded to ``self.check_inputs`` and ``self.encode_prompt``.
        height: Output height; defaults to
            ``self.default_sample_size * self.vae_scale_factor``.
        width: Output width; same default as ``height``.
        num_inference_steps: Number of denoising steps; also determines the
            sigma schedule built here.
        timesteps: Forwarded to ``retrieve_timesteps`` together with the sigma
            schedule built here. NOTE(review): since sigmas are always
            supplied, a non-None value makes ``retrieve_timesteps`` raise
            ``ValueError``.
        guidance_scale: Stored on ``self._guidance_scale`` and, when
            ``self.transformer.config.guidance_embeds`` is truthy, passed to
            the transformer as a per-sample tensor.
        num_images_per_prompt: Multiplies the latent batch size. Only the first
            image of each decoded batch is yielded.
        generator: Forwarded to ``self.prepare_latents``.
        latents: Optional initial latents, forwarded to
            ``self.prepare_latents``.
        prompt_embeds: Optional pre-computed prompt embeddings.
        pooled_prompt_embeds: Optional pre-computed pooled embeddings.
        output_type: Forwarded to ``self.image_processor.postprocess``.
        return_dict: Unused; kept so the signature stays call-compatible.
        joint_attention_kwargs: Stored on ``self`` and passed to the
            transformer; its ``"scale"`` entry (if any) is passed to
            ``encode_prompt`` as ``lora_scale``.
        max_sequence_length: Forwarded to ``check_inputs`` / ``encode_prompt``.
        good_vae: VAE used for the final decode. Falls back to ``self.vae``
            when ``None``.

    Yields:
        One post-processed preview image per executed denoising step, followed
        by the final image decoded with ``good_vae``.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # `prompt` may be None when the caller passes pre-computed embeddings, so
    # derive the batch size from whichever input is actually present.
    # (Presumably check_inputs has already rejected the case where both are
    # missing — TODO confirm.)
    if isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device

    # The final decode uses the high-quality VAE when given; without one, fall
    # back to the pipeline's own VAE instead of failing on `None.config`.
    final_vae = good_vae if good_vae is not None else self.vae

    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # Evenly spaced sigmas from 1.0 down to 1/num_inference_steps; the shift
    # `mu` is interpolated from the latent sequence length using the
    # scheduler's configured bounds.
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None

    for i, t in enumerate(timesteps):
        # Once interrupted, the remaining iterations are skipped (no preview,
        # no scheduler step); the final decode below still runs.
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # Preview: decode the latents as they are *before* this step's update,
        # using the pipeline's own VAE.
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image: decode the fully denoised latents with the chosen VAE.
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / final_vae.config.scaling_factor) + final_vae.config.shift_factor
    image = final_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]

#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
# Every model below is loaded with this dtype.
dtype = torch.bfloat16
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

# TAEF1 is a very tiny autoencoder which uses the same "latent API" as FLUX.1's VAE.
# It is installed as `pipe`'s VAE (see `vae=taef1` below), so it is the VAE the
# streaming generator uses for its per-step preview decodes.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
# VAE loaded from the base model's "vae" subfolder. It is used as the
# image-to-image pipeline's VAE and is passed to the streaming generator for
# the final decode.
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
# Text-to-image pipeline.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
# Image-to-image pipeline. It is handed `pipe`'s transformer, text encoders and
# tokenizers, so both pipelines reference the same component objects.
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
    base_model,
    vae=good_vae,
    transformer=pipe.transformer,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype,
).to(device)
# Largest seed value: upper bound for the seed slider and for random seeds.
MAX_SEED = 2**32-1

# Bind the module-level streaming generator to `pipe` so it can be called as a
# method (`pipe.flux_pipe_call_that_returns_an_iterable_of_images(...)`).
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

class calculateDuration:
    """Context manager that prints the wall-clock time spent in its body.

    After the ``with`` block exits, ``start_time``, ``end_time`` and
    ``elapsed_time`` (all in seconds, taken from ``time.time()``) are available
    on the instance. Exceptions raised inside the block are not suppressed.

    Args:
        activity_name: Optional label included in the printed message.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        # Only mention the activity when a non-empty name was given.
        prefix = f"Elapsed time for {self.activity_name}" if self.activity_name else "Elapsed time"
        print(f"{prefix}: {self.elapsed_time:.6f} seconds")

def update_selection(evt: gr.SelectData, width, height):
    """React to a gallery selection by describing the chosen LoRA.

    Args:
        evt: Selection event; ``evt.index`` is used as the index into ``loras``.
        width: Current width value, returned unchanged unless the LoRA entry
            has an ``"aspect"`` key.
        height: Current height value, handled like ``width``.

    Returns:
        A 5-tuple: an update setting the prompt placeholder, a Markdown string
        linking to the LoRA repo, the selected index, and the (possibly
        adjusted) width and height.
    """
    chosen = loras[evt.index]
    placeholder_text = f"Type a prompt for {chosen['title']}"
    repo_id = chosen["repo"]
    selection_md = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✅"

    if "aspect" in chosen:
        aspect = chosen["aspect"]
        # Known aspect names and their (width, height); anything else is square.
        presets = (
            ("portrait", (768, 1024)),
            ("landscape", (1024, 768)),
        )
        width, height = next(
            (size for label, size in presets if aspect == label),
            (1024, 1024),
        )

    return (
        gr.update(placeholder=placeholder_text),
        selection_md,
        evt.index,
        width,
        height,
    )

@spaces.GPU(duration=100)
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
    """Stream text-to-image results from the module-level ``pipe``.

    Moves ``pipe`` to CUDA, seeds a CUDA generator with ``seed`` and re-yields
    every image produced by
    ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``. The whole
    iteration is timed with ``calculateDuration``.

    Args:
        prompt_mash: Final prompt text.
        steps: Number of inference steps.
        seed: Seed for the CUDA random generator.
        cfg_scale: Guidance scale.
        width: Output width.
        height: Output height.
        lora_scale: Passed as ``joint_attention_kwargs={"scale": lora_scale}``.
        progress: Accepted but not used inside this function.

    Yields:
        Each image produced by the streaming pipeline call, in order.
    """
    pipe.to("cuda")
    rng = torch.Generator(device="cuda").manual_seed(seed)
    call_kwargs = {
        "prompt": prompt_mash,
        "num_inference_steps": steps,
        "guidance_scale": cfg_scale,
        "width": width,
        "height": height,
        "generator": rng,
        "joint_attention_kwargs": {"scale": lora_scale},
        "output_type": "pil",
        "good_vae": good_vae,
    }
    with calculateDuration("Generating image"):
        yield from pipe.flux_pipe_call_that_returns_an_iterable_of_images(**call_kwargs)

def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
    """Run the module-level image-to-image pipeline once and return its image.

    Args:
        prompt_mash: Final prompt text.
        image_input_path: Location of the input image, loaded via ``load_image``.
        image_strength: Passed to the pipeline as ``strength``.
        steps: Number of inference steps.
        cfg_scale: Guidance scale.
        width: Output width.
        height: Output height.
        lora_scale: Passed as ``joint_attention_kwargs={"scale": lora_scale}``.
        seed: Seed for the CUDA random generator.

    Returns:
        The first image of the pipeline output.
    """
    rng = torch.Generator(device="cuda").manual_seed(seed)
    pipe_i2i.to("cuda")
    source_image = load_image(image_input_path)
    output = pipe_i2i(
        prompt=prompt_mash,
        image=source_image,
        strength=image_strength,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=rng,
        joint_attention_kwargs={"scale": lora_scale},
        output_type="pil",
    )
    return output.images[0]

@spaces.GPU(duration=100)
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the selected LoRA, streaming intermediate results.

    Steps, in order: build the prompt (adding the LoRA's trigger word),
    optionally stream an enhanced prompt, swap the LoRA weights, pick the
    seed, then run either image-to-image (when ``image_input`` is given) or
    streaming text-to-image.

    Args:
        prompt: User prompt.
        image_input: Input image path for image-to-image, or ``None`` for
            text-to-image.
        image_strength: Denoise strength for image-to-image.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        selected_index: Index into ``loras`` of the chosen LoRA.
        randomize_seed: When truthy, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        seed: Seed to use when ``randomize_seed`` is falsy.
        width: Output width.
        height: Output height.
        lora_scale: LoRA scale passed to the pipelines.
        use_enhancer: When truthy, the prompt is run through
            ``enhance_generate`` first.
        progress: Gradio progress tracker, forwarded to ``generate_image``.

    Yields:
        4-tuples ``(image, seed, progress_bar_update, enhanced_text)``. While
        the enhancer is streaming, ``image`` is ``None``.

    Raises:
        gr.Error: If no LoRA has been selected.
    """
    # Check if a LoRA is selected.
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.🧨")
    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]
    # Prepare prompt by appending/prepending trigger word if available.
    if trigger_word:
        if selected_lora.get("trigger_position") == "prepend":
            prompt_mash = f"{trigger_word} {prompt}"
        else:
            prompt_mash = f"{prompt} {trigger_word}"
    else:
        prompt_mash = prompt

    # If prompt enhancer is enabled, stream the enhanced prompt.
    enhanced_text = ""
    if use_enhancer:
        for enhanced_chunk in enhance_generate(prompt_mash):
            enhanced_text = enhanced_chunk
            # Yield intermediate output (no image yet, but update enhanced prompt textbox)
            yield None, seed, gr.update(visible=False), enhanced_text
        # Only adopt the enhanced prompt when the enhancer actually produced
        # text; otherwise keep the user's prompt rather than generating from
        # an empty string.
        if enhanced_text:
            prompt_mash = enhanced_text

    with calculateDuration("Unloading LoRA"):
        pipe.unload_lora_weights()
        pipe_i2i.unload_lora_weights()

    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        pipe_to_use = pipe_i2i if image_input is not None else pipe
        weight_name = selected_lora.get("weights", None)
        pipe_to_use.load_lora_weights(
            lora_path,
            weight_name=weight_name,
            low_cpu_mem_usage=True
        )

    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    if image_input is not None:
        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
        yield final_image, seed, gr.update(visible=False), enhanced_text
    else:
        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
        final_image = None
        # Pre-initialised so the closing yield below cannot hit an unbound
        # name if the generator produces no images.
        progress_bar = ""
        step_counter = 0
        for image in image_generator:
            step_counter += 1
            final_image = image
            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
            yield image, seed, gr.update(value=progress_bar, visible=True), enhanced_text
        # Re-emit the last image with the progress bar hidden.
        yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text
        
def get_huggingface_safetensors(link):
    """Inspect a Hugging Face repo and extract what is needed to use it as a LoRA.

    Args:
        link: Repository id in ``"owner/name"`` form.

    Returns:
        A 5-tuple ``(title, repo_id, safetensors_filename, trigger_word,
        image_url)``. ``title`` is the part of ``link`` after the slash,
        ``trigger_word`` is the card's ``instance_prompt`` (``""`` if absent)
        and ``image_url`` may be ``None`` when no preview image is found.

    Raises:
        Exception: If ``link`` is not of the form ``owner/name``, if the model
            card's ``base_model`` is not FLUX.1-dev or FLUX.1-schnell, if the
            repo's files cannot be listed, or if it contains no
            ``.safetensors`` file.
    """
    split_link = link.split("/")
    if len(split_link) != 2:
        raise Exception("Invalid LoRA link format")

    model_card = ModelCard.load(link)
    base_model = model_card.data.get("base_model")
    print(base_model)
    # A model card may declare `base_model` as a single id or a list of ids;
    # accept the LoRA if any declared base is a supported FLUX model.
    declared_bases = base_model if isinstance(base_model, (list, tuple)) else [base_model]
    supported_bases = ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell")
    if not any(candidate in supported_bases for candidate in declared_bases):
        raise Exception("Flux LoRA Not Found!")

    # Tolerate a missing, empty or null `widget` / `output` section instead of
    # crashing on the indexing.
    widget_entries = model_card.data.get("widget") or [{}]
    first_widget = widget_entries[0] or {}
    image_path = (first_widget.get("output") or {}).get("url", None)
    trigger_word = model_card.data.get("instance_prompt", "") or ""
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

    invalid_repo_warning = "You didn't include a link nor a valid Hugging Face repository with a *.safetensors LoRA"
    fs = HfFileSystem()
    try:
        list_of_files = fs.ls(link, detail=False)
    except Exception as e:
        print(e)
        gr.Warning(invalid_repo_warning)
        raise Exception("Invalid LoRA repository") from e

    safetensors_name = None
    for file in list_of_files:
        file_name = file.split("/")[-1]
        if file.endswith(".safetensors"):
            # When several weight files exist, the last one listed wins.
            safetensors_name = file_name
        if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
            # No preview declared in the card: use the first image in the repo.
            image_url = f"https://huggingface.co/{link}/resolve/main/{file_name}"

    # Without any weights file the repo cannot be used as a LoRA; fail with
    # the same warning instead of an unbound-variable error.
    if safetensors_name is None:
        gr.Warning(invalid_repo_warning)
        raise Exception("Invalid LoRA repository")

    return split_link[1], link, safetensors_name, trigger_word, image_url

def check_custom_model(link):
    """Resolve a user-supplied LoRA reference and inspect the repository.

    Accepts either a bare repository id (``owner/name``) or an
    ``https://huggingface.co/owner/name`` URL (with or without ``www.``).
    Surrounding whitespace and a trailing slash on the URL are ignored.

    Args:
        link: Repository id or Hugging Face URL.

    Returns:
        Whatever ``get_huggingface_safetensors`` returns for the repository.

    Raises:
        Exception: If ``link`` is an ``https://`` URL that does not point into
            huggingface.co, plus anything raised by
            ``get_huggingface_safetensors``.
    """
    link = link.strip()
    if not link.startswith("https://"):
        return get_huggingface_safetensors(link)

    if link.startswith(("https://huggingface.co/", "https://www.huggingface.co/")):
        repo_id = link.split("huggingface.co/", 1)[1].rstrip("/")
        return get_huggingface_safetensors(repo_id)

    # Any other https URL is unsupported; fail explicitly rather than
    # returning None, which callers would trip over when unpacking.
    raise Exception("Invalid LoRA link format")

def add_custom_lora(custom_lora):
    """Register a user-supplied LoRA and build the UI updates describing it.

    Args:
        custom_lora: Repository id or Hugging Face URL typed by the user. An
            empty value hides the custom-LoRA widgets.

    Returns:
        A 6-tuple of values/updates: the info card, the remove button, the
        gallery, the selection text, the index of the LoRA inside ``loras``
        (``None`` on failure or empty input) and the trigger word.
    """
    # Local import: needed only to escape model-card text placed into HTML.
    import html

    global loras
    if not custom_lora:
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

    try:
        title, repo, path, trigger_word, image = check_custom_model(custom_lora)
        print(f"Loaded custom LoRA: {repo}")

        # Title, trigger word and image URL come from a user-chosen repository
        # and its model card, so escape them before embedding them in markup.
        safe_title = html.escape(str(title))
        safe_image = html.escape(str(image)) if image else ""
        if trigger_word:
            trigger_note = "Using: <code><b>" + html.escape(str(trigger_word)) + "</b></code> as the trigger word"
        else:
            trigger_note = "No trigger word found. Include it in your prompt"
        card = f'''
        <div class="custom_lora_card">
          <span>Loaded custom LoRA:</span>
          <div class="card_internal">
            <img src="{safe_image}" />
            <div>
                <h3>{safe_title}</h3>
                <small>{trigger_note}<br></small>
            </div>
          </div>
        </div>
        '''

        existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
        # Compare against None explicitly: index 0 is a valid match and must
        # not be treated as "not found" (which would append a duplicate).
        if existing_item_index is None:
            new_item = {
                "image": image,
                "title": title,
                "repo": repo,
                "weights": path,
                "trigger_word": trigger_word
            }
            print(new_item)
            existing_item_index = len(loras)
            loras.append(new_item)

        return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
    except Exception as e:
        # Keep the UI-level behaviour (warning + reset) but leave a trace of
        # the underlying cause in the server output.
        print(f"Failed to load custom LoRA {custom_lora!r}: {e}")
        gr.Warning("Invalid LoRA: either you entered an invalid link or a non-FLUX LoRA")
        return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, ""

def remove_custom_lora():
    """Build the updates that clear the custom-LoRA state.

    Returns:
        A 6-tuple: two updates hiding components, one no-op update, an empty
        string, ``None`` and another empty string.
    """
    hide_card = gr.update(visible=False)
    hide_button = gr.update(visible=False)
    untouched = gr.update()
    return hide_card, hide_button, untouched, "", None, ""

# Tag the handler with a `zerogpu` attribute.
# NOTE(review): presumably read by the Hugging Face ZeroGPU runtime — confirm.
run_lora.zerogpu = True

# Stylesheet handed to gr.Blocks via `css=css`. It targets the elem_ids used in
# the layout (#gen_btn, #gen_column, #title, #gallery, #lora_list, #progress),
# the custom-LoRA card markup (.card_internal) and the HTML progress bar
# (.progress-container / .progress-bar). The bar's width is derived from the
# --current and --total CSS variables set inline on the element.
css = '''
#gen_btn { height: 100%; }
#gen_column { align-self: stretch; }
#title { text-align: center; }
#title h1 { font-size: 3em; display:inline-flex; align-items:center; }
#title img { width: 100px; margin-right: 0.5em; }
#gallery .grid-wrap { height: 10vh; }
#lora_list { background: var(--block-background-fill); padding: 0 1em .3em; font-size: 90%; }
.card_internal { display: flex; height: 100px; margin-top: .5em; }
.card_internal img { margin-right: 1em; }
.styler { --form-gap-width: 0px !important; }
#progress { height:30px; }
#progress .generating { display:none; }
.progress-container { width: 100%; height: 30px; background-color: #f0f0f0; border-radius: 15px; overflow: hidden; margin-bottom: 20px; }
.progress-bar { height: 100%; background-color: #4f46e5; width: calc(var(--current) / var(--total) * 100%); transition: width 0.5s ease-in-out; }
'''

# Build the Gradio interface: prompt + generate button, LoRA gallery and
# custom-LoRA entry, result image with progress bar, and advanced settings.
# NOTE(review): delete_cache=(60, 60) is presumably (frequency, age) in
# seconds for Gradio's cache clean-up — confirm against the Gradio docs.
with gr.Blocks(theme=gr.themes.Base(), css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1>Flux LoRA Generation</h1>""",
        elem_id="title",
    )
    # Index into `loras` of the currently selected LoRA; None until one is picked.
    selected_index = gr.State(None)
    # Top row: prompt box and generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    # Middle row: LoRA selection on the left, generation output on the right.
    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("")
            # One gallery tile per entry in `loras` (image + title).
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA DLC's",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
                gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
            # Both start hidden; add_custom_lora / remove_custom_lora toggle them.
            custom_lora_info = gr.HTML(visible=False)
            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
        with gr.Column():
            # Receives the HTML progress bar emitted by run_lora while streaming.
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image")
    # Bottom row: collapsible advanced settings.
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                # type="filepath": the handler receives a path; supplying an
                # image switches run_lora to the image-to-image pipeline.
                input_image = gr.Image(label="Input image", type="filepath")
                image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
                with gr.Row():
                    use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer")
                    show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt")
                enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False)
            # Show or hide the enhanced-prompt textbox whenever the
            # "Display Enhanced Prompt" checkbox changes.
            show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show),
                                        inputs=show_enhanced_prompt,
                                        outputs=enhanced_prompt_box)
    # Selecting a gallery tile updates the prompt placeholder, the selection
    # text, the stored index and (for LoRAs with an aspect hint) width/height.
    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )
    # Typing into the custom-LoRA box tries to load that repository. The last
    # output is the prompt box, which receives the LoRA's trigger word.
    custom_lora.input(
        add_custom_lora,
        inputs=[custom_lora],
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
    )
    # Removing the custom LoRA hides its card/button, resets the selection and
    # clears the custom-LoRA textbox (the last output here).
    custom_lora_button.click(
        remove_custom_lora,
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
    )
    # Generation starts from either the button or pressing Enter in the prompt.
    # run_lora is a generator, so each yielded tuple updates the four outputs.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer],
        outputs=[result, seed, progress_bar, enhanced_prompt_box]
    )
    with gr.Row():
        gr.HTML("<div style='text-align:center; font-size:0.9em; margin-top:20px;'>Credits: <a href='https://ruslanmv.com' target='_blank'>ruslanmv.com</a></div>")
    
# Enable request queuing, then start the server in debug mode.
app.queue()
app.launch(debug=True)