bstraehle commited on
Commit
401a2a7
·
1 Parent(s): 3ddc880

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +14 -6
rag.py CHANGED
@@ -81,15 +81,23 @@ def run_rag_batch(config):
81
  embed_store_documents_mongodb(chunks)
82
 
83
  def retrieve_documents_chroma():
84
- return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
85
- persist_directory = CHROMA_DIR)
 
 
 
 
86
 
87
  def retrieve_documents_mongodb():
88
- return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
89
- MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
90
- OpenAIEmbeddings(disallowed_special = ()),
91
- index_name = MONGODB_INDEX_NAME)
 
 
92
 
 
 
93
  def get_llm(config, openai_api_key):
94
  return ChatOpenAI(model_name = config["model_name"],
95
  openai_api_key = openai_api_key,
 
81
  embed_store_documents_mongodb(chunks)
82
 
83
  def retrieve_documents_chroma():
84
+ with get_openai_callback() as cb:
85
+ db = Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
86
+ persist_directory = CHROMA_DIR)
87
+ print(cb)
88
+
89
+ return db
90
 
91
  def retrieve_documents_mongodb():
92
+ with get_openai_callback() as cb:
93
+ db = MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
94
+ MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
95
+ OpenAIEmbeddings(disallowed_special = ()),
96
+ index_name = MONGODB_INDEX_NAME)
97
+ print cb
98
 
99
+ return db
100
+
101
  def get_llm(config, openai_api_key):
102
  return ChatOpenAI(model_name = config["model_name"],
103
  openai_api_key = openai_api_key,