Samuel Schmidt committed on
Commit fbfbcf2 · 1 Parent(s): cfe24db

Update: Improve performance by using batches

Files changed (2)
  1. src/CLIP.py +12 -0
  2. src/app.py +2 -2
src/CLIP.py CHANGED
@@ -13,3 +13,15 @@ class CLIPImageEncoder:
         input = self.processor(images=image_pil, return_tensors="pt")
         image_features = self.model.get_image_features(**input)
         return image_features.cpu().detach().numpy()[0]
+
+    def encode_images(self, batch):
+        images = batch["image"]
+        print(images)
+        input = self.processor(images=images, return_tensors="pt")
+        with torch.no_grad():
+            #image_features = self.model.get_image_features(**input)
+            image_features = self.model(**input).last_hidden_state[:,0].cpu()
+            print(image_features)
+            print("--------------------")
+            print(self.model.get_image_features(**input).cpu().detach().numpy()[0])
+        return {"clip_embeddings": image_features}
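
For context, a minimal sketch of what the new batched encoder boils down to once the debug prints and the commented-out experiment are removed. The class constructor is not part of this diff, so the CLIPModel/CLIPProcessor setup and the model name below are assumptions, not the committed code.

import torch
from transformers import CLIPModel, CLIPProcessor

class CLIPImageEncoder:
    # Assumed constructor: the diff only shows encode_image / encode_images.
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def encode_images(self, batch):
        # With batched=True, datasets passes a dict of column lists,
        # e.g. {"image": [PIL.Image, ...]}, not a single row.
        images = batch["image"]
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        # Returning a dict adds a "clip_embeddings" value for every row in the batch.
        return {"clip_embeddings": image_features.cpu().numpy()}
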
src/app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from datasets import *
 
 dataset = load_dataset("huggan/CelebA-faces")
-candidate_subset = dataset["train"].select(range(10)) # This is a small CBIR app! :D
+candidate_subset = dataset["train"].select(range(40)) # This is a small CBIR app! :D
 
 def emb_dataset(dataset):
     # This function might need to be split up, to reduce start-up time of app
@@ -19,7 +19,7 @@ def emb_dataset(dataset):
 
     ## CLIP Embeddings
     clip_model = CLIPImageEncoder()
-    dataset_with_embeddings = dataset_with_embeddings.map(lambda row: {'clip_embeddings': clip_model.encode_image(row["image"])})
+    dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=8)
 
     # Add index
     dataset_with_embeddings.add_faiss_index(column='color_embeddings')
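
The map call is where the speed-up comes from: with batched=True and batch_size=8, datasets hands encode_images eight images per call, so the processor and the CLIP forward pass run once per batch instead of once per row. Below is an illustrative, standalone version of that pattern; it skips the color-embedding step and the dataset_with_embeddings plumbing from the app, so the names outside the diff are hypothetical.

from datasets import load_dataset

dataset = load_dataset("huggan/CelebA-faces")
candidate_subset = dataset["train"].select(range(40))

clip_model = CLIPImageEncoder()          # from src/CLIP.py (see the sketch above)
with_clip = candidate_subset.map(
    clip_model.encode_images,            # receives {"image": [img_1, ..., img_8]} per call
    batched=True,                        # pass column lists instead of single rows
    batch_size=8,                        # one forward pass per 8 images
)
# The committed code indexes the color_embeddings column; an index over the
# new column would be added the same way:
with_clip.add_faiss_index(column="clip_embeddings")
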