bstraehle commited on
Commit
1ce0835
·
1 Parent(s): c947c47

Update rag_langchain.py

Browse files
Files changed (1) hide show
  1. rag_langchain.py +13 -13
rag_langchain.py CHANGED
@@ -67,13 +67,13 @@ def split_documents(config, docs):
67
 
68
  return text_splitter.split_documents(docs)
69
 
70
- def store_chroma(chunks):
71
  Chroma.from_documents(
72
  documents = chunks,
73
  embedding = OpenAIEmbeddings(disallowed_special = ()),
74
  persist_directory = CHROMA_DIR)
75
 
76
- def store_mongodb(chunks):
77
  client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
78
  collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
79
 
@@ -88,15 +88,15 @@ def rag_ingestion(config):
88
 
89
  chunks = split_documents(config, docs)
90
 
91
- #store_chroma(chunks)
92
- store_mongodb(chunks)
93
 
94
- def retrieve_chroma():
95
  return Chroma(
96
  embedding_function = OpenAIEmbeddings(disallowed_special = ()),
97
  persist_directory = CHROMA_DIR)
98
 
99
- def retrieve_mongodb():
100
  return MongoDBAtlasVectorSearch.from_connection_string(
101
  MONGODB_ATLAS_CLUSTER_URI,
102
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
@@ -113,23 +113,23 @@ def llm_chain(config, prompt):
113
  llm = get_llm(config),
114
  prompt = LLM_CHAIN_PROMPT)
115
 
116
- with get_openai_callback() as cb:
117
  completion = llm_chain.generate([{"question": prompt}])
118
 
119
- return completion, llm_chain, cb
120
 
121
  def rag_chain(config, prompt):
122
- #db = retrieve_chroma()
123
- db = retrieve_mongodb()
124
 
125
  rag_chain = RetrievalQA.from_chain_type(
126
  get_llm(config),
127
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
128
  "verbose": True},
129
- retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
130
  return_source_documents = True)
131
 
132
- with get_openai_callback() as cb:
133
  completion = rag_chain({"query": prompt})
134
 
135
- return completion, rag_chain, cb
 
67
 
68
  return text_splitter.split_documents(docs)
69
 
70
+ def store_documents_chroma(chunks):
71
  Chroma.from_documents(
72
  documents = chunks,
73
  embedding = OpenAIEmbeddings(disallowed_special = ()),
74
  persist_directory = CHROMA_DIR)
75
 
76
+ def store_documents_mongodb(chunks):
77
  client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
78
  collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
79
 
 
88
 
89
  chunks = split_documents(config, docs)
90
 
91
+ #store_documents_chroma(chunks)
92
+ store_documents_mongodb(chunks)
93
 
94
+ def get_vector_store_chroma():
95
  return Chroma(
96
  embedding_function = OpenAIEmbeddings(disallowed_special = ()),
97
  persist_directory = CHROMA_DIR)
98
 
99
+ def get_vector_store_mongodb():
100
  return MongoDBAtlasVectorSearch.from_connection_string(
101
  MONGODB_ATLAS_CLUSTER_URI,
102
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
 
113
  llm = get_llm(config),
114
  prompt = LLM_CHAIN_PROMPT)
115
 
116
+ with get_openai_callback() as callback:
117
  completion = llm_chain.generate([{"question": prompt}])
118
 
119
+ return completion, llm_chain, callback
120
 
121
  def rag_chain(config, prompt):
122
+ #vector_store = get_vector_store_chroma()
123
+ vector_store = get_vector_store_mongodb()
124
 
125
  rag_chain = RetrievalQA.from_chain_type(
126
  get_llm(config),
127
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
128
  "verbose": True},
129
+ retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
130
  return_source_documents = True)
131
 
132
+ with get_openai_callback() as callback:
133
  completion = rag_chain({"query": prompt})
134
 
135
+ return completion, rag_chain, callback