yrobel-lima committed on
Commit
0e13c24
·
verified ·
1 Parent(s): d606612

Update rag/retrievers.py

Browse files
Files changed (1) hide show
  1. rag/retrievers.py +86 -86
rag/retrievers.py CHANGED
@@ -1,86 +1,86 @@
1
- import os
2
- from functools import lru_cache
3
- from typing import Literal
4
-
5
- from langchain_core.vectorstores import VectorStoreRetriever
6
- from langchain_openai import OpenAIEmbeddings
7
- from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
8
-
9
- os.environ["GRPC_VERBOSITY"] = "NONE"
10
-
11
-
12
class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment settings.

    ``QDRANT_URL`` and ``QDRANT_API_KEY`` are read from the environment when
    the instance is created. ``OPENAI_API_KEY`` is only checked for presence
    here (presumably it is consumed by ``OpenAIEmbeddings`` itself).

    Embedding clients and retrievers are memoised *per instance*. A
    ``functools.lru_cache`` applied directly to a method is a single
    class-wide cache keyed on ``self``: it keeps every instance alive for the
    lifetime of the cache and lets one instance evict another's entries, so
    the caches are attached to the instance instead.
    """

    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and record connection/model settings.

        Args:
            dense_model_name: Model name passed to ``OpenAIEmbeddings``.
            sparse_model_name: Model name passed to ``FastEmbedSparse``.

        Raises:
            EnvironmentError: If a required environment variable is missing
                or contains only whitespace.
        """
        self._validate_environment()
        # Validation ignores surrounding whitespace, so store the stripped
        # values as well; a trailing newline must not reach the client.
        self.qdrant_url = os.getenv("QDRANT_URL", "").strip()
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "").strip()
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name

        # Lazily created embedding clients (see the properties below).
        self._dense_embeddings = None
        self._sparse_embeddings = None
        # Per-instance retriever cache; wrapping the bound method keeps
        # ``self`` out of the cache key.
        self._cached_retriever = lru_cache(maxsize=8)(self._build_qdrant_retriever)

    @staticmethod
    def _validate_environment():
        """Raise ``EnvironmentError`` naming every unset or blank required variable."""
        missing_vars = [
            var
            for var in RetrieversConfig.REQUIRED_ENV_VARS
            if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    @property
    def dense_embeddings(self) -> "OpenAIEmbeddings":
        """Dense embedding client, created on first access and then reused."""
        if self._dense_embeddings is None:
            self._dense_embeddings = OpenAIEmbeddings(model=self.dense_model_name)
        return self._dense_embeddings

    @property
    def sparse_embeddings(self) -> "FastEmbedSparse":
        """Sparse embedding client, created on first access and then reused."""
        if self._sparse_embeddings is None:
            self._sparse_embeddings = FastEmbedSparse(
                model_name=self.sparse_model_name
            )
        return self._sparse_embeddings

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> "VectorStoreRetriever":
        """Return a hybrid-mode retriever for an existing Qdrant collection.

        Results are cached per instance (up to 8 distinct argument
        combinations), so repeated calls with the same arguments return the
        same retriever object.

        Args:
            collection_name: Name of the existing collection.
            dense_vector_name: Passed to the vector store as ``vector_name``.
            sparse_vector_name: Passed as ``sparse_vector_name``.
            k: Value placed in the retriever's ``search_kwargs["k"]``.
        """
        # Forward positionally so keyword and positional spellings of the same
        # call share one cache entry.
        return self._cached_retriever(
            collection_name, dense_vector_name, sparse_vector_name, k
        )

    def _build_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int,
    ) -> "VectorStoreRetriever":
        """Uncached worker behind :meth:`get_qdrant_retriever`."""
        qdrantdb = QdrantVectorStore.from_existing_collection(
            embedding=self.dense_embeddings,
            sparse_embedding=self.sparse_embeddings,
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            prefer_grpc=True,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name=dense_vector_name,
            sparse_vector_name=sparse_vector_name,
        )

        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_practitioners_retriever(self, k: int = 5) -> "VectorStoreRetriever":
        """Return the retriever for the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> "VectorStoreRetriever":
        """Return the retriever for the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )
 
1
+ import os
2
+ from functools import lru_cache
3
+ from typing import Literal
4
+
5
+ from langchain_core.vectorstores import VectorStoreRetriever
6
+ from langchain_openai import OpenAIEmbeddings
7
+ from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
8
+
9
+ os.environ["GRPC_VERBOSITY"] = "NONE"
10
+
11
+
12
class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment settings.

    ``QDRANT_URL`` and ``QDRANT_API_KEY`` are read from the environment when
    the instance is created. ``OPENAI_API_KEY`` is only checked for presence
    here (presumably it is consumed by ``OpenAIEmbeddings`` itself).

    Embedding clients and retrievers are memoised *per instance*. A
    ``functools.lru_cache`` applied directly to a method is a single
    class-wide cache keyed on ``self``: it keeps every instance alive for the
    lifetime of the cache and lets one instance evict another's entries, so
    the caches are attached to the instance instead.
    """

    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and record connection/model settings.

        Args:
            dense_model_name: Model name passed to ``OpenAIEmbeddings``.
            sparse_model_name: Model name passed to ``FastEmbedSparse``.

        Raises:
            EnvironmentError: If a required environment variable is missing
                or contains only whitespace.
        """
        self._validate_environment()
        # Validation ignores surrounding whitespace, so store the stripped
        # values as well; a trailing newline must not reach the client.
        self.qdrant_url = os.getenv("QDRANT_URL", "").strip()
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "").strip()
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name

        # Lazily created embedding clients (see the properties below).
        self._dense_embeddings = None
        self._sparse_embeddings = None
        # Per-instance retriever cache; wrapping the bound method keeps
        # ``self`` out of the cache key.
        self._cached_retriever = lru_cache(maxsize=8)(self._build_qdrant_retriever)

    @staticmethod
    def _validate_environment():
        """Raise ``EnvironmentError`` naming every unset or blank required variable."""
        missing_vars = [
            var
            for var in RetrieversConfig.REQUIRED_ENV_VARS
            if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    @property
    def dense_embeddings(self) -> "OpenAIEmbeddings":
        """Dense embedding client, created on first access and then reused."""
        if self._dense_embeddings is None:
            self._dense_embeddings = OpenAIEmbeddings(model=self.dense_model_name)
        return self._dense_embeddings

    @property
    def sparse_embeddings(self) -> "FastEmbedSparse":
        """Sparse embedding client, created on first access and then reused."""
        if self._sparse_embeddings is None:
            self._sparse_embeddings = FastEmbedSparse(
                model_name=self.sparse_model_name
            )
        return self._sparse_embeddings

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> "VectorStoreRetriever":
        """Return a hybrid-mode retriever for an existing Qdrant collection.

        Results are cached per instance (up to 8 distinct argument
        combinations), so repeated calls with the same arguments return the
        same retriever object.

        Args:
            collection_name: Name of the existing collection.
            dense_vector_name: Passed to the vector store as ``vector_name``.
            sparse_vector_name: Passed as ``sparse_vector_name``.
            k: Value placed in the retriever's ``search_kwargs["k"]``.
        """
        # Forward positionally so keyword and positional spellings of the same
        # call share one cache entry.
        return self._cached_retriever(
            collection_name, dense_vector_name, sparse_vector_name, k
        )

    def _build_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int,
    ) -> "VectorStoreRetriever":
        """Uncached worker behind :meth:`get_qdrant_retriever`."""
        qdrantdb = QdrantVectorStore.from_existing_collection(
            embedding=self.dense_embeddings,
            sparse_embedding=self.sparse_embeddings,
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            prefer_grpc=True,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name=dense_vector_name,
            sparse_vector_name=sparse_vector_name,
        )

        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_practitioners_retriever(self, k: int = 5) -> "VectorStoreRetriever":
        """Return the retriever for the ``practitioners_hybrid_db_upgrade`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db_upgrade",
            dense_vector_name="practitioners_dense_vectors_upgrade",
            sparse_vector_name="practitioners_sparse_vectors_upgrade",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> "VectorStoreRetriever":
        """Return the retriever for the ``docs_hybrid_db_upgrade`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db_upgrade",
            dense_vector_name="docs_dense_vectors_upgrade",
            sparse_vector_name="docs_sparse_vectors_upgrade",
            k=k,
        )