#####################################
##  BitsAndBytes
#####################################

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"

###### other models:
# "Trelis/Llama-2-7b-chat-hf-sharded-bf16"
# "bn22/Mistral-7B-Instruct-v0.1-sharded"
# "HuggingFaceH4/zephyr-7b-beta"

# function for loading a 4-bit quantized model
def load_quantized_model(model_name: str):
    """
    :param model_name: Name or path of the model to be loaded.
    :return: Loaded quantized model.
    """
    # NF4 quantization with double quantization and bfloat16 compute
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # load_in_4bit is already set in bnb_config; passing it again to
    # from_pretrained raises an error in recent transformers versions
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    return model
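
# The LangChain chains below (create_history_aware_retriever, the RAG chain)
# need a LangChain-compatible LLM rather than a raw transformers model. A
# minimal bridging sketch, assuming the langchain_community HuggingFacePipeline
# wrapper; the generation settings here are illustrative assumptions:
from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline

def load_quantized_llm(model_name: str):
    """Load the 4-bit model and wrap it so LangChain chains can call it."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    gen_pipeline = pipeline(
        "text-generation",
        model=load_quantized_model(model_name),
        tokenizer=tokenizer,
        max_new_tokens=256,       # assumed generation settings
        repetition_penalty=1.1,
    )
    return HuggingFacePipeline(pipeline=gen_pipeline)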

##################################################
## Vector store chat (RAG over a website)
##################################################
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma

#from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS  # only needed for the FAISS alternative below


from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain


load_dotenv()
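
# load_dotenv() reads a local .env file. A hypothetical entry this script could
# use (only required for gated or rate-limited Hugging Face downloads):
#
#   HUGGINGFACEHUB_API_TOKEN=hf_...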

def get_vectorstore_from_url(url):
    # get the text in document form
    loader = WebBaseLoader(url)
    document = loader.load()

    # split the document into chunks
    text_splitter = RecursiveCharacterTextSplitter()
    document_chunks = text_splitter.split_documents(document)

    # BGE embeddings; normalize_embeddings=True so similarity search
    # amounts to cosine similarity
    model = "BAAI/bge-base-en-v1.5"
    encode_kwargs = {"normalize_embeddings": True}
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
    )

    # create a Chroma vector store from the chunks and persist it to disk
    # (a FAISS store would work the same way:
    #  vector_store = FAISS.from_documents(document_chunks, embeddings))
    vector_store = Chroma.from_documents(
        document_chunks, embeddings, persist_directory="./chroma_db"
    )

    # quick sanity check that retrieval works
    print("-----")
    print(vector_store.similarity_search("What is ALiBi?"))
    print("-----")

    return vector_store
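
# Usage sketch: once persisted, the Chroma store can be reopened on a later run
# without re-scraping or re-embedding (assumes the same embedding model):
#
#   embeddings = HuggingFaceBgeEmbeddings(
#       model_name="BAAI/bge-base-en-v1.5",
#       encode_kwargs={"normalize_embeddings": True},
#       model_kwargs={"device": "cpu"},
#   )
#   store = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
#   print(store.similarity_search("What is ALiBi?", k=2))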





def get_context_retriever_chain(vector_store):

    # specify the Hugging Face model name
    model_name = "anakin87/zephyr-7b-alpha-sharded"
    # model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"

    # load the 4-bit quantized model, wrapped for LangChain (see the sketch
    # after load_quantized_model; the chains need a runnable LLM, not a raw model)
    llm = load_quantized_llm(model_name)
    
    retriever = vector_store.as_retriever()
    
    prompt = ChatPromptTemplate.from_messages([
      MessagesPlaceholder(variable_name="chat_history"),
      ("user", "{input}"),
      ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
    ])
    
    retriever_chain = create_history_aware_retriever(llm, retriever, prompt)
    
    return retriever_chain
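
# Sketch of what the history-aware retriever does on its own: the LLM first
# condenses chat history + new input into a standalone search query, then the
# retriever fetches matching chunks (illustrative messages, not from the script):
#
#   docs = retriever_chain.invoke({
#       "chat_history": [HumanMessage(content="What is ALiBi?"),
#                        AIMessage(content="ALiBi is a positional bias method.")],
#       "input": "Who proposed it?",
#   })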
    
def get_conversational_rag_chain(retriever_chain):

    # model_name here resolves to the module-level name set at the top of the
    # file; this loads the model a second time, so in practice load the LLM
    # once and share it between the two chain builders
    llm = load_quantized_llm(model_name)
    
    prompt = ChatPromptTemplate.from_messages([
      ("system", "Answer the user's questions based on the below context:\n\n{context}"),
      MessagesPlaceholder(variable_name="chat_history"),
      ("user", "{input}"),
    ])
    
    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
    
    return create_retrieval_chain(retriever_chain, stuff_documents_chain)
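
# The combined chain takes the same inputs as the retriever chain and returns a
# dict whose "answer" key holds the generated response (illustrative call, with
# rag_chain being the object returned above):
#
#   result = rag_chain.invoke({"chat_history": [], "input": "What is ALiBi?"})
#   print(result["answer"])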

# Streamlit variant of get_response (kept for reference; requires
# `import streamlit as st` and is superseded by the Gradio version below)
def get_response_streamlit(user_input):
    retriever_chain = get_context_retriever_chain(st.session_state.vector_store)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)

    response = conversation_rag_chain.invoke({
        "chat_history": st.session_state.chat_history,
        "input": user_input
    })

    return response['answer']



##################################################
## Gradio interface
##################################################
import gradio as gr

# Create Gradio interface
chat_history = []  # global chat history for the Gradio session

# Response function wired into the Gradio interface below
def get_response(user_input):
    # build the vector store from a fixed example product page
    # (rebuilt on every call; see the note after this function)
    vs = get_vectorstore_from_url("https://www.bofrost.de/shop/laenderkueche_5573/italienische-kueche_5576/linguine-mit-feinen-pilzen.html?position=1&clicked=")

    chat_history = []
    retriever_chain = get_context_retriever_chain(vs)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)

    response = conversation_rag_chain.invoke({
        "chat_history": chat_history,
        "input": user_input
    })

    return response['answer']
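
# Note (sketch): rebuilding the vector store and reloading the quantized model
# on every request is expensive. One option is to build everything once at
# startup and reuse it (hypothetical top-level names):
#
#   VS = get_vectorstore_from_url(DEFAULT_URL)
#   RAG_CHAIN = get_conversational_rag_chain(get_context_retriever_chain(VS))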

# trivial echo function for testing the Gradio wiring without loading models
def simple(text: str):
    return text + " hhhmmm "

app = gr.Interface(
    fn=get_response,
    # fn=simple,  # swap in to test the interface without loading models
    inputs=["text"],
    outputs="text",
    title="Chat with Websites",
    description="Type your message and chat with websites.",
    # allow_flagging=False
)

app.launch(debug=True, share=True)
# example queries: How do I register at bofrost? How much do the Linguine cost?