Gainward777 commited on
Commit
45e6244
·
verified ·
1 Parent(s): 1cb36c5

Upload CustomRetriever.py

Browse files
Files changed (1) hide show
  1. CustomRetriever.py +47 -0
CustomRetriever.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.schema.retriever import BaseRetriever
2
+ from langchain_core.documents import Document
3
+ from typing import List
4
+
5
+ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
6
+
7
+ from langchain_core.documents import Document
8
+ from langchain_core.runnables import chain
9
+
10
+ class CustomRetriever():
11
+ def __init__(self, v_db, thold=0.7):
12
+
13
+
14
+ #self.retriever=RetrieverWithScores()
15
+
16
+ class RetrieverWithScores(BaseRetriever):
17
+ #def __init__(self, vdb):
18
+ #self.vdb=vdb
19
+ #def __init__(self, retriever: BaseRetriever): # Add an __init__ to store the existing retriever
20
+ #super().__init__(retriever=retriever)
21
+ def _get_relevant_documents(
22
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
23
+
24
+ @chain
25
+ def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
26
+ docs, scores = zip(*v_db.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
27
+ result=[]
28
+ for doc, score in zip(docs, scores):
29
+ if score>thold:
30
+ doc.metadata["score"] = score
31
+ result.append(doc)
32
+ if len(result)==0:
33
+ result.append(Document(metadata={}, page_content='No data'))
34
+
35
+ return result #docs
36
+
37
+ return retr_func.invoke(query) #(self.vdb, query)
38
+
39
+ self.retriever=RetrieverWithScores()
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+