Samuel Schmidt
committed
Commit fbfbcf2 · 1 Parent(s): cfe24db
Update: Improve performance by using batches
Files changed:
- src/CLIP.py +12 -0
- src/app.py +2 -2
src/CLIP.py CHANGED
@@ -13,3 +13,15 @@ class CLIPImageEncoder:
         input = self.processor(images=image_pil, return_tensors="pt")
         image_features = self.model.get_image_features(**input)
         return image_features.cpu().detach().numpy()[0]
+
+    def encode_images(self, batch):
+        images = batch["image"]
+        print(images)
+        input = self.processor(images=images, return_tensors="pt")
+        with torch.no_grad():
+            #image_features = self.model.get_image_features(**input)
+            image_features = self.model(**input).last_hidden_state[:,0].cpu()
+            print(image_features)
+            print("--------------------")
+            print(self.model.get_image_features(**input).cpu().detach().numpy()[0])
+        return {"clip_embeddings": image_features}
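For context, encode_images is written as a batched datasets.map callback: it receives a dict of columns (batch["image"] is a list of PIL images) and returns a dict whose "clip_embeddings" column is appended to the dataset. Below is a minimal sketch of how the surrounding class might look, assuming it wraps transformers' CLIPModel and CLIPProcessor; the __init__ and checkpoint name are assumptions, not part of this commit, and the sketch uses get_image_features (as the single-image path above does) rather than the last_hidden_state pooling from the diff.

import torch
from transformers import CLIPModel, CLIPProcessor

class CLIPImageEncoder:
    # Sketch only: __init__ and the checkpoint name are assumed, not taken from the commit.
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

    def encode_images(self, batch):
        # Batched map callback: batch["image"] is a list of PIL images.
        inputs = self.processor(images=batch["image"], return_tensors="pt")
        with torch.no_grad():
            # get_image_features accepts pixel_values alone and returns one
            # projected embedding per image.
            features = self.model.get_image_features(**inputs).cpu().numpy()
        return {"clip_embeddings": features}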
src/app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from datasets import *
 
 dataset = load_dataset("huggan/CelebA-faces")
-candidate_subset = dataset["train"].select(range(
+candidate_subset = dataset["train"].select(range(40)) # This is a small CBIR app! :D
 
 def emb_dataset(dataset):
     # This function might need to be split up, to reduce start-up time of app
@@ -19,7 +19,7 @@ def emb_dataset(dataset):
 
     ## CLIP Embeddings
     clip_model = CLIPImageEncoder()
-    dataset_with_embeddings = dataset_with_embeddings.map(
+    dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=8)
 
     # Add index
     dataset_with_embeddings.add_faiss_index(column='color_embeddings')
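Taken together, the performance change is the batched map call: with batched=True the callback receives batch_size examples per invocation, so preprocessing and the model forward pass run once per 8 images instead of once per image. A minimal end-to-end sketch of that pattern follows, assuming CLIPImageEncoder is the class from src/CLIP.py; note that the committed app.py still builds its FAISS index on 'color_embeddings', so the index on the new column shown here is only illustrative.

from datasets import load_dataset
from CLIP import CLIPImageEncoder  # assumed import of src/CLIP.py from this repo

dataset = load_dataset("huggan/CelebA-faces")
candidate_subset = dataset["train"].select(range(40))  # small CBIR demo subset

clip_model = CLIPImageEncoder()

# batched=True hands the callback a dict of lists covering batch_size examples at a time,
# so the processor and the model run once per batch of 8 rather than once per image.
dataset_with_embeddings = candidate_subset.map(
    clip_model.encode_images, batched=True, batch_size=8
)

# Illustrative only: index the new embedding column for nearest-neighbour lookup.
dataset_with_embeddings.add_faiss_index(column="clip_embeddings")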