import logging

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from openai.error import AuthenticationError
import streamlit as st


def setup_memory():
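    """Wire Streamlit's chat message store into a windowed buffer memory.

    k=3 keeps only the last three conversation turns, bounding prompt
    length (and cost) for the OpenAI models.
    """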
    msgs = StreamlitChatMessageHistory(key="basic_chat_app")
    memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
                                            chat_memory=msgs,
                                            return_messages=True)
    logging.info("setting up new chat memory")
    return memory


def use_existing_chain(model, provider, model_kwargs):
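    """Return True when the chain cached in session state matches the
    requested model, provider, and decoding parameters.

    Mistral chains use a different prompt template, which is not part of
    the comparison (see TODO), so they are always rebuilt.
    """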
    # TODO: consider whether prompt needs to be checked here
    if "mistral" in model:
        return False
    if "current_chain" in st.session_state:
        current_chain = st.session_state.current_chain
        if (current_chain.model == model
                and current_chain.provider == provider
                and current_chain.model_kwargs == model_kwargs):
            return True
    return False


class CurrentChain:
    def __init__(self, model, provider, prompt, memory, model_kwargs):
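        """Bundle an LLMChain together with the settings used to build it,
        so use_existing_chain() can decide whether to reuse it on rerun."""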
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs

        logging.info(f"setting up new chain with params {model}, {provider}, {model_kwargs}")
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model,
                             temperature=model_kwargs['temperature']
                             )
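            # note: only the temperature reaches ChatOpenAI; the other
            # model_kwargs (max_new_tokens, top_p, ...) are HF-specific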
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model,
                                 model_kwargs=model_kwargs
                                 )
        else:
            raise ValueError(f"Unsupported provider: {provider}")

        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory
        )


def format_mistral_prompt(message, history):
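    """Render past turns and the new message in the Mistral instruct
    format: <s>[INST] user [/INST] response</s> ... [INST] message [/INST]
    """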
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    st.header("Basic chatbot")
    st.write("On small screens, click the `>` at top left to choose options")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable, "
                 "only the past three turns of conversation "
                 "are used for OpenAI models. Otherwise the entire chat history is used.")
        st.write("To clear all memory and start fresh, click 'Clear history'")
    st.sidebar.title("Choose options")

    #### USER INPUT ######
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo (OpenAI)",
                 # "bigscience/bloom (HuggingFace)",  # runs
                 # "google/flan-t5-xxl (HuggingFace)",  # runs
                 "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                 ],
        help="Which LLM to use",
    )

    temp = st.sidebar.slider(
        label="Temperature",
        min_value=0.0,
        max_value=2.0,
        step=0.1,
        value=0.4,
        help="Set the decoding temperature. "
             "Higher temps give more unpredictable outputs."
    )
    ##########################

    model = model_name.split("(")[0].rstrip()  # remove name of model provider
    provider = model_name.split("(")[-1].split(")")[0]
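    # e.g. "gpt-3.5-turbo (OpenAI)" -> model "gpt-3.5-turbo", provider "OpenAI"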

    model_kwargs = {"temperature": temp,
                    "max_new_tokens": 256,
                    "repetition_penalty": 1.0,
                    "top_p": 0.95,
                    "do_sample": True,
                    "seed": 42}
    # TODO: maybe expose more of these to the user

    if "session_memory" not in st.session_state:
        st.session_state.session_memory = setup_memory()  # LangChain memory, used by the OpenAI chains

    if "history" not in st.session_state:
        st.session_state.history = []  # (user, assistant) pairs; used for Mistral prompts and the transcript

    if "mistral" in model:
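        # history is baked into the prompt string by format_mistral_prompt,
        # so the template is a plain passthrough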
        prompt = PromptTemplate(input_variables=["input"],
                                template="{input}")
    else:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ]
        )

    if use_existing_chain(model, provider, model_kwargs):
        chain = st.session_state.current_chain
    else:
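        # build a fresh chain and cache it for later reruns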
        chain = CurrentChain(model,
                             provider,
                             prompt,
                             st.session_state.session_memory,
                             model_kwargs)
        st.session_state.current_chain = chain

    conversation = chain.conversation

    if st.button("Clear history"):
        conversation.memory.clear()  # for openai
        st.session_state.history = []  # for mistral
        logging.info("history cleared")

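    # replay the saved transcript so it survives Streamlit reruns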
    for user_msg, asst_msg in st.session_state.history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.write(asst_msg)

    text = st.chat_input()
    if text:
        with st.chat_message("user"):
            st.write(text)
            logging.info(text)
        try:
            if "mistral" in model:
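                # fold prior turns plus the new message into one
                # Mistral-instruct-formatted string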
                full_prompt = format_mistral_prompt(text, st.session_state.history)
                result = conversation.predict(input=full_prompt)
            else:
                result = conversation.predict(input=text)

            st.session_state.history.append((text, result))
            logging.info(repr(result))
            with st.chat_message("assistant"):
                st.write(result)
        except (AuthenticationError, ValueError):
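            # openai raises AuthenticationError; HuggingFaceHub surfaces a
            # missing or invalid token as a ValueError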
            st.warning("Supply a valid API key", icon="⚠️")