Spaces:
Runtime error
Runtime error
Update clip_model.py
Browse files- clip_model.py +3 -3
clip_model.py
CHANGED
@@ -306,7 +306,7 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
|
|
306 |
values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
|
307 |
matches = [image_filenames[idx] for idx in indices[::5]]
|
308 |
|
309 |
-
_, axes = plt.subplots(
|
310 |
|
311 |
results = []
|
312 |
for match, ax in zip(matches, axes.flatten()):
|
@@ -321,11 +321,11 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
|
|
321 |
def clip_image_search(model,image_embeddings,
|
322 |
query,
|
323 |
image_filenames,
|
324 |
-
n=
|
325 |
_, valid_df = make_train_valid_dfs()
|
326 |
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
|
327 |
return find_matches(model,
|
328 |
image_embeddings,
|
329 |
query,
|
330 |
image_filenames = valid_df['image'].values,
|
331 |
-
n
|
|
|
306 |
values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
|
307 |
matches = [image_filenames[idx] for idx in indices[::5]]
|
308 |
|
309 |
+
_, axes = plt.subplots(4, 4, figsize=(10, 10))
|
310 |
|
311 |
results = []
|
312 |
for match, ax in zip(matches, axes.flatten()):
|
|
|
321 |
def clip_image_search(model,image_embeddings,
|
322 |
query,
|
323 |
image_filenames,
|
324 |
+
n=16):
|
325 |
_, valid_df = make_train_valid_dfs()
|
326 |
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
|
327 |
return find_matches(model,
|
328 |
image_embeddings,
|
329 |
query,
|
330 |
image_filenames = valid_df['image'].values,
|
331 |
+
n)
|