"""Retrieval-augmented QA chain built from LangChain runnables.

``create_chain`` answers a user's question with a chat prompt whose context is
the concatenation of (a) documents returned by a vector-store retriever and
(b) a single pseudo-document holding the LLM's own answer to the bare question.
"""

import itertools
from operator import itemgetter

from langchain_core.vectorstores import VectorStoreRetriever
from langchain.schema.runnable import RunnableLambda, RunnableParallel, RunnableSequence
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.documents import Document
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.function import FunctionMessage


# Single-turn prompt (context + question only). Note: create_chain below does
# not use this; it uses chat_prompt.
template = """
You are a helpful assistant, your job is to answer the user's question using the relevant context.
CONTEXT
=========
{context}
=========

User question: {question}
"""
prompt = PromptTemplate.from_template(template=template)

# Chat prompt used by create_chain. The conversation history is interpolated as
# plain text into the system message via {chat_history}, rather than injected as
# separate chat messages (hence no MessagesPlaceholder here).
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", """
You are a helpful assistant, your job is to answer the user's question using the relevant context in the context section and in the conversation history.
Make sure to relate the question to the conversation history and the context in the context section.
If the question, the context and the conversation history does not align please let the user know about this and ask for further clarification.
=========
CONTEXT:
{context}
=========

PREVIOUS CONVERSATION HISTORY:
{chat_history}
"""),
    ("human", "{question}"),
])


def to_doc(input: AIMessage) -> list[Document]:  # noqa: A002 - name kept for callers
    """Wrap an LLM reply as a one-element document list.

    The reply text is stored in ``metadata["text"]``; ``page_content`` is the
    fixed marker ``"LLM"``. The ``chunk`` and ``page_number`` values are
    constant placeholders (presumably to mirror the metadata of retrieved
    documents - TODO confirm against the vector store schema).

    Args:
        input: The chat model's reply message.

    Returns:
        A list containing exactly one ``Document``.
    """
    return [
        Document(
            page_content="LLM",
            metadata={'chunk': 1.0, 'page_number': 1.0, 'text': input.content},
        )
    ]


def merge_docs(a: dict[str, list[Document]]) -> list[Document]:
    """Flatten a mapping of document lists into a single list.

    Args:
        a: Mapping from an arbitrary key to a list of documents. Keys are ignored.

    Returns:
        A new list with all documents, concatenated in the mapping's iteration order.
    """
    return list(itertools.chain.from_iterable(a.values()))


def create_chain(**kwargs) -> RunnableSequence:
    """Build the retrieval-augmented QA chain.

    Keyword Args:
        retriever: Required ``VectorStoreRetriever`` used to fetch documents for
            the question.
        llm: Required ``AzureChatOpenAI`` used both for the "self knowledge"
            answer and for the final response.

    Any other keyword arguments (including ``prompt``) are ignored; the
    module-level ``chat_prompt`` is always used.

    The returned runnable expects a dict with ``"question"`` and
    ``"chat_history"`` keys. It produces a dict with ``"response"`` (the LLM
    output for ``chat_prompt``), ``"context"`` (the merged document list) and
    ``"chat_history"`` (passed through unchanged).

    Raises:
        KeyError: If ``retriever`` is not supplied.
        ValueError: If ``retriever`` or ``llm`` has the wrong type. A missing
            ``llm`` is reported this way too, since it defaults to ``None``.
    """
    retriever: VectorStoreRetriever = kwargs["retriever"]
    llm: AzureChatOpenAI = kwargs.get("llm", None)

    if not isinstance(retriever, VectorStoreRetriever):
        raise ValueError(
            "retriever must be a VectorStoreRetriever, "
            f"got {type(retriever).__name__}"
        )
    if not isinstance(llm, AzureChatOpenAI):
        raise ValueError(
            f"llm must be an AzureChatOpenAI instance, got {type(llm).__name__}"
        )

    # Documents retrieved for the question.
    docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
    # The LLM's own answer to the bare question, wrapped as a pseudo-document.
    self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(
        config={"run_name": "self knowledge"}
    )
    response_chain = (chat_prompt | llm).with_config(config={"run_name": "response"})
    merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})

    # Run both context sources in parallel, then flatten into one document list.
    context_chain = (
        RunnableParallel(
            {
                "docs": docs_chain,
                "self_knowledge": self_knowledge_chain,
            }
        ).with_config(config={"run_name": "parallel context"})
        | merge_docs_link
    )

    # Step 1 builds {question, chat_history, context}; step 2 answers with that
    # dict and also echoes context and chat_history in the output.
    retrieval_augmented_qa_chain = (
        RunnableParallel({
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
            "context": context_chain,
        })
        | RunnableParallel({
            "response": response_chain,
            "context": itemgetter("context"),
            "chat_history": itemgetter("chat_history"),
        })
    )
    return retrieval_augmented_qa_chain