import logging
import os
from typing import List, Optional

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.conversational_retrieval.base import (
    BaseConversationalRetrievalChain,
)
from langchain.chains.history_aware_retriever import (
    create_history_aware_retriever,
)
from langchain.chains.retrieval import create_retrieval_chain

from ..utilities.llm_models import get_llm_model_chat
from ..vector_store.vector_store import VectorStoreManager
from .prompts import CHAT_PROMPT, CONTEXTUEL_QUERY_PROMPT


class RAGSystem:
    """Retrieval-Augmented Generation system.

    Wires a chat LLM to a persisted vector store through LangChain's
    history-aware retrieval chain and streams grounded answers.
    """

    def __init__(
        self,
        docs_dir: str = "data/chunks",
        persist_directory_dir: str = "data/chroma_db",
        batch_size: int = 64,
        top_k_documents: int = 5,
    ):
        """Create the system and its vector-store manager.

        Args:
            docs_dir: Directory containing the source documents/chunks.
            persist_directory_dir: On-disk location of the vector store.
            batch_size: Indexing batch size passed to the store manager.
            top_k_documents: Number of documents the retriever returns.
        """
        self.top_k_documents = top_k_documents
        self.llm = self._get_llm()
        # Built lazily by setup_rag_chain() on first use.
        self.chain: Optional[BaseConversationalRetrievalChain] = None
        self.vector_store_management = VectorStoreManager(
            persist_directory_dir, batch_size
        )
        self.docs_dir = docs_dir

    def _get_llm(self):
        """Build the chat LLM (low temperature for factual answers)."""
        return get_llm_model_chat(temperature=0.1, max_tokens=1000)

    def load_documents(self) -> List:
        """Load and split documents from the configured directory."""
        return self.vector_store_management.load_and_process_documents(self.docs_dir)

    def initialize_vector_store(self, documents: Optional[List] = None):
        """Initialize or load the vector store.

        Args:
            documents: Pre-processed documents to index; ``None`` loads
                the existing persisted store instead of re-indexing.
        """
        self.vector_store_management.initialize_vector_store(documents)

    def setup_rag_chain(self):
        """Build the retrieval chain once; subsequent calls are no-ops.

        Returns:
            The (possibly cached) retrieval chain.
        """
        if self.chain is not None:
            # Fix: previously returned None here, unlike the build path below.
            return self.chain
        retriever = self.vector_store_management.create_retriever(
            self.llm, self.top_k_documents, bm25_portion=0.03
        )

        # Contextualize question: rewrite follow-ups into standalone queries
        # using the chat history before hitting the retriever.
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, retriever, CONTEXTUEL_QUERY_PROMPT
        )
        self.question_answer_chain = create_stuff_documents_chain(self.llm, CHAT_PROMPT)
        self.chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )
        # Lazy %-style formatting avoids building the string when INFO is off.
        logging.info("RAG chain setup complete%s", self.chain)

        return self.chain

    def query(self, question: str, history: Optional[list] = None):
        """Query the RAG system, yielding answer tokens as they stream.

        Args:
            question: The user question.
            history: Prior chat messages; defaults to an empty history.

        Yields:
            str: Incremental answer fragments.
        """
        # Fix: original used a mutable default (history=[]) shared across calls.
        if history is None:
            history = []
        if not self.vector_store_management.vs_initialized:
            self.initialize_vector_store()

        self.setup_rag_chain()

        for token in self.chain.stream({"input": question, "chat_history": history}):
            if "answer" in token:
                yield token["answer"]


if __name__ == "__main__":
    from glob import glob

    from dotenv import load_dotenv

    # Load variables (API keys, model config) from the .env file.
    load_dotenv()

    docs_dir = "data/docs"
    persist_directory_dir = "data/chroma_db"
    batch_size = 64

    # Initialize RAG system
    rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)

    # Reuse the persisted index when present; otherwise build it from docs.
    if glob(os.path.join(persist_directory_dir, "*/*.bin")):
        rag.initialize_vector_store()  # vector store already on disk
    else:
        # Load and index documents
        documents = rag.load_documents()
        rag.initialize_vector_store(documents)

    queries = [
        "Quand a eu lieu la traite négrière ?",
        "Explique moi comment soigner la tiphoide puis le paludisme",
        "Quels étaient les premiers peuples d'afrique centrale et quelles ont été leurs migrations?",
    ]

    for query in queries:
        print("Query: ", query, "\n\n")
        # Fix: query() is a generator — the original called it without
        # consuming it (so nothing ran), then called the nonexistent
        # rag.query_complex(), which raised AttributeError.
        for token in rag.query(question=query):
            print(token, end="", flush=True)
        print("\n")