Gainward777 commited on
Commit
1de93cb
·
verified ·
1 Parent(s): 2a86356

Update CustomRetriever.py

Browse files
Files changed (1) hide show
  1. CustomRetriever.py +19 -37
CustomRetriever.py CHANGED
@@ -7,44 +7,26 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
7
  from langchain_core.documents import Document
8
  from langchain_core.runnables import chain
9
 
10
- class CustomRetriever():
11
- def __init__(self, v_db, thold=0.7):
12
 
13
- self.vdb=v_db
14
-
15
- #self.retriever=RetrieverWithScores()
16
-
17
- @chain
18
- def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
19
- docs, scores = zip(*self.vdb.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
20
- result=[]
21
- for doc, score in zip(docs, scores):
22
- if score>thold:
23
- doc.metadata["score"] = score
24
- result.append(doc)
25
- if len(result)==0:
26
- result.append(Document(metadata={}, page_content='No data'))
27
-
28
- return result #docs
29
-
30
- class RetrieverWithScores(BaseRetriever):
31
- #def __init__(self, vdb):
32
- #self.vdb=vdb
33
  #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
34
  #super().__init__(retriever=retriever)
35
- def _get_relevant_documents(
36
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
37
-
38
-
39
-
40
- return retr_func.invoke(query) #(self.vdb, query)
41
-
42
- self.retriever=RetrieverWithScores()
 
 
 
 
 
 
43
 
44
-
45
-
46
-
47
-
48
-
49
-
50
-
 
7
  from langchain_core.documents import Document
8
  from langchain_core.runnables import chain
9
 
 
 
10
 
11
+ class RetrieverWithScores(BaseRetriever):
12
+ def __init__(self, vdb, thold=0.7):
13
+ self.vdb=vdb
14
+ self.thold=thold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
16
  #super().__init__(retriever=retriever)
17
+ def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
18
+
19
+ return self.retr_func.invoke(query) #(self.vdb, query)
20
+
21
+ @chain
22
+ def retr_func(self, query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
23
+ docs, scores = zip(*self.vdb.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
24
+ result=[]
25
+ for doc, score in zip(docs, scores):
26
+ if score>self.thold:
27
+ doc.metadata["score"] = score
28
+ result.append(doc)
29
+ if len(result)==0:
30
+ result.append(Document(metadata={}, page_content='No data'))
31
 
32
+ return result #docs