from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
import gradio as gr
import os
from typing import List

from llm.CustomRetriever import CustomRetriever
from langchain.schema.retriever import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.runnables import chain

API_TOKEN = os.getenv("TOKEN")
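
# CustomRetriever lives in llm/CustomRetriever.py and is not shown in this file.
# Given the imports above and the `vectorstore` / `thold` arguments used below, it is
# plausibly a BaseRetriever that drops vector-store hits below a relevance-score
# threshold. A minimal sketch under that assumption (names and logic are illustrative
# only, not the actual implementation):
#
# class CustomRetriever(BaseRetriever):
#     vectorstore: Any    # any LangChain vector store (e.g. Chroma, FAISS)
#     thold: float = 0.8  # minimum relevance score for a document to be returned
#
#     def _get_relevant_documents(
#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
#     ) -> List[Document]:
#         docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores(query)
#         return [doc for doc, score in docs_and_scores if score >= self.thold]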

# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
                        thold=0.8, progress=gr.Progress()):
    # Hugging Face Inference API endpoint for the chosen model
    llm = HuggingFaceEndpoint(
        huggingfacehub_api_token=API_TOKEN,
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # Conversation memory so follow-up questions keep their context
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    # Retrieval-augmented QA chain; the custom retriever filters results by thold
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=CustomRetriever(vectorstore=vdb, thold=thold),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

# Initialize LLM with a fixed model
def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, thold, progress=gr.Progress()):
    llm_name = "mistralai/Mistral-7B-Instruct-v0.2"
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, thold)
    return qa_chain
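
# Example usage (a sketch; `vector_db` is assumed to be a LangChain vector store
# built elsewhere in the app, the parameter values are placeholders):
# qa_chain = initialize_LLM(llm_temperature=0.5, max_tokens=1024,
#                           top_k=3, vector_db=vector_db, thold=0.8)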

def format_chat_history(chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

def postprocess(response):
    try:
        result = response["answer"]
        # Here should be a binary classification model (grounded vs. not grounded).
        if "I don't know" not in result:
            # Append the source file, page and text fragment for each retrieved document
            for doc in response["source_documents"]:
                file_doc = "\n\nFile: " + doc.metadata["source"].split("/")[-1]
                page = "\nPage: " + str(doc.metadata["page"])
                content = "\nFragment: " + doc.page_content.strip()
                result += file_doc + page + content
        return result
    except Exception:
        return "I don't know."