lilkm committed
Commit 6670841 · verified · 1 parent: 147f912

Upload ResNet10

Files changed (2):
  1. config.json (+7 -1)
  2. modeling_resnet.py (+38 -6)
config.json CHANGED
@@ -1,6 +1,11 @@
 {
+  "_name_or_path": "lilkm/resnet10",
+  "architectures": [
+    "ResNet10"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_resnet.ResNet10Config"
+    "AutoConfig": "lilkm/resnet10--configuration_resnet.ResNet10Config",
+    "AutoModel": "modeling_resnet.ResNet10"
   },
   "depths": [
     1,
@@ -18,5 +23,6 @@
   ],
   "model_type": "resnet10",
   "num_channels": 3,
+  "torch_dtype": "float32",
   "transformers_version": "4.48.1"
 }
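With "auto_map" now exposing both the config and the model class, the checkpoint loads through Transformers' remote-code path. A minimal sketch (repo id taken from "_name_or_path" above; trust_remote_code is required because ResNet10 is a custom architecture):

from transformers import AutoModel

# Downloads configuration_resnet.py / modeling_resnet.py from the repo
# and instantiates the custom ResNet10 class registered in auto_map.
model = AutoModel.from_pretrained("lilkm/resnet10", trust_remote_code=True)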
modeling_resnet.py CHANGED
@@ -27,8 +27,44 @@ import math
 
 
 class JaxStyleMaxPool(nn.Module):
+    """Mimics JAX's max pooling with padding='SAME' for exact parity."""
+
+    def __init__(self, kernel_size, stride=2):
+        super().__init__()
+
+        # Normalize kernel_size and stride to (h, w) tuples
+        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
+        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
+
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=0,  # padding is applied manually in forward()
+        )
+
+    def _compute_padding(self, input_height, input_width):
+        """Calculate asymmetric padding to match JAX's 'SAME' behavior."""
+
+        # Total padding needed so the output size is ceil(input / stride)
+        pad_h = max(0, (math.ceil(input_height / self.stride[0]) - 1) * self.stride[0] + self.kernel_size[0] - input_height)
+        pad_w = max(0, (math.ceil(input_width / self.stride[1]) - 1) * self.stride[1] + self.kernel_size[1] - input_width)
+
+        # Asymmetric split (JAX-style: the extra unit goes to the bottom/right)
+        pad_top = pad_h // 2
+        pad_bottom = pad_h - pad_top
+        pad_left = pad_w // 2
+        pad_right = pad_w - pad_left
+
+        return (pad_left, pad_right, pad_top, pad_bottom)
+
     def forward(self, x):
-        x = nn.functional.pad(x, (0, 1, 0, 1), value=-float('inf'))  # Pad right/bottom by 1 to match JAX's maxpooling padding="SAME"
-        return nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
+        """Apply JAX-style 'SAME' padding, then pool."""
+        _, _, h, w = x.shape
+
+        # Pad with -inf so padded positions can never win the max
+        pad_left, pad_right, pad_top, pad_bottom = self._compute_padding(h, w)
+        x = nn.functional.pad(x, (pad_left, pad_right, pad_top, pad_bottom), value=-float('inf'))
+
+        return self.maxpool(x)
 
 
@@ -159,10 +195,6 @@ class Encoder(nn.Module):
             hidden_states=hidden_states,
         )
 
-class JaxStyleMaxPool(nn.Module):
-    def forward(self, x):
-        x = nn.functional.pad(x, (0, 1, 0, 1), value=-float('inf'))  # Pad right/bottom by 1
-        return nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
 
 class ResNet10(PreTrainedModel):
     config_class = ResNet10Config
@@ -191,7 +223,7 @@ class ResNet10(PreTrainedModel):
             # return super().__call__(x)
             nn.GroupNorm(num_groups=4, eps=1e-5, num_channels=self.config.embedding_size),
             ACT2FN[self.config.hidden_act],
-            JaxStyleMaxPool(),
+            JaxStyleMaxPool(kernel_size=3, stride=2),
         )
 
         self.encoder = Encoder(self.config)
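Why the general formula matters: the old hardcoded (0, 1, 0, 1) pad is only equivalent to 'SAME' for kernel 3, stride 2 on even spatial sizes. A quick check of the arithmetic in _compute_padding (plain Python, mirroring the diff above):

import math

def same_pad(size, k=3, s=2):
    # Total padding so the output size equals ceil(size / s)
    return max(0, (math.ceil(size / s) - 1) * s + k - size)

print(same_pad(224))  # 1 -> split as (top=0, bottom=1): matches the old hardcoded pad
print(same_pad(7))    # 2 -> split as (top=1, bottom=1): the old (0, 1) pad under-padded here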
 
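A quick shape sanity check of the new module (assumes torch is installed; the expected output size is ceil(input / stride) per spatial dim, as JAX's padding='SAME' produces):

import torch

pool = JaxStyleMaxPool(kernel_size=3, stride=2)
x = torch.randn(1, 64, 7, 7)
print(pool(x).shape)  # torch.Size([1, 64, 4, 4]) == ceil(7 / 2) per spatial dim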