import torch
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image

from src.eunms import Model_Type, Scheduler_Type
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from src.lcm_scheduler import MyLCMScheduler
from src.ddpm_scheduler import MyDDPMScheduler
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.sd_inversion_pipeline import SDDDIMPipeline
    
def scheduler_type_to_class(scheduler_type):
    """Return the scheduler class associated with ``scheduler_type``.

    Args:
        scheduler_type: A ``Scheduler_Type`` value (DDIM, EULER, LCM or DDPM).

    Returns:
        The scheduler class (not an instance) registered for that type.

    Raises:
        ValueError: If ``scheduler_type`` matches none of the known types.
    """
    # Ordered (type, class) pairs; scanned top to bottom with ``==``.
    known_schedulers = (
        (Scheduler_Type.DDIM, DDIMScheduler),
        (Scheduler_Type.EULER, MyEulerAncestralDiscreteScheduler),
        (Scheduler_Type.LCM, MyLCMScheduler),
        (Scheduler_Type.DDPM, MyDDPMScheduler),
    )
    for candidate, scheduler_cls in known_schedulers:
        if scheduler_type == candidate:
            return scheduler_cls
    raise ValueError("Unknown scheduler type")
    
def model_type_to_class(model_type):
    """Return the pair of pipeline classes used for ``model_type``.

    Args:
        model_type: A ``Model_Type`` value.

    Returns:
        A 2-tuple ``(img2img_pipeline_class, inversion_pipeline_class)``.
        The SDXL family pairs with ``SDXLDDIMPipeline``; the SD 1.x / 2.x
        family pairs with ``SDDDIMPipeline``.

    Raises:
        ValueError: If ``model_type`` matches none of the known types.
    """
    sdxl_pair = (StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline)
    sdxl_auto_pair = (AutoPipelineForImage2Image, SDXLDDIMPipeline)
    sd_pair = (StableDiffusionImg2ImgPipeline, SDDDIMPipeline)

    # Ordered (type, classes) pairs; scanned top to bottom with ``==``.
    known_models = (
        (Model_Type.SDXL, sdxl_pair),
        (Model_Type.SDXL_Turbo, sdxl_auto_pair),
        (Model_Type.LCM_SDXL, sdxl_auto_pair),
        (Model_Type.SD15, sd_pair),
        (Model_Type.SD14, sd_pair),
        (Model_Type.SD21, sd_pair),
        (Model_Type.SD21_Turbo, sd_pair),
    )
    for candidate, pipeline_classes in known_models:
        if model_type == candidate:
            return pipeline_classes
    raise ValueError("Unknown model type")
    
def model_type_to_model_name(model_type):
    """Return the pretrained-model identifier string for ``model_type``.

    Args:
        model_type: A ``Model_Type`` value.

    Returns:
        The model name string for that type. Note that ``LCM_SDXL`` shares
        the SDXL base model name.

    Raises:
        ValueError: If ``model_type`` matches none of the known types.
    """
    sdxl_base = "stabilityai/stable-diffusion-xl-base-1.0"

    # Ordered (type, name) pairs; scanned top to bottom with ``==``.
    known_models = (
        (Model_Type.SDXL, sdxl_base),
        (Model_Type.SDXL_Turbo, "stabilityai/sdxl-turbo"),
        (Model_Type.LCM_SDXL, sdxl_base),
        (Model_Type.SD15, "runwayml/stable-diffusion-v1-5"),
        (Model_Type.SD14, "CompVis/stable-diffusion-v1-4"),
        (Model_Type.SD21, "stabilityai/stable-diffusion-2-1"),
        (Model_Type.SD21_Turbo, "stabilityai/sd-turbo"),
    )
    for candidate, model_name in known_models:
        if model_type == candidate:
            return model_name
    raise ValueError("Unknown model type")

    
def model_type_to_size(model_type):
    """Return the two-element size tuple associated with ``model_type``.

    Args:
        model_type: A ``Model_Type`` value.

    Returns:
        A tuple of two ints (presumably the working image resolution in
        pixels -- confirm against the callers).

    Raises:
        ValueError: If ``model_type`` matches none of the known types.
    """
    small = (512, 512)

    # Ordered (type, size) pairs; scanned top to bottom with ``==``.
    known_models = (
        (Model_Type.SDXL, (1024, 1024)),
        (Model_Type.SDXL_Turbo, small),
        (Model_Type.LCM_SDXL, (768, 768)),  # TODO: check
        (Model_Type.SD15, small),
        (Model_Type.SD14, small),
        (Model_Type.SD21, small),
        (Model_Type.SD21_Turbo, small),
    )
    for candidate, size in known_models:
        if model_type == candidate:
            return size
    raise ValueError("Unknown model type")
    
def is_float16(model_type):
    """Tell whether ``model_type`` is flagged as a float16 model.

    Args:
        model_type: A ``Model_Type`` value.

    Returns:
        ``True`` for the SDXL family (SDXL, SDXL_Turbo, LCM_SDXL) and
        ``False`` for SD15, SD14, SD21 and SD21_Turbo.

    Raises:
        ValueError: If ``model_type`` matches none of the known types.
    """
    half_precision_types = (
        Model_Type.SDXL,
        Model_Type.SDXL_Turbo,
        Model_Type.LCM_SDXL,
    )
    full_precision_types = (
        Model_Type.SD15,
        Model_Type.SD14,
        Model_Type.SD21,
        Model_Type.SD21_Turbo,
    )

    if any(model_type == candidate for candidate in half_precision_types):
        return True
    if any(model_type == candidate for candidate in full_precision_types):
        return False
    raise ValueError("Unknown model type")
    
def is_sd(model_type):
    """Tell whether ``model_type`` belongs to the non-XL Stable Diffusion family.

    Args:
        model_type: A ``Model_Type`` value.

    Returns:
        ``False`` for SDXL, SDXL_Turbo and LCM_SDXL; ``True`` for SD15, SD14,
        SD21 and SD21_Turbo.

    Raises:
        ValueError: If ``model_type`` matches none of the known types.
    """
    sdxl_family = (
        Model_Type.SDXL,
        Model_Type.SDXL_Turbo,
        Model_Type.LCM_SDXL,
    )
    sd_family = (
        Model_Type.SD15,
        Model_Type.SD14,
        Model_Type.SD21,
        Model_Type.SD21_Turbo,
    )

    if any(model_type == candidate for candidate in sdxl_family):
        return False
    if any(model_type == candidate for candidate in sd_family):
        return True
    raise ValueError("Unknown model type")
    
def _get_pipes(model_type, device):
    """Load the inversion and inference pipelines for ``model_type``.

    Both pipelines are created with ``from_pretrained`` from the same model
    name, with safetensors enabled and the safety checker disabled, and are
    then moved to ``device``. Model types flagged by ``is_float16`` are
    additionally loaded with ``torch.float16`` and the ``"fp16"`` variant.

    Args:
        model_type: A ``Model_Type`` value.
        device: Target passed to each pipeline's ``.to(...)``.

    Returns:
        A 2-tuple ``(pipe_inversion, pipe_inference)``.

    Raises:
        ValueError: If ``model_type`` is not a known model type.
    """
    model_name = model_type_to_model_name(model_type)
    pipeline_inf, pipeline_inv = model_type_to_class(model_type)

    # Arguments common to every load; half-precision extras are added below.
    load_kwargs = {"use_safetensors": True, "safety_checker": None}
    if is_float16(model_type):
        load_kwargs.update(torch_dtype=torch.float16, variant="fp16")

    def _load(pipeline_cls):
        return pipeline_cls.from_pretrained(model_name, **load_kwargs).to(device)

    # Inversion pipeline is loaded first, then the inference pipeline.
    pipe_inversion = _load(pipeline_inv)
    pipe_inference = _load(pipeline_inf)

    return pipe_inversion, pipe_inference
    
def get_pipes(model_type, scheduler_type, device="cuda"):
    """Build the inversion and inference pipelines with the chosen scheduler.

    Steps performed:
      1. Resolve the scheduler class for ``scheduler_type``.
      2. Load both pipelines via ``_get_pipes``.
      3. Replace each pipeline's scheduler with an instance of the chosen
         class built from the existing scheduler config, and attach an
         extra ``scheduler_inference`` to the inversion pipeline.
      4. For the non-XL SD family, override ``add_noise`` on all three
         schedulers so it returns the input latents unchanged.
      5. For ``LCM_SDXL``, load the LCM LoRA weights into both pipelines
         (the LoRA is loaded but not fused).

    Args:
        model_type: A ``Model_Type`` value.
        scheduler_type: A ``Scheduler_Type`` value.
        device: Target passed to each pipeline's ``.to(...)``.

    Returns:
        A 2-tuple ``(pipe_inversion, pipe_inference)``.

    Raises:
        ValueError: If the scheduler type or model type is unknown.
    """
    # Resolved before any model is loaded, so a bad scheduler type fails first.
    scheduler_class = scheduler_type_to_class(scheduler_type)

    pipe_inversion, pipe_inference = _get_pipes(model_type, device)

    pipe_inference.scheduler = scheduler_class.from_config(
        pipe_inference.scheduler.config
    )
    pipe_inversion.scheduler = scheduler_class.from_config(
        pipe_inversion.scheduler.config
    )
    # Built from the inference pipeline's (already replaced) scheduler config.
    pipe_inversion.scheduler_inference = scheduler_class.from_config(
        pipe_inference.scheduler.config
    )

    if is_sd(model_type):
        def _keep_latents(init_latents, noise, timestep):
            # add_noise override: hand the latents back untouched.
            return init_latents

        schedulers = (
            pipe_inference.scheduler,
            pipe_inversion.scheduler,
            pipe_inversion.scheduler_inference,
        )
        for scheduler in schedulers:
            scheduler.add_noise = _keep_latents

    if model_type == Model_Type.LCM_SDXL:
        adapter_id = "latent-consistency/lcm-lora-sdxl"
        # Load the LCM LoRA into both pipelines (not fused).
        for pipe in (pipe_inversion, pipe_inference):
            pipe.load_lora_weights(adapter_id)

    return pipe_inversion, pipe_inference