Spaces · Build error
Update rag_langchain.py
rag_langchain.py CHANGED (+15 -9)
@@ -62,7 +62,8 @@ class LangChainRAG(BaseRAG):
         Chroma.from_documents(
             documents = chunks,
             embedding = OpenAIEmbeddings(disallowed_special = ()),
-            persist_directory = self.CHROMA_DIR
+            persist_directory = self.CHROMA_DIR
+        )
 
     def store_documents_mongodb(self, chunks):
         client = MongoClient(self.MONGODB_ATLAS_CLUSTER_URI)
@@ -72,7 +73,8 @@ class LangChainRAG(BaseRAG):
             documents = chunks,
             embedding = OpenAIEmbeddings(disallowed_special = ()),
             collection = collection,
-            index_name = self.MONGODB_INDEX_NAME
+            index_name = self.MONGODB_INDEX_NAME
+        )
 
     def ingestion(self, config):
         docs = self.load_documents()
@@ -85,24 +87,28 @@ class LangChainRAG(BaseRAG):
     def get_vector_store_chroma(self):
         return Chroma(
             embedding_function = OpenAIEmbeddings(disallowed_special = ()),
-            persist_directory = self.CHROMA_DIR
+            persist_directory = self.CHROMA_DIR
+        )
 
     def get_vector_store_mongodb(self):
         return MongoDBAtlasVectorSearch.from_connection_string(
             self.MONGODB_ATLAS_CLUSTER_URI,
             self.MONGODB_DB_NAME + "." + self.MONGODB_COLLECTION_NAME,
             OpenAIEmbeddings(disallowed_special = ()),
-            index_name = self.MONGODB_INDEX_NAME
+            index_name = self.MONGODB_INDEX_NAME
+        )
 
     def get_llm(self, config):
         return ChatOpenAI(
             model_name = config["model_name"],
-            temperature = config["temperature"]
+            temperature = config["temperature"]
+        )
 
     def llm_chain(self, config, prompt):
         llm_chain = LLMChain(
             llm = self.get_llm(config),
-            prompt = self.LLM_CHAIN_PROMPT
+            prompt = self.LLM_CHAIN_PROMPT
+        )
 
         with get_openai_callback() as callback:
             completion = llm_chain.generate([{"question": prompt}])
@@ -115,10 +121,10 @@ class LangChainRAG(BaseRAG):
 
         rag_chain = RetrievalQA.from_chain_type(
             self.get_llm(config),
-            chain_type_kwargs = {"prompt": self.RAG_CHAIN_PROMPT,
-                                 "verbose": True},
+            chain_type_kwargs = {"prompt": self.RAG_CHAIN_PROMPT},
            retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
-            return_source_documents = True
+            return_source_documents = True
+        )
 
         with get_openai_callback() as callback:
             completion = rag_chain({"query": prompt})
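
For context: most of the removed lines above ended a call whose closing parenthesis was missing, so Python rejected rag_langchain.py with a SyntaxError at import time, which is what surfaced as the Space's "Build error" status (the last hunk additionally drops "verbose": True from the RetrievalQA chain_type_kwargs). Below is a minimal standalone sketch of the failure mode and the fix; get_llm here is a simplified stand-in for the class method, the config values are placeholders, and the import assumes the pre-0.1 langchain package layout this file uses.

from langchain.chat_models import ChatOpenAI

# Broken shape before the commit -- the unclosed call swallows the next
# "def", so the module fails to import with a SyntaxError, and on Spaces
# a module that cannot import shows up as a build error:
#
#     return ChatOpenAI(
#         model_name = config["model_name"],
#         temperature = config["temperature"]
#
#     def llm_chain(self, config, prompt):
#         ...

# Fixed shape after the commit -- the call is closed before the next
# statement begins:
def get_llm(config):
    # Simplified stand-in for LangChainRAG.get_llm (no self, no validation).
    return ChatOpenAI(
        model_name = config["model_name"],
        temperature = config["temperature"]
    )

if __name__ == "__main__":
    # Placeholder config, only to show the module now parses and runs.
    llm = get_llm({"model_name": "gpt-3.5-turbo", "temperature": 0})
    print(llm.model_name)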