lucytuan committed on
Commit
d1aff91
·
1 Parent(s): 9a3d99f

🔨 [Add] RepNCSPELAN and base modules in module.py

Browse files
Files changed (1) hide show
  1. yolo/model/module.py +111 -1
yolo/model/module.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional, Tuple
2
 
3
  import torch
4
  from torch import Tensor, nn
@@ -121,6 +121,116 @@ class RepConv(nn.Module):
121
  return self.act(self.conv1(x) + self.conv2(x))
122
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # ResNet
125
  class Res(nn.Module):
126
  # ResNet bottleneck
 
1
+ from typing import Any, Dict, Optional, Tuple
2
 
3
  import torch
4
  from torch import Tensor, nn
 
121
  return self.act(self.conv1(x) + self.conv2(x))
122
 
123
 
124
class RepNBottleneck(nn.Module):
    """A bottleneck block with optional residual connections.

    The input goes through a ``RepConv`` followed by a ``Conv``; when the
    residual path is enabled the block input is added to that result.

    Args:
        in_channels: Channel count passed to the first convolution.
        out_channels: Channel count produced by the second convolution.
        kernel_size: Pair of kernel sizes; element 0 is handed to the
            ``RepConv`` and element 1 to the ``Conv``.
        residual: Request a skip connection. It is forced off when
            ``in_channels != out_channels``.
        expand: Multiplier applied to ``out_channels`` (truncated with
            ``int``) to obtain the intermediate channel count.
        **kwargs: Forwarded unchanged to both ``RepConv`` and ``Conv``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: Tuple[int, int] = (3, 3),
        residual: bool = True,
        expand: float = 1.0,
        **kwargs
    ):
        super().__init__()
        neck_channels = int(out_channels * expand)
        self.conv1 = RepConv(in_channels, neck_channels, kernel_size[0], **kwargs)
        self.conv2 = Conv(neck_channels, out_channels, kernel_size[1], **kwargs)
        self.residual = residual

        if residual and (in_channels != out_channels):
            # Imported here so this branch cannot fail with a NameError when
            # the module's import block does not provide ``logging``.
            import logging

            # ``x + y`` in forward() is only attempted when the requested
            # input and output channel counts are equal.
            self.residual = False
            logging.warning("Residual is turned off since in_channels is not equal to out_channels.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both convolutions and add ``x`` back when residual is on."""
        y = self.conv2(self.conv1(x))
        return x + y if self.residual else y
150
+
151
+
152
class RepNCSP(nn.Module):
    """RepNCSP block: two parallel convolutions, a bottleneck stack, and a fusing convolution.

    ``conv1`` feeds a sequence of ``repeat_num`` ``RepNBottleneck`` blocks,
    ``conv2`` provides a parallel branch, and ``conv3`` is applied to the
    channel-wise concatenation of the two branches.

    Args:
        in_channels: Channel count given to ``conv1`` and ``conv2``.
        out_channels: Channel count produced by ``conv3``.
        kernel_size: Kernel size passed to all three ``Conv`` layers.
        csp_expand: Multiplier on ``out_channels`` (truncated with ``int``)
            giving the width of each branch.
        repeat_num: Number of ``RepNBottleneck`` blocks in the stack.
        bottleneck_args: Keyword arguments for every ``RepNBottleneck``;
            ``None`` selects ``kernel_size=(3, 3)``, ``residual=True``,
            ``expand=0.5``.
        **kwargs: Forwarded unchanged to the three ``Conv`` layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        *,
        csp_expand: float = 0.5,
        repeat_num: int = 1,
        bottleneck_args: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        super().__init__()

        neck_cfg = (
            {"kernel_size": (3, 3), "residual": True, "expand": 0.5}
            if bottleneck_args is None
            else bottleneck_args
        )
        branch_width = int(out_channels * csp_expand)

        # Layers are created in a fixed order: conv1, conv2, conv3, then the stack.
        self.conv1 = Conv(in_channels, branch_width, kernel_size, **kwargs)
        self.conv2 = Conv(in_channels, branch_width, kernel_size, **kwargs)
        self.conv3 = Conv(2 * branch_width, out_channels, kernel_size, **kwargs)

        stack = []
        for _ in range(repeat_num):
            stack.append(RepNBottleneck(branch_width, branch_width, **neck_cfg))
        self.bottleneck_block = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both branches on ``x`` and fuse them with ``conv3``."""
        main_branch = self.conv1(x)
        side_branch = self.conv2(x)
        main_branch = self.bottleneck_block(main_branch)
        fused = torch.cat((main_branch, side_branch), dim=1)
        return self.conv3(fused)
185
+
186
+
187
class RepNCSPELAN(nn.Module):
    """RepNCSPELAN block combining RepNCSP blocks with ELAN structure.

    ``conv1`` output is split into two chunks along dim 1. The second chunk
    goes through ``conv2`` (``RepNCSP`` followed by a 3-kernel ``Conv``), and
    that result goes through ``conv3`` (same layout). Both chunks and both
    stage outputs are concatenated along dim 1 and passed to ``conv4``.

    Args:
        in_channels: Channel count given to ``conv1``.
        out_channels: Channel count produced by ``conv4``.
        partition_channels: Channel count produced by ``conv1``.
        process_channels: Channel count produced by each RepNCSP stage.
        expand: Passed to each ``RepNCSP`` as ``csp_expand``.
        repncsp_args: Extra keyword arguments for each ``RepNCSP``;
            ``None`` means no extras.
        bottleneck_args: Passed to each ``RepNCSP`` as ``bottleneck_args``.
        **kwargs: Forwarded unchanged to every ``Conv`` built here.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int,
        partition_channels: int,
        process_channels: int,
        expand: float,
        repncsp_args: Optional[Dict[str, Any]] = None,
        bottleneck_args: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        super().__init__()

        csp_extra = {} if repncsp_args is None else repncsp_args

        def make_stage(stage_in: int) -> nn.Sequential:
            # One stage: a RepNCSP block, then a 3-kernel Conv with padding=1.
            csp = RepNCSP(
                stage_in,
                process_channels,
                csp_expand=expand,
                bottleneck_args=bottleneck_args,
                **csp_extra
            )
            smooth = Conv(process_channels, process_channels, 3, padding=1, **kwargs)
            return nn.Sequential(csp, smooth)

        self.conv1 = Conv(in_channels, partition_channels, 1, **kwargs)
        # The first stage is built for half of the conv1 output channels.
        self.conv2 = make_stage(partition_channels // 2)
        self.conv3 = make_stage(process_channels)
        fused_width = partition_channels + 2 * process_channels
        self.conv4 = Conv(fused_width, out_channels, 1, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split, run the two stages in sequence, concatenate, and fuse."""
        halves = self.conv1(x).chunk(2, 1)
        kept, routed = halves
        stage_one = self.conv2(routed)
        stage_two = self.conv3(stage_one)
        merged = torch.cat([kept, routed, stage_one, stage_two], dim=1)
        return self.conv4(merged)
232
+
233
+
234
  # ResNet
235
  class Res(nn.Module):
236
  # ResNet bottleneck