from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_message_histories import ChatMessageHistory
from modules.chat.base import BaseRAG
from langchain_core.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from .utils import (
    CustomConversationalRetrievalChain,
    create_history_aware_retriever,
    create_stuff_documents_chain,
    create_retrieval_chain,
    return_questions,
    CustomRunnableWithHistory,
    BaseChatMessageHistory,
    InMemoryHistory,
)


class Langchain_RAG_V1(BaseRAG):
    """RAG chain built on ``CustomConversationalRetrievalChain`` with a windowed buffer memory."""

    def __init__(
        self,
        llm,
        memory,
        retriever,
        qa_prompt: str,
        rephrase_prompt: str,
        config: dict,
        callbacks=None,
    ):
        """
        Initialize the Langchain_RAG_V1 class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory: Previous conversation history. NOTE: currently unused by
                this class; a fresh window memory is created instead.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt template string. It is wrapped in a
                ``PromptTemplate`` with the input variables ``context``,
                ``chat_history`` and ``input``.
            rephrase_prompt (str): The rephrase prompt string (stored only).
            config (dict): Configuration; ``config["llm_params"]["memory_window"]``
                sets the memory window size ``k``.
            callbacks: Accepted for interface parity; not used by this class.
        """
        self.llm = llm
        self.config = config
        # The ``memory`` argument is not loaded into the chain (see
        # ``add_history_from_list``, which is still a TODO for this class).
        self.memory = ConversationBufferWindowMemory(
            k=self.config["llm_params"]["memory_window"],
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
            max_token_limit=128,
        )
        self.retriever = retriever
        self.rephrase_prompt = rephrase_prompt
        self.store = {}

        # Wrap the raw QA prompt string in a template for the combine-docs step.
        self.qa_prompt = PromptTemplate(
            template=qa_prompt,
            input_variables=["context", "chat_history", "input"],
        )

        self.rag_chain = CustomConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.qa_prompt},
            response_if_no_docs_found="No context found",
        )

    def add_history_from_list(self, history_list):
        """
        TODO: Add messages from a list to the chat history.

        Currently a placeholder that ignores ``history_list`` and returns an
        empty list.
        """
        history = []
        return history

    async def invoke(self, user_query, config):
        """
        Invoke the chain asynchronously.

        Args:
            user_query (dict): Mapping whose ``"input"`` entry is the user's question.
            config: Accepted for interface parity; not used by this class.

        Returns:
            The result of the chain's ``acall`` on the question.
        """
        res = await self.rag_chain.acall(user_query["input"])
        return res


class QuestionGenerator:
    """
    Generate follow-up questions from the LLM's response, the user's input
    and past conversations.
    """

    def __init__(self):
        pass

    def generate_questions(self, query, response, chat_history, context, config):
        """Delegate to ``return_questions`` and return its result unchanged."""
        questions = return_questions(query, response, chat_history, context, config)
        return questions


class Langchain_RAG_V2(BaseRAG):
    """RAG chain with a history-aware retriever and per-session message history."""

    def __init__(
        self,
        llm,
        memory,
        retriever,
        qa_prompt: str,
        rephrase_prompt: str,
        config: dict,
        callbacks=None,
    ):
        """
        Initialize the Langchain_RAG_V2 class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory (list | None): Previous conversation messages (dicts or
                objects exposing ``type`` and ``content``); converted to a
                ``ChatMessageHistory``. ``None`` yields an empty history.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt string; a built-in default is used
                when it is empty/None.
            rephrase_prompt (str): The rephrase prompt string; a built-in
                default is used when it is empty/None.
            config (dict): Configuration dictionary (stored on the instance).
            callbacks: Optional callbacks attached to the final chain.
        """
        self.llm = llm
        self.config = config
        self.memory = self.add_history_from_list(memory)
        self.retriever = retriever
        self.qa_prompt = qa_prompt
        self.rephrase_prompt = rephrase_prompt
        # In-memory session store keyed by (user_id, conversation_id).
        self.store = {}

        # Contextualize question prompt
        contextualize_q_system_prompt = rephrase_prompt or (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, just "
            "reformulate it if needed and otherwise return it as is."
        )
        self.contextualize_q_prompt = ChatPromptTemplate.from_template(
            contextualize_q_system_prompt
        )

        # History-aware retriever
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )

        # Answer question prompt
        qa_system_prompt = qa_prompt or (
            "You are an assistant for question-answering tasks. Use "
            "the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise."
            "\n\n"
            "{context}"
        )
        self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)

        # Question-answer chain
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt_template
        )

        # Final retrieval chain
        self.rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )

        # Wrap the chain so each call resolves its history through
        # ``get_session_history`` using the configurable fields below.
        self.rag_chain = CustomRunnableWithHistory(
            self.rag_chain,
            get_session_history=self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
            history_factory_config=[
                ConfigurableFieldSpec(
                    id="user_id",
                    annotation=str,
                    name="User ID",
                    description="Unique identifier for the user.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="conversation_id",
                    annotation=str,
                    name="Conversation ID",
                    description="Unique identifier for the conversation.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="memory_window",
                    annotation=int,
                    name="Number of Conversations",
                    description="Number of conversations to consider for context.",
                    default=1,
                    is_shared=True,
                ),
            ],
        ).with_config(run_name="Langchain_RAG_V2")

        if callbacks is not None:
            self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)

    def get_session_history(
        self, user_id: str, conversation_id: str, memory_window: int
    ) -> BaseChatMessageHistory:
        """
        Get the session history for a user and conversation.

        Args:
            user_id (str): The user identifier.
            conversation_id (str): The conversation identifier.
            memory_window (int): The number of conversations to consider for
                context. Not used inside this method.

        Returns:
            BaseChatMessageHistory: The chat message history.
        """
        session_key = (user_id, conversation_id)
        if session_key not in self.store:
            self.store[session_key] = InMemoryHistory()
            # Seed a new session with the previous messages exactly once.
            # Note: the store is in-memory.
            self.store[session_key].add_messages(self.memory.messages)
        return self.store[session_key]

    async def invoke(self, user_query, config, **kwargs):
        """
        Invoke the chain asynchronously.

        Args:
            user_query: The chain input, forwarded to ``ainvoke``.
            config: The runnable config, forwarded to ``ainvoke``.
            **kwargs: Extra keyword arguments forwarded to ``ainvoke``.

        Returns:
            dict: The chain output, extended with the ``"rephrase_prompt"``
            and ``"qa_prompt"`` strings this instance was created with.
        """
        res = await self.rag_chain.ainvoke(user_query, config, **kwargs)
        res["rephrase_prompt"] = self.rephrase_prompt
        res["qa_prompt"] = self.qa_prompt
        return res

    def stream(self, user_query, config):
        """Return the result of the chain's ``stream`` for the given input and config."""
        res = self.rag_chain.stream(user_query, config)
        return res

    def add_history_from_list(self, conversation_list):
        """
        Build a chat history from a list of messages.

        Args:
            conversation_list (list | None): Messages given either as dicts or
                as objects, each exposing ``type`` and ``content``. ``None`` is
                treated as an empty conversation.

        Returns:
            ChatMessageHistory: History containing the recognised messages.
            Types ``"human"``/``"user_message"`` are added as user messages and
            ``"ai"``/``"ai_message"`` as AI messages; any other type is skipped.
        """
        history = ChatMessageHistory()
        # ``or []`` keeps construction working when no prior history is supplied.
        for message in conversation_list or []:
            if isinstance(message, dict):
                message_type = message.get("type")
                message_content = message.get("content")
            else:
                message_type = getattr(message, "type", None)
                message_content = getattr(message, "content", None)

            if message_type in ("human", "user_message"):
                history.add_user_message(message_content)
            elif message_type in ("ai", "ai_message"):
                history.add_ai_message(message_content)
        return history