import gradio as gr
import numpy as np
import torch
import spaces
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.utils import export_to_gif
from huggingface_hub import hf_hub_download
from PIL import Image
import uuid
import random

# Choose where the model runs and the matching precision:
# bfloat16 on a CUDA GPU, full float32 when falling back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32



# 파이프라인 초기화 수정
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    torch_dtype=torch_dtype,
    use_safetensors=True
).to(device)

MAX_SEED = np.iinfo(np.int32).max

def split_image(input_image, num_splits=8):
    """Cut ``input_image`` into ``num_splits`` equal-width vertical strips.

    Each strip spans the full image height and is
    ``input_image.width // num_splits`` pixels wide. When the width is not
    divisible by ``num_splits``, the leftover right-most columns belong to no
    strip. Every strip is converted to RGB mode (mode normalisation only; it
    does not change image quality).

    Args:
        input_image: Image object exposing ``width``, ``height``, ``crop`` and
            ``convert`` (e.g. a PIL image).
        num_splits: Number of strips to produce.

    Returns:
        List of RGB strips ordered from left to right.
    """
    tile_w = input_image.width // num_splits
    full_h = input_image.height
    return [
        input_image.crop((idx * tile_w, 0, (idx + 1) * tile_w, full_h)).convert('RGB')
        for idx in range(num_splits)
    ]

@spaces.GPU
def infer(prompt, seed=1, randomize_seed=False, num_inference_steps=20, progress=gr.Progress(track_tqdm=True)):
    """Generate an 8-frame animated GIF for ``prompt`` with the FLUX pipeline.

    Every frame is a separate square text-to-image call on the same templated
    prompt (suffixed with its frame number) using seed ``seed + frame_index``,
    so the frames are independent images rather than a temporally coherent
    video.

    Args:
        prompt: Description of the scene/action to animate.
        seed: Base seed; frame ``i`` uses ``seed + i``. Replaced by a random
            value in ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh base seed.
        num_inference_steps: Denoising steps per frame.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(gif_path, preview_image, seed)``: file name of the GIF written to the
        current working directory, a horizontal strip of all frames, and the
        base seed that was actually used (so the UI can show it).
    """
    progress(0, desc="Starting...")
    # Slider/API values may arrive as floats; manual_seed and the step
    # arithmetic below need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    prompt_template = f"A single clear frame of {prompt}. The scene should show only one moment of the action, high quality, detailed, centered composition."

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    total_frames = 8
    frame_size = 320  # width and height of every generated frame, in pixels
    # Frame generation fills the first 80% of the progress bar; GIF export the rest.
    frame_share = 0.8 / total_frames

    frames = []
    for i in range(total_frames):
        frame_start = i * frame_share
        progress(frame_start, desc=f"🎨 Generating frame {i+1}/{total_frames}")
        frame_prompt = f"{prompt_template} Frame {i+1} of sequence."
        generator = torch.Generator().manual_seed(seed + i)

        # Invoked by the pipeline after each real denoising step, so the bar
        # tracks actual work (previously it was advanced through every step
        # before generation even started). Loop values are bound as defaults
        # to make the per-frame capture explicit.
        def report_step(pipeline, step, timestep, callback_kwargs, frame_index=i, frame_start=frame_start):
            progress(
                frame_start + ((step + 1) / num_inference_steps) * frame_share,
                desc=f"Frame {frame_index+1}/{total_frames} - Step {step+1}/{num_inference_steps}",
            )
            # diffusers expects step-end callbacks to return their kwargs dict.
            return callback_kwargs

        frame = pipe(
            prompt=frame_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=1,
            generator=generator,
            height=frame_size,
            width=frame_size,
            # NOTE(review): FluxPipeline's documented default is 3.5; confirm 7.5 is intended.
            guidance_scale=7.5,
            callback_on_step_end=report_step,
        ).images[0]

        frames.append(frame)
        progress((i + 1) * frame_share, desc=f"✅ Completed frame {i+1}/{total_frames}")

    progress(0.9, desc="🎬 Creating GIF...")
    # NOTE(review): GIFs are written to the working directory and nothing here
    # deletes them, so they accumulate across requests.
    gif_name = f"{uuid.uuid4().hex}-flux.gif"
    export_to_gif(frames, gif_name, fps=8)

    # Horizontal strip of all frames, pasted left to right.
    preview_image = create_preview_image(frames)

    progress(1.0, desc="✨ Done!")
    return gif_name, preview_image, seed

def create_preview_image(frames):
    """Concatenate ``frames`` left to right into a single RGB preview strip.

    The canvas is as wide as all frames combined and as tall as the tallest
    frame. Each frame is pasted against the top edge, immediately to the right
    of the previous one. An empty ``frames`` raises ``ValueError`` (from
    ``max`` on an empty sequence).
    """
    widths = [f.width for f in frames]
    canvas = Image.new('RGB', (sum(widths), max(f.height for f in frames)))
    x = 0
    for f, w in zip(frames, widths):
        canvas.paste(f, (x, 0))
        x += w
    return canvas

# Prompts offered as clickable examples in the UI; each one fills the prompt box.
examples = [
    "a red panda in mid-backflip",
    "an astronaut floating in space",
    "a butterfly spreading its wings",
    "a robot arm painting with a brush",
    "a dragon egg with cracks appearing",
    "a person stepping through a glowing portal",
    "a mermaid swimming underwater",
    "a steampunk clock gear turning",
    "a flower bud slowly opening",
    "a wizard with magical energy swirling",
]

# Stylesheet passed to gr.Blocks(css=...). The !important rules force
# transparent backgrounds in the Examples area; the remaining rules style the
# progress-bar related classes.
# The first line used to be a bare placeholder ("... (이전 CSS와 동일)"),
# which is not valid CSS. It merged into the next selector and caused
# the .gr-examples-parent rule to be discarded. It is now a proper CSS comment.
css = """
/* Base styles from the previous version of this stylesheet were elided here. */

/* Examples 영역 스타일 완전 재정의 */
.gr-examples-parent {
    background: transparent !important;
}
.gr-examples-parent > div {
    background: transparent !important;
}
.gr-examples {
    background: transparent !important;
}
.gr-examples * {
    background: transparent !important;
}
.gr-samples-table {
    background: transparent !important;
}
.gr-samples-table > div {
    background: transparent !important;
}
.gr-samples-table button {
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
}
.gr-samples-table button:hover {
    background: rgba(0,0,0,0.05) !important;
}
div[class*="examples"] {
    background: transparent !important;
}

/* 프로그레스 바 스타일 강화 */
.progress-bar {
    background-color: #f0f0f0;
    border-radius: 10px;
    padding: 5px;
    margin: 15px 0;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}

.progress-bar-fill {
    background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
    height: 25px;
    border-radius: 7px;
    transition: width 0.3s ease-out;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}

.progress-text {
    color: black;
    font-weight: 600;
    margin-bottom: 8px;
    font-size: 1.1em;
}

/* 진행 상태 텍스트 스타일 */
.progress-label {
    display: block;
    text-align: center;
    margin-top: 5px;
    color: #666;
    font-size: 0.9em;
}
"""


def create_snow_effect():
    """Return a gr.HTML component meant to overlay falling "❄" snowflakes.

    The HTML bundles two parts:
    - A <style> block with a ``snowfall`` keyframe animation and a
      fixed-position, click-through ``.snowflake`` class.
    - A <script> that appends a snowflake every 200 ms. Each flake gets a
      random horizontal position, a random 2-5 s fall duration and a random
      opacity, and is removed after 5 s.

    NOTE(review): Gradio may not execute <script> tags placed inside gr.HTML.
    If no snow appears, move the JavaScript to gr.Blocks(js=...) or head=...
    (confirm against the Gradio version in use).
    """
    # Client-side spawner: one snowflake element every 200 ms, each removed after 5 s.
    script_body = """
    function createSnowflake() {
        const snowflake = document.createElement('div');
        snowflake.innerHTML = '❄';
        snowflake.className = 'snowflake';
        snowflake.style.left = Math.random() * 100 + 'vw';
        snowflake.style.animationDuration = Math.random() * 3 + 2 + 's';
        snowflake.style.opacity = Math.random();
        document.body.appendChild(snowflake);
        
        setTimeout(() => {
            snowflake.remove();
        }, 5000);
    }
    setInterval(createSnowflake, 200);
    """

    # Fall animation plus the styling applied to each flake element.
    keyframes_css = """
    @keyframes snowfall {
        0% {
            transform: translateY(-10vh) translateX(0);
            opacity: 1;
        }
        100% {
            transform: translateY(100vh) translateX(100px);
            opacity: 0.3;
        }
    }
    .snowflake {
        position: fixed;
        color: white;
        font-size: 1.5em;
        user-select: none;
        z-index: 1000;
        pointer-events: none;
        animation: snowfall linear infinite;
    }
    """

    markup = f"""
    <style>
        {keyframes_css}
    </style>
    <script>
        {script_body}
    </script>
    """
    return gr.HTML(markup)


# Assemble the Gradio interface using the "Yntec/HaleyCH_Theme_Orange" hub
# theme plus the module-level `css` string. Component order inside the
# context managers defines the page layout.
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    # Page header: title and subtitle.
    gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
            <h1 style="font-size: 3rem; font-weight: 700; margin-bottom: 1rem;">
                FLUX Animation Creator
            </h1>
            <p style="font-size: 1.2rem; color: #666; margin-bottom: 2rem;">
                Create amazing animated GIFs with AI - Just describe what you want to see!
            </p>
        </div>
    """)
    # Adds the snowflake overlay component.
    # NOTE(review): the effect depends on a <script> inside gr.HTML, which
    # Gradio may not execute; confirm the snow actually appears.
    create_snow_effect()        
    with gr.Column(elem_id="col-container"):
        # Prompt box and the button that starts generation.
        with gr.Row():
            prompt = gr.Text(
                label="Your Animation Prompt",
                show_label=True,
                max_lines=1,
                placeholder="Describe the animation you want to create...",
                container=True,
                elem_id="prompt-input"
            )
            run_button = gr.Button("✨ Generate", scale=0, variant="primary")
        
        # Main output: receives the GIF file path returned by infer.
        result = gr.Image(
            label="Generated Animation", 
            show_label=True,
            elem_id="main-output",
            height=500
        )
        
        with gr.Row():
            # Receives the horizontal frame strip returned by infer.
            result_full = gr.Image(
                label="Preview", 
                elem_id="preview-output",
                height=200
            )
            # NOTE(review): strip_image is not listed in any `outputs` below,
            # so nothing ever populates it.
            strip_image = gr.Image(
                label="Animation Strip", 
                elem_id="strip-output",
                height=150
            )
        
        with gr.Accordion("Advanced Settings", open=False):
            # Base seed. infer ignores it while "Randomize seed" is checked
            # (the default) and writes the seed it actually used back here.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # Denoising steps per frame (1-25, default 20).
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=25,
                step=1,
                value=20,
            )    
        
        # Clickable example prompts. Only `prompt` is passed as input, so
        # infer runs with its own defaults (seed=1, randomize_seed=False,
        # num_inference_steps=20).
        # NOTE(review): cache_examples=True presumably pre-generates every
        # example (8 FLUX frames each); confirm the startup cost is acceptable.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            outputs=[result, result_full, seed],
            fn=infer,
            cache_examples=True,
            label="Click on any example to try it out"
        )
    
    # Run infer on a button click or on Enter in the prompt box. The outputs
    # map to infer's (gif_path, preview_image, seed) return tuple.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, num_inference_steps],
        outputs=[result, result_full, seed]
    )

# Force black text on a white primary background via the default Gradio theme.
# NOTE(review): `demo` was already constructed with the
# "Yntec/HaleyCH_Theme_Orange" hub theme. Depending on the Gradio version,
# assigning demo.theme after construction may not change the CSS actually
# served. Confirm, and if needed pass this theme to gr.Blocks(theme=...) instead.
demo.theme = gr.themes.Default().set(
    background_fill_primary="white",
    body_text_color="black",
    body_text_color_subdued="black",
    block_label_text_color="black",
    block_title_text_color="black",
)

# Enable request queuing and start the web server.
demo.queue().launch()