import logging
from collections import defaultdict

import numpy as np
import torch
from qdrant_client import QdrantClient
from qdrant_client.http.models import FieldCondition, Filter


class QdrantSearcher:
    def __init__(self, qdrant_url, access_token):
        self.client = QdrantClient(url=qdrant_url, api_key=access_token)

    def search_documents(self, collection_name, query_embedding, user_id, limit=3,
                         similarity_threshold=0.6, file_id=None):
        logging.info("Starting document search")

        # Ensure the query_embedding is in the correct format (flat list of floats).
        # .cpu() guards against tensors that live on a GPU device.
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.detach().cpu().numpy().flatten().tolist()
        elif isinstance(query_embedding, np.ndarray):
            query_embedding = query_embedding.flatten().tolist()
        else:
            raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")

        # Validate that all elements in the query vector are floats
        if not all(isinstance(x, float) for x in query_embedding):
            raise ValueError("All elements in query_embedding must be of type float")

        # Always filter by user_id; optionally narrow the search to a single file
        filter_conditions = [FieldCondition(key="user_id", match={"value": user_id})]
        if file_id:
            filter_conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
        query_filter = Filter(must=filter_conditions)

        logging.info(f"Performing search using the precomputed embeddings for user_id: {user_id}")
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error(f"Error during Qdrant search: {e}")
            return None, str(e)

        # Drop hits that fall below the similarity threshold
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]

        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."
        # Flatten each hit into a plain dict for callers
        hits_list = []
        for hit in filtered_hits:
            hit_info = {
                "id": hit.id,
                "score": hit.score,
                "file_id": hit.payload.get('file_id'),
                "file_name": hit.payload.get('file_name'),
                "organization_id": hit.payload.get('organization_id'),
                "chunk_index": hit.payload.get('chunk_index'),
                "chunk_text": hit.payload.get('chunk_text'),
                "s3_bucket_key": hit.payload.get('s3_bucket_key')
            }
            hits_list.append(hit_info)

        logging.info(f"Document search completed with {len(hits_list)} hits")
        logging.info(f"Hits: {hits_list}")
        return hits_list, None

    def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60,
                                 similarity_threshold=0.6, file_id=None):
        logging.info("Starting grouped document search")

        # Ensure the query_embedding is in the correct format (flat list of floats)
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.detach().cpu().numpy().flatten().tolist()
        elif isinstance(query_embedding, np.ndarray):
            query_embedding = query_embedding.flatten().tolist()
        else:
            raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")

        if not all(isinstance(x, float) for x in query_embedding):
            raise ValueError("All elements in query_embedding must be of type float")

        # Always filter by user_id; optionally narrow the search to a single file
        filter_conditions = [FieldCondition(key="user_id", match={"value": user_id})]
        if file_id:
            filter_conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
        query_filter = Filter(must=filter_conditions)

        logging.info(f"Performing grouped search using the precomputed embeddings for user_id: {user_id}")
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error(f"Error during Qdrant search: {e}")
            return None, str(e)

        # Apply the similarity threshold so the parameter is honored here too
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]

        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."

        # Group hits by file name and calculate the average score per file
        grouped_hits = defaultdict(list)
        for hit in filtered_hits:
            grouped_hits[hit.payload.get('file_name')].append(hit.score)

        grouped_results = []
        for file_name, scores in grouped_hits.items():
            average_score = sum(scores) / len(scores)
            grouped_results.append({
                "file_name": file_name,
                "average_score": average_score
            })

        logging.info(f"Grouped search completed with {len(grouped_results)} results")
        logging.info(f"Grouped Hits: {grouped_results}")
        return grouped_results, None
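
# Usage sketch, not part of the original module: shows how the two search methods
# might be called. The Qdrant URL, API key, collection name, user_id, and the
# 384-dimensional random vector below are placeholder assumptions; a real caller
# would supply an embedding produced by the same model used at indexing time.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    searcher = QdrantSearcher(
        qdrant_url="http://localhost:6333",  # assumed local Qdrant instance
        access_token="your-api-key",         # placeholder credential
    )

    # Stand-in for a real sentence embedding (float64, so elements are Python floats)
    query_embedding = np.random.rand(384)

    hits, error = searcher.search_documents(
        collection_name="documents",         # assumed collection name
        query_embedding=query_embedding,
        user_id="user-123",
        limit=3,
        similarity_threshold=0.6,
    )
    if error:
        print(f"Search failed: {error}")
    else:
        for hit in hits:
            print(hit["file_name"], hit["score"])

    grouped, error = searcher.search_documents_grouped(
        collection_name="documents",
        query_embedding=query_embedding,
        user_id="user-123",
    )
    if not error:
        for result in grouped:
            print(result["file_name"], result["average_score"])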