ragavsachdeva commited on
Commit
b8b475d
·
verified ·
1 Parent(s): 75798cc

Update modelling_magiv2.py

Browse files
Files changed (1) hide show
  1. modelling_magiv2.py +4 -2
modelling_magiv2.py CHANGED
@@ -29,13 +29,15 @@ class Magiv2Model(PreTrainedModel):
29
  def move_to_device(self, input):
30
  return move_to_device(input, self.device)
31
 
32
- def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
33
  move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
34
  if len(images) == 0:
35
  return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))
36
 
37
  assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"
38
- images = [np.array(image.convert("L").convert("RGB")) for image in images]
 
 
39
  images = self.processor(images, return_tensors="pt").pixel_values
40
  images = move_to_device_fn(images)
41
 
 
29
  def move_to_device(self, input):
30
  return move_to_device(input, self.device)
31
 
32
+ def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256, convert_to_grayscale=True):
33
  move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
34
  if len(images) == 0:
35
  return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))
36
 
37
  assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"
38
+ if convert_to_grayscale:
39
+ images = [x.convert("L") for x in images]
40
+ images = [np.array(image.convert("RGB")) for image in images]
41
  images = self.processor(images, return_tensors="pt").pixel_values
42
  images = move_to_device_fn(images)
43