Gainward777 commited on
Commit
2a86356
·
verified ·
1 Parent(s): edb6e23

Update CustomRetriever.py

Browse files
Files changed (1) hide show
  1. CustomRetriever.py +50 -47
CustomRetriever.py CHANGED
@@ -1,47 +1,50 @@
1
- from langchain.schema.retriever import BaseRetriever
2
- from langchain_core.documents import Document
3
- from typing import List
4
-
5
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
6
-
7
- from langchain_core.documents import Document
8
- from langchain_core.runnables import chain
9
-
10
- class CustomRetriever():
11
- def __init__(self, v_db, thold=0.7):
12
-
13
-
14
- #self.retriever=RetrieverWithScores()
15
-
16
- class RetrieverWithScores(BaseRetriever):
17
- #def __init__(self, vdb):
18
- #self.vdb=vdb
19
- #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
20
- #super().__init__(retriever=retriever)
21
- def _get_relevant_documents(
22
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
23
-
24
- @chain
25
- def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
26
- docs, scores = zip(*v_db.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
27
- result=[]
28
- for doc, score in zip(docs, scores):
29
- if score>thold:
30
- doc.metadata["score"] = score
31
- result.append(doc)
32
- if len(result)==0:
33
- result.append(Document(metadata={}, page_content='No data'))
34
-
35
- return result #docs
36
-
37
- return retr_func.invoke(query) #(self.vdb, query)
38
-
39
- self.retriever=RetrieverWithScores()
40
-
41
-
42
-
43
-
44
-
45
-
46
-
47
-
 
 
 
 
1
+ from langchain.schema.retriever import BaseRetriever
2
+ from langchain_core.documents import Document
3
+ from typing import List
4
+
5
+ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
6
+
7
+ from langchain_core.documents import Document
8
+ from langchain_core.runnables import chain
9
+
10
+ class CustomRetriever():
11
+ def __init__(self, v_db, thold=0.7):
12
+
13
+ self.vdb=v_db
14
+
15
+ #self.retriever=RetrieverWithScores()
16
+
17
+ @chain
18
+ def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
19
+ docs, scores = zip(*self.vdb.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
20
+ result=[]
21
+ for doc, score in zip(docs, scores):
22
+ if score>thold:
23
+ doc.metadata["score"] = score
24
+ result.append(doc)
25
+ if len(result)==0:
26
+ result.append(Document(metadata={}, page_content='No data'))
27
+
28
+ return result #docs
29
+
30
+ class RetrieverWithScores(BaseRetriever):
31
+ #def __init__(self, vdb):
32
+ #self.vdb=vdb
33
+ #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
34
+ #super().__init__(retriever=retriever)
35
+ def _get_relevant_documents(
36
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
37
+
38
+
39
+
40
+ return retr_func.invoke(query) #(self.vdb, query)
41
+
42
+ self.retriever=RetrieverWithScores()
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+