# Utility functions for loading and using diffusers models
import os
import warnings
from copy import deepcopy
from typing import Sequence, Union

import diffusers
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
import transformers
from PIL import Image

def build_generator(
        device : torch.device,
        seed : int,
):
    """
    Create a ``torch.Generator`` that lives on *device* and is seeded with *seed*.

    ``Generator.manual_seed`` returns the generator itself, so the seeded
    generator is what gets returned.
    """
    rng = torch.Generator(device)
    rng = rng.manual_seed(seed)
    return rng

def load_stablediffusion_model(
        model_id : Union[str, os.PathLike],
        device : torch.device,
        ):
    """
    Load a complete diffusers pipeline from *model_id* and move it to *device*.

    The pipeline is loaded from the ``"fp16"`` revision with ``torch.float16``
    weights. Its scheduler is replaced by a ``DPMSolverMultistepScheduler``
    built from the pipeline's original scheduler config.

    Arguments:
        model_id: hub id or local path of the pipeline.
        device: device to move the pipeline to.

    Returns:
        The loaded ``diffusers.DiffusionPipeline``. If moving it to *device*
        raises, a warning (including the cause) is emitted and the pipeline is
        returned on the CPU instead.
    """
    pipe = diffusers.DiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    # Reuse the pipeline's own scheduler config for the multistep DPM-Solver.
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    try:
        pipe = pipe.to(device)
    except Exception as err:
        # Deliberate best-effort fallback. Unlike a bare ``except:`` this does
        # not swallow KeyboardInterrupt/SystemExit, and it reports the cause.
        # NOTE(review): the weights stay float16 on CPU; some ops may not
        # support half precision there — confirm this fallback is usable.
        warnings.warn(
            f'Could not load model to device:{device} ({err}). Using CPU instead.'
        )
        pipe = pipe.to('cpu')

    return pipe


def visualize_image_grid(
        imgs : Sequence["Image.Image"],
        rows : int,
        cols : int):
    """
    Tile images into a single ``rows`` x ``cols`` RGB grid image.

    Images are pasted in row-major order: image ``i`` goes to row
    ``i // cols`` and column ``i % cols``. The grid is returned, not displayed.

    Arguments:
        imgs: exactly ``rows * cols`` images. Each must expose ``.size`` as
            ``(width, height)`` (e.g. PIL images). The first image's size is
            used for every cell, so all images are expected to share it.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new PIL ``'RGB'`` image of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or if ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
        )
    if not imgs:
        raise ValueError("at least one image is required to build a grid")

    w, h = imgs[0].size  # cell size taken from the first image

    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid


def build_pipeline(
        autoencoder : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
        tokenizer : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
        text_encoder : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
        unet : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
        device : torch.device = torch.device('cuda'),
        ):
    """
    Assemble Stable Diffusion components by loading each one separately.

    Arguments:
        autoencoder: hub id / path whose ``vae`` subfolder holds the AutoencoderKL.
        tokenizer: hub id / path of the CLIP tokenizer.
        text_encoder: hub id / path of the CLIP text model.
        unet: hub id / path whose ``unet`` subfolder holds the UNet2DConditionModel.
        device: device the VAE, text encoder and UNet are moved to
            (the tokenizer is not moved).

    Returns:
        Tuple ``(vae, tokenizer, text_encoder, unet)``.
    """
    vae_model = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder='vae')
    clip_tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer)
    clip_text_model = transformers.CLIPTextModel.from_pretrained(text_encoder)
    unet_model = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder='unet')

    # Move the weight-bearing modules, in the order vae -> text encoder -> unet.
    vae_model, clip_text_model, unet_model = (
        module.to(device) for module in (vae_model, clip_text_model, unet_model)
    )

    return vae_model, clip_tokenizer, clip_text_model, unet_model

# TODO: Add negative prompting (the unconditional branch currently always uses the empty prompt "").
def _encode_prompts(tokenizer, text_encoder, prompts, device):
    """Tokenize *prompts* (list of str) to max length and return the text encoder's first output, without autograd."""
    tokens = tokenizer(
        prompts,
        padding='max_length',
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors='pt',
    )
    with torch.no_grad():
        return text_encoder(tokens.input_ids.to(device))[0]


def _decode_latents_to_pil(vae, latents):
    """Decode *latents* with the VAE and return one uint8 PIL image per batch entry."""
    # Undo the latent scaling before decoding. Falls back to the previously
    # hard-coded 0.18215 when the VAE config does not provide it.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    with torch.no_grad():
        image = vae.decode(1 / scaling_factor * latents).sample

    # Map from [-1, 1] to [0, 1], then to channels-last uint8 arrays.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()  # (batch, h, w, c)
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(img) for img in images]


def custom_stablediffusion_inference(
        vae,
        tokenizer,
        text_encoder,
        unet,
        noise_scheduler,
        prompt : Union[str, list],
        device : torch.device,
        num_inference_steps = 100,
        image_size = (512,512),
        guidance_scale = 8,
        seed = 42,
        return_image_step = 5,
    ):
    """
    Run classifier-free-guided Stable Diffusion sampling, yielding previews.

    This is a generator. After denoising step ``i`` it yields a decoded PIL
    image whenever ``i % return_image_step == 0``. After the last step it
    yields one final image decoded from the fully denoised latents. Only the
    image for the first prompt is yielded, even when several prompts are given.

    Arguments:
        vae, tokenizer, text_encoder, unet: components, e.g. from ``build_pipeline``.
        noise_scheduler: diffusers scheduler providing ``set_timesteps``,
            ``init_noise_sigma``, ``scale_model_input`` and ``step``.
        prompt: a prompt string or a non-empty list of prompt strings.
        device: device the embeddings and latents are placed on.
        num_inference_steps: number of scheduler timesteps.
        image_size: ``(height, width)`` in pixels; the latents are 1/8 of this.
        guidance_scale: classifier-free guidance weight.
        seed: seed for the initial noise, or ``None`` for unseeded noise.
        return_image_step: decode a preview every this many steps (>= 1).

    Raises:
        ValueError: if ``prompt`` is empty or ``return_image_step < 1``.
    """
    if return_image_step < 1:
        raise ValueError(f"return_image_step must be >= 1, got {return_image_step}")
    if isinstance(prompt, str):
        prompt = [prompt]
    if not prompt:
        raise ValueError("prompt must contain at least one string")
    batch_size = len(prompt)

    # Text-conditioned and unconditional (empty-prompt) embeddings. The
    # unconditional half comes first to match the ``chunk(2)`` split below.
    text_embeddings = _encode_prompts(tokenizer, text_encoder, prompt, device)
    uncond_embeddings = _encode_prompts(tokenizer, text_encoder, [""] * batch_size, device)
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Noise is drawn on the CPU from a private generator, so seeding does not
    # reset the global torch RNG. One latent is drawn per prompt so the batch
    # matches the embeddings.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    latents = torch.randn(
        (batch_size, unet.config.in_channels, image_size[0] // 8, image_size[1] // 8),
        generator=generator,
    ).to(device)

    noise_scheduler.set_timesteps(num_inference_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    timesteps = noise_scheduler.timesteps
    for i, t in tqdm.tqdm(enumerate(timesteps), total=len(timesteps)):
        # Duplicate the latents so one UNet pass yields both the unconditional
        # and the text-conditioned noise predictions.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # x_t -> x_{t-1}
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        if i % return_image_step == 0:
            yield _decode_latents_to_pil(vae, latents)[0]

    # Always finish with the fully denoised result, not a stale intermediate preview.
    yield _decode_latents_to_pil(vae, latents)[0]

if __name__ == "__main__":
    # Demo: sample Stable Diffusion 2.1 on the CPU and save every yielded image as step_<i>.png.
    device = torch.device("cpu")
    model_id = "stabilityai/stable-diffusion-2-1"
    # NOTE(review): SD 2.1 ships its own tokenizer/text_encoder subfolders;
    # confirm this standalone CLIP checkpoint matches what the UNet expects.
    tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
    # Implicit string concatenation: a backslash continuation inside the literal
    # would embed the next line's indentation whitespace into the prompt.
    prompt = (
        "A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares, "
        "cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, octane "
        "render, architectural HD, HQ, 4k, 8k"
    )

    vae, tokenizer, text_encoder, unet = build_pipeline(
        autoencoder=model_id,
        tokenizer=tokenizer_id,
        text_encoder=tokenizer_id,
        unet=model_id,
        device=device,
    )
    image_iter = custom_stablediffusion_inference(
        vae, tokenizer, text_encoder, unet, noise_scheduler,
        prompt=prompt, device=device, seed=None,
    )
    for i, image in enumerate(image_iter):
        image.save(f"step_{i}.png")