from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
import qdrant_client
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# RAG over local documents: Qdrant for vector storage, Ollama for embeddings
# and generation, and HyDE query transformation at query time.
class ChatPDF:
    def __init__(self):
        # Per-instance state (instance attributes, so separate ChatPDF
        # objects do not share chunks, doc ids, or nodes).
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None

        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)

        logger.info("initializing the vector store related objects")
        client = qdrant_client.QdrantClient(host="localhost", port=6333)
        self.vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")

        logger.info("initializing the OllamaEmbedding")
        self.embed_model = OllamaEmbedding(model_name="mxbai-embed-large", request_timeout=1000000)
        logger.info("initializing the global settings")
        Settings.embed_model = self.embed_model
        Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
        Settings.transformations = [self.text_parser]

    def ingest(self, dir_path: str):
        docs = SimpleDirectoryReader(input_dir=dir_path).load_data()

        logger.info("enumerating docs")
        for doc_idx, doc in enumerate(docs):
            curr_text_chunks = self.text_parser.split_text(doc.text)
            self.text_chunks.extend(curr_text_chunks)
            self.doc_ids.extend([doc_idx] * len(curr_text_chunks))

        logger.info("enumerating text_chunks")
        for idx, text_chunk in enumerate(self.text_chunks):
            node = TextNode(text=text_chunk)
            src_doc = docs[self.doc_ids[idx]]
            node.metadata = src_doc.metadata
            self.nodes.append(node)

        logger.info("enumerating nodes")
        for node in self.nodes:
            node_embedding = self.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )
            node.embedding = node_embedding

        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=self.nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        logger.info("initializing the VectorIndexRetriever with top_k as 5")
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
        response_synthesizer = get_response_synthesizer()
        logger.info("creating the RetrieverQueryEngine instance")
        vector_query_engine = RetrieverQueryEngine(
            retriever=vector_retriever,
            response_synthesizer=response_synthesizer,
        )
        logger.info("creating the HyDEQueryTransform instance")
        hyde = HyDEQueryTransform(include_original=True)
        self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)

    def ask(self, query: str):
        if not self.hyde_query_engine:
            return "Please, add a PDF document first."

        logger.info("retrieving the response to the query")
        response = self.hyde_query_engine.query(str_or_query_bundle=query)
        print(response)
        return response

    def clear(self):
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
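

# A minimal usage sketch, assuming a Qdrant instance is reachable on
# localhost:6333, an Ollama server has the mxbai-embed-large and qwen:1.8b
# models pulled, and "./docs" is a hypothetical directory holding the PDFs
# to index.
if __name__ == "__main__":
    chat = ChatPDF()
    chat.ingest(dir_path="./docs")
    answer = chat.ask("What are these documents about?")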