AmitIsraeli commited on
Commit
0d2ab62
·
1 Parent(s): 211a3a6

solve error shapes

Browse files
Files changed (1) hide show
  1. help_function.py +4 -0
help_function.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import open_clip
3
  from torchvision import transforms
4
  from torchvision.transforms import ToPILImage
 
5
 
6
  class help_function:
7
  def __init__(self):
@@ -24,6 +25,9 @@ class help_function:
24
 
25
  def get_image_inversion(self, image):
26
  image = self.transform(image)
 
 
 
27
  w_inversion = self.encoder(image.reshape(1,3,224,224)).reshape(1,16,512)
28
  return w_inversion + self.mean_person
29
 
 
2
  import open_clip
3
  from torchvision import transforms
4
  from torchvision.transforms import ToPILImage
5
+ import torch.nn.functional as F
6
 
7
  class help_function:
8
  def __init__(self):
 
25
 
26
  def get_image_inversion(self, image):
27
  image = self.transform(image)
28
+ if not image.shape == torch.Size([3, 224, 224]):
29
+ image = image.reshape(1,3,image.shape[1],image.shape[2])
30
+ image = F.interpolate(image, [224,224], mode='bilinear', align_corners=True)
31
  w_inversion = self.encoder(image.reshape(1,3,224,224)).reshape(1,16,512)
32
  return w_inversion + self.mean_person
33