as-cle-bert committed (verified)
Commit 6b750d5 · Parent(s): c4ae4d8

Update QdrantRag.py

Files changed (1):
  1. QdrantRag.py +7 -3
QdrantRag.py CHANGED
@@ -21,6 +21,10 @@ qdrant_client = QdrantClient(url=os.getenv("qdrant_url"), api_key=os.getenv("qdr
 sparse_encoder = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
 co = cohere.ClientV2(os.getenv("cohere_api_key"))
 
+dataset = load_dataset("Karbo31881/Pokemon_images")
+ds = dataset["train"]
+labels = ds["text"]
+
 def get_sparse_embedding(text: str, model: SparseTextEmbedding):
     embeddings = list(model.embed(text))
     vector = {f"sparse-text": models.SparseVector(indices=embeddings[0].indices, values=embeddings[0].values)}
@@ -152,7 +156,7 @@ class NeuralSearcher:
         results = co.rerank(model="rerank-v3.5", query=text, documents=search_result, top_n = 3)
         ranked_results = [search_result[results.results[i].index] for i in range(3)]
         return ranked_results
-    def search_image(self, image: ImageFile, limit: int = 1):
+    def search_image(self, image: ImageFile, limit: int = 5):
         img = image
         inputs = self.image_processor(images=img, return_tensors="pt").to(device)
         with torch.no_grad():
@@ -163,5 +167,5 @@ class NeuralSearcher:
             query_filter=None,
             limit=limit,
         )
-        payloads = [hit.payload["label"] for hit in search_result]
-        return payloads
+        payloads = [f"- {hit.payload['label']} with score {hit.score}" for hit in search_result]
+        return payloads
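
For orientation, here is a minimal standalone sketch of the image-search path this commit modifies. The NeuralSearcher constructor, the image encoder, and the collection name do not appear in the diff, so the CLIP checkpoint and the "pokemon_images" collection below are placeholder assumptions, not the repository's actual values; only the method body and the payload formatting mirror the hunks above. Note also that the added load_dataset(...) lines presuppose a "from datasets import load_dataset" import earlier in the file, which this diff does not show.

# Hypothetical sketch, not the repository's code: the encoder checkpoint,
# collection name, and qdrant_api_key env var are assumptions for illustration.
import os

import torch
from PIL import Image
from qdrant_client import QdrantClient
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

qdrant_client = QdrantClient(url=os.getenv("qdrant_url"), api_key=os.getenv("qdrant_api_key"))
image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")  # assumed checkpoint
image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

def search_image(image: Image.Image, limit: int = 5):
    # Encode the query image, mirroring the method body shown in the diff.
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        query_vector = image_model.get_image_features(**inputs)[0].tolist()
    search_result = qdrant_client.search(
        collection_name="pokemon_images",  # assumed collection name
        query_vector=query_vector,
        query_filter=None,
        limit=limit,
    )
    # Same output format the commit introduces: one bullet per hit, label plus score.
    return [f"- {hit.payload['label']} with score {hit.score}" for hit in search_result]

print("\n".join(search_image(Image.open("pikachu.png"))))  # example query image

Returning pre-formatted "- label with score" strings rather than raw payloads, together with the higher default limit of 5, presumably lets the caller drop the hits straight into a text response instead of post-processing them.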