File size: 3,358 Bytes
cdda8d7
 
 
01e3f20
dd49b84
cdda8d7
 
dd49b84
 
 
cdda8d7
 
 
 
 
 
 
 
 
 
dd49b84
cdda8d7
dd49b84
 
01e3f20
 
 
dd49b84
 
 
 
01e3f20
 
dd49b84
01e3f20
dd49b84
 
 
 
cdda8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01e3f20
cdda8d7
 
 
 
01e3f20
cdda8d7
 
 
 
dd49b84
cdda8d7
01e3f20
cdda8d7
 
 
 
 
 
 
 
 
 
 
 
 
dd49b84
cdda8d7
 
 
 
 
dd49b84
cdda8d7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from operator import itemgetter
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.schema.runnable import RunnableLambda, RunnableParallel, RunnableSequence
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.documents import Document
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.function import FunctionMessage

# Plain (non-chat) RAG prompt with {context} and {question} placeholders.
# NOTE(review): `prompt` is not referenced by create_chain in this module (it
# uses `chat_prompt`); it may be imported elsewhere — confirm before removing.
template = """
You are a helpful assistant, your job is to answer the user's question using the relevant context.
CONTEXT
=========
{context}
=========

User question: {question}
"""

prompt = PromptTemplate.from_template(template=template)
# Chat prompt used by create_chain's "response" step. Input variables:
#   {context}      - interpolated as text into the system message
#   {chat_history} - interpolated as text into the system message
#   {question}     - sent as the human turn
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", """
     You are a helpful assistant, your job is to answer the user's question using the relevant context in the context section and in the conversation history.
     Make sure to relate the question to the conversation history and the context in the context section. If the question, the context and the conversation history
     does not align please let the user know about this and ask for further clarification.
     =========
     CONTEXT: 
     {context}
     =========
     PREVIOUS CONVERSATION HISTORY:
     {chat_history}
     """),
    # Chat history is rendered as plain text inside the system message above
    # instead of as structured messages via
    # MessagesPlaceholder(variable_name="chat_history").
    # NOTE(review): presumably callers pass chat_history as a pre-formatted
    # string; a list of messages would be stringified verbatim — confirm.
     ("human", "{question}")
])




def to_doc(input: AIMessage) -> list[Document]:
    """Wrap a chat-model reply as a one-element list of documents.

    The reply text is stored under the ``'text'`` metadata key, while
    ``page_content`` is the fixed label ``"LLM"`` and ``chunk`` /
    ``page_number`` are both ``1.0``.
    NOTE(review): presumably this mimics the metadata layout of the
    retriever's documents so both can share one context list — confirm.

    Args:
        input: The model's message; only its ``content`` attribute is read.

    Returns:
        A list holding exactly one ``Document``.
    """
    reply_text = input.content
    llm_metadata = {
        'chunk': 1.0,
        'page_number': 1.0,
        'text': reply_text,
    }
    wrapped = Document(page_content="LLM", metadata=llm_metadata)
    return [wrapped]

def merge_docs(a: dict[str, list[Document]]) -> list[Document]:
    """Flatten a mapping of document lists into one new list.

    ``create_chain`` uses this to combine the outputs of its parallel
    ``"docs"`` and ``"self_knowledge"`` branches into a single context.

    Args:
        a: Mapping from branch name to the documents that branch produced.
            The keys are ignored.

    Returns:
        A fresh list containing every document from every value. Values are
        visited in the mapping's iteration order, and each list keeps its own
        internal order. The input mapping and its lists are not modified.
    """
    # Only the values matter, so iterate .values() rather than .items().
    return [doc for docs in a.values() for doc in docs]



def create_chain(**kwargs) -> RunnableSequence:
    """Build the retrieval-augmented question-answering runnable.

    Keyword Args:
        retriever (VectorStoreRetriever): Required. Fetches documents for the
            incoming ``"question"``.
        llm (AzureChatOpenAI): Required. Used twice. First it answers the bare
            question on its own, and that answer is wrapped by ``to_doc`` and
            merged into the context. Then it produces the final response from
            ``chat_prompt``.

    The prompt is not configurable: the response step always uses the
    module-level ``chat_prompt``. Any other keyword arguments are ignored.

    The returned runnable expects an input mapping with ``"question"`` and
    ``"chat_history"`` keys. It produces a dict with these keys:

    - ``"response"``: the model output.
    - ``"context"``: the merged list of retrieved and self-knowledge
      documents.
    - ``"chat_history"``: passed through unchanged.

    Raises:
        KeyError: If ``retriever`` is not supplied.
        ValueError: If ``retriever`` is not a ``VectorStoreRetriever``, or if
            ``llm`` is missing or not an ``AzureChatOpenAI``.
    """
    retriever: VectorStoreRetriever = kwargs["retriever"]
    llm: AzureChatOpenAI = kwargs.get("llm", None)

    if not isinstance(retriever, VectorStoreRetriever):
        raise ValueError(
            "create_chain: 'retriever' must be a VectorStoreRetriever, "
            f"got {type(retriever).__name__}"
        )
    if not isinstance(llm, AzureChatOpenAI):
        raise ValueError(
            "create_chain: 'llm' must be an AzureChatOpenAI instance, "
            f"got {type(llm).__name__}"
        )

    # Branch 1: documents retrieved for the question.
    docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
    # Branch 2: the model's own answer to the bare question, wrapped as a Document.
    self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(config={"run_name": "self knowledge"})
    # Final answer generated from question + merged context + chat history.
    response_chain = (chat_prompt | llm).with_config(config={"run_name": "response"})
    merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})

    # Run both context branches concurrently, then flatten them into one list.
    context_chain = (
        RunnableParallel(
            {
                "docs": docs_chain,
                "self_knowledge": self_knowledge_chain
            }
        ).with_config(config={"run_name": "parallel context"})
        | merge_docs_link
    )

    # Stage 1 assembles the prompt inputs. Stage 2 generates the response and
    # also passes the context and history through, so callers can inspect them.
    retrieval_augmented_qa_chain = (
        RunnableParallel({
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
            "context": context_chain
        })
        | RunnableParallel({
            "response": response_chain,
            "context": itemgetter("context"),
            "chat_history": itemgetter("chat_history")
        })
    )
    return retrieval_augmented_qa_chain