|
import os |
|
from functools import cache |
|
|
|
import qdrant_client |
|
import torch |
|
from langchain.retrievers import ContextualCompressionRetriever |
|
from langchain.retrievers.document_compressors import EmbeddingsFilter |
|
from langchain_community.retrievers import QdrantSparseVectorRetriever |
|
from langchain_community.vectorstores import Qdrant |
|
from langchain_openai.embeddings import OpenAIEmbeddings |
|
from transformers import AutoModelForMaskedLM, AutoTokenizer |
|
|
|
|
|
class DenseRetrieverClient:
    """Initialize the dense retriever using OpenAI text embeddings and the Qdrant vector database.

    Attributes:
        embeddings_model (str): The embeddings model to use. Right now only OpenAI text embeddings.
        collection_name (str): Qdrant collection name.
        client (QdrantClient): Qdrant client.
        qdrant_collection (Qdrant): Qdrant collection.
    """

    def __init__(self, embeddings_model: str = "text-embedding-ada-002", collection_name: str = "practitioners_db"):
        self.validate_environment_variables()
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        # load_qdrant_collection() memoizes its result on this attribute.
        self.qdrant_collection = self.load_qdrant_collection()

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set.

        Raises:
            EnvironmentError: If QDRANT_API_KEY or QDRANT_URL is unset or empty.
        """
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        for var in required_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_qdrant_collection(self, embeddings):
        """Prepare the Qdrant collection wrapper for the given embeddings model."""
        return Qdrant(client=self.client,
                      collection_name=self.collection_name,
                      embeddings=embeddings)

    def load_qdrant_collection(self):
        """Load the Qdrant collection for the configured embeddings model.

        The result is memoized on ``self.qdrant_collection``. (The previous
        implementation used ``@functools.cache`` on this instance method,
        which keys the cache on ``self`` and therefore keeps every instance
        alive for the lifetime of the process — ruff B019.)

        Returns:
            Qdrant: The vector-store wrapper for this collection.

        Raises:
            ValueError: If the embeddings model is not supported.
        """
        # Reuse the already-built collection on repeat calls.
        if getattr(self, "qdrant_collection", None) is not None:
            return self.qdrant_collection
        if self.embeddings_model == "text-embedding-ada-002":
            self.qdrant_collection = self.set_qdrant_collection(
                OpenAIEmbeddings(model=self.embeddings_model))
        else:
            raise ValueError(
                f"Invalid embeddings model: {self.embeddings_model}. Select 'text-embedding-ada-002' from OpenAI.")

        return self.qdrant_collection

    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
        """Set up retrievers (Qdrant vectorstore as retriever).

        Args:
            search_type (str, optional): similarity or mmr. Defaults to "similarity".
            k (int, optional): Number of documents retrieved. Defaults to 4.

        Returns:
            Retriever: Vectorstore as a retriever
        """
        dense_retriever = self.qdrant_collection.as_retriever(
            search_type=search_type,
            search_kwargs={"k": k},
        )
        return dense_retriever
|
|
|
|
|
class SparseRetrieverClient:
    """Initialize the sparse retriever using the SPLADE neural retrieval model and the Qdrant vector database.

    Attributes:
        collection_name (str): Qdrant collection name.
        vector_name (str): Qdrant vector name.
        model_id (str): The SPLADE neural retrieval model id.
        k (int): Number of documents retrieved.
        client (QdrantClient): Qdrant client.
    """

    def __init__(self, collection_name: str, vector_name: str, splade_model_id: str = "naver/splade-cocondenser-ensembledistil", k: int = 15):
        self.validate_environment_variables()
        self.client = qdrant_client.QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
        self.model_id = splade_model_id
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k
        # Per-instance caches. The previous implementation decorated the
        # instance methods below with @functools.cache, which keys the cache
        # on `self` and keeps every instance alive for the lifetime of the
        # process (ruff B019); the encoding cache was also unbounded and
        # process-global. These attributes die with the instance instead.
        self._tokenizer = None
        self._model = None
        self._encoding_cache: dict[str, tuple[list[int], list[float]]] = {}

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set.

        Raises:
            EnvironmentError: If QDRANT_API_KEY or QDRANT_URL is unset or empty.
        """
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        for var in required_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_tokenizer_config(self):
        """Lazily initialize the tokenizer and the SPLADE neural retrieval model.

        See https://huggingface.co/naver/splade-cocondenser-ensembledistil for more details.

        Returns:
            tuple: (tokenizer, model), loaded at most once per instance.
        """
        if self._tokenizer is None or self._model is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return self._tokenizer, self._model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector. The encoder is required for the QdrantSparseVectorRetriever.

        Adapted from the Qdrant documentation: Computing the Sparse Vector code.
        Results are memoized per instance, so repeated queries skip the model
        forward pass.

        Args:
            text (str): Text to encode

        Returns:
            tuple[list[int], list[float]]: Indices and values of the sparse vector
        """
        cached = self._encoding_cache.get(text)
        if cached is not None:
            return cached

        tokenizer, model = self.set_tokenizer_config()
        tokens = tokenizer(text, return_tensors="pt",
                           max_length=512, padding="max_length", truncation=True)
        # Pure inference: disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            output = model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        # SPLADE weighting: log(1 + ReLU(logits)), masked, then max-pooled
        # over the sequence dimension.
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        max_val, _ = torch.max(weighted_log, dim=1)
        vec = max_val.squeeze()
        indices = vec.nonzero().numpy().flatten()
        values = vec.detach().numpy()[indices]

        result = (indices.tolist(), values.tolist())
        self._encoding_cache[text] = result
        return result

    def get_sparse_retriever(self):
        """Build the Qdrant sparse-vector retriever backed by `sparse_encoder`.

        Returns:
            QdrantSparseVectorRetriever: Retriever returning `self.k` documents.
        """
        sparse_retriever = QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )
        return sparse_retriever
|
|
|
|
|
def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002", similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
    """Wrap *base_retriever* in a ContextualCompressionRetriever that drops weak matches.

    An EmbeddingsFilter backed by OpenAI embeddings scores each retrieved
    document against the query and filters out those whose similarity falls
    below the threshold.

    Args:
        base_retriever: Retriever whose results will be filtered.
        embeddings_model (str, optional): OpenAI embeddings model used by the
            filter. Defaults to "text-embedding-ada-002".
        similarity_threshold (float, optional): Documents with a similarity
            score below this value are filtered out. Defaults to 0.76
            (obtained by experimenting with text-embedding-ada-002).

    Returns:
        ContextualCompressionRetriever: The created ContextualCompressionRetriever.
    """
    filter_embeddings = OpenAIEmbeddings(model=embeddings_model)
    similarity_filter = EmbeddingsFilter(
        embeddings=filter_embeddings,
        similarity_threshold=similarity_threshold,
    )
    return ContextualCompressionRetriever(
        base_compressor=similarity_filter,
        base_retriever=base_retriever,
    )
|
|