import os, sys, time, re, pdb
import torch, torchvision
import numpy
from PIL import Image
import hashlib
from tqdm import tqdm
import openai
from utils.direction_utils import *

# Put the pix2pix-zero utility directory on the import path. The bare-name
# imports that follow (edit_directions, edit_pipeline, ddim_inv, scheduler)
# are presumably resolved from this directory -- confirm against the submodule
# layout. The path is relative, so it is resolved against the current working
# directory of the process.
p = "submodules/pix2pix-zero/src/utils"
if p not in sys.path:
    sys.path.append(p)
from diffusers import DDIMScheduler
from edit_directions import construct_direction
from edit_pipeline import EditingPipeline
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler
from lavis.models import load_model_and_preprocess
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM



def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    """Encode every sentence and return the mean of the encodings.

    Each sentence is tokenized on its own, padded/truncated to
    ``tokenizer.model_max_length``, and run through ``text_encoder`` with
    gradients disabled. The first element of each encoder output is kept,
    the results are concatenated along dim 0, averaged over that dim, and a
    leading dim of size 1 is restored.

    Args:
        l_sentences: iterable of sentences to encode.
        tokenizer: callable tokenizer exposing ``model_max_length`` and
            returning an object with an ``input_ids`` tensor.
        text_encoder: callable taking ``(input_ids, attention_mask=None)``;
            its output is indexed with ``[0]``.
        device: device the token ids are moved to before encoding.

    Returns:
        The averaged encoding with a leading dimension of size 1.
    """
    per_sentence = []
    with torch.no_grad():
        for sentence in tqdm(l_sentences):
            token_ids = tokenizer(
                sentence,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids
            encoded = text_encoder(token_ids.to(device), attention_mask=None)[0]
            per_sentence.append(encoded)

    stacked = torch.concatenate(per_sentence, dim=0)
    # keepdim=True restores the leading axis that the mean collapses.
    return stacked.mean(dim=0, keepdim=True)



def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
    """Sample one image from a seeded noise map and save that noise to disk.

    Args:
        prompt: text prompt handed to the pipeline.
        seed: value converted with ``int()`` and used to seed the CUDA RNG.
        negative_scale: forwarded to the pipeline as ``guidance_scale``.
        num_ddim: forwarded as ``num_inference_steps``; also part of the
            saved noise file name.

    Returns:
        A tuple ``(first element of the pipeline output, path of the saved
        noise tensor)``.
    """
    os.makedirs("tmp", exist_ok=True)

    # Build the pipeline and swap in a DDIM scheduler built from its config.
    pipe = EditingPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32
    )
    pipe = pipe.to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Seed the CUDA generator, then draw the input noise map on the GPU.
    torch.cuda.manual_seed(int(seed))
    noise = torch.randn((1, 4, 64, 64), device="cuda")

    # The file name is derived from a SHA-256 digest of the noise bytes.
    digest = hashlib.sha256(noise.cpu().numpy().tobytes()).hexdigest()
    z_inv_fname = f"tmp/{digest}_ddim_{num_ddim}_inv.pt"
    torch.save(noise, z_inv_fname)

    sampled = pipe(
        prompt,
        num_inference_steps=num_ddim,
        x_in=noise,
        only_sample=True,  # only generate the sampled image, not the edited image
        guidance_scale=negative_scale,
        negative_prompt="",  # use the empty string for the negative prompt
    )

    # Drop the pipeline and release cached GPU memory before returning.
    del pipe
    torch.cuda.empty_cache()

    return sampled[0], z_inv_fname



def clean_l_sentences(ls):
    """Remove list-numbering characters from each sentence.

    Every digit and every ".", "-" and ")" character is deleted from each
    sentence, after which surrounding whitespace is stripped.

    Args:
        ls: iterable of strings.

    Returns:
        A new list with the cleaned strings, in the same order.
    """
    # Raw string on purpose: "\d" inside a plain string literal is an invalid
    # escape sequence and triggers a warning on modern Python versions.
    # One character class deletes the same characters the original removed in
    # four separate passes.
    junk = re.compile(r"[\d.\-)]")
    return [junk.sub("", sentence).strip() for sentence in ls]



def gpt3_compute_word2sentences(task_type, word, num=100):
    """Collect at least ``num`` captions about ``word`` from the OpenAI API.

    The completion endpoint is queried repeatedly; each returned line longer
    than 10 characters is kept. Lines that do not contain every
    space-separated part of ``word`` get ``", {word}"`` appended. The
    collected lines are passed through ``clean_l_sentences`` before being
    returned.

    Args:
        task_type: ``"object"`` or ``"style"``; selects the prompt template.
        word: the object or style description to caption.
        num: minimum number of sentences to collect before stopping. At least
            one request is always made.

    Returns:
        List of cleaned sentences (may contain more than ``num`` entries).

    Raises:
        ValueError: if ``task_type`` is neither ``"object"`` nor ``"style"``.
    """
    if task_type == "object":
        template_prompt = f"Provide many captions for images containing {word}."
    elif task_type == "style":
        template_prompt = f"Provide many captions for images that are in the {word} style."
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on template_prompt.
        raise ValueError(
            f"Unknown task_type {task_type!r}; expected 'object' or 'style'."
        )

    subwords = word.split(" ")
    l_sentences = []
    # NOTE(review): openai.Completion is the legacy (pre-1.0) client interface;
    # confirm the pinned openai version still provides it.
    while True:
        ret = openai.Completion.create(
            model="text-davinci-002",
            prompt=template_prompt,
            max_tokens=1000,
            temperature=1.0)
        raw_return = ret.choices[0].text
        for line in raw_return.split("\n"):
            line = line.strip()
            if len(line) <= 10:
                continue
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                # Make sure the target word is mentioned in every caption.
                l_sentences.append(line + f", {word}")
        time.sleep(0.05)
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    return clean_l_sentences(l_sentences)


def flant5xl_compute_word2sentences(word, num=100):
    """Sample at least ``num`` captions about ``word`` from FLAN-T5 XL.

    The model is loaded in float16, sampled 16 sequences at a time, and
    released again before returning. Captions that do not contain every
    space-separated part of ``word`` get ``", {word}"`` appended. The
    collected captions are passed through ``clean_l_sentences``.

    Args:
        word: description of the image content to caption.
        num: minimum number of captions to collect before stopping. At least
            one generation batch is always run.

    Returns:
        List of cleaned captions (may contain more than ``num`` entries).
    """
    text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
    input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to("cuda")

    subwords = word.split(" ")
    l_sentences = []
    while True:
        outputs = model.generate(input_ids, temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
        # T5 is an encoder-decoder model: generate() returns only the decoder
        # tokens, not the prompt followed by the continuation. Slicing off
        # ``input_length`` tokens (as is done for decoder-only models) would
        # cut away the beginning of every caption, so decode the full output.
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for line in output:
            line = line.strip()
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                # Make sure the target word is mentioned in every caption.
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)

    # Drop the model and release cached GPU memory before returning.
    del model
    del tokenizer
    torch.cuda.empty_cache()

    return l_sentences

def bloomz_compute_sentences(word, num=100):
    """Sample at least ``num`` captions about ``word`` from BLOOMZ-7B1.

    The model is loaded in float16, sampled 16 sequences at a time, and
    released again before returning. The prompt tokens are sliced off each
    returned sequence before decoding. Captions that do not contain every
    space-separated part of ``word`` get ``", {word}"`` appended. The
    collected captions are passed through ``clean_l_sentences``.

    Args:
        word: description of the image content to caption.
        num: minimum number of captions to collect before stopping.

    Returns:
        List of cleaned captions (may contain more than ``num`` entries).

    Raises:
        Exception: the last generation error is re-raised once generation has
            failed 10 times in a row.
    """
    l_sentences = []
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
    model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
    input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    input_length = input_ids.shape[1]
    t = 0.95
    eta = 1e-5
    min_length = 15

    subwords = word.split(" ")
    # A failed batch is retried, but a failure that repeats every time must
    # not spin forever, so consecutive failures are capped.
    max_consecutive_failures = 10
    consecutive_failures = 0

    while True:
        try:
            outputs = model.generate(input_ids, temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
            output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
        except Exception as err:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit still stop the loop.
            consecutive_failures += 1
            print(f"BLOOMZ generation failed ({consecutive_failures}/{max_consecutive_failures}): {err}")
            if consecutive_failures >= max_consecutive_failures:
                raise
            continue
        consecutive_failures = 0

        for line in output:
            line = line.strip()
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                # Make sure the target word is mentioned in every caption.
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)

    # Drop the model and release cached GPU memory before returning.
    del model
    del tokenizer
    torch.cuda.empty_cache()

    return l_sentences



def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
    """Build an averaged sentence embedding for a user-defined description.

    Sentences are obtained according to ``sent_type`` and then encoded with
    the tokenizer and text encoder of the Stable Diffusion pipeline via
    ``load_sentence_embeddings``.

    Args:
        description: text describing the concept.
        sent_type: how sentences are produced. ``"fixed-template"`` and
            ``"custom sentences"`` are matched exactly; ``"GPT3"``,
            ``"flan-t5-xl"`` and ``"BLOOMZ-7B"`` are matched as substrings.
        api_key: OpenAI API key (only used for the GPT3 option).
        org_key: OpenAI organization (only used for the GPT3 option).
        l_custom_sentences: newline-separated sentences (only used for the
            ``"custom sentences"`` option).

    Returns:
        The embedding returned by ``load_sentence_embeddings``.

    Raises:
        ValueError: if ``sent_type`` matches none of the known options.
    """
    if sent_type == "fixed-template":
        l_sentences = generate_image_prompts_with_templates(description)
    elif "GPT3" in sent_type:
        import openai
        openai.organization = org_key
        openai.api_key = api_key
        # Result is discarded; the call presumably serves as an early check
        # that the credentials work -- TODO confirm.
        _ = openai.Model.retrieve("text-davinci-002")
        l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
    elif "flan-t5-xl" in sent_type:
        l_sentences = flant5xl_compute_word2sentences(description, num=1000)
        # save the sentences to file; make sure the target directory exists
        os.makedirs("tmp", exist_ok=True)
        with open(f"tmp/flant5xl_sentences_{description}.txt", "w") as f:
            for line in l_sentences:
                f.write(line + "\n")
    elif "BLOOMZ-7B" in sent_type:
        l_sentences = bloomz_compute_sentences(description, num=1000)
        # save the sentences to file; make sure the target directory exists
        os.makedirs("tmp", exist_ok=True)
        with open(f"tmp/bloomz_sentences_{description}.txt", "w") as f:
            for line in l_sentences:
                f.write(line + "\n")
    elif sent_type == "custom sentences":
        l_sentences = l_custom_sentences.split("\n")
        print(f"length of new sentence is {len(l_sentences)}")
    else:
        # Previously an unknown option fell through and crashed later with an
        # UnboundLocalError on l_sentences.
        raise ValueError(f"Unknown sentence type: {sent_type!r}")

    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device="cuda")
    # Drop the pipeline and release cached GPU memory before returning.
    del pipe
    torch.cuda.empty_cache()
    return emb


def _load_or_build_direction_emb(choice, custom_desc, sent_type, api_key, org_key, custom_sentences, d_desc2name):
    """Return the sentence embedding for one side (source or target) of an edit.

    A predefined choice is looked up through ``get_emb``. For
    ``"make your own!"`` the embedding is loaded from its cache file under
    ``tmp/`` if present, otherwise built with ``make_custom_dir`` and cached.
    """
    if choice != "make your own!":
        return get_emb(d_desc2name[choice])

    # NOTE(review): the cache key contains only the description and sentence
    # type, so editing the custom sentences without changing the description
    # reuses the previously cached embedding.
    cache_fname = f"tmp/template_emb_{custom_desc}_{sent_type}.pt"
    if os.path.exists(cache_fname):
        return torch.load(cache_fname)
    emb = make_custom_dir(custom_desc, sent_type, api_key, org_key, custom_sentences)
    torch.save(emb, cache_fname)
    return emb


def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
    """Edit a real or a synthetic image along a source -> destination direction.

    Exactly one of ``img_in_real`` / ``img_in_synth`` must be given.

    * Real image: it is resized so its longer side is 512, captioned with
      BLIP, inverted with DDIM inversion, and then edited. The caption and
      the inverted noise are cached under ``tmp/``, keyed by a SHA-256 digest
      of the resized image bytes.
    * Synthetic image: the noise tensor stored at ``fpath_z_gen`` and the
      prompt ``gen_prompt`` are used directly for the edit.

    Args:
        img_in_real: PIL image to edit, or ``None``.
        img_in_synth: previously generated image, or ``None``; only its
            presence is checked.
        src, dest: direction descriptions, or ``"make your own!"``.
        src_custom, dest_custom: custom descriptions for the custom option.
        num_ddim: number of DDIM steps (inversion and sampling).
        xa_guidance: forwarded to the pipeline as ``guidance_amount``.
        edit_mul: multiplier applied to the embedding difference.
        fpath_z_gen: path of the saved noise tensor (synthetic case).
        gen_prompt: prompt used for the synthetic image.
        sent_type_src, sent_type_dest: sentence sources for custom directions.
        api_key, org_key: OpenAI credentials passed to ``make_custom_dir``.
        custom_sentences_src, custom_sentences_dest: newline-separated custom
            sentences passed to ``make_custom_dir``.

    Returns:
        The first element of the edited output returned by the pipeline.

    Raises:
        ValueError: if not exactly one of the two images is provided.
    """
    use_real = img_in_real is not None and img_in_synth is None
    use_synth = img_in_real is None and img_in_synth is not None
    # Validate up front so a bad call fails before any model is loaded.
    if not (use_real or use_synth):
        raise ValueError(f"Invalid image type found: {img_in_real} {img_in_synth}")

    d_name2desc = get_all_directions_names()
    d_desc2name = {v: k for k, v in d_name2desc.items()}
    os.makedirs("tmp", exist_ok=True)

    # Resolve both embeddings (custom ones are generated first if needed).
    src_emb = _load_or_build_direction_emb(
        src, src_custom, sent_type_src, api_key, org_key, custom_sentences_src, d_desc2name
    )
    dest_emb = _load_or_build_direction_emb(
        dest, dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest, d_desc2name
    )
    text_dir = (dest_emb.cuda() - src_emb.cuda()) * edit_mul

    if use_synth:
        print("using synthetic image")
        x_inv = torch.load(fpath_z_gen)
        pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        _, edit_pil = pipe(gen_prompt,
            num_inference_steps=num_ddim,
            x_in=x_inv,
            edit_dir=text_dir,
            guidance_amount=xa_guidance,
            guidance_scale=5,
            negative_prompt=""  # use the empty string for the negative prompt
        )
        del pipe
        torch.cuda.empty_cache()
        return edit_pil[0]

    print("using real image")
    # resize the image so that the longer side is 512
    width, height = img_in_real.size
    scale_factor = 512 / width if width > height else 512 / height
    new_size = (int(width * scale_factor), int(height * scale_factor))
    img_in_real = img_in_real.resize(new_size, Image.Resampling.LANCZOS)

    # Cache files are keyed by a digest of the resized image bytes.
    img_hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
    inv_fname = f"tmp/{img_hash}_ddim_{num_ddim}_inv.pt"
    caption_fname = f"tmp/{img_hash}_caption.txt"

    # make the caption if it hasn't been made before
    if not os.path.exists(caption_fname):
        # BLIP
        model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
        _image = vis_processors["eval"](img_in_real).unsqueeze(0).cuda()
        prompt_str = model_blip.generate({"image": _image})[0]
        del model_blip
        torch.cuda.empty_cache()
        with open(caption_fname, "w") as f:
            f.write(prompt_str)
    else:
        # ``with`` closes the file handle, which the bare open().read() did not.
        with open(caption_fname, "r") as f:
            prompt_str = f.read().strip()
    print(f"CAPTION: {prompt_str}")

    # do the inversion if it hasn't been done before
    if not os.path.exists(inv_fname):
        # inversion pipeline
        pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
        pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
        # Only the first of the three returned values is used.
        x_inv, _, _ = pipe_inv(prompt_str,
                guidance_scale=1, num_inversion_steps=num_ddim,
                img=img_in_real, torch_dtype=torch.float32)
        x_inv = x_inv.detach()
        torch.save(x_inv, inv_fname)
        del pipe_inv
        torch.cuda.empty_cache()
    else:
        x_inv = torch.load(inv_fname)

    # do the editing
    edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)

    _, edit_pil = edit_pipe(prompt_str,
            num_inference_steps=num_ddim,
            x_in=x_inv,
            edit_dir=text_dir,
            guidance_amount=xa_guidance,
            guidance_scale=5.0,
            negative_prompt=prompt_str  # use the unedited prompt for the negative prompt
    )
    del edit_pipe
    torch.cuda.empty_cache()
    return edit_pil[0]



if __name__ == "__main__":
    # Manual smoke test: generate captions with FLAN-T5 XL and print them.
    demo_sentences = flant5xl_compute_word2sentences("cat wearing sunglasses", num=100)
    print(demo_sentences)