Spaces:
Runtime error
Runtime error
Change topk output for I2I mode
Browse files
app.py
CHANGED
@@ -52,9 +52,9 @@ def image_2_image(model, img_emb, img_names, img_urls,n_top_k_images):
|
|
52 |
st.write("The image with the most similar embedding is:")
|
53 |
cosine_sim = get_match(model, image.convert("RGB"), img_emb)
|
54 |
logger.info(cosine_sim.shape)
|
55 |
-
top_k_images_indices = torch.topk(cosine_sim, n_top_k_images,
|
56 |
logger.info(top_k_images_indices.squeeze().tolist())
|
57 |
-
images_found = [img_names[top_k_best_image] for top_k_best_image in top_k_images_indices
|
58 |
cols = st.columns(n_top_k_images)
|
59 |
for i, image_found in enumerate(images_found):
|
60 |
logger.success(f"Image match found: {image_found}")
|
|
|
52 |
st.write("The image with the most similar embedding is:")
|
53 |
cosine_sim = get_match(model, image.convert("RGB"), img_emb)
|
54 |
logger.info(cosine_sim.shape)
|
55 |
+
top_k_images_indices = torch.topk(cosine_sim, n_top_k_images, 0).indices
|
56 |
logger.info(top_k_images_indices.squeeze().tolist())
|
57 |
+
images_found = [img_names[top_k_best_image] for top_k_best_image in top_k_images_indices]
|
58 |
cols = st.columns(n_top_k_images)
|
59 |
for i, image_found in enumerate(images_found):
|
60 |
logger.success(f"Image match found: {image_found}")
|