nnngoc committed on
Commit
0b802db
·
1 Parent(s): ada25c5
Files changed (2) hide show
  1. rag.py +4 -4
  2. utility.py +67 -1
rag.py CHANGED
@@ -1,6 +1,6 @@
1
  from langchain.embeddings import HuggingFaceEmbeddings
2
  from langchain.prompts import PromptTemplate
3
- from utility import load_data, process_data, CustomRetriever
4
 
5
 
6
  data1 = load_data('raw_data/sv')
@@ -137,9 +137,9 @@ ensemble_retriever3 = EnsembleRetriever(retrievers=[bm25_retriever3, retriever3]
137
 
138
  #########################################################################################
139
 
140
- custom_retriever1 = CustomRetriever(retriever = ensemble_retriever1)
141
- custom_retriever2 = CustomRetriever(retriever = ensemble_retriever2)
142
- custom_retriever3 = CustomRetriever(retriever = ensemble_retriever3)
143
 
144
  multiq_chain1 = generate_queries | custom_retriever1
145
  multiq_chain2 = generate_queries | custom_retriever2
 
1
  from langchain.embeddings import HuggingFaceEmbeddings
2
  from langchain.prompts import PromptTemplate
3
+ from utility import load_data, process_data, CustomRetriever, CustomRetriever1
4
 
5
 
6
  data1 = load_data('raw_data/sv')
 
137
 
138
  #########################################################################################
139
 
140
+ custom_retriever1 = CustomRetriever1(retriever = ensemble_retriever1)
141
+ custom_retriever2 = CustomRetriever1(retriever = ensemble_retriever2)
142
+ custom_retriever3 = CustomRetriever1(retriever = ensemble_retriever3)
143
 
144
  multiq_chain1 = generate_queries | custom_retriever1
145
  multiq_chain2 = generate_queries | custom_retriever2
utility.py CHANGED
@@ -144,4 +144,70 @@ class CustomRetriever(BaseRetriever):
144
 
145
  docs_top_10 = docs[0:10]
146
 
147
- return docs_top_10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  docs_top_10 = docs[0:10]
146
 
147
+ return docs_top_10
148
+
149
+
150
import os

import cohere

# SECURITY: the Cohere credential used to be hard-coded here and was published
# with this commit. That key must be treated as leaked and revoked in the
# Cohere dashboard. The key is now read from the environment so no secret
# lives in source control; it is None when the variable is not set.
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
152
+
153
class CustomRetriever1(BaseRetriever):
    """Multi-query retriever: fuse per-query results with RRF, then rerank with Cohere.

    For each query the wrapped ``retriever`` is asked for documents, the ranked
    lists are merged with reciprocal rank fusion, and the fused candidates are
    reranked against the first query by the Cohere rerank endpoint.
    """

    # vectorstores:Chroma
    # Wrapped retriever; must expose get_relevant_documents(query, callbacks=...).
    retriever: Any

    def reciprocal_rank_fusion(self, results: list[list], k=60):
        """Merge several ranked document lists into one list using RRF.

        Args:
            results: One ranked list of documents per query.
            k: Smoothing constant of the RRF formula ``1 / (rank + k)``,
                where ``rank`` is the 0-based position in its list.

        Returns:
            At most 30 documents, ordered by descending fused score. Documents
            that serialize to the same string are treated as one document.
        """
        # NOTE(review): `dumps` / `loads` are expected to be in scope from
        # elsewhere in this file (not visible in this chunk) — confirm.
        fused_scores = {}

        for docs in results:
            for rank, doc in enumerate(docs):
                # The serialized form is the dictionary key, so identical
                # documents returned for different queries accumulate a score.
                doc_str = dumps(doc)
                fused_scores[doc_str] = fused_scores.get(doc_str, 0) + 1 / (rank + k)

        # sorted() is stable, so documents with equal scores keep their
        # first-seen order.
        ranked_keys = sorted(fused_scores, key=fused_scores.get, reverse=True)

        # Slice before deserializing so only the kept documents are rebuilt.
        return [loads(doc_str) for doc_str in ranked_keys[:30]]

    def _get_relevant_documents(
        self, queries: list, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve for every query, fuse the results, and rerank the candidates.

        Args:
            queries: Query strings; the first one is the query used for reranking.
            run_manager: Callback manager whose child callbacks are passed to
                each call of the wrapped retriever.

        Returns:
            Up to 10 documents in the order returned by the Cohere reranker.
            An empty list when there are no queries or nothing was retrieved.
        """
        # Without a query there is nothing to retrieve or to rerank against.
        if not queries:
            return []

        # Use the existing retriever to get one ranked list per query.
        per_query_docs = [
            self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            for query in queries
        ]

        unique_documents = self.reciprocal_rank_fusion(per_query_docs)

        # Skip the remote call when there are no candidates to rerank.
        if not unique_documents:
            return []

        docs_content = [doc.page_content for doc in unique_documents]

        co = cohere.Client(COHERE_API_KEY)
        # Only the result indices are read below, so the response does not
        # need to carry the document text back.
        results = co.rerank(
            query=queries[0],
            documents=docs_content,
            top_n=10,
            model='rerank-multilingual-v3.0',
            return_documents=False,
        )

        # Each result's index points into docs_content, which is aligned
        # one-to-one with unique_documents.
        return [unique_documents[result.index] for result in results.results]