🔨 [Add] RepNCSPELAN and base modules in module.py
Browse files- yolo/model/module.py +111 -1
yolo/model/module.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Optional, Tuple
|
2 |
|
3 |
import torch
|
4 |
from torch import Tensor, nn
|
@@ -121,6 +121,116 @@ class RepConv(nn.Module):
|
|
121 |
return self.act(self.conv1(x) + self.conv2(x))
|
122 |
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
# ResNet
|
125 |
class Res(nn.Module):
|
126 |
# ResNet bottleneck
|
|
|
1 |
+
from typing import Any, Dict, Optional, Tuple
|
2 |
|
3 |
import torch
|
4 |
from torch import Tensor, nn
|
|
|
121 |
return self.act(self.conv1(x) + self.conv2(x))
|
122 |
|
123 |
|
124 |
+
class RepNBottleneck(nn.Module):
|
125 |
+
"""A bottleneck block with optional residual connections."""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
in_channels: int,
|
130 |
+
out_channels: int,
|
131 |
+
*,
|
132 |
+
kernel_size: Tuple[int, int] = (3, 3),
|
133 |
+
residual: bool = True,
|
134 |
+
expand: float = 1.0,
|
135 |
+
**kwargs
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
neck_channels = int(out_channels * expand)
|
139 |
+
self.conv1 = RepConv(in_channels, neck_channels, kernel_size[0], **kwargs)
|
140 |
+
self.conv2 = Conv(neck_channels, out_channels, kernel_size[1], **kwargs)
|
141 |
+
self.residual = residual
|
142 |
+
|
143 |
+
if residual and (in_channels != out_channels):
|
144 |
+
self.residual = False
|
145 |
+
logging.warning("Residual is turned off since in_channels is not equal to out_channels.")
|
146 |
+
|
147 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
148 |
+
y = self.conv2(self.conv1(x))
|
149 |
+
return x + y if self.residual else y
|
150 |
+
|
151 |
+
|
152 |
+
class RepNCSP(nn.Module):
|
153 |
+
"""RepNCSP block with convolutions, split, and bottleneck processing."""
|
154 |
+
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
in_channels: int,
|
158 |
+
out_channels: int,
|
159 |
+
kernel_size: int = 1,
|
160 |
+
*,
|
161 |
+
csp_expand: float = 0.5,
|
162 |
+
repeat_num: int = 1,
|
163 |
+
bottleneck_args: Optional[Dict[str, Any]] = None,
|
164 |
+
**kwargs
|
165 |
+
):
|
166 |
+
super().__init__()
|
167 |
+
|
168 |
+
if bottleneck_args is None:
|
169 |
+
bottleneck_args = {"kernel_size": (3, 3), "residual": True, "expand": 0.5}
|
170 |
+
|
171 |
+
neck_channels = int(out_channels * csp_expand)
|
172 |
+
self.conv1 = Conv(in_channels, neck_channels, kernel_size, **kwargs)
|
173 |
+
self.conv2 = Conv(in_channels, neck_channels, kernel_size, **kwargs)
|
174 |
+
self.conv3 = Conv(2 * neck_channels, out_channels, kernel_size, **kwargs)
|
175 |
+
|
176 |
+
self.bottleneck_block = nn.Sequential(
|
177 |
+
*[RepNBottleneck(neck_channels, neck_channels, **bottleneck_args) for _ in range(repeat_num)]
|
178 |
+
)
|
179 |
+
|
180 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
181 |
+
input_features = self.conv1(x)
|
182 |
+
split_features = self.conv2(x)
|
183 |
+
bottleneck_output = self.bottleneck_block(input_features)
|
184 |
+
return self.conv3(torch.cat((bottleneck_output, split_features), dim=1))
|
185 |
+
|
186 |
+
|
187 |
+
class RepNCSPELAN(nn.Module):
|
188 |
+
"""RepNCSPELAN block combining RepNCSP blocks with ELAN structure."""
|
189 |
+
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
*,
|
193 |
+
in_channels: int,
|
194 |
+
out_channels: int,
|
195 |
+
partition_channels: int,
|
196 |
+
process_channels: int,
|
197 |
+
expand: float,
|
198 |
+
repncsp_args: Optional[Dict[str, Any]] = None,
|
199 |
+
bottleneck_args: Optional[Dict[str, Any]] = None,
|
200 |
+
**kwargs
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
|
204 |
+
if repncsp_args is None:
|
205 |
+
repncsp_args = {}
|
206 |
+
|
207 |
+
self.conv1 = Conv(in_channels, partition_channels, 1, **kwargs)
|
208 |
+
self.conv2 = nn.Sequential(
|
209 |
+
RepNCSP(
|
210 |
+
partition_channels // 2,
|
211 |
+
process_channels,
|
212 |
+
csp_expand=expand,
|
213 |
+
bottleneck_args=bottleneck_args,
|
214 |
+
**repncsp_args
|
215 |
+
),
|
216 |
+
Conv(process_channels, process_channels, 3, padding=1, **kwargs),
|
217 |
+
)
|
218 |
+
self.conv3 = nn.Sequential(
|
219 |
+
RepNCSP(
|
220 |
+
process_channels, process_channels, csp_expand=expand, bottleneck_args=bottleneck_args, **repncsp_args
|
221 |
+
),
|
222 |
+
Conv(process_channels, process_channels, 3, padding=1, **kwargs),
|
223 |
+
)
|
224 |
+
self.conv4 = Conv(partition_channels + 2 * process_channels, out_channels, 1, **kwargs)
|
225 |
+
|
226 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
227 |
+
partition1, partition2 = self.conv1(x).chunk(2, 1)
|
228 |
+
csp_output1 = self.conv2(partition2)
|
229 |
+
csp_output2 = self.conv3(csp_output1)
|
230 |
+
concat = torch.cat([partition1, partition2, csp_output1, csp_output2], dim=1)
|
231 |
+
return self.conv4(concat)
|
232 |
+
|
233 |
+
|
234 |
# ResNet
|
235 |
class Res(nn.Module):
|
236 |
# ResNet bottleneck
|