raja5259 committed on
Commit
71c2a66
·
verified ·
1 Parent(s): 415a1fe

change function to return the matched image filenames instead of plotting them

Browse files
Files changed (1) hide show
  1. s23_openai_clip.py +2 -9
s23_openai_clip.py CHANGED
@@ -436,15 +436,8 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
436
 
437
  values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
438
  matches = [image_filenames[idx] for idx in indices[::5]]
 
439
 
440
- _, axes = plt.subplots(3, 3, figsize=(10, 10))
441
- for match, ax in zip(matches, axes.flatten()):
442
- image = cv2.imread(f"{CFG.image_path}/{match}")
443
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
444
- ax.imshow(image)
445
- ax.axis("off")
446
-
447
- plt.show()
448
 
449
  """This is how we use this function. Aaaannnndddd the results:
450
  (The results in the blog post and the one at the beginning of the notebook were achieved with training on the 30k version)
@@ -455,7 +448,7 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
455
  def inference_CLIP(query_text):
456
  _, valid_df = make_train_valid_dfs()
457
  model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
458
- find_matches(model,
459
  image_embeddings,
460
  query=query_text,
461
  # query="dogs on the grass",
 
436
 
437
  values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
438
  matches = [image_filenames[idx] for idx in indices[::5]]
439
+ return matches
440
 
 
 
 
 
 
 
 
 
441
 
442
  """This is how we use this function. Aaaannnndddd the results:
443
  (The results in the blog post and the one at the beginning of the notebook were achieved with training on the 30k version)
 
448
  def inference_CLIP(query_text):
449
  _, valid_df = make_train_valid_dfs()
450
  model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
451
+ return find_matches(model,
452
  image_embeddings,
453
  query=query_text,
454
  # query="dogs on the grass",