Upload ResNet10
Browse files- config.json +7 -1
- modeling_resnet.py +38 -6
config.json
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
{
|
|
|
|
|
|
|
|
|
2 |
"auto_map": {
|
3 |
-
"AutoConfig": "configuration_resnet.ResNet10Config"
|
|
|
4 |
},
|
5 |
"depths": [
|
6 |
1,
|
@@ -18,5 +23,6 @@
|
|
18 |
],
|
19 |
"model_type": "resnet10",
|
20 |
"num_channels": 3,
|
|
|
21 |
"transformers_version": "4.48.1"
|
22 |
}
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "lilkm/resnet10",
|
3 |
+
"architectures": [
|
4 |
+
"ResNet10"
|
5 |
+
],
|
6 |
"auto_map": {
|
7 |
+
"AutoConfig": "lilkm/resnet10--configuration_resnet.ResNet10Config",
|
8 |
+
"AutoModel": "modeling_resnet.ResNet10"
|
9 |
},
|
10 |
"depths": [
|
11 |
1,
|
|
|
23 |
],
|
24 |
"model_type": "resnet10",
|
25 |
"num_channels": 3,
|
26 |
+
"torch_dtype": "float32",
|
27 |
"transformers_version": "4.48.1"
|
28 |
}
|
modeling_resnet.py
CHANGED
@@ -27,8 +27,44 @@ import math
|
|
27 |
|
28 |
|
29 |
class JaxStyleMaxPool(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def forward(self, x):
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
return nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
|
33 |
|
34 |
|
@@ -159,10 +195,6 @@ class Encoder(nn.Module):
|
|
159 |
hidden_states=hidden_states,
|
160 |
)
|
161 |
|
162 |
-
class JaxStyleMaxPool(nn.Module):
|
163 |
-
def forward(self, x):
|
164 |
-
x = nn.functional.pad(x, (0, 1, 0, 1), value=-float('inf')) # Pad right/bottom by 1
|
165 |
-
return nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
|
166 |
|
167 |
class ResNet10(PreTrainedModel):
|
168 |
config_class = ResNet10Config
|
@@ -191,7 +223,7 @@ class ResNet10(PreTrainedModel):
|
|
191 |
# return super().__call__(x)
|
192 |
nn.GroupNorm(num_groups=4, eps=1e-5, num_channels=self.config.embedding_size),
|
193 |
ACT2FN[self.config.hidden_act],
|
194 |
-
JaxStyleMaxPool(),
|
195 |
)
|
196 |
|
197 |
self.encoder = Encoder(self.config)
|
|
|
27 |
|
28 |
|
29 |
class JaxStyleMaxPool(nn.Module):
|
30 |
+
"""Mimics JAX's MaxPool with padding='SAME' for exact parity."""
|
31 |
+
|
32 |
+
def __init__(self, kernel_size, stride=2):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
# Ensure kernel_size and stride are tuples
|
36 |
+
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
37 |
+
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
|
38 |
+
|
39 |
+
self.maxpool = nn.MaxPool2d(
|
40 |
+
kernel_size=self.kernel_size,
|
41 |
+
stride=self.stride,
|
42 |
+
padding=0, # No padding
|
43 |
+
)
|
44 |
+
|
45 |
+
def _compute_padding(self, input_height, input_width):
|
46 |
+
"""Calculate asymmetric padding to match JAX's 'SAME' behavior."""
|
47 |
+
|
48 |
+
# Compute padding needed for height and width
|
49 |
+
pad_h = max(0, (math.ceil(input_height / self.stride[0]) - 1) * self.stride[0] + self.kernel_size[0] - input_height)
|
50 |
+
pad_w = max(0, (math.ceil(input_width / self.stride[1]) - 1) * self.stride[1] + self.kernel_size[1] - input_width)
|
51 |
+
|
52 |
+
# Asymmetric padding (JAX-style: more padding on the bottom/right if needed)
|
53 |
+
pad_top = pad_h // 2
|
54 |
+
pad_bottom = pad_h - pad_top
|
55 |
+
pad_left = pad_w // 2
|
56 |
+
pad_right = pad_w - pad_left
|
57 |
+
|
58 |
+
return (pad_left, pad_right, pad_top, pad_bottom)
|
59 |
+
|
60 |
def forward(self, x):
|
61 |
+
"""Apply asymmetric padding before convolution."""
|
62 |
+
_, _, h, w = x.shape
|
63 |
+
|
64 |
+
# Compute asymmetric padding
|
65 |
+
pad_left, pad_right, pad_top, pad_bottom = self._compute_padding(h, w)
|
66 |
+
x = nn.functional.pad(x, (pad_left, pad_right, pad_top, pad_bottom), value=-float('inf')) # Pad right/bottom by 1 to match JAX's maxpooling padding="SAME"
|
67 |
+
|
68 |
return nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
|
69 |
|
70 |
|
|
|
195 |
hidden_states=hidden_states,
|
196 |
)
|
197 |
|
|
|
|
|
|
|
|
|
198 |
|
199 |
class ResNet10(PreTrainedModel):
|
200 |
config_class = ResNet10Config
|
|
|
223 |
# return super().__call__(x)
|
224 |
nn.GroupNorm(num_groups=4, eps=1e-5, num_channels=self.config.embedding_size),
|
225 |
ACT2FN[self.config.hidden_act],
|
226 |
+
JaxStyleMaxPool(kernel_size=3, stride=2),
|
227 |
)
|
228 |
|
229 |
self.encoder = Encoder(self.config)
|