Gainward777 commited on
Commit
5a63c61
·
verified ·
1 Parent(s): 7c73845

Update CustomRetriever.py

Browse files
Files changed (1) hide show
  1. CustomRetriever.py +33 -19
CustomRetriever.py CHANGED
@@ -7,27 +7,41 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
7
  from langchain_core.documents import Document
8
  from langchain_core.runnables import chain
9
 
10
-
11
- class RetrieverWithScores(BaseRetriever):
12
- def __init__(self, vdb, thold=0.7):
13
- self.vdb=vdb
14
- self.thold=thold
15
-
16
- @chain
17
- def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
18
- docs, scores = zip(*self.vdb.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
19
- result=[]
20
- for doc, score in zip(docs, scores):
21
- if score>self.thold:
22
- doc.metadata["score"] = score
23
- result.append(doc)
24
- if len(result)==0:
25
- result.append(Document(metadata={}, page_content='No data'))
26
 
27
- return result #docs
 
 
 
 
 
28
  #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
29
  #super().__init__(retriever=retriever)
30
- def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- return self.retr_func.invoke(query) #(self.vdb, query)
33
 
 
7
  from langchain_core.documents import Document
8
  from langchain_core.runnables import chain
9
 
10
+ class CustomRetriever():
11
+ def __init__(self, v_db, thold=0.7):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+
14
+ #self.retriever=RetrieverWithScores()
15
+
16
+ class RetrieverWithScores(BaseRetriever):
17
+ #def __init__(self, vdb):
18
+ #self.vdb=vdb
19
  #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
20
  #super().__init__(retriever=retriever)
21
+ def _get_relevant_documents(
22
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
23
+
24
+ @chain
25
+ def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
26
+ docs, scores = zip(*v_db.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
27
+ result=[]
28
+ for doc, score in zip(docs, scores):
29
+ if score>thold:
30
+ doc.metadata["score"] = score
31
+ result.append(doc)
32
+ if len(result)==0:
33
+ result.append(Document(metadata={}, page_content='No data'))
34
+
35
+ return result #docs
36
+
37
+ return retr_func.invoke(query) #(self.vdb, query)
38
+
39
+ self.retriever=RetrieverWithScores()
40
+
41
+
42
+
43
+
44
+
45
+
46
 
 
47