SeemG committed on
Commit 764b2ab · verified · 1 Parent(s): ddedf1b

Update clip_model.py

Files changed (1)
  1. clip_model.py +3 -3
clip_model.py CHANGED
@@ -306,7 +306,7 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
     matches = [image_filenames[idx] for idx in indices[::5]]
 
-    _, axes = plt.subplots(3, 3, figsize=(10, 10))
+    _, axes = plt.subplots(4, 4, figsize=(10, 10))
 
     results = []
     for match, ax in zip(matches, axes.flatten()):
@@ -321,11 +321,11 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
 def clip_image_search(model, image_embeddings,
                       query,
                       image_filenames,
-                      n=9):
+                      n=16):
     _, valid_df = make_train_valid_dfs()
     model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
     return find_matches(model,
                         image_embeddings,
                         query,
                         image_filenames=valid_df['image'].values,
-                        n=9)
+                        n)
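
For context, a minimal self-contained sketch of what the updated grid logic does after this commit: the top n*5 similarity scores are strided by 5 (the existing pattern for skipping repeated hits for the same image), and the resulting n=16 filenames fill the new 4x4 axes grid. This is not the file's exact code; the load_image helper is a hypothetical stand-in for whatever clip_model.py uses to read images, and n is forwarded via the function signature rather than the wrapper call shown in the diff.

import matplotlib.pyplot as plt
import torch

def show_matches(dot_similarity, image_filenames, load_image, n=16):
    # Top n*5 scores, then stride by 5 so adjacent duplicates are skipped
    # (mirrors the torch.topk / indices[::5] pattern in find_matches).
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    # n=16 matches fill the 4x4 grid introduced by this change.
    _, axes = plt.subplots(4, 4, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        ax.imshow(load_image(match))  # load_image is an assumed helper, not part of this diff
        ax.axis("off")
    plt.show()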