bstraehle commited on
Commit
62a61d6
·
1 Parent(s): 5ad33e1

Update rag_langchain.py

Browse files
Files changed (1) hide show
  1. rag_langchain.py +15 -9
rag_langchain.py CHANGED
@@ -62,7 +62,8 @@ class LangChainRAG(BaseRAG):
62
  Chroma.from_documents(
63
  documents = chunks,
64
  embedding = OpenAIEmbeddings(disallowed_special = ()),
65
- persist_directory = self.CHROMA_DIR)
 
66
 
67
  def store_documents_mongodb(self, chunks):
68
  client = MongoClient(self.MONGODB_ATLAS_CLUSTER_URI)
@@ -72,7 +73,8 @@ class LangChainRAG(BaseRAG):
72
  documents = chunks,
73
  embedding = OpenAIEmbeddings(disallowed_special = ()),
74
  collection = collection,
75
- index_name = self.MONGODB_INDEX_NAME)
 
76
 
77
  def ingestion(self, config):
78
  docs = self.load_documents()
@@ -85,24 +87,28 @@ class LangChainRAG(BaseRAG):
85
  def get_vector_store_chroma(self):
86
  return Chroma(
87
  embedding_function = OpenAIEmbeddings(disallowed_special = ()),
88
- persist_directory = self.CHROMA_DIR)
 
89
 
90
  def get_vector_store_mongodb(self):
91
  return MongoDBAtlasVectorSearch.from_connection_string(
92
  self.MONGODB_ATLAS_CLUSTER_URI,
93
  self.MONGODB_DB_NAME + "." + self.MONGODB_COLLECTION_NAME,
94
  OpenAIEmbeddings(disallowed_special = ()),
95
- index_name = self.MONGODB_INDEX_NAME)
 
96
 
97
  def get_llm(self, config):
98
  return ChatOpenAI(
99
  model_name = config["model_name"],
100
- temperature = config["temperature"])
 
101
 
102
  def llm_chain(self, config, prompt):
103
  llm_chain = LLMChain(
104
  llm = self.get_llm(config),
105
- prompt = self.LLM_CHAIN_PROMPT)
 
106
 
107
  with get_openai_callback() as callback:
108
  completion = llm_chain.generate([{"question": prompt}])
@@ -115,10 +121,10 @@ class LangChainRAG(BaseRAG):
115
 
116
  rag_chain = RetrievalQA.from_chain_type(
117
  self.get_llm(config),
118
- chain_type_kwargs = {"prompt": self.RAG_CHAIN_PROMPT,
119
- "verbose": True},
120
  retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
121
- return_source_documents = True)
 
122
 
123
  with get_openai_callback() as callback:
124
  completion = rag_chain({"query": prompt})
 
62
  Chroma.from_documents(
63
  documents = chunks,
64
  embedding = OpenAIEmbeddings(disallowed_special = ()),
65
+ persist_directory = self.CHROMA_DIR
66
+ )
67
 
68
  def store_documents_mongodb(self, chunks):
69
  client = MongoClient(self.MONGODB_ATLAS_CLUSTER_URI)
 
73
  documents = chunks,
74
  embedding = OpenAIEmbeddings(disallowed_special = ()),
75
  collection = collection,
76
+ index_name = self.MONGODB_INDEX_NAME
77
+ )
78
 
79
  def ingestion(self, config):
80
  docs = self.load_documents()
 
87
  def get_vector_store_chroma(self):
88
  return Chroma(
89
  embedding_function = OpenAIEmbeddings(disallowed_special = ()),
90
+ persist_directory = self.CHROMA_DIR
91
+ )
92
 
93
  def get_vector_store_mongodb(self):
94
  return MongoDBAtlasVectorSearch.from_connection_string(
95
  self.MONGODB_ATLAS_CLUSTER_URI,
96
  self.MONGODB_DB_NAME + "." + self.MONGODB_COLLECTION_NAME,
97
  OpenAIEmbeddings(disallowed_special = ()),
98
+ index_name = self.MONGODB_INDEX_NAME
99
+ )
100
 
101
  def get_llm(self, config):
102
  return ChatOpenAI(
103
  model_name = config["model_name"],
104
+ temperature = config["temperature"]
105
+ )
106
 
107
  def llm_chain(self, config, prompt):
108
  llm_chain = LLMChain(
109
  llm = self.get_llm(config),
110
+ prompt = self.LLM_CHAIN_PROMPT
111
+ )
112
 
113
  with get_openai_callback() as callback:
114
  completion = llm_chain.generate([{"question": prompt}])
 
121
 
122
  rag_chain = RetrievalQA.from_chain_type(
123
  self.get_llm(config),
124
+ chain_type_kwargs = {"prompt": self.RAG_CHAIN_PROMPT},
 
125
  retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
126
+ return_source_documents = True
127
+ )
128
 
129
  with get_openai_callback() as callback:
130
  completion = rag_chain({"query": prompt})