import inspect
from abc import ABC, abstractclassmethod, abstractmethod

import gradio as gr
import pandas as pd
from gradio.components import Component

class BaseTCOModel(ABC):
    """Abstract base for a model whose total cost of ownership (TCO) is compared.

    Every Gradio ``Component`` assigned as an instance attribute is recorded, in
    assignment order, so the page can show or hide all of a model's widgets at
    once. Subclasses build their widgets in :meth:`render` and implement
    :meth:`compute_cost_per_token`; the parameter names of that method (after
    ``self``) must be attribute names of the components whose values feed it
    (see :meth:`register_components_for_cost_computing`).
    """

    # TODO: Find a way to specify which component should be used for computing cost

    def __setattr__(self, name, value):
        """Set an attribute, first recording it if it is a Gradio component."""
        if isinstance(value, Component):
            # Subclasses assign attributes before calling super().__init__(),
            # so create the registry on demand instead of assuming it exists.
            self.__dict__.setdefault("_components", []).append(value)
        super().__setattr__(name, value)

    def __init__(self):
        # setdefault keeps any components registered before this call.
        self.__dict__.setdefault("_components", [])
        self.use_case = None

    def get_components(self) -> list[Component]:
        """Return every component registered on this model, in creation order."""
        return self._components

    def get_components_for_cost_computing(self):
        """Return the components feeding :meth:`compute_cost_per_token`.

        Only available after :meth:`register_components_for_cost_computing` ran.
        """
        return self.components_for_cost_computing

    def get_name(self):
        """Return the display name set with :meth:`set_name`."""
        return self.name

    def register_components_for_cost_computing(self):
        """Map each parameter of :meth:`compute_cost_per_token` to a component.

        Each parameter name (``self`` excluded) is looked up as an attribute of
        this model, so :meth:`render` must have created those attributes first.
        """
        # getfullargspec still lists the bound ``self`` for bound methods; drop it.
        arg_names = inspect.getfullargspec(self.compute_cost_per_token).args[1:]
        self.components_for_cost_computing = [getattr(self, arg) for arg in arg_names]

    @abstractmethod
    def compute_cost_per_token(self):
        """Return ``(cost_per_input_token, cost_per_output_token, labor)``."""

    @abstractmethod
    def render(self):
        """Create this model's Gradio components."""

    def set_name(self, name):
        """Set the display name used to select this model."""
        self.name = name

    def set_latency(self, latency):
        """Set the latency value reported alongside the computed cost."""
        self.latency = latency

    def get_latency(self):
        """Return the latency value set with :meth:`set_latency`."""
        return self.latency

class OpenAIModelGPT4(BaseTCOModel):
    """OpenAI GPT-4 consumed as a SaaS API, priced per 1K tokens."""

    def __init__(self):
        self.set_name("(SaaS) OpenAI GPT4")
        self.set_latency("15s")  # default latency shown for GPT-4
        super().__init__()

    def render(self):
        """Create the hidden GPT-4 widgets and wire the context-size dropdown."""

        def define_cost_per_token(context_length):
            # ($/1K input tokens, $/1K output tokens); anything but "8K" gets 32K pricing.
            if context_length != "8K":
                return 0.06, 0.12
            return 0.03, 0.06

        self.context_length = gr.Dropdown(
            ["8K", "32K"],
            value="8K",
            label="Context size",
            info="Number of tokens the model considers when processing text",
            interactive=True,
            visible=False,
        )
        self.input_tokens_cost_per_token = gr.Number(
            0.03,
            label="($) Price/1K input prompt tokens",
            interactive=False,
            visible=False,
        )
        self.output_tokens_cost_per_token = gr.Number(
            0.06,
            label="($) Price/1K output prompt tokens",
            interactive=False,
            visible=False,
        )
        self.info = gr.Markdown(
            "The cost per input and output tokens values are from OpenAI's [pricing web page](https://openai.com/pricing)",
            interactive=False,
            visible=False,
        )
        price_outputs = [self.input_tokens_cost_per_token, self.output_tokens_cost_per_token]
        self.context_length.change(define_cost_per_token, inputs=self.context_length, outputs=price_outputs)

        self.labor = gr.Number(
            0,
            label="($) Labor cost per month",
            info="This is an estimate of the labor cost of the AI engineer in charge of deploying the model",
            interactive=True,
            visible=False,
        )

    def compute_cost_per_token(self, input_tokens_cost_per_token, output_tokens_cost_per_token, labor):
        """Convert the $/1K-token prices to $/token and pass the labor cost through.

        Parameter names must match the component attributes created in render().
        """
        return input_tokens_cost_per_token / 1000, output_tokens_cost_per_token / 1000, labor

class OpenAIModelGPT3_5(BaseTCOModel):
    """OpenAI GPT-3.5 Turbo consumed as a SaaS API, priced per 1K tokens."""

    def __init__(self):
        self.set_name("(SaaS) OpenAI GPT3.5 Turbo")
        self.set_latency("5s")  # average latency value for GPT-3.5 Turbo
        super().__init__()

    def render(self):
        """Create the hidden GPT-3.5 Turbo widgets and wire the context-size dropdown."""

        def define_cost_per_token(context_length):
            # ($/1K input tokens, $/1K output tokens); anything but "4K" gets 16K pricing.
            if context_length != "4K":
                return 0.003, 0.004
            return 0.0015, 0.002

        self.context_length = gr.Dropdown(
            choices=["4K", "16K"],
            value="4K",
            label="Context size",
            info="Number of tokens the model considers when processing text",
            interactive=True,
            visible=False,
        )
        self.input_tokens_cost_per_token = gr.Number(
            0.0015,
            label="($) Price/1K input prompt tokens",
            interactive=False,
            visible=False,
        )
        self.output_tokens_cost_per_token = gr.Number(
            0.002,
            label="($) Price/1K output prompt tokens",
            interactive=False,
            visible=False,
        )
        self.info = gr.Markdown(
            "The cost per input and output tokens values are from OpenAI's [pricing web page](https://openai.com/pricing)",
            interactive=False,
            visible=False,
        )
        price_outputs = [self.input_tokens_cost_per_token, self.output_tokens_cost_per_token]
        self.context_length.change(define_cost_per_token, inputs=self.context_length, outputs=price_outputs)

        self.labor = gr.Number(
            0,
            label="($) Labor cost per month",
            info="This is an estimate of the labor cost of the AI engineer in charge of deploying the model",
            interactive=True,
            visible=False,
        )

    def compute_cost_per_token(self, input_tokens_cost_per_token, output_tokens_cost_per_token, labor):
        """Convert the $/1K-token prices to $/token and pass the labor cost through.

        Parameter names must match the component attributes created in render().
        """
        return input_tokens_cost_per_token / 1000, output_tokens_cost_per_token / 1000, labor

class DIYLlama2Model(BaseTCOModel):
    """Llama 2 70B deployed on a rented GPU VM and operated in-house."""

    def __init__(self):
        self.set_name("(Deploy yourself) Llama 2 70B")
        self.set_latency("27s")
        super().__init__()

    def render(self):
        """Create the hidden Llama 2 widgets and wire the "Maxed out" slider.

        The per-1K-token prices follow the slider: they equal the benchmark
        values at the slider's default of 65 and scale as ``65 / maxed_out``.
        """
        # Benchmark $/1K-token prices; the formula below returns them unchanged
        # when the slider sits at 65 %.
        benchmark_input_cost = 0.00052
        benchmark_output_cost = 0.06656
        benchmark_maxed_out_ratio = 0.65

        def on_maxed_out_change(maxed_out):
            # CT = VM_C / TS with TS proportional to maxed_out, so cost per token
            # scales inversely with the maxed-out ratio (slider minimum is 1, so r > 0).
            r = maxed_out / 100
            return (benchmark_input_cost * benchmark_maxed_out_ratio / r,
                    benchmark_output_cost * benchmark_maxed_out_ratio / r)

        # Created hidden like every other model-specific component; visibility is
        # driven by ModelPage.make_model_visible.
        self.source = gr.Markdown("""<span style="font-size: 16px; font-weight: 600; color: #212529;">Source</span>""",
                                  visible=False)
        self.info = gr.Markdown("The cost per input and output tokens values below are from [these benchmark results](https://www.cursor.so/blog/llama-inference#user-content-fn-llama-paper) that were obtained using the following initial configurations.", 
                                 interactive=False, 
                                 visible=False)
        self.vm = gr.Textbox(value="2x A100 80GB NVLINK", 
                              visible=False,
                              label="Instance of VM with GPU",
                              )
        self.vm_cost_per_hour = gr.Number(4.42, label="Instance cost ($) per hour",
                                      interactive=False, visible=False)
        self.info_vm = gr.Markdown("This price above is from [CoreWeave's pricing web page](https://www.coreweave.com/gpu-cloud-pricing)", interactive=False, visible=False)
        self.maxed_out = gr.Slider(minimum=1, maximum=100, value=65, step=1, label="Maxed out", info="Estimated average percentage of total GPU memory that is used. The instantaneous value can go from very high when many users are using the service to very low when no one does.",
                                   visible=False)
        self.info_maxed_out = gr.Markdown(r"""This percentage influences the input and output cost/token values, and more precisely the number of token/s. Here is the formula used:<br>
                                          $CT = \frac{VM_C}{TS}$ where $TS = TS_{max} * \frac{MO}{100}$ <br>
                                          with: <br>
                                          $CT$ = Cost per Token (Input or output), <br>
                                          $VM_C$ = VM Cost per second, <br>
                                          $TS$ = Tokens per second (Input or output), <br>
                                          $TS_{max}$ = Tokens per second when the GPU is maxed out at 100%, <br>
                                          $MO$ = Maxed Out, <br>
                                          """, interactive=False, visible=False)
        self.input_tokens_cost_per_token = gr.Number(benchmark_input_cost, visible=False,
                                           label="($) Price/1K input prompt tokens",
                                           interactive=False
                                           )
        self.output_tokens_cost_per_token = gr.Number(benchmark_output_cost, visible=False,
                                           label="($) Price/1K output prompt tokens",
                                           interactive=False
                                           )
        # The callback recomputes prices purely from the slider, so it is the only input.
        self.maxed_out.change(on_maxed_out_change, inputs=self.maxed_out,
                              outputs=[self.input_tokens_cost_per_token, self.output_tokens_cost_per_token])

        self.labor = gr.Number(5000, visible=False, 
                                label="($) Labor cost per month",
                                info="This is an estimate of the labor cost of the AI engineer in charge of deploying the model",
                                interactive=True
                                )

    def compute_cost_per_token(self, input_tokens_cost_per_token, output_tokens_cost_per_token, labor):
        """Convert the $/1K-token prices to $/token and pass the labor cost through.

        Parameter names must match the component attributes created in render().
        """
        cost_per_input_token = input_tokens_cost_per_token / 1000
        cost_per_output_token = output_tokens_cost_per_token / 1000
        return cost_per_input_token, cost_per_output_token, labor

class CohereModel(BaseTCOModel):
    """Cohere consumed as a SaaS API, with one price shared by input and output tokens."""

    def __init__(self):
        self.set_name("(SaaS) Cohere")
        self.set_latency("Not available")
        super().__init__()

    def render(self):
        """Create the hidden Cohere widgets and wire the model-type dropdown."""

        def on_model_change(model):
            # Prices are quoted per 1M tokens (15 for "Default", 30 otherwise) and shown per 1K.
            price_per_1k = (15 if model == "Default" else 30) / 1000
            return gr.update(value=price_per_1k), gr.update(value=price_per_1k)

        self.model = gr.Dropdown(
            ["Default", "Custom"],
            value="Default",
            label="Model",
            interactive=True,
            visible=False,
        )
        self.input_tokens_cost_per_token = gr.Number(
            0.015,
            label="($) Price/1K input prompt tokens",
            interactive=False,
            visible=False,
        )
        self.output_tokens_cost_per_token = gr.Number(
            0.015,
            label="($) Price/1K output prompt tokens",
            interactive=False,
            visible=False,
        )
        self.info = gr.Markdown(
            "The cost per input and output tokens value is from Cohere's [pricing web page](https://cohere.com/pricing?utm_term=&utm_campaign=Cohere+Brand+%26+Industry+Terms&utm_source=adwords&utm_medium=ppc&hsa_acc=4946693046&hsa_cam=20368816223&hsa_grp=154209120409&hsa_ad=666081801359&hsa_src=g&hsa_tgt=dsa-19959388920&hsa_kw=&hsa_mt=&hsa_net=adwords&hsa_ver=3&gad=1&gclid=CjwKCAjww7KmBhAyEiwA5-PUSlyO7pq0zxeVrhViXMd8WuILW6uY-cfP1-SVuUfs-leUAz14xHlOHxoCmfkQAvD_BwE)",
            interactive=False,
            visible=False,
        )
        price_outputs = [self.input_tokens_cost_per_token, self.output_tokens_cost_per_token]
        self.model.change(on_model_change, inputs=self.model, outputs=price_outputs)
        self.labor = gr.Number(
            0,
            label="($) Labor cost per month",
            info="This is an estimate of the labor cost of the AI engineer in charge of deploying the model",
            interactive=True,
            visible=False,
        )

    def compute_cost_per_token(self, input_tokens_cost_per_token, output_tokens_cost_per_token, labor):
        """Convert the $/1K-token prices to $/token and pass the labor cost through.

        Parameter names must match the component attributes created in render().
        """
        return input_tokens_cost_per_token / 1000, output_tokens_cost_per_token / 1000, labor

class ModelPage:
    """Holds one instance of every TCO model and routes UI events to them."""

    def __init__(self, Models: "list[type[BaseTCOModel]]"):
        """Instantiate each model class in ``Models``, preserving their order."""
        self.models: list[BaseTCOModel] = [Model() for Model in Models]

    def render(self):
        """Render every model, then bind each one's cost-computing components."""
        for model in self.models:
            model.render()
            model.register_components_for_cost_computing()

    def get_all_components(self) -> "list[Component]":
        """Return all models' components, concatenated in model order."""
        return [c for model in self.models for c in model.get_components()]

    def get_all_components_for_cost_computing(self) -> "list[Component]":
        """Return all models' cost-computing components, concatenated in model order.

        This order defines the layout of the leading ``args`` expected by
        :meth:`compute_cost_per_token`.
        """
        return [c for model in self.models for c in model.get_components_for_cost_computing()]

    def make_model_visible(self, name: str, use_case: "gr.Dropdown"):
        """Return one visibility update per component: shown for ``name``, hidden otherwise.

        The selected model also stores ``use_case`` as its ``use_case`` attribute.
        NOTE(review): annotated as a Dropdown, but it is stored verbatim — likely
        the dropdown's current value; confirm against the caller.
        """
        output = []
        for model in self.models:
            is_selected = model.get_name() == name
            if is_selected:
                model.use_case = use_case
            output += [gr.update(visible=is_selected)] * len(model.get_components())
        return output

    def compute_cost_per_token(self, *args):
        """Compute the selected model's token cost, latency and labor cost.

        ``args`` holds the values of :meth:`get_all_components_for_cost_computing`
        (in that order), followed by the selected model name, the input-token
        count and the output-token count. The token counts may be plain numbers
        or objects exposing ``.value``.

        Returns:
            ``(model_tco, latency, labor_cost)`` where ``model_tco`` is
            ``cost_per_input_token * input_tokens + cost_per_output_token * output_tokens``.

        Raises:
            ValueError: if no model is named ``args[-3]``.
        """
        current_model = args[-3]
        # Gradio event inputs arrive as raw values; unwrap components defensively.
        input_tokens = getattr(args[-2], "value", args[-2])
        output_tokens = getattr(args[-1], "value", args[-1])

        begin = 0
        for model in self.models:
            model_n_args = len(model.get_components_for_cost_computing())
            if model.get_name() == current_model:
                model_args = args[begin:begin + model_n_args]
                cost_per_input_token, cost_per_output_token, labor_cost = model.compute_cost_per_token(*model_args)
                model_tco = cost_per_input_token * input_tokens + cost_per_output_token * output_tokens
                return model_tco, model.get_latency(), labor_cost
            # Skip past this model's slice of the flattened component values.
            begin += model_n_args

        raise ValueError(f"Unknown model: {current_model!r}")