Samuel Schmidt commited on
Commit
f9494ca
·
1 Parent(s): f365545

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +5 -12
src/app.py CHANGED
@@ -49,12 +49,6 @@ def check_index(ds):
49
 
50
  else:
51
  return index_dataset(ds)
52
-
53
-
54
- def find_similar_images(method):
55
- if method == "FAISS":
56
-
57
- return retrieved examples
58
 
59
 
60
  dataset_with_embeddings = check_index(candidate_subset)
@@ -62,7 +56,7 @@ dataset_with_embeddings = check_index(candidate_subset)
62
  # Main function, to find similar images
63
  # TODO: implement different distance measures
64
 
65
- def get_neighbors(query_image, selected_descriptor, top_k=5):
66
  """Returns the top k nearest examples to the query image.
67
 
68
  Args:
@@ -76,16 +70,15 @@ def get_neighbors(query_image, selected_descriptor, top_k=5):
76
  cd = ColorDescriptor((8, 12, 3))
77
  qi_embedding = cd.describe(query_image)
78
  qi_np = np.array(qi_embedding)
79
- if distance_measure == "FAISS":
80
  scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
81
  'color_embeddings', qi_np, k=top_k)
82
- elif distance_measure == "Chi-squared":
83
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
84
  retrieved_examples = tmp_dataset.sort("distance")
85
  else:
86
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
87
  retrieved_examples = tmp_dataset.sort("distance")
88
-
89
  images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
90
  return images
91
  if "CLIP" == selected_descriptor:
@@ -133,13 +126,13 @@ with gr.Blocks() as demo:
133
 
134
  with gr.Row():
135
  descr_dropdown = gr.Dropdown(["Color Descriptor", "LBP", "CLIP"], value="LBP", label="Please choose an descriptor")
136
- checkboxes_descr = gr.CheckboxGroup(["Color Descriptor", "LBP", "CLIP"], label="Please choose an descriptor")
137
  dataset_dropdown = gr.Dropdown(
138
  ["huggan/CelebA-faces", "EIT/cbir-eit"],
139
  value="huggan/CelebA-faces",
140
  label="Please select a dataset"
141
  )
142
- btn.click(get_neighbors, inputs=[image_input, descr_dropdown], outputs=[gallery_output])
143
 
144
 
145
  if __name__ == "__main__":
 
49
 
50
  else:
51
  return index_dataset(ds)
 
 
 
 
 
 
52
 
53
 
54
  dataset_with_embeddings = check_index(candidate_subset)
 
56
  # Main function, to find similar images
57
  # TODO: implement different distance measures
58
 
59
+ def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
60
  """Returns the top k nearest examples to the query image.
61
 
62
  Args:
 
70
  cd = ColorDescriptor((8, 12, 3))
71
  qi_embedding = cd.describe(query_image)
72
  qi_np = np.array(qi_embedding)
73
+ if selected_distance == "FAISS":
74
  scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
75
  'color_embeddings', qi_np, k=top_k)
76
+ elif selected_distance == "Chi-squared":
77
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
78
  retrieved_examples = tmp_dataset.sort("distance")
79
  else:
80
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
81
  retrieved_examples = tmp_dataset.sort("distance")
 
82
  images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
83
  return images
84
  if "CLIP" == selected_descriptor:
 
126
 
127
  with gr.Row():
128
  descr_dropdown = gr.Dropdown(["Color Descriptor", "LBP", "CLIP"], value="LBP", label="Please choose an descriptor")
129
+ dist_dropdown = gr.Dropdown(["FAISS", "Chi-squared", "Euclid"], value="FAISS", label="Please choose a distance measure")
130
  dataset_dropdown = gr.Dropdown(
131
  ["huggan/CelebA-faces", "EIT/cbir-eit"],
132
  value="huggan/CelebA-faces",
133
  label="Please select a dataset"
134
  )
135
+ btn.click(get_neighbors, inputs=[image_input, descr_dropdown, dist_dropdown], outputs=[gallery_output])
136
 
137
 
138
  if __name__ == "__main__":