bstraehle commited on
Commit
a0b5dc6
·
1 Parent(s): 7186dc1

Update rag_llamaindex.py

Browse files
Files changed (1) hide show
  1. rag_llamaindex.py +13 -1
rag_llamaindex.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  from llama_hub.youtube_transcript import YoutubeTranscriptReader
4
  from llama_index import download_loader, PromptTemplate
5
  from llama_index.indices.vector_store.base import VectorStoreIndex
 
6
  from llama_index.storage.storage_context import StorageContext
7
  from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
8
 
@@ -50,7 +51,8 @@ class LlamaIndexRAG(BaseRAG):
50
 
51
  def store_documents(self, config, docs):
52
  storage_context = StorageContext.from_defaults(
53
- vector_store = self.get_vector_store())
 
54
 
55
  VectorStoreIndex.from_documents(
56
  docs,
@@ -70,11 +72,21 @@ class LlamaIndexRAG(BaseRAG):
70
 
71
  self.store_documents(config, docs)
72
 
 
 
 
 
 
73
  def retrieval(self, config, prompt):
 
 
 
 
74
  index = VectorStoreIndex.from_vector_store(
75
  vector_store = self.get_vector_store())
76
 
77
  query_engine = index.as_query_engine(
 
78
  similarity_top_k = config["k"]
79
  )
80
 
 
3
  from llama_hub.youtube_transcript import YoutubeTranscriptReader
4
  from llama_index import download_loader, PromptTemplate
5
  from llama_index.indices.vector_store.base import VectorStoreIndex
6
+ from llama_index.llms import OpenAI
7
  from llama_index.storage.storage_context import StorageContext
8
  from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
9
 
 
51
 
52
  def store_documents(self, config, docs):
53
  storage_context = StorageContext.from_defaults(
54
+ vector_store = self.get_vector_store()
55
+ )
56
 
57
  VectorStoreIndex.from_documents(
58
  docs,
 
72
 
73
  self.store_documents(config, docs)
74
 
75
+ def get_llm(self, config):
76
+ return OpenAI(
77
+ model = config["model_name"],
78
+ temperature = config["temperature"])
79
+
80
  def retrieval(self, config, prompt):
81
+ service_context = ServiceContext.from_defaults(
82
+ llm = self.get_llm(config)
83
+ )
84
+
85
  index = VectorStoreIndex.from_vector_store(
86
  vector_store = self.get_vector_store())
87
 
88
  query_engine = index.as_query_engine(
89
+ service_context = service_context,
90
  similarity_top_k = config["k"]
91
  )
92