π [Fix] params mismatch with origin v9
Browse files- yolo/model/module.py +11 -5
yolo/model/module.py
CHANGED
@@ -49,14 +49,16 @@ class Pool(nn.Module):
|
|
49 |
class Detection(nn.Module):
|
50 |
"""A single YOLO Detection head for detection models"""
|
51 |
|
52 |
-
def __init__(self, in_channels: int, num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
53 |
super().__init__()
|
54 |
|
55 |
groups = 4 if use_group else 1
|
56 |
anchor_channels = 4 * reg_max
|
|
|
|
|
57 |
# TODO: round up head[0] channels or each head?
|
58 |
-
anchor_neck = max(round_up(
|
59 |
-
class_neck = max(
|
60 |
|
61 |
self.anchor_conv = nn.Sequential(
|
62 |
Conv(in_channels, anchor_neck, 3),
|
@@ -78,8 +80,12 @@ class MultiheadDetection(nn.Module):
|
|
78 |
|
79 |
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
80 |
super().__init__()
|
|
|
81 |
self.heads = nn.ModuleList(
|
82 |
-
[
|
|
|
|
|
|
|
83 |
)
|
84 |
|
85 |
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
@@ -118,7 +124,7 @@ class RepNBottleneck(nn.Module):
|
|
118 |
*,
|
119 |
kernel_size: Tuple[int, int] = (3, 3),
|
120 |
residual: bool = True,
|
121 |
-
expand: float = 0
|
122 |
**kwargs
|
123 |
):
|
124 |
super().__init__()
|
|
|
49 |
class Detection(nn.Module):
|
50 |
"""A single YOLO Detection head for detection models"""
|
51 |
|
52 |
+
def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
53 |
super().__init__()
|
54 |
|
55 |
groups = 4 if use_group else 1
|
56 |
anchor_channels = 4 * reg_max
|
57 |
+
|
58 |
+
first_neck, in_channels = in_channels
|
59 |
# TODO: round up head[0] channels or each head?
|
60 |
+
anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
|
61 |
+
class_neck = max(first_neck, min(num_classes * 2, 128))
|
62 |
|
63 |
self.anchor_conv = nn.Sequential(
|
64 |
Conv(in_channels, anchor_neck, 3),
|
|
|
80 |
|
81 |
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
82 |
super().__init__()
|
83 |
+
# TODO: Refactor these parts
|
84 |
self.heads = nn.ModuleList(
|
85 |
+
[
|
86 |
+
Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
|
87 |
+
for idx, in_channel in enumerate(in_channels)
|
88 |
+
]
|
89 |
)
|
90 |
|
91 |
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
|
124 |
*,
|
125 |
kernel_size: Tuple[int, int] = (3, 3),
|
126 |
residual: bool = True,
|
127 |
+
expand: float = 1.0,
|
128 |
**kwargs
|
129 |
):
|
130 |
super().__init__()
|