vhr1007 commited on
Commit
a80ee03
·
1 Parent(s): 3d40486

debug query embed

Browse files
Files changed (1) hide show
  1. services/qdrant_searcher.py +9 -6
services/qdrant_searcher.py CHANGED
@@ -6,19 +6,22 @@ from qdrant_client.http.models import Filter, FieldCondition
6
 
7
  class QdrantSearcher:
8
  def __init__(self, qdrant_url, access_token):
9
- # Removed the encoder since embeddings are precomputed externally
10
  self.client = QdrantClient(url=qdrant_url, api_key=access_token)
11
 
12
  def search_documents(self, collection_name, query_embedding, user_id, limit=3):
13
  logging.info("Starting document search")
14
 
15
- # Ensure the query_embedding is in the correct format (list)
16
  if isinstance(query_embedding, torch.Tensor):
17
- query_embedding = query_embedding.detach().numpy().tolist()
18
- logging.info("Converted query embedding to list")
19
  elif isinstance(query_embedding, np.ndarray):
20
- query_embedding = query_embedding.tolist()
21
- logging.info("Converted query embedding to list")
 
 
 
 
 
22
 
23
  # Filter by user_id
24
  query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
 
6
 
7
  class QdrantSearcher:
8
  def __init__(self, qdrant_url, access_token):
 
9
  self.client = QdrantClient(url=qdrant_url, api_key=access_token)
10
 
11
  def search_documents(self, collection_name, query_embedding, user_id, limit=3):
12
  logging.info("Starting document search")
13
 
14
+ # Ensure the query_embedding is in the correct format (flat list of floats)
15
  if isinstance(query_embedding, torch.Tensor):
16
+ query_embedding = query_embedding.detach().numpy().flatten().tolist()
 
17
  elif isinstance(query_embedding, np.ndarray):
18
+ query_embedding = query_embedding.flatten().tolist()
19
+ else:
20
+ raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
21
+
22
+ # Validate that all elements in the query_vector are floats
23
+ if not all(isinstance(x, float) for x in query_embedding):
24
+ raise ValueError("All elements in query_embedding must be of type float")
25
 
26
  # Filter by user_id
27
  query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])