import time, torch
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAI
from langchain.schema import HumanMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain.output_parsers import RetryWithErrorOutputParser

from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
from vouchervision.utils_geolocate_HERE import validate_coordinates_here

class OpenAIHandler:
    """Send a prompt to an OpenAI (or Azure OpenAI) model and return validated JSON.

    The handler builds a LangChain ``prompt | model`` chain, parses the model's
    reply as JSON through a ``RetryWithErrorOutputParser``, aligns the result
    with ``JSON_dict_structure`` and then runs ``validate_taxonomy_WFO`` and
    ``validate_coordinates_here`` on it.
    """

    RETRY_DELAY = 10  # Seconds to wait before retrying after an exception
    MAX_RETRIES = 3  # Attempts in call_llm_api_OpenAI; also the retry parser's max_retries
    STARTING_TEMP = 0.5
    TOKENIZER_NAME = 'gpt-4'
    VENDOR = 'openai'

    def __init__(self, logger, model_name, JSON_dict_structure, is_azure, llm_object):
        """Store configuration and build the prompt, JSON parser and model chain.

        Args:
            logger: Logger used for progress and error messages.
            model_name: Model name; for Azure it is also used as the deployment name.
                If one of its dash-separated parts is ``instruct``, the non-chat
                ``OpenAI`` client is used instead of ``ChatOpenAI``.
            JSON_dict_structure: Template the parsed JSON is aligned against.
            is_azure: If True, ``llm_object`` is used for all model calls.
            llm_object: Pre-built model object used when ``is_azure`` is True
                (presumably a LangChain chat model, since it is called with a list
                of messages); replaced by ``None`` otherwise.
        """
        self.logger = logger
        self.model_name = model_name
        self.JSON_dict_structure = JSON_dict_structure
        self.is_azure = is_azure
        self.llm_object = llm_object
        self.name_parts = self.model_name.split('-')

        self.monitor = SystemLoadMonitor(logger)
        # Recorded for callers; not used elsewhere in this class.
        self.has_GPU = torch.cuda.is_available()

        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        # JSON parser whose format instructions are baked into the prompt.
        self.parser = JsonOutputParser()
        self.prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": self.parser.get_format_instructions()},
        )
        self._set_config()

    def _set_config(self):
        """Initialise generation settings, configure the Azure object and build the chain."""
        # NOTE(review): self.config is bookkeeping only; none of these values are
        # passed to the model clients built in _build_model_chain_parser.
        self.config = {
            'max_new_tokens': 1024,
            'temperature': self.starting_temp,
            'random_seed': 2023,
            'top_p': 1,
        }
        if self.is_azure:
            # For Azure, both the deployment name and the model name are model_name.
            self.llm_object.deployment_name = self.model_name
            self.llm_object.model_name = self.model_name
        else:
            # Non-Azure clients are created in _build_model_chain_parser.
            self.llm_object = None
        self._build_model_chain_parser()

    def format_input_for_azure(self, prompt_text):
        """Send the rendered prompt to the Azure model as a single human message.

        Used as the model step of ``self.chain`` when ``is_azure`` is True.

        Args:
            prompt_text: Prompt value produced by ``self.prompt``; its ``.text`` is sent.

        Returns:
            Whatever ``self.llm_object`` returns for the message list.
        """
        msg = HumanMessage(content=prompt_text.text)
        return self.llm_object(messages=[msg])

    def _adjust_config(self):
        """Raise the tracked temperature by ``temp_increment`` before another attempt."""
        new_temp = self.adjust_temp + self.temp_increment
        message = f'Incrementing temperature from {self.adjust_temp} to {new_temp}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = new_temp
        self.config['temperature'] = self.adjust_temp

    def _reset_config(self):
        """Restore the tracked temperature to ``starting_temp``."""
        message = f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _build_model_chain_parser(self):
        """Build ``self.retry_parser`` and ``self.chain`` for the configured backend."""
        if self.is_azure:
            parser_llm = self.llm_object
            chain_model = self.format_input_for_azure
        else:
            # Names with an 'instruct' part use the (non-chat) completion client.
            client_cls = OpenAI if 'instruct' in self.name_parts else ChatOpenAI
            # NOTE(review): no temperature is passed, so the value tracked by
            # _adjust_config/_reset_config does not reach the model.
            parser_llm = client_cls(model=self.model_name)
            chain_model = parser_llm
        # The retry parser uses parser_llm to repair output that fails to parse.
        self.retry_parser = RetryWithErrorOutputParser.from_llm(
            parser=self.parser, llm=parser_llm, max_retries=self.MAX_RETRIES
        )
        self.chain = self.prompt | chain_model

    def call_llm_api_OpenAI(self, prompt_template, json_report):
        """Query the model until its reply yields JSON aligned with the template.

        Each attempt invokes the chain, parses the reply as JSON (via the retry
        parser), aligns it with ``JSON_dict_structure`` and runs the WFO and
        geolocation validators. A failed attempt raises the tracked temperature.
        An attempt that raises an exception also waits ``RETRY_DELAY`` seconds,
        but only when another attempt will follow.

        Args:
            prompt_template: Query text used as the ``query`` prompt variable.
            json_report: Progress reporter exposing ``set_text(text_main=...)``.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` after ``MAX_RETRIES`` failed attempts.
        """
        self.json_report = json_report
        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()
        nt_in = 0
        nt_out = 0

        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                # NOTE(review): 'model_kwargs' is not a variable of self.prompt, so
                # this temperature is ignored by the chain.
                model_kwargs = {"temperature": self.adjust_temp}
                response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
                response_text = response if isinstance(response, str) else response.content

                # parse_with_prompt needs a PromptValue (it calls .to_string() when
                # retrying), so pass the rendered prompt, not the raw query string.
                prompt_value = self.prompt.format_prompt(query=prompt_template)
                output = self.retry_parser.parse_with_prompt(response_text, prompt_value=prompt_value)

                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(response_text, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
                    self._adjust_config()
                    continue

                self.json_report.set_text(text_main='Working on WFO and Geolocation')
                # TODO: make replace_if_success_wfo / replace_if_success_geo configurable.
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON: {output}")
                self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()
                self.json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record

            except Exception as e:
                self.logger.error(f'{e}')
                self._adjust_config()
                # Waiting only makes sense if another attempt follows.
                if ind < self.MAX_RETRIES:
                    time.sleep(self.RETRY_DELAY)

        self.logger.error(f"Failed to extract valid JSON after [{ind}] attempts")
        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

        self.monitor.stop_monitoring_report_usage()
        self._reset_config()

        self.json_report.set_text(text_main='LLM call failed')
        return None, nt_in, nt_out, None, None