"""LLM chain retrieval."""

import json

import gradio as gr
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint


PROMPT_TEMPLATE = """
You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
Question: {question}
Context: {context}
Helpful Answer:
"""


def initialize_llmchain(
    llm_model,
    huggingfacehub_api_token,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    progress=gr.Progress(),
):
    """Initialize the LangChain conversational retrieval (RAG) chain."""
    progress(0.5, desc="Initializing HF Hub...")
    # Wrap the Hugging Face Inference endpoint as a LangChain LLM.
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        task="text-generation",
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        huggingfacehub_api_token=huggingfacehub_api_token,
    )

    progress(0.75, desc="Defining buffer memory...")
    # Keep the running conversation in memory so follow-up questions have context.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    retriever = vector_db.as_retriever()

    progress(0.8, desc="Defining retrieval chain...")
    # Prefer an external prompt file; fall back to the module-level default template.
    try:
        with open("prompt_template.json", "r") as file:
            prompt_template = json.load(file)["prompt"]
    except (FileNotFoundError, KeyError):
        prompt_template = PROMPT_TEMPLATE
    rag_prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")

    return qa_chain


def format_chat_history(message, chat_history):
    """Format chat history for the LLM chain."""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history


def invoke_qa_chain(qa_chain, message, history):
    """Invoke the question-answering chain and update the chat history."""
    formatted_chat_history = format_chat_history(message, history)

    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})

    response_sources = response["source_documents"]
    response_answer = response["answer"]
    # Some models echo the prompt; keep only the text after the final "Helpful Answer:".
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]

    new_history = history + [(message, response_answer)]

    return qa_chain, new_history, response_sources
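

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module above: it wires the helpers
# together against a small in-memory FAISS index. The embedding model, the
# LLM repo_id, the sample texts, and the HUGGINGFACEHUB_API_TOKEN environment
# variable are illustrative assumptions, not requirements of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    from langchain_community.vectorstores import FAISS
    from langchain_huggingface import HuggingFaceEmbeddings

    # Hypothetical corpus; a real app would index its own documents.
    texts = [
        "LangChain provides building blocks for retrieval-augmented generation.",
        "FAISS is a library for efficient similarity search over dense vectors.",
    ]
    vector_db = FAISS.from_texts(texts, HuggingFaceEmbeddings())

    qa_chain = initialize_llmchain(
        llm_model="mistralai/Mistral-7B-Instruct-v0.2",  # assumed repo id
        huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
        temperature=0.5,
        max_tokens=512,
        top_k=3,
        vector_db=vector_db,
        progress=lambda *args, **kwargs: None,  # no-op progress outside a Gradio event
    )

    _, history, sources = invoke_qa_chain(qa_chain, "What is FAISS used for?", history=[])
    print(history[-1][1])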