# ChatitoArXiv / utils/chain.py
# Author: RubenAMtz — commit 01e3f20 ("changed chain and system prompt")
from operator import itemgetter
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.schema.runnable import RunnableLambda, RunnableParallel, RunnableSequence
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.documents import Document
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.function import FunctionMessage
# Simple single-turn RAG prompt: context block followed by the user question.
# Kept for reference/backwards use; `create_chain` below uses `chat_prompt` instead.
template = """
You are a helpful assistant, your job is to answer the user's question using the relevant context.
CONTEXT
=========
{context}
=========
User question: {question}
"""
# PromptTemplate with input variables inferred from the braces: {context}, {question}.
prompt = PromptTemplate.from_template(template=template)
# Multi-turn RAG prompt: the system message carries both the retrieved context
# and the serialized conversation history; the human message carries the question.
# Expected input variables: {context}, {chat_history}, {question}.
chat_prompt = ChatPromptTemplate.from_messages([
("system", """
You are a helpful assistant, your job is to answer the user's question using the relevant context in the context section and in the conversation history.
Make sure to relate the question to the conversation history and the context in the context section. If the question, the context and the conversation history
does not align please let the user know about this and ask for further clarification.
=========
CONTEXT:
{context}
=========
PREVIOUS CONVERSATION HISTORY:
{chat_history}
"""),
# NOTE(review): history is interpolated as text above instead of as structured
# messages; this placeholder is the structured-message alternative, left disabled.
# MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}")
])
def to_doc(input: AIMessage) -> list[Document]:
    """Wrap an LLM reply in a one-element Document list.

    The message text is stored under the ``text`` metadata key (mirroring the
    retriever's document metadata shape), while ``page_content`` is the fixed
    marker ``"LLM"`` so downstream code can tell self-knowledge apart from
    retrieved chunks.
    """
    metadata = {'chunk': 1.0, 'page_number': 1.0, 'text': input.content}
    doc = Document(page_content="LLM", metadata=metadata)
    return [doc]
def merge_docs(a: "dict[str, list[Document]]") -> "list[Document]":
    """Flatten a mapping of named document lists into a single list.

    Used to merge the outputs of the parallel "docs" (retriever) and
    "self_knowledge" (LLM) branches. Insertion order of the dict is preserved.

    Args:
        a: mapping of branch name -> list of documents.

    Returns:
        One list containing every document from every branch.
    """
    merged_docs: list = []
    # Keys are only branch labels — iterate values directly (avoids PERF102).
    for docs in a.values():
        merged_docs.extend(docs)
    return merged_docs
def create_chain(**kwargs) -> RunnableSequence:
    """Build the retrieval-augmented QA chain.

    Keyword Args:
        retriever (VectorStoreRetriever): vector-store retriever that supplies
            context documents for the question.
        llm (AzureChatOpenAI): chat model used both to generate "self
            knowledge" pseudo-documents and to produce the final answer.

    Returns:
        A runnable that maps ``{"question", "chat_history"}`` to
        ``{"response", "context", "chat_history"}``.

    Raises:
        KeyError: if ``retriever`` is not passed at all.
        ValueError: if ``retriever`` or ``llm`` has the wrong type.
    """
    retriever: VectorStoreRetriever = kwargs["retriever"]
    llm: AzureChatOpenAI = kwargs.get("llm", None)
    if not isinstance(retriever, VectorStoreRetriever):
        raise ValueError(
            f"'retriever' must be a VectorStoreRetriever, got {type(retriever).__name__}"
        )
    if not isinstance(llm, AzureChatOpenAI):
        raise ValueError(
            f"'llm' must be an AzureChatOpenAI instance, got {type(llm).__name__}"
        )
    # Branch 1: fetch context documents from the vector store for the question.
    docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
    # Branch 2: ask the LLM directly and wrap its answer as a pseudo-document.
    self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(config={"run_name": "self knowledge"})
    # Final answer generation from the chat prompt (context + history + question).
    response_chain = (chat_prompt | llm).with_config(config={"run_name": "response"})
    merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})
    # Run both context branches in parallel, then flatten their outputs into one list.
    context_chain = (
        RunnableParallel(
            {
                "docs": docs_chain,
                "self_knowledge": self_knowledge_chain
            }
        ).with_config(config={"run_name": "parallel context"})
        | merge_docs_link
    )
    # Stage 1 assembles the prompt inputs; stage 2 generates the response while
    # passing context and history through so callers can inspect/display them.
    retrieval_augmented_qa_chain = (
        RunnableParallel({
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
            "context": context_chain
        })
        | RunnableParallel({
            "response": response_chain,
            "context": itemgetter("context"),
            "chat_history": itemgetter("chat_history")
        })
    )
    return retrieval_augmented_qa_chain