Spaces:
Runtime error
Runtime error
Samuel Schmidt
commited on
Commit
·
cb03df9
1
Parent(s):
49a4ce1
Bugfix, comma
Browse files- src/app.py +5 -5
src/app.py
CHANGED
@@ -49,14 +49,14 @@ def check_index(ds):
|
|
49 |
|
50 |
else:
|
51 |
return index_dataset(ds)
|
52 |
-
|
53 |
|
54 |
dataset_with_embeddings = check_index(candidate_subset)
|
55 |
|
56 |
# Main function, to find similar images
|
57 |
# TODO: implement different distance measures
|
58 |
|
59 |
-
def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
|
60 |
"""Returns the top k nearest examples to the query image.
|
61 |
|
62 |
Args:
|
@@ -75,10 +75,10 @@ def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
|
|
75 |
'color_embeddings', qi_np, k=top_k)
|
76 |
elif selected_distance == "Chi-squared":
|
77 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
|
78 |
-
retrieved_examples = tmp_dataset.sort("distance")
|
79 |
-
else:
|
80 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
|
81 |
-
retrieved_examples = tmp_dataset.sort("distance")
|
82 |
images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
|
83 |
return images
|
84 |
if "CLIP" == selected_descriptor:
|
|
|
49 |
|
50 |
else:
|
51 |
return index_dataset(ds)
|
52 |
+
|
53 |
|
54 |
dataset_with_embeddings = check_index(candidate_subset)
|
55 |
|
56 |
# Main function, to find similar images
|
57 |
# TODO: implement different distance measures
|
58 |
|
59 |
+
def get_neighbors(query_image, selected_descriptor, selected_distance, top_k=5):
|
60 |
"""Returns the top k nearest examples to the query image.
|
61 |
|
62 |
Args:
|
|
|
75 |
'color_embeddings', qi_np, k=top_k)
|
76 |
elif selected_distance == "Chi-squared":
|
77 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
|
78 |
+
retrieved_examples = tmp_dataset.sort("distance")[:5]
|
79 |
+
else:
|
80 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
|
81 |
+
retrieved_examples = tmp_dataset.sort("distance")[:5]
|
82 |
images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
|
83 |
return images
|
84 |
if "CLIP" == selected_descriptor:
|