import logging
import os
import random
import uuid
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Directory (relative to the current working directory) where generated images
# and their metadata.txt log are written.
# NOTE(review): nothing in this file makes this directory persistent; on hosted
# deployments local disk may be ephemeral — confirm the storage setup if the
# images must survive restarts.
SAVE_DIR = "saved_images"
# exist_ok=True already makes this a no-op when the directory exists, so a
# separate (racy) existence check is unnecessary. If a non-directory occupies
# the path, this raises at startup instead of every later save failing.
os.makedirs(SAVE_DIR, exist_ok=True)

# Path of the placeholder shown in the result image before anything has been
# generated. Only the path is stored here; it is passed to the UI as the
# initial value of the result image.
DEFAULT_IMAGE_PATH = "cover1.webp"

# Hugging Face Hub identifiers: the FLUX.1-dev base model and the
# Ctoon-Plus-Plus LoRA adapter loaded on top of it.
repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "strangerzonehf/Ctoon-Plus-Plus"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the base pipeline in bfloat16, attach the LoRA weights, then move the
# whole pipeline to the selected device.
pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
pipeline.load_lora_weights(adapter_id)
pipeline = pipeline.to(device)

# Largest seed offered by the seed slider and by seed randomisation:
# the maximum signed 32-bit integer (2**31 - 1).
MAX_SEED = np.iinfo("int32").max
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 1024

def save_generated_image(image, prompt, save_dir=None):
    """Write *image* to disk and append a one-line metadata record for it.

    The file is named ``<YYYYmmdd_HHMMSS>_<8 hex chars>.png``. The random
    suffix makes name collisions between images saved in the same second
    very unlikely. A line of the form ``filename|prompt|timestamp`` is then
    appended to ``metadata.txt`` in the same directory.

    Args:
        image: Object with a ``save(path)`` method (e.g. a PIL image).
        prompt: Prompt that produced the image. Each line break is replaced
            by a space so the record stays on exactly one line of the
            line-oriented metadata file.
        save_dir: Target directory. Defaults to ``SAVE_DIR`` and is created
            if it does not exist.

    Returns:
        Path of the saved image file.
    """
    if save_dir is None:
        save_dir = SAVE_DIR
    # The directory is created at startup, but re-create it in case it was
    # removed since; this is a no-op when it already exists.
    os.makedirs(save_dir, exist_ok=True)

    # Timestamp plus a short random id gives a sortable, practically unique name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = uuid.uuid4().hex[:8]
    filename = f"{timestamp}_{unique_id}.png"
    filepath = os.path.join(save_dir, filename)

    image.save(filepath)

    # One record per line: a raw line break in the prompt would split the
    # record and corrupt the log for any line-based reader.
    single_line_prompt = " ".join(str(prompt).splitlines())
    metadata_file = os.path.join(save_dir, "metadata.txt")
    with open(metadata_file, "a", encoding="utf-8") as f:
        f.write(f"{filename}|{single_line_prompt}|{timestamp}\n")

    return filepath

@spaces.GPU(duration=120)
def inference(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    lora_scale: float,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Generate one image from *prompt* with the LoRA-adapted pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the torch generator. Ignored when *randomize_seed* is true.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        lora_scale: LoRA strength, forwarded as ``joint_attention_kwargs["scale"]``.
        progress: Gradio progress tracker (``track_tqdm=True``). It is not part
            of the UI inputs and is supplied by Gradio.

    Returns:
        ``(image, seed)``: the generated image and the seed actually used.
        The seed is returned so the UI can show it for reproduction.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver numbers as floats, but torch.Generator.manual_seed
    # and the pipeline's size/step arguments expect ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    image = pipeline(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    # Persisting to disk is best-effort. A full disk or permission error must
    # not discard an image the user already waited for, so log and continue.
    try:
        save_generated_image(image, prompt)
    except OSError:
        logging.getLogger(__name__).exception("Failed to save generated image")

    return image, seed


# Example prompts listed under the generator; clicking one fills the prompt box.
# NOTE(review): every prompt ends with a literal "[trigger]" placeholder —
# confirm whether it should be replaced by the LoRA's actual trigger word.
examples = [
    "A cartoon drawing of a majestic Persian cat wearing a tiny golden hanbok and crown. The cat has sparkling blue eyes and perfectly groomed white fur that seems to glow. It sits with regal posture on a traditional Korean cushion decorated with cloud patterns. The background is a soft pink with delicate cherry blossom petals floating around. The cat's expression shows a mix of dignity and subtle amusement. [trigger]",
    "A cartoon drawing of an enthusiastic orange tabby cat in a puffy white chef's hat. The cat stands on its hind legs at a tiny wooden counter, wearing a white apron covered in flour pawprints. Its green eyes are focused intently on the cookie dough it's rolling with a miniature rolling pin. The background is a warm cream color with tiny floating cooking utensils and swirling steam patterns. [trigger]",
    "A cartoon drawing of a sophisticated tuxedo cat photographer with round wire-rimmed glasses perched on its nose. The cat balances carefully on a tree branch, one paw holding a vintage camera while its tail curls in concentration. It wears a tiny brown beret and leather camera bag. The background is a soft blue with playful butterfly silhouettes and floating leaves. [trigger]",
    "A cartoon drawing of a chubby Scottish Fold cat floating in a space capsule. The cat wears an adorable white spacesuit with colorful patches, its round face visible through the helmet visor. Its paws are batting at star-shaped toys that float around in zero gravity. The background shows a stylized view of Earth and twinkling stars through the capsule window. [trigger]",
    "A cartoon drawing of an elegant Siamese ballet dancer cat in mid-twirl. The cat wears a sparkly pink tutu that flares out perfectly, with tiny satin ribbons wrapped around its ankles. Its blue eyes are closed in graceful concentration as it performs a pirouette. The background is a soft lavender with swirling musical notes and floating rose petals. [trigger]",
    "A cartoon drawing of an adventurous calico cat riding atop a smiling elephant. The cat wears a tiny khaki explorer's vest filled with equipment, and a safari hat tilted at a jaunty angle. It holds a comically large map while the elephant's trunk curls up playfully. The background is a warm orange sunset with stylized acacia trees and cartoon birds soaring past. [trigger]",
]
# Custom stylesheet passed to gr.Blocks; it only hides the page footer.
css = "\nfooter {\n    visibility: hidden;\n}\n"

# Build the Gradio interface: a "Generation" tab with a one-line prompt box,
# a Run button, the result image, collapsible sampling controls and example
# prompts, all wired to `inference` via `gr.on` at the end of this block.
# Component creation order defines the layout, so statement order matters.
with gr.Blocks(theme=gr.themes.Soft(), css=css, analytics_enabled=False) as demo:
    # NOTE(review): `css` defines no rule for the "title" class used here, so
    # this heading presumably renders with default styling — confirm intent.
    gr.HTML('<div class="title"> Cartoon Image Generation </div>')

    # Visitor-counter badge loaded from visitorbadge.io.
    gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fginigen-cartoon.hf.space">
               <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fginigen-cartoon.hf.space&countColor=%23263759" />
               </a>""")

    with gr.Tab("Generation"):
        with gr.Column(elem_id="col-container"):
            # Prompt box and Run button share one row. Pressing Enter in the
            # box also triggers generation (see `gr.on` below).
            with gr.Row():
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run", scale=0)

            # Shows DEFAULT_IMAGE_PATH until the first image is generated.
            result = gr.Image(
                label="Result",
                show_label=False,
                value=DEFAULT_IMAGE_PATH
            )

            with gr.Accordion("Advanced Settings", open=False):
                # While "Randomize seed" is checked (the default), `inference`
                # ignores this value and draws its own seed. The seed actually
                # used is written back to this slider.
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                # Output size in pixels, from 256 up to MAX_IMAGE_SIZE in
                # 32-pixel steps.
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=768,
                    )

                # Sampling controls, passed as inputs to `inference`.
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=3.5,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=30,
                    )
                    lora_scale = gr.Slider(
                        label="LoRA scale",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=1.0,
                    )

            # Clicking an example fills the prompt box.
            # NOTE(review): no `fn` is passed, so `outputs` presumably has no
            # effect here — confirm against the installed Gradio version.
            gr.Examples(
                examples=examples,
                inputs=[prompt],
                outputs=[result, seed],
            )

    # Run generation on a Run click or on Enter in the prompt box. The returned
    # (image, seed) pair updates the result image and the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lora_scale,
        ],
        outputs=[result, seed],
    )

# Enable the request queue, then start the server. `Blocks.queue` returns the
# Blocks instance itself, so the two calls can be chained.
demo.queue().launch()