import os
from langchain_community.llms import HuggingFaceHub
from langchain_community.llms import OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import warnings

warnings.filterwarnings("ignore")

class LLLResponseGenerator():
    """Generate mental-health-assistant responses via HuggingFace Hub or OpenAI.

    Maintains a running conversation context that always begins with the base
    (non-medical) system instruction.

    NOTE(review): the class name looks like a typo for ``LLMResponseGenerator``;
    kept as-is because external callers may reference it.
    """

    def __init__(self):
        # Base system instruction, kept separately so update_context() can
        # rebuild self.context without ever losing this instruction.
        self.base_context = (
            "You are a mental health supporting non-medical assistant. "
            "DO NOT PROVIDE any medical advice with conviction."
        )
        # Full prompt context fed to the LLM; starts as just the instruction.
        self.context = self.base_context
        # Raw user turns, in order of arrival.
        self.conversation_history = []

    def update_context(self, user_text):
        """Append *user_text* to the history and rebuild the prompt context.

        Bug fix: the previous implementation replaced ``self.context`` with the
        joined history alone, silently dropping the base system instruction
        after the first user turn.

        Args:
            user_text: The user's latest message.
        """
        self.conversation_history.append(user_text)
        # The base instruction is always kept as the first line of the context.
        self.context = "\n".join([self.base_context] + self.conversation_history)

    def llm_inference(
        self,
        model_type: str,
        question: str,
        prompt_template: str,
        ai_tone: str,
        questionnaire: str,
        user_text: str,
        openai_model_name: str = "",
        hf_repo_id: str = "tiiuae/falcon-7b-instruct",
        temperature: float = 0.1,
        max_length: int = 128,
    ) -> str:
        """Call HuggingFace/OpenAI model for inference.

        Given a question, prompt_template, and other parameters, this function
        calls the relevant API to fetch LLM inference results.

        Args:
            model_type: The LLM vendor's name. Either 'huggingface' or 'openai'.
            question: The question to be asked to the LLM.
            prompt_template: The prompt template itself.
            ai_tone: Can be either empathy, encouragement or suggest medical help.
            questionnaire: Can be either depression, anxiety or adhd.
            user_text: Response given by the user.
            openai_model_name: OpenAI model name; used only when
                model_type == 'openai'.
            hf_repo_id: The Huggingface model's repo_id.
            temperature: Sampling temperature. 0 means always take the highest
                score; larger values sample more uniformly.
            max_length: Maximum length in tokens of the output summary.

        Returns:
            A Python string which contains the inference result.

        Raises:
            ValueError: If model_type is neither 'openai' nor 'huggingface'.

        HuggingFace repo_id examples:
            - google/flan-t5-xxl
            - tiiuae/falcon-7b-instruct
        """
        # Validate model_type before constructing any LangChain objects so an
        # unsupported vendor fails fast instead of printing and returning None
        # (the annotated return type is str).
        if model_type not in ("openai", "huggingface"):
            raise ValueError(
                "Please use the correct value of model_type parameter: It can have a value of either openai or huggingface"
            )

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )

        if model_type == "openai":
            llm = OpenAI(
                model_name=openai_model_name, temperature=temperature, max_tokens=max_length
            )
            llm_chain = LLMChain(prompt=prompt, llm=llm)
            return llm_chain.run(
                context=self.context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )

        # model_type == "huggingface"
        llm = HuggingFaceHub(
            repo_id=hf_repo_id,
            model_kwargs={"temperature": temperature, "max_length": max_length},
        )
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response = llm_chain.run(
            context=self.context,
            ai_tone=ai_tone,
            questionnaire=questionnaire,
            question=question,
            user_text=user_text,
        )

        # Extract only the part after the "Response;" marker echoed by the
        # template. Bug fix: find() returns -1 when the marker is absent, and
        # the old slice [-1 + len("Response;"):] silently stripped the first
        # 8 characters of the output; now the full output is returned instead.
        marker = "Response;"
        marker_index = response.find(marker)
        if marker_index == -1:
            return response.strip()
        return response[marker_index + len(marker):].strip()


if __name__ == "__main__":
    # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' and 'OPENAI_API_KEY' values.
    # Fail fast if the token is missing. LangChain's HuggingFaceHub reads the
    # env var itself, so the value only needs to exist — it is never passed on
    # explicitly (the old code fetched it into an unused local).
    if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
        print("Warning: HUGGINGFACEHUB_API_TOKEN is not set; HuggingFace inference will fail.")

    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    # The template must reference every variable declared in the class's
    # PromptTemplate input_variables (including {questionnaire}), otherwise
    # template validation can reject the prompt.
    template = """INSTRUCTIONS: {context}

    The user may have signs of {questionnaire}.

    Respond to the user with a tone of {ai_tone}.

    Question asked to the user: {question}

    Response by the user: {user_text}

    Provide some advice and ask a relevant question back to the user.

    Response;
    """

    temperature = 0.1
    max_length = 128

    model = LLLResponseGenerator()

    def generate_reply(text: str) -> str:
        # Single place for the inference call so the initial prompt and the
        # chat loop cannot drift apart.
        return model.llm_inference(
            model_type="huggingface",
            question=question,
            prompt_template=template,
            ai_tone=ai_tone,
            questionnaire=questionnaire,
            user_text=text,
            temperature=temperature,
            max_length=max_length,
        )

    # Initial prompt
    print("Bot:", generate_reply(user_text))

    # Chat loop: each user turn is added to the running context before the
    # next inference call; typing "exit" ends the session.
    while True:
        user_input = input("You: ")
        if user_input.lower() == "exit":
            break

        model.update_context(user_input)
        print("Bot:", generate_reply(user_input))