import concurrent.futures 
import random
import gradio as gr
import requests
import io, base64, json
import spaces
from PIL import Image
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum

class ModelManager:
    """Route generation requests to local pipelines, remote APIs, or pre-generated museum results.

    Image-generation names starting with ``"imagenhub"`` and video-generation names
    starting with ``"videogenhub"`` are run through the ``@spaces.GPU``-decorated methods;
    other generation names go through the ``*_api`` variants. Image editing always uses
    ``generate_image_ie``. Pipelines are cached per model name in ``self.loaded_models``.
    """

    def __init__(self):
        self.model_ig_list = IMAGE_GENERATION_MODELS  # candidates for "t2i" battles
        self.model_ie_list = IMAGE_EDITION_MODELS  # candidates for "tie" battles
        self.model_vg_list = VIDEO_GENERATION_MODELS  # candidates for "t2v" battles
        # Names that the random (anonymous) pickers never choose.
        self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
        # Not read by any method of this class.
        self.desired_model_list = DESIRED_APPEAR_MODEL
        # Pipelines already built by load_pipeline, keyed by model name.
        self.loaded_models = {}

    def load_model_pipe(self, model_name):
        """Return the pipeline for ``model_name``, building and caching it on first use."""
        # NOTE: the *_parallel methods call this from worker threads. The check-then-store
        # below is unlocked, so one model requested twice at once may be loaded twice.
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]
        pipe = load_pipeline(model_name)
        self.loaded_models[model_name] = pipe
        return pipe

    # ------------------------------------------------------------------ helpers

    def _pick_models(self, candidates, model_A, model_B):
        """Return the pair of model names to use for an anonymous battle.

        If both ``model_A`` and ``model_B`` are empty strings, two entries are sampled
        without replacement from ``candidates`` minus ``self.excluding_model_list``.
        Otherwise ``[model_A, model_B]`` is returned unchanged.

        Raises:
            ValueError: from ``random.sample`` when fewer than two candidates remain.
        """
        if model_A == "" and model_B == "":
            picking_list = [item for item in candidates if item not in self.excluding_model_list]
            return random.sample(picking_list, 2)
        return [model_A, model_B]

    @staticmethod
    def _museum_key(model_name):
        """Return the second '_'-separated field of ``model_name``, used as the museum lookup key."""
        # Raises IndexError when the name has no '_' (e.g. an empty string).
        return model_name.split('_')[1]

    @staticmethod
    def _run_concurrently(calls):
        """Run each ``(fn, args)`` pair on a thread pool and return the results in input order."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(fn, *args) for fn, args in calls]
            return [future.result() for future in futures]

    def _ig_fn(self, model_name):
        """Choose the GPU-local or API text-to-image entry point for ``model_name``."""
        if model_name.startswith("imagenhub"):
            return self.generate_image_ig
        return self.generate_image_ig_api

    def _vg_fn(self, model_name):
        """Choose the GPU-local or API text-to-video entry point for ``model_name``."""
        if model_name.startswith("videogenhub"):
            return self.generate_video_vg
        return self.generate_video_vg_api

    # --------------------------------------------------------- text-to-image

    @spaces.GPU(duration=120)
    def generate_image_ig(self, prompt, model_name):
        """Generate an image for ``prompt`` with the locally loaded ``model_name`` pipeline."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_image_ig_api(self, prompt, model_name):
        """Generate an image for ``prompt`` with ``model_name`` outside the GPU decorator."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_image_ig_museum(self, model_name):
        """Fetch one pre-generated "t2i" museum sample for ``model_name``.

        Returns:
            ``(image_link, prompt)``: items 0 and 1 of ``draw_from_imagen_museum``'s result.
        """
        result_list = draw_from_imagen_museum("t2i", self._museum_key(model_name))
        return result_list[0], result_list[1]

    def generate_image_ig_parallel_anony(self, prompt, model_A, model_B):
        """Generate images for ``prompt`` with two (possibly random) models concurrently.

        Returns:
            ``(result_A, result_B, model_name_A, model_name_B)``.
        """
        model_names = self._pick_models(self.model_ig_list, model_A, model_B)
        results = self._run_concurrently(
            [(self._ig_fn(model), (prompt, model)) for model in model_names])
        return results[0], results[1], model_names[0], model_names[1]

    def generate_image_ig_museum_parallel_anony(self, model_A, model_B):
        """Fetch a paired "t2i" museum sample for two (possibly random) models.

        Returns:
            ``(image_link_A, image_link_B, model_name_A, model_name_B, prompt)``.
        """
        model_names = self._pick_models(self.model_ig_list, model_A, model_B)
        result_list = draw2_from_imagen_museum(
            "t2i", self._museum_key(model_names[0]), self._museum_key(model_names[1]))
        image_links, prompt_list = result_list[0], result_list[1]
        return image_links[0], image_links[1], model_names[0], model_names[1], prompt_list[0]

    def generate_image_ig_parallel(self, prompt, model_A, model_B):
        """Generate images for ``prompt`` with ``model_A`` and ``model_B`` concurrently.

        Returns:
            ``(result_A, result_B)``.
        """
        results = self._run_concurrently(
            [(self._ig_fn(model), (prompt, model)) for model in (model_A, model_B)])
        return results[0], results[1]

    def generate_image_ig_museum_parallel(self, model_A, model_B):
        """Fetch a paired "t2i" museum sample for ``model_A`` and ``model_B``.

        Returns:
            ``(image_link_A, image_link_B, prompt)``.
        """
        result_list = draw2_from_imagen_museum(
            "t2i", self._museum_key(model_A), self._museum_key(model_B))
        image_links, prompt_list = result_list[0], result_list[1]
        return image_links[0], image_links[1], prompt_list[0]

    # ------------------------------------------------------------ image edit

    @spaces.GPU(duration=200)
    def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
        """Edit ``source_image`` with ``model_name``, passing source/target/instruction prompts."""
        pipe = self.load_model_pipe(model_name)
        return pipe(src_image=source_image, src_prompt=textbox_source,
                    target_prompt=textbox_target, instruct_prompt=textbox_instruct)

    def generate_image_ie_museum(self, model_name):
        """Fetch one pre-generated "tie" museum sample for ``model_name``.

        Returns:
            ``(src_image_link, model_image_link, source_caption, target_caption, instruction)``.
        """
        result_list = draw_from_imagen_museum("tie", self._museum_key(model_name))
        image_links, prompt_list = result_list[0], result_list[1]
        # image_links = [src, model]
        # prompt_list = [source_caption, target_caption, instruction]
        return image_links[0], image_links[1], prompt_list[0], prompt_list[1], prompt_list[2]

    def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        """Run the same edit with ``model_A`` and ``model_B`` concurrently.

        Returns:
            ``(result_A, result_B)``.
        """
        args = (textbox_source, textbox_target, textbox_instruct, source_image)
        results = self._run_concurrently(
            [(self.generate_image_ie, args + (model,)) for model in (model_A, model_B)])
        return results[0], results[1]

    def generate_image_ie_museum_parallel(self, model_A, model_B):
        """Fetch a paired "tie" museum sample for ``model_A`` and ``model_B``.

        Returns:
            ``(src_link, link_A, link_B, source_caption, target_caption, instruction)``.
        """
        result_list = draw2_from_imagen_museum(
            "tie", self._museum_key(model_A), self._museum_key(model_B))
        image_links, prompt_list = result_list[0], result_list[1]
        # image_links = [src, model_A, model_B]
        # prompt_list = [source_caption, target_caption, instruction]
        return image_links[0], image_links[1], image_links[2], prompt_list[0], prompt_list[1], prompt_list[2]

    def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        """Run the same edit with two (possibly random) models concurrently.

        Returns:
            ``(result_A, result_B, model_name_A, model_name_B)``.
        """
        model_names = self._pick_models(self.model_ie_list, model_A, model_B)
        args = (textbox_source, textbox_target, textbox_instruct, source_image)
        results = self._run_concurrently(
            [(self.generate_image_ie, args + (model,)) for model in model_names])
        return results[0], results[1], model_names[0], model_names[1]

    def generate_image_ie_museum_parallel_anony(self, model_A, model_B):
        """Fetch a paired "tie" museum sample for two (possibly random) models.

        Returns:
            ``(src_link, link_A, link_B, source_caption, target_caption, instruction,
            model_name_A, model_name_B)``.
        """
        model_names = self._pick_models(self.model_ie_list, model_A, model_B)
        result_list = draw2_from_imagen_museum(
            "tie", self._museum_key(model_names[0]), self._museum_key(model_names[1]))
        image_links, prompt_list = result_list[0], result_list[1]
        # image_links = [src, model_A, model_B]
        # prompt_list = [source_caption, target_caption, instruction]
        return (image_links[0], image_links[1], image_links[2],
                prompt_list[0], prompt_list[1], prompt_list[2],
                model_names[0], model_names[1])

    # --------------------------------------------------------- text-to-video

    @spaces.GPU(duration=150)
    def generate_video_vg(self, prompt, model_name):
        """Generate a video for ``prompt`` with the locally loaded ``model_name`` pipeline."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_video_vg_api(self, prompt, model_name):
        """Generate a video for ``prompt`` with ``model_name`` outside the GPU decorator."""
        pipe = self.load_model_pipe(model_name)
        return pipe(prompt=prompt)

    def generate_video_vg_museum(self, model_name):
        """Fetch one pre-generated "t2v" museum sample for ``model_name``.

        Returns:
            ``(video_link, prompt)``: items 0 and 1 of ``draw_from_videogen_museum``'s result.
        """
        result_list = draw_from_videogen_museum("t2v", self._museum_key(model_name))
        return result_list[0], result_list[1]

    def generate_video_vg_parallel_anony(self, prompt, model_A, model_B):
        """Generate videos for ``prompt`` with two (possibly random) models concurrently.

        Returns:
            ``(result_A, result_B, model_name_A, model_name_B)``.
        """
        model_names = self._pick_models(self.model_vg_list, model_A, model_B)
        results = self._run_concurrently(
            [(self._vg_fn(model), (prompt, model)) for model in model_names])
        return results[0], results[1], model_names[0], model_names[1]

    def generate_video_vg_museum_parallel_anony(self, model_A, model_B):
        """Fetch a paired "t2v" museum sample for two (possibly random) models.

        Returns:
            ``(video_link_A, video_link_B, model_name_A, model_name_B, prompt)``.
        """
        # To bias the draw, one slot could be overridden with
        # random.choice(self.desired_model_list); that override is currently disabled.
        model_names = self._pick_models(self.model_vg_list, model_A, model_B)
        result_list = draw2_from_videogen_museum(
            "t2v", self._museum_key(model_names[0]), self._museum_key(model_names[1]))
        video_links, prompt_list = result_list[0], result_list[1]
        return video_links[0], video_links[1], model_names[0], model_names[1], prompt_list[0]

    def generate_video_vg_parallel(self, prompt, model_A, model_B):
        """Generate videos for ``prompt`` with ``model_A`` and ``model_B`` concurrently.

        Returns:
            ``(result_A, result_B)``.
        """
        results = self._run_concurrently(
            [(self._vg_fn(model), (prompt, model)) for model in (model_A, model_B)])
        return results[0], results[1]

    def generate_video_vg_museum_parallel(self, model_A, model_B):
        """Fetch a paired "t2v" museum sample for ``model_A`` and ``model_B``.

        Returns:
            ``(video_link_A, video_link_B, prompt)``.
        """
        result_list = draw2_from_videogen_museum(
            "t2v", self._museum_key(model_A), self._museum_key(model_B))
        video_links, prompt_list = result_list[0], result_list[1]
        return video_links[0], video_links[1], prompt_list[0]