File size: 2,724 Bytes
e921012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from typing import Dict, List, Literal

from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode

# Silence gRPC's own log output; the retrievers in this module connect to
# Qdrant with prefer_grpc=True.
# NOTE(review): this runs after langchain_qdrant is imported above; confirm
# gRPC still honors the setting when it is applied this late.
os.environ.update(GRPC_VERBOSITY="NONE")


class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment config.

    The constructor validates the required environment variables, reads the
    Qdrant connection settings, and instantiates the two embedding models once.
    Every retriever created afterwards reuses those model objects.

    Attributes:
        required_env_vars: Environment variables that must be set and non-blank.
        qdrant_url: ``QDRANT_URL`` with surrounding whitespace removed.
        qdrant_api_key: ``QDRANT_API_KEY`` with surrounding whitespace removed.
        dense_embeddings: OpenAI embedding model used for the dense vectors.
        sparse_embeddings: FastEmbed sparse model used for the sparse vectors.
    """

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and create the embedding models.

        Args:
            dense_model_name: OpenAI embedding model name for the dense vectors.
            sparse_model_name: FastEmbed model name for the sparse vectors.

        Raises:
            EnvironmentError: If any of ``QDRANT_API_KEY``, ``QDRANT_URL`` or
                ``OPENAI_API_KEY`` is unset or contains only whitespace.
        """
        self.required_env_vars = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]
        env = self._validate_environment(self.required_env_vars)
        # Use the stripped values. Validation accepts e.g. "https://host\n", so
        # the raw string could otherwise reach the Qdrant client with stray
        # whitespace (a trailing newline from a .env file or secret mount).
        self.qdrant_url = env["QDRANT_URL"]
        self.qdrant_api_key = env["QDRANT_API_KEY"]
        # OPENAI_API_KEY is only validated here. It is not passed explicitly,
        # so OpenAIEmbeddings falls back to reading it from the environment.
        self.dense_embeddings = OpenAIEmbeddings(model=dense_model_name)
        self.sparse_embeddings = FastEmbedSparse(model_name=sparse_model_name)

    def _validate_environment(self, required_env_vars: List[str]) -> Dict[str, str]:
        """Check that every variable in ``required_env_vars`` is set and non-blank.

        Args:
            required_env_vars: Names of the environment variables to check.

        Returns:
            Mapping of each variable name to its value, with surrounding
            whitespace stripped.

        Raises:
            EnvironmentError: Naming every variable that is unset or contains
                only whitespace, in the order they were requested.
        """
        values = {var: os.getenv(var, "").strip() for var in required_env_vars}
        # Iterate the requested list (not the dict) so the error message lists
        # names exactly as the caller passed them.
        missing_vars = [var for var in required_env_vars if not values[var]]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )
        return values

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> VectorStoreRetriever:
        """Return a hybrid-search retriever over an existing Qdrant collection.

        Args:
            collection_name: Name of a collection that already exists in Qdrant.
            dense_vector_name: Name of the collection's dense vector field.
            sparse_vector_name: Name of the collection's sparse vector field.
            k: Number of documents each query returns.

        Returns:
            A retriever that runs hybrid (dense + sparse) search and returns
            up to ``k`` documents per query.
        """
        # Each call builds a new QdrantVectorStore. gRPC is preferred over
        # REST as the transport to Qdrant.
        qdrantdb = QdrantVectorStore.from_existing_collection(
            embedding=self.dense_embeddings,
            sparse_embedding=self.sparse_embeddings,
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            prefer_grpc=True,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name=dense_vector_name,
            sparse_vector_name=sparse_vector_name,
        )

        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``docs_hybrid_db`` collection.

        Args:
            k: Number of documents each query returns.
        """
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )

    def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``practitioners_hybrid_db`` collection.

        Args:
            k: Number of documents each query returns.
        """
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )