vhr1007
commited on
Commit
·
a80ee03
1
Parent(s):
3d40486
debug query embed
Browse files
services/qdrant_searcher.py
CHANGED
@@ -6,19 +6,22 @@ from qdrant_client.http.models import Filter, FieldCondition
|
|
6 |
|
7 |
class QdrantSearcher:
|
8 |
def __init__(self, qdrant_url, access_token):
|
9 |
-
# Removed the encoder since embeddings are precomputed externally
|
10 |
self.client = QdrantClient(url=qdrant_url, api_key=access_token)
|
11 |
|
12 |
def search_documents(self, collection_name, query_embedding, user_id, limit=3):
|
13 |
logging.info("Starting document search")
|
14 |
|
15 |
-
# Ensure the query_embedding is in the correct format (list)
|
16 |
if isinstance(query_embedding, torch.Tensor):
|
17 |
-
query_embedding = query_embedding.detach().numpy().tolist()
|
18 |
-
logging.info("Converted query embedding to list")
|
19 |
elif isinstance(query_embedding, np.ndarray):
|
20 |
-
query_embedding = query_embedding.tolist()
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Filter by user_id
|
24 |
query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
|
|
|
6 |
|
7 |
class QdrantSearcher:
|
8 |
def __init__(self, qdrant_url, access_token):
|
|
|
9 |
self.client = QdrantClient(url=qdrant_url, api_key=access_token)
|
10 |
|
11 |
def search_documents(self, collection_name, query_embedding, user_id, limit=3):
|
12 |
logging.info("Starting document search")
|
13 |
|
14 |
+
# Ensure the query_embedding is in the correct format (flat list of floats)
|
15 |
if isinstance(query_embedding, torch.Tensor):
|
16 |
+
query_embedding = query_embedding.detach().numpy().flatten().tolist()
|
|
|
17 |
elif isinstance(query_embedding, np.ndarray):
|
18 |
+
query_embedding = query_embedding.flatten().tolist()
|
19 |
+
else:
|
20 |
+
raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
|
21 |
+
|
22 |
+
# Validate that all elements in the query_vector are floats
|
23 |
+
if not all(isinstance(x, float) for x in query_embedding):
|
24 |
+
raise ValueError("All elements in query_embedding must be of type float")
|
25 |
|
26 |
# Filter by user_id
|
27 |
query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
|