ai-virtual-assistant / rag /retrievers.py
yrobel-lima's picture
Upload 4 files
e921012 verified
raw
history blame
2.72 kB
import os
from typing import List, Literal
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
# Turn off gRPC's own log output for this process (the Qdrant client below is
# asked to prefer gRPC). NOTE(review): this runs after the imports above --
# confirm the gRPC runtime reads the variable late enough for it to apply.
os.environ.update(GRPC_VERBOSITY="NONE")
class RetrieversConfig:
    """Factory for hybrid (dense + sparse) Qdrant retrievers.

    Connection settings come from the environment. ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` are read here and stored with surrounding whitespace
    removed. ``OPENAI_API_KEY`` is only validated; it is not read by this
    class (presumably the OpenAI embeddings client picks it up from the
    environment itself -- confirm against the ``langchain_openai`` docs).

    Attributes are set in ``__init__``:
        required_env_vars: names of the variables that must be non-empty.
        qdrant_url: stripped value of ``QDRANT_URL``.
        qdrant_api_key: stripped value of ``QDRANT_API_KEY``.
        dense_embeddings: ``OpenAIEmbeddings`` instance for dense vectors.
        sparse_embeddings: ``FastEmbedSparse`` instance for sparse vectors.
    """

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and build both embedding models.

        Args:
            dense_model_name: model name passed to ``OpenAIEmbeddings``.
            sparse_model_name: model name passed to ``FastEmbedSparse``.

        Raises:
            EnvironmentError: if any required variable is unset or blank.
        """
        self.required_env_vars = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]
        # Fail fast, before any embedding client is constructed.
        self._validate_environment(self.required_env_vars)

        # Store the same normalised value that validation inspected, so a
        # stray trailing newline/space in a secret never reaches the client.
        self.qdrant_url = self._read_env("QDRANT_URL")
        self.qdrant_api_key = self._read_env("QDRANT_API_KEY")

        self.dense_embeddings = OpenAIEmbeddings(model=dense_model_name)
        self.sparse_embeddings = FastEmbedSparse(
            model_name=sparse_model_name,
        )

    @staticmethod
    def _read_env(name: str) -> str:
        """Return environment variable *name* stripped, or ``""`` if unset."""
        return os.getenv(name, "").strip()

    def _validate_environment(self, required_env_vars: List[str]) -> None:
        """Ensure every name in *required_env_vars* has a non-blank value.

        Raises:
            EnvironmentError: listing, in input order, every variable that is
                unset or contains only whitespace.
        """
        missing_vars = [var for var in required_env_vars if not self._read_env(var)]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> VectorStoreRetriever:
        """Open an existing Qdrant collection and wrap it as a retriever.

        A new ``QdrantVectorStore`` is created on every call, in hybrid
        retrieval mode and with gRPC preferred.

        Args:
            collection_name: name of the existing Qdrant collection.
            dense_vector_name: named dense vector inside the collection.
            sparse_vector_name: named sparse vector inside the collection.
            k: value passed to the retriever as ``search_kwargs["k"]``.

        Returns:
            The retriever produced by ``QdrantVectorStore.as_retriever``.
        """
        qdrantdb = QdrantVectorStore.from_existing_collection(
            embedding=self.dense_embeddings,
            sparse_embedding=self.sparse_embeddings,
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            prefer_grpc=True,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name=dense_vector_name,
            sparse_vector_name=sparse_vector_name,
        )
        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a retriever over the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )

    def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a retriever over the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )