import os
import random

import cv2
import numpy as np
import torch

# Fix all random seeds for reproducible generation
seed = 1024
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from PIL import Image
from gdown import download_folder
from spiga_draw import spiga_process, spiga_segmentation
from pipeline_sd15 import StableDiffusionControlNetPipeline
from diffusers import DDIMScheduler, ControlNetModel
from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
from detail_encoder.encoder_plus import detail_encoder

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"


def get_draw(pil_img, size):
    """Run SPIGA facial-landmark detection and return the segmentation drawing.

    Falls back to an all-black image of the same size when no face is detected.
    """
    cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    spigas = spiga_process(cv2_img)
    if spigas is False:
        # No face found: return a black placeholder
        width, height = pil_img.size
        black_image_pil = Image.new("RGB", (width, height), color=(0, 0, 0))
        return black_image_pil
    else:
        spigas_faces = spiga_segmentation(spigas, size=size)
        return spigas_faces


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"])


def concatenate_images(image_files, output_file):
    """Concatenate a list of PIL images horizontally and save the result."""
    images = image_files  # list of PIL images
    max_height = max(img.height for img in images)
    images = [img.resize((img.width, max_height)) for img in images]
    total_width = sum(img.width for img in images)
    combined = Image.new("RGB", (total_width, max_height))
    x_offset = 0
    for img in images:
        combined.paste(img, (x_offset, 0))
        x_offset += img.width
    combined.save(output_file)


def init_pipeline():
    # Initialize the model
    model_id = "runwayml/stable-diffusion-v1-5"  # or your local sdv1-5 path
    base_path = "./checkpoints/stablemakeup"
    folder_id = "1397t27GrUyLPnj17qVpKWGwg93EcaFfg"
    if not os.path.exists(base_path):
        # Download the pretrained Stable-Makeup weights from Google Drive
        download_folder(id=folder_id, output=base_path, quiet=False, use_cookies=False)
    makeup_encoder_path = base_path + "/pytorch_model.bin"
    id_encoder_path = base_path + "/pytorch_model_1.bin"
    pose_encoder_path = base_path + "/pytorch_model_2.bin"

    # Base SD 1.5 UNet, two ControlNets (identity and pose) derived from it,
    # and the CLIP-based makeup detail encoder
    Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, device=device, subfolder="unet").half()
    id_encoder = ControlNetModel.from_unet(Unet)
    pose_encoder = ControlNetModel.from_unet(Unet)
    makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", device=device, dtype=torch.float16)

    # Load the fine-tuned weights on CPU, then move everything to the target device in fp16
    id_state_dict = torch.load(id_encoder_path, map_location=torch.device("cpu"))
    pose_state_dict = torch.load(pose_encoder_path, map_location=torch.device("cpu"))
    makeup_state_dict = torch.load(makeup_encoder_path, map_location=torch.device("cpu"))
    id_encoder.load_state_dict(id_state_dict, strict=False)
    pose_encoder.load_state_dict(pose_state_dict, strict=False)
    makeup_encoder.load_state_dict(makeup_state_dict, strict=False)
    id_encoder.to(device=device).half()
    pose_encoder.to(device=device).half()
    makeup_encoder.to(device=device).half()

    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        model_id,
        safety_checker=None,
        unet=Unet,
        controlnet=[id_encoder, pose_encoder],
        device=device,
        torch_dtype=torch.float16,
    ).to(device=device)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe, makeup_encoder


# Initialize the model
pipeline, makeup_encoder = init_pipeline()


def inference(id_image_pil, makeup_image_pil, guidance_scale=1.6, size=512):
    """Transfer the makeup style of makeup_image_pil onto id_image_pil."""
    id_image = id_image_pil.resize((size, size))
    makeup_image = makeup_image_pil.resize((size, size))
    pose_image = get_draw(id_image, size=size)
    result_img = makeup_encoder.generate(
        id_image=[id_image, pose_image],
        makeup_image=makeup_image,
        pipe=pipeline,
        guidance_scale=guidance_scale,
    )
    return result_img
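
# Example usage (a minimal sketch, not part of the original script): run one
# makeup-transfer pass and save the output. The image paths below are
# hypothetical placeholders, and inference() is assumed to return a PIL image
# (matching how result_img is produced above).
if __name__ == "__main__":
    source_img = Image.open("./test_imgs/id/source.jpg").convert("RGB")
    reference_img = Image.open("./test_imgs/makeup/reference.jpg").convert("RGB")
    result = inference(source_img, reference_img, guidance_scale=1.6, size=512)
    os.makedirs("./output", exist_ok=True)
    result.save("./output/transfer_result.jpg")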