File size: 4,707 Bytes
b7b243c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import (
    create_history_aware_retriever,
)
from langchain.chains.retrieval import create_retrieval_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter


class ConversationalQA:
    """
    Conversational question-answering via retrieval-augmented generation
    (RAG) with per-session chat history.

    Documents are split into overlapping chunks, embedded with OpenAI
    embeddings, and indexed in a Chroma vector store. Each incoming
    question is first reformulated into a standalone query using the
    session's chat history, then answered from the retrieved context.
    """

    def __init__(
        self,
        docs: list,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """
        Build the full RAG pipeline from the given documents.

        NOTE(review): no API key is passed explicitly; ChatOpenAI and
        OpenAIEmbeddings presumably pick it up from the environment
        (OPENAI_API_KEY) — confirm in deployment config.

        :param docs: List of documents to be used for retrieval and answering
        :param chunk_size: Maximum size of each text chunk for processing
        :param chunk_overlap: Number of characters to overlap between chunks
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        self.splits = self.text_splitter.split_documents(docs)
        self.llm = ChatOpenAI()
        self.vectorstore = Chroma.from_documents(
            documents=self.splits,
            embedding=OpenAIEmbeddings(),
            collection_name="youtube",
        )
        self.retriever = self.vectorstore.as_retriever()

        # Answering prompt: answer concisely, strictly from the retrieved
        # {context} stuffed in by create_stuff_documents_chain.
        self.qa_system_prompt = """You are an assistant for question-answering 
        tasks. Use the following pieces of retrieved context to answer the 
        question. If you don't know the answer, just say that you don't know. 
        Use three sentences maximum and keep the answer 
        concise.\n\n{context}"""

        self.qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.qa_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # Contextualization prompt: rewrite the latest question so it is
        # self-contained without the chat history (no answering here).
        self.contextualize_q_system_prompt = """Given a chat history and the 
        latest user question which might reference context in the chat 
        history, formulate a standalone question which can be understood 
        without the chat history. Do NOT answer the question, just 
        reformulate it if needed and otherwise return it as is."""

        self.contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt
        )
        self.history_aware_chain = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )
        self.rag_chain = create_retrieval_chain(
            self.history_aware_chain, self.question_answer_chain
        )
        # Maps session_id -> ChatMessageHistory (see get_session_history).
        self.store: dict = {}
        # Built once here rather than on every invoke_chain() call: the
        # wrapper is invariant after construction (the history callback
        # resolves per-session state at invoke time), so re-creating it
        # per request was wasted work.
        self.conversational_rag_chain = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """
        Retrieve the chat history for a session, creating it on first use.

        :param session_id: Unique session identifier
        :return: ChatMessageHistory object for the session
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def invoke_chain(self, session_id: str, user_input: str) -> str:
        """
        Invoke the conversational question-answering chain with user input
        and the session's history.

        :param session_id: Unique session identifier
        :param user_input: User's question input
        :return: Answer generated by the system
        """
        result = self.conversational_rag_chain.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": session_id}},
        )
        return result["answer"]