""" LLM chain retrieval """ import json import gradio as gr from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain from langchain.memory import ConversationBufferMemory from langchain_huggingface import HuggingFaceEndpoint from langchain_core.prompts import PromptTemplate # Initialize langchain LLM chain def initialize_llmchain( llm_model, huggingfacehub_api_token, temperature, max_tokens, top_k, vector_db, progress=gr.Progress(), ): """Initialize Langchain LLM chain""" progress(0.1, desc="Initializing HF tokenizer...") progress(0.5, desc="Initializing HF Hub...") llm = HuggingFaceEndpoint( repo_id=llm_model, task="text-generation", provider="hf-inference", temperature=temperature, max_new_tokens=max_tokens, top_k=top_k, huggingfacehub_api_token=huggingfacehub_api_token, ) progress(0.75, desc="Defining buffer memory...") memory = ConversationBufferMemory( memory_key="chat_history", output_key="answer", return_messages=True, ) retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={'k': top_k}) progress(0.8, desc="Defining retrieval chain...") with open('prompt_template.json', 'r') as file: system_prompt = json.load(file) prompt_template = system_prompt["prompt"] rag_prompt = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) qa_chain = ConversationalRetrievalChain.from_llm( llm, retriever=retriever, chain_type="stuff", memory=memory, combine_docs_chain_kwargs={"prompt": rag_prompt}, return_source_documents=True, verbose=False, ) progress(0.9, desc="Done!") return qa_chain # Format chat history def format_chat_history(message, chat_history): """Format chat history for LLM""" formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history # Invoke QA chain with history def invoke_qa_chain(qa_chain, message, history): """Invoke question-answering chain""" formatted_chat_history = format_chat_history(message, history) response = qa_chain.invoke({ "question": message, "chat_history": formatted_chat_history, }) response_sources = response["source_documents"] response_answer = response["answer"] # Clean up if "Helpful Answer:" is included if "Helpful Answer:" in response_answer: response_answer = response_answer.split("Helpful Answer:")[-1].strip() new_history = history + [(message, response_answer)] return qa_chain, new_history, response_sources