import io
import os
import base64
from PIL import Image
import cv2
import numpy as np
from scripts.generate_prompt import load_wd14_tagger_model, generate_tags, preprocess_image as wd14_preprocess_image
from scripts.lineart_util import scribble_xdog, get_sketch, canny
from scripts.anime import init_model
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, AutoencoderKL
import gc
from peft import PeftModel
from dotenv import load_dotenv
from scripts.hf_utils import download_file

import spaces

# Global variables
use_local = False
model = None
device = None
torch_dtype = None # torch.float16 if device == "cuda" else torch.float32
sotai_gen_pipe = None
refine_gen_pipe = None

def get_file_path(filename, subfolder):
    if use_local:
        return os.path.join(subfolder, filename)
    else:
        return download_file(filename, subfolder)

def ensure_rgb(image):
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image

def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
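    """Load every model used by this module and set the module-level globals.

    Resolves the device and dtype, initializes the model from scripts.anime,
    loads the WD-14 tagger, and builds the sotai and refine ControlNet
    pipelines. Must be called before any of the generation functions.
    """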
    if use_dotenv:
        load_dotenv()
    global model, sotai_gen_pipe, refine_gen_pipe, use_local, device, torch_dtype
    device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    use_local = _use_local

    print(f"\nDevice: {device}, Local model: {_use_local}\n")

    init_model(use_local)
    model = load_wd14_tagger_model()
    sotai_gen_pipe = initialize_sotai_model()
    refine_gen_pipe = initialize_refine_model()

def load_lora(pipeline, lora_path, alpha=0.75):
    pipeline.load_lora_weights(lora_path)
    pipeline.fuse_lora(lora_scale=alpha)

def initialize_sotai_model():
    global device, torch_dtype

    sotai_sd_model_path = get_file_path(os.environ["sotai_sd_model_name"], subfolder=os.environ["sd_models_dir"])
    controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
    # controlnet_path1 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
    controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])

    # Load the Stable Diffusion model
    sd_pipe = StableDiffusionPipeline.from_single_file(
        sotai_sd_model_path,
        torch_dtype=torch_dtype,
        use_safetensors=True
    ).to(device)
    
    # Load the first ControlNet model
    controlnet1 = ControlNetModel.from_single_file(
        controlnet_path1,
        torch_dtype=torch_dtype
    ).to(device)
    
    # Load the second ControlNet model
    controlnet2 = ControlNetModel.from_single_file(
        controlnet_path2,
        torch_dtype=torch_dtype
    ).to(device)

    # Create the ControlNet pipeline
    sotai_gen_pipe = StableDiffusionControlNetPipeline(
        vae=sd_pipe.vae,
        text_encoder=sd_pipe.text_encoder,
        tokenizer=sd_pipe.tokenizer,
        unet=sd_pipe.unet,
        scheduler=sd_pipe.scheduler,
        safety_checker=sd_pipe.safety_checker,
        feature_extractor=sd_pipe.feature_extractor,
        controlnet=[controlnet1, controlnet2]
    ).to(device)

    # Apply LoRA weights
    lora_names = [
        (os.environ["lora_name1"], 1.0),
        # (os.environ["lora_name2"], 0.3),
    ]
    
    for lora_name, alpha in lora_names:
        lora_path = get_file_path(lora_name, subfolder=os.environ["lora_dir"])
        load_lora(sotai_gen_pipe, lora_path, alpha)

    # Configure the scheduler
    sotai_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(sotai_gen_pipe.scheduler.config)

    return sotai_gen_pipe

def initialize_refine_model():
    global device, torch_dtype

    refine_sd_model_path = get_file_path(os.environ["refine_sd_model_name"], subfolder=os.environ["sd_models_dir"])
    controlnet_path3 = get_file_path(os.environ["controlnet_name3"], subfolder=os.environ["controlnet_dir1"])
    controlnet_path4 = get_file_path(os.environ["controlnet_name4"], subfolder=os.environ["controlnet_dir1"])
    vae_path = get_file_path(os.environ["vae_name"], subfolder=os.environ["vae_dir"])

    # Load the Stable Diffusion model
    sd_pipe = StableDiffusionPipeline.from_single_file(
        refine_sd_model_path,
        torch_dtype=torch_dtype,
        variant="fp16", 
        use_safetensors=True
    ).to(device)
    
    # controlnet_path = "models/cn/control_v11p_sd15_canny.pth"
    controlnet1 = ControlNetModel.from_single_file(
        controlnet_path3,
        torch_dtype=torch_dtype
    ).to(device)
    
    # Load the ControlNet model
    controlnet2 = ControlNetModel.from_single_file(
        controlnet_path4,
        torch_dtype=torch_dtype
    ).to(device)

    # Create the ControlNet pipeline
    refine_gen_pipe = StableDiffusionControlNetPipeline(
        vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype).to(device),
        text_encoder=sd_pipe.text_encoder,
        tokenizer=sd_pipe.tokenizer,
        unet=sd_pipe.unet,
        scheduler=sd_pipe.scheduler,
        safety_checker=sd_pipe.safety_checker,
        feature_extractor=sd_pipe.feature_extractor,
        controlnet=[controlnet1, controlnet2],  # Multiple ControlNets
    ).to(device)

    # Configure the scheduler
    refine_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(refine_gen_pipe.scheduler.config)

    return refine_gen_pipe

def get_wd_tags(images: list) -> list:
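    """Run the WD-14 tagger on a list of RGB images (NumPy arrays) and return the generated tags for each image."""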
    global model
    if model is None:
        raise ValueError("Model is not initialized")
        # initialize()
    preprocessed_images = [wd14_preprocess_image(img) for img in images]
    preprocessed_images = np.array(preprocessed_images)
    return generate_tags(preprocessed_images, os.environ["wd_model_name"], model)

def preprocess_image_for_generation(image):
    if isinstance(image, str):  # base64-encoded string
        image = Image.open(io.BytesIO(base64.b64decode(image)))
    elif isinstance(image, np.ndarray):  # NumPy array
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Unsupported image type")
    
    # Compute the output size (long side capped at 736 px)
    input_width, input_height = image.size
    max_size = 736
    output_width = max_size if input_height < input_width else int(input_width / input_height * max_size)
    output_height = max_size if input_height > input_width else int(input_height / input_width * max_size)
    
    image = image.resize((output_width, output_height))
    return image, output_width, output_height

def binarize_image(image: Image.Image) -> np.ndarray:
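    """Binarize a line image: invert, apply CLAHE, blur, then adaptive-threshold.

    Returns a uint8 array in which line pixels are 255 and the background is 0.
    """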
    image = np.array(image.convert('L'))
    # Invert the image so lines become bright
    image = 255 - image
    
    # Equalize local contrast (CLAHE)
    clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8, 8))
    image = clahe.apply(image)

    # Apply a Gaussian blur
    image = cv2.GaussianBlur(image, (5, 5), 0)

    # Adaptive thresholding
    binary_image = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 9, -8)

    return binary_image

def create_rgba_image(binary_image: np.ndarray, color: list) -> Image.Image:
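    """Build an RGBA overlay whose RGB channels hold the given color and whose alpha channel is the binary mask, so only line pixels are opaque."""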
    rgba_image = np.zeros((binary_image.shape[0], binary_image.shape[1], 4), dtype=np.uint8)
    rgba_image[:, :, 0] = color[0]
    rgba_image[:, :, 1] = color[1]
    rgba_image[:, :, 2] = color[2]
    rgba_image[:, :, 3] = binary_image
    return Image.fromarray(rgba_image, 'RGBA')

@spaces.GPU
def generate_sotai_image(input_image: Image.Image, output_width: int, output_height: int) -> Image.Image:
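    """Generate a monochrome sotai (base-body outline) image from the input, using both ControlNets of the sotai pipeline with a fixed pose/sketch prompt."""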
    input_image = ensure_rgb(input_image)
    global sotai_gen_pipe
    if sotai_gen_pipe is None:
        raise ValueError("Model is not initialized")
        # initialize()

    prompt = "anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)"
    try:
        # Resize the input image so its long side is 512 px
        if input_image.size[0] > input_image.size[1]:
            input_image = input_image.resize((512, int(512 * input_image.size[1] / input_image.size[0])))
        else:
            input_image = input_image.resize((int(512 * input_image.size[0] / input_image.size[1]), 512))

        # EasyNegativeV2 contents (referenced only by the commented-out negative_prompt below)
        easy_negative_v2 = "(worst quality, low quality, normal quality:1.4), lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry, artist name, (bad_prompt_version2:0.8)"

        output = sotai_gen_pipe(
            prompt,
            image=[input_image, input_image],
            negative_prompt="(wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
            # negative_prompt=f"{easy_negative_v2}, (wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
            num_inference_steps=20,
            guidance_scale=8,
            width=output_width,
            height=output_height,
            denoising_strength=0.13,
            num_images_per_prompt=1,  # Equivalent to batch_size
            guess_mode=True,  # Equivalent to pixel_perfect (a single bool, applied to both ControlNets)
            controlnet_conditioning_scale=[1.4, 1.3],  # Weight for each ControlNet
            control_guidance_start=[0.0, 0.0],
            control_guidance_end=[1.0, 1.0],
        )
        generated_image = output.images[0]
        
        return generated_image

    finally:
        # Free memory
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

@spaces.GPU
def generate_refined_image(prompt: str, original_image: Image.Image, output_width: int, output_height: int, weight1: float, weight2: float) -> Image.Image:
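    """Refine the original image with the refine pipeline, conditioning its two ControlNets on XDoG scribble edges and the resized original image."""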
    original_image = ensure_rgb(original_image)
    global refine_gen_pipe
    if refine_gen_pipe is None:
        raise ValueError("Model is not initialized")
        # initialize()

    try:
        original_image_np = np.array(original_image)
        # Extract scribble (XDoG) edges from the original image
        scribble_image, _ = scribble_xdog(original_image_np, 2048, 20)

        original_image = original_image.resize((output_width, output_height))
        output = refine_gen_pipe(
            prompt,
            image=[scribble_image, original_image],  # Input images for the two ControlNets
            negative_prompt="extra limb, monochrome, black and white",
            num_inference_steps=20,
            width=output_width,
            height=output_height,
            controlnet_conditioning_scale=[weight1, weight2],  # Weight for each ControlNet
            control_guidance_start=[0.0, 0.0],
            control_guidance_end=[1.0, 1.0],
            guess_mode=False,  # pixel_perfect (guess_mode takes a single bool)
        )
        generated_image = output.images[0]
        
        return generated_image

    finally:
        # Free memory
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
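    """Generate a (sotai_image, sketch_image) pair from the input image.

    mode:
        "refine":   tag the input with WD-14, generate a refined image, then derive the sketch (blue RGBA) and the sotai from it.
        "original": generate the sotai directly from the input and extract a sketch from the input.
        "sketch":   extract a sketch from the input and generate the sotai from that sketch.

    The returned sotai image is a red-line RGBA overlay. weight1 and weight2 are the ControlNet conditioning scales used in "refine" mode.
    """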
    input_image = ensure_rgb(input_image)
    # Compute the output size (long side capped at 736 px)
    input_width, input_height = input_image.size
    max_size = 736
    output_width = max_size if input_height < input_width else int(input_width / input_height * max_size)
    output_height = max_size if input_height > input_width else int(input_height / input_width * max_size)

    if mode == "refine":
        # Generate a prompt from the input image with the WD-14 tagger
        image_np = np.array(ensure_rgb(input_image))
        prompt = get_wd_tags([image_np])[0]

        refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
        refined_image = refined_image.convert('RGB')

        # Generate the sketch image
        refined_image_np = np.array(refined_image)
        sketch_image = get_sketch(refined_image_np, "both", 2048, 10)
        sketch_image = sketch_image.resize((output_width, output_height))  # Match the output size
        # Binarize the sketch image
        sketch_binary = binarize_image(sketch_image)
        # Convert to RGBA (transparent background) and draw the lines in blue
        sketch_image = create_rgba_image(sketch_binary, [0, 0, 255])

        # Generate the sotai (base-body) image
        sotai_image = generate_sotai_image(refined_image, output_width, output_height)

    elif mode == "original":
        sotai_image = generate_sotai_image(input_image, output_width, output_height)
        
        # Generate the sketch image
        input_image_np = np.array(input_image)
        sketch_image = get_sketch(input_image_np, "both", 2048, 16)

    elif mode == "sketch":
        # Generate the sketch image
        input_image_np = np.array(input_image)
        sketch_image = get_sketch(input_image_np, "both", 2048, 16)
        
        # Generate the sotai image from the sketch
        sotai_image = generate_sotai_image(sketch_image, output_width, output_height)

    else:
        raise ValueError("Invalid mode")

    # Binarize the sotai image
    sotai_binary = binarize_image(sotai_image)
    # Convert to RGBA (transparent background) and draw the lines in red
    sotai_image = create_rgba_image(sotai_binary, [255, 0, 0])

    return sotai_image, sketch_image

def image_to_base64(image):
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def process_image_as_base64(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
    sotai_image, sketch_image = process_image(input_image, mode, weight1, weight2)
    return image_to_base64(sotai_image), image_to_base64(sketch_image)
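
# Minimal usage sketch, assuming the required environment variables and model
# files are available; "input.png" is a hypothetical example path.
if __name__ == "__main__":
    initialize(_use_local=False, use_gpu=True, use_dotenv=True)
    example_image = Image.open("input.png")
    sotai_b64, sketch_b64 = process_image_as_base64(example_image, "original")
    print(len(sotai_b64), len(sketch_b64))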