import inspect
from abc import ABC, abstractmethod

import gradio as gr
from gradio.components import Component

class BaseTCOModel(ABC):
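    """Base class for a TCO model: its inputs are Gradio components that are
    auto-registered so a page can render, show/hide and read them."""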
    # TO DO: Find way to specify which component should be used for computing cost
    def __setattr__(self, name, value):
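        # Auto-register every Gradio component assigned as an attribute so the
        # page can later collect them (for visibility toggling and cost inputs).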
        if isinstance(value, Component):
            self._components.append(value)
        self.__dict__[name] = value

    def __init__(self):
        super(BaseTCOModel, self).__setattr__("_components", [])

    def get_components(self) -> list[Component]:
        return self._components
    
    def get_components_for_cost_computing(self):
        return self.components_for_cost_computing
    
    def get_name(self):
        return self.name
    
    def register_components_for_cost_computing(self):
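        # Map each argument of compute_cost_per_token (excluding self) to the
        # component attribute with the same name, in signature order.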
        args = inspect.getfullargspec(self.compute_cost_per_token)[0][1:]
        self.components_for_cost_computing = [self.__getattribute__(arg) for arg in args]
    
    @abstractmethod
    def compute_cost_per_token(self):
        pass
    
    @abstractmethod
    def render(self):
        pass
    
    def set_name(self, name):
        self.name = name

class OpenAIModel(BaseTCOModel):
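    """TCO model for OpenAI's hosted API, priced per 1k input tokens."""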

    def __init__(self):
        super().__init__()
        self.set_name("(SaaS) OpenAI")

    def render(self):
        def on_model_change(model):
            # Refresh the available context sizes when the model changes.
            if model == "GPT-4":
                return gr.Dropdown.update(choices=["8K", "32K"], value="8K")
            else:
                return gr.Dropdown.update(choices=["4K", "16K"], value="4K")

        self.model = gr.Dropdown(["GPT-4", "GPT-3.5 Turbo"], value="GPT-4",
                                 label="OpenAI model",
                                 interactive=True, visible=False)
        self.context_length = gr.Dropdown(["8K", "32K"], value="8K", interactive=True,
                                          label="Context size",
                                          visible=False)
        self.model.change(on_model_change, inputs=self.model, outputs=self.context_length)
        self.input_length = gr.Number(350, label="Average number of input tokens", 
                                      interactive=True, visible=False)

    def compute_cost_per_token(self, model, context_length, input_length):
        """Cost of the input tokens of an average request, in $."""
        if model == "GPT-4" and context_length == "8K":
            cost_per_1k_input_tokens = 0.03
        elif model == "GPT-4" and context_length == "32K":
            cost_per_1k_input_tokens = 0.06
        elif model == "GPT-3.5 Turbo" and context_length == "4K":
            cost_per_1k_input_tokens = 0.0015
        else:
            cost_per_1k_input_tokens = 0.003

        # Prices are quoted per 1k input tokens, so scale by the average input length.
        cost_of_input_tokens = cost_per_1k_input_tokens * input_length / 1000

        return cost_of_input_tokens

class OpenSourceLlama2Model(BaseTCOModel):
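    """TCO model for self-hosting Llama 2 on a cloud VM with GPUs."""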
    def __init__(self):
        super().__init__()
        self.set_name("(Open source) Llama 2")
    
    def render(self):
        vm_choices = ["1x Nvidia A100 (Azure NC24ads A100 v4)",
                      "2x Nvidia A100 (Azure NC48ads A100 v4)"]
        
        def on_model_change(model):
            if model == "Llama 2 7B":
                return gr.Dropdown.update(choices=vm_choices)
            else:
                not_supported_vm = ["1x Nvidia A100 (Azure NC24ads A100 v4)"]
                choices = [x for x in vm_choices if x not in not_supported_vm]
                return gr.Dropdown.update(choices=choices)

        def on_vm_change(model, vm):
            # TO DO: load info from CSV
            if model == "Llama 2 7B" and vm == "1x Nvidia A100 (Azure NC24ads A100 v4)":
                return gr.Number.update(value=900)
            elif model == "Llama 2 7B" and vm == "2x Nvidia A100 (Azure NC48ads A100 v4)":
                return gr.Number.update(value=1800)
            # No benchmark figures yet for the other model/VM combinations.
            return gr.Number.update()
        
        self.model = gr.Dropdown(["Llama 2 7B", "Llama 2 70B"], value="Llama 2 7B", visible=False)
        self.vm = gr.Dropdown(vm_choices, 
                              visible=False,
                              label="Instance of VM with GPU"
                              )
        self.vm_cost_per_hour = gr.Number(3.5, label="VM instance cost per hour",
                                          interactive=True, visible=False)
        self.tokens_per_second = gr.Number(900, visible=False,
                                           label="Number of tokens per second for this specific model and VM instance",
                                           interactive=False
                                           )
        self.input_length = gr.Number(350, label="Average number of input tokens", 
                                      interactive=True, visible=False)
        
        self.model.change(on_model_change, inputs=self.model, outputs=self.vm)
        self.vm.change(on_vm_change, inputs=[self.model, self.vm], outputs=self.tokens_per_second)
        self.maxed_out = gr.Slider(minimum=0.01, maximum=1., value=1., step=0.01,
                                   label="GPU utilization",
                                   info="Fraction of the time the GPU is fully used (1 = 100%).",
                                   interactive=True,
                                   visible=False)

    def compute_cost_per_token(self, vm_cost_per_hour, tokens_per_second, maxed_out):
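        # Cost per token = hourly VM cost / tokens generated per hour,
        # where tokens per hour = tokens_per_second * 3600 * utilization fraction.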
        cost_per_token = vm_cost_per_hour / (tokens_per_second * 3600 * maxed_out)
        return cost_per_token 
    
class ModelPage:
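    """Renders several TCO models side by side and routes the inputs of the
    currently selected one to its compute_cost_per_token method."""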
    def __init__(self, Models: list[type[BaseTCOModel]]):
        self.models: list[BaseTCOModel] = []
        for Model in Models:
            model = Model()
            self.models.append(model)

    def render(self):
        for model in self.models:
            model.render()
            model.register_components_for_cost_computing() 

    def get_all_components(self) -> list[Component]:
        output = []
        for model in self.models:
            output += model.get_components()
        return output
    
    def get_all_components_for_cost_computing(self) -> list[Component]:
        output = []
        for model in self.models:
            output += model.get_components_for_cost_computing()
        return output

    def make_model_visible(self, name: str):
        # Show the components of the selected model and hide all the others.
        output = []
        for model in self.models:
            if model.get_name() == name:
                output += [gr.update(visible=True)] * len(model.get_components())
            else:
                output += [gr.update(visible=False)] * len(model.get_components())
        return output
    
    def compute_cost_per_token(self, *args):
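        # args contains the values of every model's cost components, concatenated
        # in the order of get_all_components_for_cost_computing(), followed by the
        # name of the currently selected model as the last element.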
        begin = 0
        current_model = args[-1]
        for model in self.models:
            model_n_args = len(model.get_components_for_cost_computing())
            if current_model == model.get_name():
                model_args = args[begin:begin + model_n_args]
                model_tco = model.compute_cost_per_token(*model_args)
                return f"Model {current_model} has TCO {model_tco}"
            begin = begin + model_n_args
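
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; the wiring below is
    # an assumption about how the page is meant to be driven): a selector
    # dropdown toggles which model's components are visible, and a button feeds
    # every registered cost component plus the selected model name into
    # ModelPage.compute_cost_per_token.
    with gr.Blocks() as demo:
        page = ModelPage([OpenAIModel, OpenSourceLlama2Model])
        names = [model.get_name() for model in page.models]
        selector = gr.Dropdown(names, value=names[0], label="Model provider")
        page.render()
        result = gr.Textbox(label="Cost per token")
        compute_btn = gr.Button("Compute")

        selector.change(page.make_model_visible, inputs=selector,
                        outputs=page.get_all_components())
        compute_btn.click(page.compute_cost_per_token,
                          inputs=page.get_all_components_for_cost_computing() + [selector],
                          outputs=result)

    demo.launch()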