import os
import requests

from llama_hub.youtube_transcript import YoutubeTranscriptReader
from llama_index import download_loader, PromptTemplate, ServiceContext
from llama_index.indices.vector_store.base import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.storage.storage_context import StorageContext
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

from pathlib import Path
from pymongo import MongoClient
from rag_base import BaseRAG

class LlamaIndexRAG(BaseRAG):
    MONGODB_DB_NAME = "llamaindex_db"
    
    def load_documents(self):
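        # Gather documents from three sources (a PDF, a web page, and YouTube
        # transcripts), using the *_URL attributes expected on BaseRAG.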
        docs = []
    
        # PDF
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()
        out_dir = Path("data")
    
        if not out_dir.exists():
            os.makedirs(out_dir)
    
        out_path = out_dir / "gpt-4.pdf"
    
        if not out_path.exists():
            r = requests.get(self.PDF_URL)
            with open(out_path, "wb") as f:
                f.write(r.content)

        docs.extend(loader.load_data(file = out_path))
        #print("docs = " + str(len(docs)))
    
        # Web
        SimpleWebPageReader = download_loader("SimpleWebPageReader")
        loader = SimpleWebPageReader()
        docs.extend(loader.load_data(urls = [self.WEB_URL]))
        #print("docs = " + str(len(docs)))

        # YouTube
        loader = YoutubeTranscriptReader()
        docs.extend(loader.load_data(ytlinks = [self.YOUTUBE_URL_1,
                                                self.YOUTUBE_URL_2]))
        #print("docs = " + str(len(docs)))
    
        return docs

    def get_llm(self, config):
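        # Instantiate the OpenAI LLM named in the config at the configured temperature.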
        return OpenAI(
            model = config["model_name"], 
            temperature = config["temperature"]
        )

    def get_vector_store(self):
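        # Connect to the Atlas cluster and return the vector store backed by the
        # configured collection and search index (connection settings inherited
        # from BaseRAG, database name defined above).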
        return MongoDBAtlasVectorSearch(
            MongoClient(self.MONGODB_ATLAS_CLUSTER_URI),
            db_name = self.MONGODB_DB_NAME,
            collection_name = self.MONGODB_COLLECTION_NAME,
            index_name = self.MONGODB_INDEX_NAME
        )
        
    # Build a ServiceContext with the configured chunk size/overlap and LLM.
    def get_service_context(self, config):
        return ServiceContext.from_defaults(
            chunk_overlap = config["chunk_overlap"],
            chunk_size = config["chunk_size"],
            llm = self.get_llm(config)
        )

    # Build a StorageContext wired to the MongoDB Atlas vector store.
    def get_storage_context(self):
        return StorageContext.from_defaults(
            vector_store = self.get_vector_store()
        )
        
    def store_documents(self, config, docs):
        # Embed the documents and persist them to the Atlas vector store by
        # building a VectorStoreIndex over them.
        VectorStoreIndex.from_documents(
            docs,
            service_context = self.get_service_context(config),
            storage_context = self.get_storage_context()
        )

    def ingestion(self, config):
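        # Ingestion pipeline: load the source documents, then embed and store them.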
        docs = self.load_documents()
    
        self.store_documents(config, docs)
       
    def retrieval(self, config, prompt):
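        # Retrieval pipeline: rebuild the index from the existing vector store and
        # answer the prompt with a top-k similarity query engine.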
        index = VectorStoreIndex.from_vector_store(
            vector_store = self.get_vector_store()
        )

        query_engine = index.as_query_engine(
            service_context = self.get_service_context(config),
            similarity_top_k = config["k"]
        )
 
        return query_engine.query(prompt)
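

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes BaseRAG takes no constructor arguments and supplies the PDF/web/
# YouTube URLs and the MongoDB Atlas connection settings, that OPENAI_API_KEY
# and the Atlas credentials are available in the environment, and that the
# config keys match the ones read above; the concrete values are placeholders.
if __name__ == "__main__":
    config = {
        "model_name": "gpt-3.5-turbo",
        "temperature": 0.0,
        "chunk_size": 1024,
        "chunk_overlap": 64,
        "k": 3,
    }

    rag = LlamaIndexRAG()
    rag.ingestion(config)  # index the PDF, web, and YouTube documents in Atlas
    print(rag.retrieval(config, "What is GPT-4?"))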