🔨 [Add] CBFuse, last v9 module is uploaded
Browse files- yolo/model/module.py +15 -0
yolo/model/module.py
CHANGED
@@ -269,6 +269,21 @@ class UpSample(nn.Module):
|
|
269 |
return self.UpSample(x)
|
270 |
|
271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
############# Waiting For Refactor #############
|
273 |
# ResNet
|
274 |
class Res(nn.Module):
|
|
|
269 |
return self.UpSample(x)
|
270 |
|
271 |
|
272 |
+
class CBFuse(nn.Module):
|
273 |
+
def __init__(self, index: List[int], mode: str = "nearest"):
|
274 |
+
super().__init__()
|
275 |
+
self.idx = index
|
276 |
+
self.mode = mode
|
277 |
+
|
278 |
+
def forward(self, x_list: List[torch.Tensor]) -> List[Tensor]:
|
279 |
+
target = x_list[-1]
|
280 |
+
target_size = target.shape[2:] # Batch, Channel, H, W
|
281 |
+
|
282 |
+
res = [F.interpolate(x[pick_id], size=target_size, mode=self.mode) for pick_id, x in zip(self.idx, x_list)]
|
283 |
+
out = torch.stack(res + [target]).sum(dim=0)
|
284 |
+
return out
|
285 |
+
|
286 |
+
|
287 |
############# Waiting For Refactor #############
|
288 |
# ResNet
|
289 |
class Res(nn.Module):
|