import os
import pickle

import faiss
import langchain
from langchain import HuggingFaceHub
from langchain.cache import InMemoryCache
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader, PyPDFDirectoryLoader
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS

from mapping import FILE_URL_MAPPING
from memory import CustomMongoDBChatMessageHistory

# Process-wide LangChain LLM cache, kept in memory.
langchain.llm_cache = InMemoryCache()

# Display names handled by set_model (any other name falls into its else branch).
models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]

# Suffixes of the persisted vector-store files; get_file_path prefixes them with
# models_folder and a provider tag ("openai" or "hf").
pickle_file = "_vs.pkl"
index_file = "_vs.index"
models_folder = "models/"
# Required setting: importing this module raises KeyError when it is not set.
MONGO_DB_URL = os.environ['MONGO_DB_URL']

# Defaults used until set_model / set_embeddings replace these globals.
llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)

embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')

# Placeholder history; set_session_id replaces it with a per-session one.
# NOTE(review): this uses collection '3d_printing_applications' while
# set_session_id uses 'printing_3d_applications' — confirm which is intended.
message_history = CustomMongoDBChatMessageHistory(
    connection_string=MONGO_DB_URL, session_id='session_id', database_name='coursera_bots',
    collection_name='3d_printing_applications'
)

# NOTE(review): unlike the memory built in set_session_id, this one is not given
# chat_memory=message_history and uses k=4 — confirm that is intended.
memory = ConversationBufferWindowMemory(memory_key="chat_history", k=4)

# Assigned by get_search_index; read by generate_answer.
vectorstore_index = None

system_template = """You are Coursera QA Bot. Have a conversation with a human, answering the following questions as best you can.
You are a teaching assistant for a Coursera Course: 3D Printing Applications and can answer any question about that using vectorstore or context.
Use the following pieces of context to answer the users question. 
----------------
{context}"""

# Prompt for the combine-documents step: retrieved context goes into the system
# message, the user's question into the human message.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)


def set_session_id(session_id):
    """Bind the module-level chat memory to the MongoDB history of ``session_id``.

    Replaces the globals ``message_history`` and ``memory`` with a new
    ``CustomMongoDBChatMessageHistory`` (collection 'printing_3d_applications')
    and a ``ConversationBufferWindowMemory`` (k=10, return_messages=True) that
    stores into it. Nothing is rebuilt when both the history and the memory's
    own chat history already belong to ``session_id``.
    """
    global message_history, memory
    # Comparing message_history.session_id alone is not enough: the memory
    # created at import time is not backed by message_history, so passing the
    # placeholder id 'session_id' used to keep a memory not backed by MongoDB.
    # Also require that the memory's own chat history carries this session id.
    already_bound = (
        message_history.session_id == session_id
        and getattr(memory.chat_memory, "session_id", None) == session_id
    )
    if already_bound:
        print("Session id already set: " + str(message_history.session_id))
        return

    # create new message history with session id
    print("Setting session id to " + str(session_id))
    message_history = CustomMongoDBChatMessageHistory(
        connection_string=MONGO_DB_URL, session_id=session_id, database_name='coursera_bots',
        collection_name='printing_3d_applications'
    )
    memory = ConversationBufferWindowMemory(memory_key="chat_history", chat_memory=message_history, k=10,
                                            return_messages=True)


def set_model_and_embeddings(model):
    """Switch the active LLM to ``model``.

    Despite the name, only the LLM changes: ``set_embeddings`` is not called,
    so the module-level ``embeddings`` keep their current value.
    """
    set_model(model)


def set_model(model):
    """Replace the module-level ``llm`` with the LLM for a display name.

    Args:
        model: "GPT-3.5", "GPT-4", "Flan UL2" or "Flan T5". Any other value
            falls back to GPT-3.5 (``gpt-3.5-turbo``).
    """
    global llm
    print("Setting model to " + str(model))
    if model == "GPT-3.5":
        print("Loading GPT-3.5")
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.1)
    elif model == "GPT-4":
        print("Loading GPT-4")
        llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
    elif model == "Flan UL2":
        print("Loading Flan-UL2")
        llm = HuggingFaceHub(repo_id="google/flan-ul2", model_kwargs={"temperature": 0.1, "max_new_tokens": 500})
    elif model == "Flan T5":
        print("Loading Flan T5")
        llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.1})
    else:
        # ChatOpenAI talks to the chat-completions endpoint, so the fallback must
        # be a chat model; a completions-only model such as "text-davinci-002"
        # is rejected there. Use the GPT-3.5 chat model the log line announces.
        print("Loading GPT-3.5 from else")
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.1)


def set_embeddings(model):
    """Swap the module-level ``embeddings`` to match the provider of ``model``.

    OpenAI models get ``text-embedding-ada-002``; the Flan models get the
    Hugging Face Hub ``sentence-transformers/all-MiniLM-L6-v2`` embeddings.
    Any other name leaves ``embeddings`` untouched.
    """
    global embeddings
    if model in ("GPT-3.5", "GPT-4"):
        print("Loading OpenAI embeddings")
        embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
        return
    if model in ("Flan UL2", "Flan T5"):
        print("Loading Hugging Face embeddings")
        embeddings = HuggingFaceHubEmbeddings(repo_id="sentence-transformers/all-MiniLM-L6-v2")


def get_search_index(model):
    """Return the vector store for ``model``, loading it from disk or building it.

    The pickled store is reused when both the pickle and the FAISS index file
    exist and the pickle is non-empty; otherwise ``create_index`` builds and
    saves a new one. The index file is only checked for existence — the store
    itself is restored from the pickle. The result is also stored in the
    module-level ``vectorstore_index``.
    """
    global vectorstore_index
    pickle_path = get_file_path(model, pickle_file)
    index_path = get_file_path(model, index_file)
    has_cached_store = (
        os.path.isfile(pickle_path)
        and os.path.isfile(index_path)
        and os.path.getsize(pickle_path) > 0
    )

    if not has_cached_store:
        store = create_index(model)
        print("Created index")
    else:
        # pickle.load can execute arbitrary code; only load files this app wrote.
        with open(pickle_path, "rb") as handle:
            store = pickle.load(handle)
            print("Loaded index")

    vectorstore_index = store
    return store


def create_index(model):
    """Build a FAISS vector store from the ``docs`` folder and persist it for ``model``.

    The documents are chunked via ``create_chunk_documents`` and embedded via
    ``search_index_from_docs``. The raw FAISS index is written with
    ``faiss.write_index`` and the whole store is pickled, both to paths from
    ``get_file_path``. Missing parent directories are created first.

    Returns:
        The newly built vector store.
    """
    source_chunks = create_chunk_documents()
    search_index = search_index_from_docs(source_chunks)

    index_path = get_file_path(model, index_file)
    pickle_path = get_file_path(model, pickle_file)
    # faiss.write_index and open(..., "wb") fail if the target folder is missing
    # (e.g. a fresh checkout without models/), so create it up front.
    for path in (index_path, pickle_path):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    faiss.write_index(search_index.index, index_path)
    # Save index to pickle file
    with open(pickle_path, "wb") as f:
        pickle.dump(search_index, f)
    return search_index


def get_file_path(model, file):
    """Return the on-disk path of a vector-store file for ``model``.

    "GPT-3.5" and "GPT-4" use the ``openai`` prefix; every other model name
    uses ``hf``. With the default ``models_folder``,
    ``get_file_path("GPT-4", "_vs.pkl")`` gives ``"models/openai_vs.pkl"``.
    """
    provider = "openai" if model in ("GPT-3.5", "GPT-4") else "hf"
    return models_folder + provider + file


def search_index_from_docs(source_chunks):
    """Embed ``source_chunks`` with the module-level ``embeddings`` into a new FAISS store."""
    return FAISS.from_documents(source_chunks, embeddings)


def get_pdf_files():
    """Load every ``.pdf`` under ``docs`` (recursively) as LangChain documents."""
    return PyPDFDirectoryLoader('docs', glob="**/*.pdf", recursive=True).load()


def get_html_files():
    """Load every ``.html`` under ``docs`` (recursively) with ``UnstructuredHTMLLoader``."""
    html_loader = DirectoryLoader(
        'docs',
        glob="**/*.html",
        loader_cls=UnstructuredHTMLLoader,
        recursive=True,
    )
    return html_loader.load()


def fetch_data_for_embeddings():
    """Collect the text, HTML and PDF documents from ``docs`` and tag each with a URL.

    Documents are ordered text, then HTML, then PDF. Each document's ``url``
    metadata is looked up in ``FILE_URL_MAPPING`` by its ``source`` path;
    files missing from the mapping get ``None``.
    """
    documents = [*get_text_files(), *get_html_files(), *get_pdf_files()]

    for doc in documents:
        doc.metadata["url"] = FILE_URL_MAPPING.get(doc.metadata["source"])

    print("document list: " + str(len(documents)))
    return documents


def get_text_files():
    """Load every ``.txt`` under ``docs`` (recursively) with ``TextLoader``."""
    text_loader = DirectoryLoader(
        'docs',
        glob="**/*.txt",
        loader_cls=TextLoader,
        recursive=True,
    )
    return text_loader.load()


def create_chunk_documents():
    """Load all source documents and split them into chunks for embedding.

    Uses ``CharacterTextSplitter`` splitting on single spaces with
    ``chunk_size=800`` and no overlap, and prints the resulting chunk count.
    """
    documents = fetch_data_for_embeddings()
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=800, chunk_overlap=0
    ).split_documents(documents)
    print("chunks: " + str(len(chunks)))
    return chunks


def get_qa_chain(vectorstore_index, score_threshold=0.76):
    """Build a conversational retrieval chain over ``vectorstore_index`` using the current ``llm``.

    Args:
        vectorstore_index: Vector store providing ``as_retriever`` (e.g. the
            FAISS store returned by ``get_search_index``).
        score_threshold: Minimum relevance score a chunk needs to be retrieved.
            The default 0.76 matches the previously hard-coded value.

    Returns:
        A ``ConversationalRetrievalChain`` that returns its source documents
        and combines them with ``CHAT_PROMPT``.
    """
    print(llm)

    # Documents scoring below the threshold are dropped, so the context passed
    # to the prompt may be empty for off-topic questions.
    retriever = vectorstore_index.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": score_threshold},
    )

    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever,
        return_source_documents=True,
        verbose=True,
        combine_docs_chain_kwargs={"prompt": CHAT_PROMPT},
    )


def get_chat_history(inputs) -> str:
    """Render ``(human, ai)`` message pairs as a plain-text transcript.

    Each pair becomes ``Human:<human>`` and ``AI:<ai>`` on consecutive lines,
    and pairs are joined with newlines. An empty input yields ``""``.
    """
    return "\n".join(
        f"Human:{human_msg}\nAI:{ai_msg}" for human_msg, ai_msg in inputs
    )


def generate_answer(question) -> str:
    """Answer ``question`` with the retrieval chain and append its sources.

    Uses the module-level ``vectorstore_index`` (set by ``get_search_index``),
    the current ``llm`` and the last four messages of ``memory`` as chat
    history. The question and answer are then saved via ``save_chat_history``.

    Returns:
        The chain's answer followed by a ``SOURCES:`` line that lists the
        de-duplicated URLs of the retrieved documents in retrieval order.
    """
    chain = get_qa_chain(vectorstore_index)
    # get last 4 messages from chat history
    history = memory.chat_memory.messages[-4:]
    result = chain({"question": question, "chat_history": history})

    save_chat_history(question, result)
    print(result)

    sources = []
    for document in result['source_documents']:
        # FILE_URL_MAPPING.get() leaves 'url' as None for unmapped files; fall
        # back to the document's source path instead of failing on str + None.
        url = document.metadata.get('url') or document.metadata.get('source', '')
        sources.append("\n" + str(url))
    print(sources)

    # dict.fromkeys de-duplicates while keeping retrieval order (set() did not).
    source = ',\n'.join(dict.fromkeys(sources))
    return result['answer'] + '\nSOURCES: ' + source


def save_chat_history(question, result):
    """Append the user's question and the chain's ``result["answer"]`` to the chat memory."""
    history = memory.chat_memory
    history.add_user_message(question)
    history.add_ai_message(result["answer"])
    print("chat history after saving: " + str(history.messages))