lilkm committed
Commit c3601db · verified · 1 Parent(s): c9417a6

Upload ResNet10

Files changed (3)
  1. config.json +7 -1
  2. model.safetensors +3 -0
  3. modeling_resnet.py +214 -0
config.json CHANGED
@@ -1,6 +1,11 @@
 {
+  "_name_or_path": "lilkm/resnet10",
+  "architectures": [
+    "ResNet10"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_resnet.ResNet10Config"
+    "AutoConfig": "lilkm/resnet10--configuration_resnet.ResNet10Config",
+    "AutoModel": "modeling_resnet.ResNet10"
   },
   "depths": [
     1,
@@ -18,5 +23,6 @@
   ],
   "model_type": "resnet10",
   "num_channels": 3,
+  "torch_dtype": "float32",
   "transformers_version": "4.48.1"
 }
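The new auto_map entries point AutoConfig and AutoModel at the custom classes shipped with this repo. A minimal loading sketch (assuming the repo id lilkm/resnet10 from the config and that remote code is trusted):

from transformers import AutoConfig, AutoModel

# auto_map resolves these to the repo's configuration_resnet.py / modeling_resnet.py
config = AutoConfig.from_pretrained("lilkm/resnet10", trust_remote_code=True)
model = AutoModel.from_pretrained("lilkm/resnet10", trust_remote_code=True)
print(type(config).__name__, type(model).__name__)  # ResNet10Config ResNet10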
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10f7d125770aa256bd45ec9e4f586ca1157e29380fa1306d14a025664ae173d0
+size 19626736
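The weights are stored via Git LFS; this pointer only records the object's sha256 and byte size. A quick integrity check for a locally downloaded copy (a sketch, assuming the file sits at ./model.safetensors):

import hashlib
import os

path = "model.safetensors"
print(os.path.getsize(path) == 19626736)  # size from the pointer file
with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print(digest == "10f7d125770aa256bd45ec9e4f586ca1157e29380fa1306d14a025664ae173d0")  # oid from the pointer file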
modeling_resnet.py ADDED
@@ -0,0 +1,214 @@
+#!/usr/bin/env python3
+# -----------------------------------------------------------------------------
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -----------------------------------------------------------------------------
+
+import math
+from typing import Optional
+
+import torch.nn as nn
+from torch import Tensor
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithNoAttention
+
+from .configuration_resnet import ResNet10Config
+
+
+class JaxStyleMaxPool(nn.Module):
+    def forward(self, x):
+        x = nn.functional.pad(x, (0, 1, 0, 1), value=-float('inf'))  # Pad right/bottom by 1 to match JAX's max pooling with padding="SAME"
+        return nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
+
+
+class JaxStyleConv2d(nn.Module):
+    """Mimics JAX's Conv2D with padding='SAME' for exact parity."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
+        super().__init__()
+
+        # Ensure kernel_size and stride are tuples
+        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
+        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
+
+        self.conv = nn.Conv2d(
+            in_channels, out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=0,  # No padding
+            bias=bias,
+        )
+
+    def _compute_padding(self, input_height, input_width):
+        """Calculate asymmetric padding to match JAX's 'SAME' behavior."""
+
+        # Compute padding needed for height and width
+        pad_h = max(0, (math.ceil(input_height / self.stride[0]) - 1) * self.stride[0] + self.kernel_size[0] - input_height)
+        pad_w = max(0, (math.ceil(input_width / self.stride[1]) - 1) * self.stride[1] + self.kernel_size[1] - input_width)
+
+        # Asymmetric padding (JAX-style: more padding on the bottom/right if needed)
+        pad_top = pad_h // 2
+        pad_bottom = pad_h - pad_top
+        pad_left = pad_w // 2
+        pad_right = pad_w - pad_left
+
+        return (pad_left, pad_right, pad_top, pad_bottom)
+
+    def forward(self, x):
+        """Apply asymmetric padding before convolution."""
+        _, _, h, w = x.shape
+
+        # Compute asymmetric padding
+        pad_left, pad_right, pad_top, pad_bottom = self._compute_padding(h, w)
+        x = nn.functional.pad(x, (pad_left, pad_right, pad_top, pad_bottom))
+
+        return self.conv(x)
+
+
+class BasicBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, activation, stride=1, norm_groups=4):
+        super().__init__()
+
+        self.conv1 = JaxStyleConv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=stride,
+            bias=False,
+        )
+        self.norm1 = nn.GroupNorm(num_groups=norm_groups, num_channels=out_channels)
+        self.act1 = ACT2FN[activation]
+        self.act2 = ACT2FN[activation]
+        self.conv2 = JaxStyleConv2d(out_channels, out_channels, kernel_size=3, stride=1, bias=False)
+        self.norm2 = nn.GroupNorm(num_groups=norm_groups, num_channels=out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = nn.Sequential(
+                JaxStyleConv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                nn.GroupNorm(num_groups=norm_groups, num_channels=out_channels),
+            )
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = self.act1(out)
+
+        out = self.conv2(out)
+        out = self.norm2(out)
+
+        if self.shortcut is not None:
+            identity = self.shortcut(identity)
+
+        out += identity
+        return self.act2(out)
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: ResNet10Config):
+        super().__init__()
+        self.config = config
+        self.stages = nn.ModuleList([])
+
+        for i, size in enumerate(self.config.hidden_sizes):
+            if i == 0:
+                self.stages.append(
+                    BasicBlock(
+                        self.config.embedding_size,
+                        size,
+                        activation=self.config.hidden_act,
+                    )
+                )
+            else:
+                self.stages.append(
+                    BasicBlock(
+                        self.config.hidden_sizes[i - 1],
+                        size,
+                        activation=self.config.hidden_act,
+                        stride=2,
+                    )
+                )
+
+    def forward(self, hidden_state: Tensor, output_hidden_states: bool = False) -> BaseModelOutputWithNoAttention:
+        hidden_states = () if output_hidden_states else None
+
+        for stage in self.stages:
+            if output_hidden_states:
+                hidden_states = hidden_states + (hidden_state,)
+
+            hidden_state = stage(hidden_state)
+
+        if output_hidden_states:
+            hidden_states = hidden_states + (hidden_state,)
+
+        return BaseModelOutputWithNoAttention(
+            last_hidden_state=hidden_state,
+            hidden_states=hidden_states,
+        )
+
+
+class ResNet10(PreTrainedModel):
+    config_class = ResNet10Config
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embedder = nn.Sequential(
+            nn.Conv2d(
+                self.config.num_channels,
+                self.config.embedding_size,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False,
+            ),
+            # The original code has a small trick -
+            # https://github.com/rail-berkeley/hil-serl/blob/main/serl_launcher/serl_launcher/vision/resnet_v1.py#L119
+            # class MyGroupNorm(nn.GroupNorm):
+            #     def __call__(self, x):
+            #         if x.ndim == 3:
+            #             x = x[jnp.newaxis]
+            #             x = super().__call__(x)
+            #             return x[0]
+            #         else:
+            #             return super().__call__(x)
+            nn.GroupNorm(num_groups=4, eps=1e-5, num_channels=self.config.embedding_size),
+            ACT2FN[self.config.hidden_act],
+            JaxStyleMaxPool(),
+        )
+
+        self.encoder = Encoder(self.config)
+
+    def forward(self, x: Tensor, output_hidden_states: Optional[bool] = None) -> BaseModelOutputWithNoAttention:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        embedding_output = self.embedder(x)
+        encoder_outputs = self.encoder(embedding_output, output_hidden_states=output_hidden_states)
+
+        return BaseModelOutputWithNoAttention(
+            last_hidden_state=encoder_outputs.last_hidden_state,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+    def print_model_hash(self):
+        print("Model parameters hashes:")
+        for name, param in self.named_parameters():
+            print(name, param.sum())
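For a quick smoke test of the custom forward pass defined above, a sketch like the following could be used (assuming the lilkm/resnet10 repo id and an arbitrary 224x224 input; the exact feature-map shape depends on the hidden sizes in the config):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("lilkm/resnet10", trust_remote_code=True).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224), output_hidden_states=True)

print(out.last_hidden_state.shape)  # final stage feature map
print(len(out.hidden_states))       # the embedder output plus each stage's output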