bendeguzszabo committed
Commit 6d93dbe · 1 Parent(s): 83942d1

Update src/CLIP.py

Files changed (1):
  src/CLIP.py +30 -1
src/CLIP.py CHANGED
@@ -3,18 +3,47 @@ import torch
 
 
 class CLIPImageEncoder:
+    """
+    A class for encoding images using the CLIP model.
+
+    Args:
+        device (str): The device to run the model on (default: "cpu").
+
+    Attributes:
+        device (str): The device to run the model on.
+        model (CLIPModel): The CLIP model used for image encoding.
+        processor (AutoProcessor): The tokenizer and input processor for the CLIP model.
+    """
     def __init__(self, device="cpu"):
         self.device = device
         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
         self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
-
+
     def encode_image(self, image_pil):
+        """
+        Encodes a single image using the CLIP model.
+
+        Args:
+            image_pil: A PIL Image object representing the image to encode.
+
+        Returns:
+            numpy.ndarray: The CLIP embedding for the image.
+        """
         with torch.no_grad():
            input = self.processor(images=image_pil, return_tensors="pt")
            image_features = self.model.get_image_features(**input)
         return image_features.cpu().detach().numpy()[0]
 
     def encode_images(self, batch):
+        """
+        Encodes a batch of images using the CLIP model.
+
+        Args:
+            batch (Dict[str, Any]): A dictionary containing the batch of images to encode.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the CLIP embeddings for the batch of images.
+        """
         images = batch["image"]
         input = self.processor(images=images, return_tensors="pt")
         with torch.no_grad():
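
For context, a minimal usage sketch of the class as it stands after this commit. It assumes src/CLIP.py imports CLIPModel and AutoProcessor from transformers (not shown in this hunk), that torch, transformers, and Pillow are installed, and that the import path and the file name "example.jpg" are hypothetical placeholders:

from PIL import Image

from src.CLIP import CLIPImageEncoder  # import path assumed from the repo layout above

# Instantiate the encoder; the ViT-B/32 weights are fetched from the Hugging Face Hub on first use.
encoder = CLIPImageEncoder(device="cpu")

# Encode a single image into a CLIP embedding.
image = Image.open("example.jpg")  # hypothetical input file
embedding = encoder.encode_image(image)
print(embedding.shape)  # expected: (512,) for openai/clip-vit-base-patch32

encode_images, by contrast, takes a dict-style batch with an "image" key (the shape datasets.Dataset.map passes in batched mode); its return value lies outside this hunk, so it is not exercised here.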