|
""" |
|
LLM chain retrieval |
|
""" |
|
|
|
import json

import gradio as gr

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
|
|
def initialize_llmchain(
    llm_model,
    huggingfacehub_api_token,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    progress=gr.Progress(),
):
    """Initialize Langchain LLM chain"""
|
    progress(0.1, desc="Initializing LLM chain...")
    progress(0.5, desc="Initializing HF Hub endpoint...")
|
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        task="text-generation",
        provider="hf-inference",
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        huggingfacehub_api_token=huggingfacehub_api_token,
    )
|
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    # top_k also sets how many documents the retriever returns per query
    retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": top_k})
|
    progress(0.8, desc="Defining retrieval chain...")
    # The RAG prompt template is stored under the "prompt" key of prompt_template.json
    with open("prompt_template.json", "r") as file:
        system_prompt = json.load(file)
    prompt_template = system_prompt["prompt"]
    rag_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
|
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        verbose=False,
    )

    progress(0.9, desc="Done!")
    return qa_chain
|
|
def format_chat_history(message, chat_history):
    """Format chat history for LLM"""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
|
|
def invoke_qa_chain(qa_chain, message, history):
    """Invoke question-answering chain"""
    formatted_chat_history = format_chat_history(message, history)

    response = qa_chain.invoke({
        "question": message,
        "chat_history": formatted_chat_history,
    })

    response_sources = response["source_documents"]
    response_answer = response["answer"]

    # Some models echo the prompt; keep only the text after the last "Helpful Answer:" marker
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1].strip()

    new_history = history + [(message, response_answer)]
    return qa_chain, new_history, response_sources
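

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only). Assumptions: faiss-cpu and
    # sentence-transformers are installed, the HF_TOKEN environment variable holds a
    # valid Hugging Face API token, and prompt_template.json sits next to this file.
    # The model id and generation parameters are placeholders.
    import os

    from langchain_community.vectorstores import FAISS
    from langchain_huggingface import HuggingFaceEmbeddings

    # Build a tiny in-memory vector store so the chain can be exercised end to end
    vector_db = FAISS.from_texts(
        [
            "LangChain composes prompts, retrievers, and LLM calls into chains.",
            "FAISS stores dense vectors and supports similarity search over them.",
        ],
        HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
    )

    qa_chain = initialize_llmchain(
        llm_model="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=os.environ["HF_TOKEN"],
        temperature=0.5,
        max_tokens=1024,
        top_k=3,
        vector_db=vector_db,
        progress=lambda *args, **kwargs: None,  # no Gradio progress bar outside the UI
    )
    _, chat, sources = invoke_qa_chain(qa_chain, "What does FAISS store?", [])
    print(chat[-1][1])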
|
|