from operator import itemgetter

from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnableSequence
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import AzureChatOpenAI
template = """ | |
You are a helpful assistant, your job is to answer the user's question using the relevant context. | |
CONTEXT | |
========= | |
{context} | |
========= | |
User question: {question} | |
""" | |
prompt = PromptTemplate.from_template(template=template) | |
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", """
You are a helpful assistant; your job is to answer the user's question using the relevant context in the context section and in the conversation history.
Relate the question to both the conversation history and the context in the context section. If the question, the context, and the conversation history
do not align, let the user know and ask for further clarification.
=========
CONTEXT:
{context}
=========
PREVIOUS CONVERSATION HISTORY:
{chat_history}
"""),
    # MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{question}")
])
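
# Illustrative only: how chat_prompt renders into messages. The values here are
# made-up placeholders, not part of the chain:
#
#   chat_prompt.format_messages(
#       context="Paris is the capital of France.",
#       chat_history="",
#       question="What is the capital of France?",
#   )
#   # -> [SystemMessage(content="...CONTEXT: ..."), HumanMessage(content="What is the capital of France?")]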
def to_doc(message: AIMessage) -> list[Document]:
    """Wrap an LLM reply in a Document so it can be merged with retrieved docs."""
    return [Document(page_content="LLM", metadata={"chunk": 1.0, "page_number": 1.0, "text": message.content})]
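
# Illustrative only: note that to_doc stores the reply text in metadata, not in
# page_content:
#
#   to_doc(AIMessage(content="Paris."))
#   # -> [Document(page_content="LLM", metadata={"chunk": 1.0, "page_number": 1.0, "text": "Paris."})]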
def merge_docs(results: dict[str, list[Document]]) -> list[Document]:
    """Flatten the per-branch document lists into a single list."""
    merged_docs: list[Document] = []
    for docs in results.values():
        merged_docs.extend(docs)
    return merged_docs
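
# Illustrative only: merge_docs flattens the output of the parallel context
# branches into one document list:
#
#   merge_docs({"docs": [doc_a, doc_b], "self_knowledge": [doc_c]})
#   # -> [doc_a, doc_b, doc_c]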
def create_chain(**kwargs) -> RunnableSequence:
    """
    Build the retrieval-augmented QA chain.

    Required keyword arguments: ``retriever`` (a VectorStoreRetriever) and
    ``llm`` (an AzureChatOpenAI instance). The module-level ``chat_prompt``
    is used for the final response step.
    """
    retriever: VectorStoreRetriever = kwargs["retriever"]
    llm: AzureChatOpenAI = kwargs.get("llm")
    if not isinstance(retriever, VectorStoreRetriever):
        raise ValueError("'retriever' must be a VectorStoreRetriever")
    if not isinstance(llm, AzureChatOpenAI):
        raise ValueError("'llm' must be an AzureChatOpenAI instance")
    # Branch 1: retrieve documents from the vector store for the question.
    docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
    # Branch 2: ask the LLM directly and wrap its answer as a Document.
    self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(config={"run_name": "self knowledge"})
    # Final step: render the chat prompt and generate the answer.
    response_chain = (chat_prompt | llm).with_config(config={"run_name": "response"})
    merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})
    # Run both context branches in parallel, then flatten their outputs.
    context_chain = (
        RunnableParallel(
            {
                "docs": docs_chain,
                "self_knowledge": self_knowledge_chain
            }
        ).with_config(config={"run_name": "parallel context"})
        | merge_docs_link
    )
    retrieval_augmented_qa_chain = (
        RunnableParallel({
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
            "context": context_chain
        })
        | RunnableParallel({
            "response": response_chain,
            "context": itemgetter("context"),
            "chat_history": itemgetter("chat_history")
        })
    )
    return retrieval_augmented_qa_chain
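

# Minimal usage sketch, not part of the deployed app. Assumes Azure OpenAI
# credentials are provided via the standard AZURE_OPENAI_ENDPOINT /
# AZURE_OPENAI_API_KEY environment variables; the deployment names and the
# "faiss_index" path are hypothetical stand-ins for your own resources.
if __name__ == "__main__":
    from langchain_community.vectorstores import FAISS
    from langchain_openai import AzureOpenAIEmbeddings

    llm = AzureChatOpenAI(azure_deployment="my-chat-deployment", api_version="2024-02-01")
    embeddings = AzureOpenAIEmbeddings(azure_deployment="my-embedding-deployment")
    vector_store = FAISS.load_local(
        "faiss_index", embeddings, allow_dangerous_deserialization=True
    )
    chain = create_chain(retriever=vector_store.as_retriever(), llm=llm)
    result = chain.invoke({"question": "What is this corpus about?", "chat_history": ""})
    print(result["response"].content)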