|
import logging |
|
import torch |
|
import numpy as np |
|
from qdrant_client import QdrantClient |
|
from qdrant_client.http.models import Filter, FieldCondition |
|
from collections import defaultdict |
|
|
|
class QdrantSearcher:
    """Thin wrapper around a Qdrant vector-search client.

    Provides per-user (and optionally per-file) similarity search over a
    collection of document chunks, returning either individual chunk hits
    or results grouped by file name with an averaged score.
    """

    def __init__(self, qdrant_url, access_token):
        """Create a Qdrant client.

        Args:
            qdrant_url: Base URL of the Qdrant server.
            access_token: API key used to authenticate with Qdrant.
        """
        self.client = QdrantClient(url=qdrant_url, api_key=access_token)

    @staticmethod
    def _to_vector(query_embedding):
        """Convert a torch.Tensor or numpy.ndarray embedding to a flat list of floats.

        Raises:
            ValueError: if the input is neither a tensor nor an ndarray, or if
                the flattened values are not all floats (e.g. an integer dtype).
        """
        if isinstance(query_embedding, torch.Tensor):
            # .cpu() is required before .numpy() when the tensor lives on a GPU;
            # it is a no-op for CPU tensors, so behavior is unchanged there.
            vector = query_embedding.detach().cpu().numpy().flatten().tolist()
        elif isinstance(query_embedding, np.ndarray):
            vector = query_embedding.flatten().tolist()
        else:
            raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")

        if not all(isinstance(x, float) for x in vector):
            raise ValueError("All elements in query_embedding must be of type float")
        return vector

    @staticmethod
    def _build_filter(user_id, file_id):
        """Build a Qdrant Filter restricting results to one user and, optionally, one file.

        NOTE(review): `if file_id:` skips falsy-but-non-None ids (0, "") —
        kept as-is for backward compatibility; confirm callers never use such ids.
        """
        filter_conditions = [FieldCondition(key="user_id", match={"value": user_id})]
        if file_id:
            filter_conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
        return Filter(must=filter_conditions)

    def search_documents(self, collection_name, query_embedding, user_id, limit=3, similarity_threshold=0.6, file_id=None):
        """Search a collection for chunks similar to the query embedding.

        Args:
            collection_name: Qdrant collection to search.
            query_embedding: torch.Tensor or numpy.ndarray query vector.
            user_id: Restrict results to this user's documents.
            limit: Maximum number of hits requested from Qdrant.
            similarity_threshold: Minimum score a hit must have to be kept.
            file_id: Optional — restrict results to a single file.

        Returns:
            (hits_list, None) on success, where each entry carries the hit id,
            score, and selected payload fields; (None, error_message) on search
            failure or when no hit passes the threshold.

        Raises:
            ValueError: if the embedding has an unsupported type or dtype.
        """
        logging.info("Starting document search")

        vector = self._to_vector(query_embedding)
        query_filter = self._build_filter(user_id, file_id)

        logging.info("Performing search using the precomputed embeddings for user_id: %s", user_id)
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=vector,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error("Error during Qdrant search: %s", e)
            return None, str(e)

        # Qdrant's own limit is count-based; apply the score cutoff client-side.
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]

        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."

        hits_list = [
            {
                "id": hit.id,
                "score": hit.score,
                "file_id": hit.payload.get('file_id'),
                "file_name": hit.payload.get('file_name'),
                "organization_id": hit.payload.get('organization_id'),
                "chunk_index": hit.payload.get('chunk_index'),
                "chunk_text": hit.payload.get('chunk_text'),
                "s3_bucket_key": hit.payload.get('s3_bucket_key')
            }
            for hit in filtered_hits
        ]

        logging.info("Document search completed with %d hits", len(hits_list))
        logging.info("Hits: %s", hits_list)
        return hits_list, None

    def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60, similarity_threshold=0.6, file_id=None):
        """Search a collection and group hits by file name with an averaged score.

        Args:
            collection_name: Qdrant collection to search.
            query_embedding: torch.Tensor or numpy.ndarray query vector.
            user_id: Restrict results to this user's documents.
            limit: Maximum number of hits requested from Qdrant.
            similarity_threshold: Accepted for signature parity but currently
                unused. NOTE(review): unlike search_documents, no score cutoff
                is applied here — confirm whether that is intentional.
            file_id: Optional — restrict results to a single file.

        Returns:
            (grouped_results, None) on success — a list of
            {"file_name", "average_score"} dicts; (None, error_message) on
            search failure or when there are no hits.

        Raises:
            ValueError: if the embedding has an unsupported type or dtype.
        """
        logging.info("Starting grouped document search")

        vector = self._to_vector(query_embedding)
        query_filter = self._build_filter(user_id, file_id)

        logging.info("Performing grouped search using the precomputed embeddings for user_id: %s", user_id)
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=vector,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error("Error during Qdrant search: %s", e)
            return None, str(e)

        if not hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."

        # Collect every hit's score under its file name, then average per file.
        grouped_hits = defaultdict(list)
        for hit in hits:
            grouped_hits[hit.payload.get('file_name')].append(hit.score)

        grouped_results = [
            {
                "file_name": file_name,
                "average_score": sum(scores) / len(scores)
            }
            for file_name, scores in grouped_hits.items()
        ]

        logging.info("Grouped search completed with %d results", len(grouped_results))
        logging.info("Grouped Hits: %s", grouped_results)
        return grouped_results, None
|
|
|
|