javierabad01 committed
Commit f997506 · verified · 1 Parent(s): 50b5a48

Update archs/model.py

Files changed (1)
  1. archs/model.py +14 -1
archs/model.py CHANGED
@@ -33,6 +33,8 @@ class AttentionBlock(nn.Module):
 class UNet(nn.Module):
     def __init__(self):
         super(UNet, self).__init__()
+
+        self.padder_size = 32
 
         self.encoder = nn.Sequential(
             nn.Conv2d(3, 32, kernel_size=3, padding=1),
@@ -68,6 +70,10 @@ class UNet(nn.Module):
         )
 
     def forward(self, x):
+
+        _, _, H, W = x.shape
+        x = self.check_image_size(x)
+
         skip_connections = []
 
         for layer in self.encoder:
@@ -93,5 +99,12 @@ class UNet(nn.Module):
             else:
                 x = layer(x)
 
-        return x
+        return x[:, :, :H, :W]
 
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value=0)
+        return x
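For context on the change: the new check_image_size zero-pads the input on the right and bottom until both spatial dimensions are multiples of padder_size, and forward crops the result back to the original H × W at the end, so the network accepts inputs of arbitrary size. Below is a minimal, self-contained sketch of that pad-then-crop pattern, assuming padder_size = 32 matches the UNet's total downsampling factor; the free-standing function is a hypothetical restatement of the method in the diff, not the file's actual code.

import torch
import torch.nn.functional as F

padder_size = 32  # from the diff: spatial dims must become multiples of 32

def check_image_size(x):
    # Zero-pad on the right and bottom so H and W are multiples of padder_size.
    _, _, h, w = x.size()
    mod_pad_h = (padder_size - h % padder_size) % padder_size
    mod_pad_w = (padder_size - w % padder_size) % padder_size
    # For a 4D NCHW tensor, F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value=0)

x = torch.randn(1, 3, 250, 333)    # height/width deliberately not multiples of 32
_, _, H, W = x.shape
padded = check_image_size(x)
print(padded.shape)                # torch.Size([1, 3, 256, 352])
# The diff's final crop x[:, :, :H, :W] then restores the original size:
print(padded[:, :, :H, :W].shape)  # torch.Size([1, 3, 250, 333])

Because padding is applied only on the right and bottom edges, the crop x[:, :, :H, :W] recovers exactly the original pixels. The added method also presupposes that torch.nn.functional is imported as F in archs/model.py.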