Spaces:
Runtime error
Runtime error
Samuel Schmidt
commited on
Commit
·
f9494ca
1
Parent(s):
f365545
Update src/app.py
Browse files- src/app.py +5 -12
src/app.py
CHANGED
@@ -49,12 +49,6 @@ def check_index(ds):
|
|
49 |
|
50 |
else:
|
51 |
return index_dataset(ds)
|
52 |
-
|
53 |
-
|
54 |
-
def find_similar_images(method):
|
55 |
-
if method == "FAISS":
|
56 |
-
|
57 |
-
return retrieved examples
|
58 |
|
59 |
|
60 |
dataset_with_embeddings = check_index(candidate_subset)
|
@@ -62,7 +56,7 @@ dataset_with_embeddings = check_index(candidate_subset)
|
|
62 |
# Main function, to find similar images
|
63 |
# TODO: implement different distance measures
|
64 |
|
65 |
-
def get_neighbors(query_image, selected_descriptor, top_k=5):
|
66 |
"""Returns the top k nearest examples to the query image.
|
67 |
|
68 |
Args:
|
@@ -76,16 +70,15 @@ def get_neighbors(query_image, selected_descriptor, top_k=5):
|
|
76 |
cd = ColorDescriptor((8, 12, 3))
|
77 |
qi_embedding = cd.describe(query_image)
|
78 |
qi_np = np.array(qi_embedding)
|
79 |
-
if
|
80 |
scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
|
81 |
'color_embeddings', qi_np, k=top_k)
|
82 |
-
elif
|
83 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
|
84 |
retrieved_examples = tmp_dataset.sort("distance")
|
85 |
else:
|
86 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
|
87 |
retrieved_examples = tmp_dataset.sort("distance")
|
88 |
-
|
89 |
images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
|
90 |
return images
|
91 |
if "CLIP" == selected_descriptor:
|
@@ -133,13 +126,13 @@ with gr.Blocks() as demo:
|
|
133 |
|
134 |
with gr.Row():
|
135 |
descr_dropdown = gr.Dropdown(["Color Descriptor", "LBP", "CLIP"], value="LBP", label="Please choose an descriptor")
|
136 |
-
|
137 |
dataset_dropdown = gr.Dropdown(
|
138 |
["huggan/CelebA-faces", "EIT/cbir-eit"],
|
139 |
value="huggan/CelebA-faces",
|
140 |
label="Please select a dataset"
|
141 |
)
|
142 |
-
btn.click(get_neighbors, inputs=[image_input, descr_dropdown], outputs=[gallery_output])
|
143 |
|
144 |
|
145 |
if __name__ == "__main__":
|
|
|
49 |
|
50 |
else:
|
51 |
return index_dataset(ds)
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
|
54 |
dataset_with_embeddings = check_index(candidate_subset)
|
|
|
56 |
# Main function, to find similar images
|
57 |
# TODO: implement different distance measures
|
58 |
|
59 |
+
def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
|
60 |
"""Returns the top k nearest examples to the query image.
|
61 |
|
62 |
Args:
|
|
|
70 |
cd = ColorDescriptor((8, 12, 3))
|
71 |
qi_embedding = cd.describe(query_image)
|
72 |
qi_np = np.array(qi_embedding)
|
73 |
+
if selected_distance == "FAISS":
|
74 |
scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
|
75 |
'color_embeddings', qi_np, k=top_k)
|
76 |
+
elif selected_distance == "Chi-squared":
|
77 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
|
78 |
retrieved_examples = tmp_dataset.sort("distance")
|
79 |
else:
|
80 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
|
81 |
retrieved_examples = tmp_dataset.sort("distance")
|
|
|
82 |
images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
|
83 |
return images
|
84 |
if "CLIP" == selected_descriptor:
|
|
|
126 |
|
127 |
with gr.Row():
|
128 |
descr_dropdown = gr.Dropdown(["Color Descriptor", "LBP", "CLIP"], value="LBP", label="Please choose an descriptor")
|
129 |
+
dist_dropdown = gr.Dropdown(["FAISS", "Chi-squared", "Euclid"], value="FAISS", label="Please choose a distance measure")
|
130 |
dataset_dropdown = gr.Dropdown(
|
131 |
["huggan/CelebA-faces", "EIT/cbir-eit"],
|
132 |
value="huggan/CelebA-faces",
|
133 |
label="Please select a dataset"
|
134 |
)
|
135 |
+
btn.click(get_neighbors, inputs=[image_input, descr_dropdown, dist_dropdown], outputs=[gallery_output])
|
136 |
|
137 |
|
138 |
if __name__ == "__main__":
|