# llm/utils.py — Manuscript project (update by Gainward777, commit 46a38fc)
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
import gradio as gr
import os
from llm.CustomRetriever import CustomRetriever
from langchain.schema.retriever import BaseRetriever
from langchain_core.documents import Document
from typing import List
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.runnables import chain
# HuggingFace Hub API token, read from the TOKEN env var (None if unset).
API_TOKEN=os.getenv("TOKEN")
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
        thold=0.8, progress=gr.Progress()):
    """Build a ConversationalRetrievalChain over a HuggingFace-hosted LLM.

    Args:
        llm_model: HuggingFace Hub repo id of the model to query.
        temperature: sampling temperature passed to the endpoint.
        max_tokens: maximum number of new tokens to generate.
        top_k: top-k sampling parameter.
        vdb: vector store wrapped by CustomRetriever for document lookup.
        thold: similarity-score threshold forwarded to CustomRetriever.
        progress: gradio progress tracker (gradio convention; not used here).

    Returns:
        A ConversationalRetrievalChain with buffered chat history that also
        returns its source documents.
    """
    # Remote inference endpoint — authenticated with the module-level token.
    endpoint = HuggingFaceEndpoint(
        huggingfacehub_api_token=API_TOKEN,
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # Chat history buffer; output_key='answer' matches the chain's result key.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    # "stuff" chain type: all retrieved docs are stuffed into one prompt.
    chain = ConversationalRetrievalChain.from_llm(
        endpoint,
        retriever=CustomRetriever(vectorstore=vdb, thold=thold),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
    return chain
# Initialize LLM
def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, thold, progress=gr.Progress()):
    """Create the QA chain with the project's fixed Mistral-7B-Instruct model.

    Thin wrapper around initialize_llmchain that pins the model repo id;
    all sampling/retrieval parameters are forwarded unchanged. The gradio
    progress argument follows UI convention and is not consumed here.
    """
    model_repo = "mistralai/Mistral-7B-Instruct-v0.2"
    return initialize_llmchain(model_repo, llm_temperature, max_tokens, top_k, vector_db, thold)
def format_chat_history(chat_history):
    """Flatten (user, bot) message pairs into labeled one-line strings.

    Args:
        chat_history: iterable of (user_message, bot_message) tuples.

    Returns:
        A list alternating "User: ..." and "Assistant: ..." entries, in order.
    """
    return [
        line
        for user_msg, bot_msg in chat_history
        for line in (f"User: {user_msg}", f"Assistant: {bot_msg}")
    ]
def postprocess(response):
    """Format a QA-chain result, appending source citations to the answer.

    Args:
        response: chain output dict with an "answer" string and, when the
            model answered, a "source_documents" list of documents carrying
            `metadata["source"]`, `metadata["page"]` and `page_content`.

    Returns:
        The answer text, followed by "File/Page/Fragment" citations for each
        source document — unless the answer contains "I don't know", in which
        case it is returned as-is. On a malformed response, "I don't know.".
    """
    try:
        result = response["answer"]
        # Here should be a binary classification model.
        if "I don't know" not in result:
            for doc in response['source_documents']:
                file_doc = "\n\nFile: " + doc.metadata["source"].split('/')[-1]
                page = "\nPage: " + str(doc.metadata["page"])
                content = "\nFragment: " + doc.page_content.strip()
                result += file_doc + page + content
        return result
    except (KeyError, TypeError, AttributeError):
        # Narrowed from a bare `except:` — only shape errors in `response`
        # (missing keys, malformed docs) trigger the fallback; real bugs and
        # KeyboardInterrupt/SystemExit are no longer silently swallowed.
        return "I don't know."