henry000 commited on
Commit
a3f8ecb
Β·
1 Parent(s): 0f9ffa2

πŸ› [Fix] params mismatch with origin v9

Browse files
Files changed (1) hide show
  1. yolo/model/module.py +11 -5
yolo/model/module.py CHANGED
@@ -49,14 +49,16 @@ class Pool(nn.Module):
49
  class Detection(nn.Module):
50
  """A single YOLO Detection head for detection models"""
51
 
52
- def __init__(self, in_channels: int, num_classes: int, *, reg_max: int = 16, use_group: bool = True):
53
  super().__init__()
54
 
55
  groups = 4 if use_group else 1
56
  anchor_channels = 4 * reg_max
 
 
57
  # TODO: round up head[0] channels or each head?
58
- anchor_neck = max(round_up(in_channels // 4, groups), anchor_channels, 16)
59
- class_neck = max(in_channels, min(num_classes * 2, 128))
60
 
61
  self.anchor_conv = nn.Sequential(
62
  Conv(in_channels, anchor_neck, 3),
@@ -78,8 +80,12 @@ class MultiheadDetection(nn.Module):
78
 
79
  def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
80
  super().__init__()
 
81
  self.heads = nn.ModuleList(
82
- [Detection(head_in_channels, num_classes, **head_kwargs) for head_in_channels in in_channels]
 
 
 
83
  )
84
 
85
  def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -118,7 +124,7 @@ class RepNBottleneck(nn.Module):
118
  *,
119
  kernel_size: Tuple[int, int] = (3, 3),
120
  residual: bool = True,
121
- expand: float = 0.5,
122
  **kwargs
123
  ):
124
  super().__init__()
 
49
  class Detection(nn.Module):
50
  """A single YOLO Detection head for detection models"""
51
 
52
+ def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True):
53
  super().__init__()
54
 
55
  groups = 4 if use_group else 1
56
  anchor_channels = 4 * reg_max
57
+
58
+ first_neck, in_channels = in_channels
59
  # TODO: round up head[0] channels or each head?
60
+ anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
61
+ class_neck = max(first_neck, min(num_classes * 2, 128))
62
 
63
  self.anchor_conv = nn.Sequential(
64
  Conv(in_channels, anchor_neck, 3),
 
80
 
81
  def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
82
  super().__init__()
83
+ # TODO: Refactor these parts
84
  self.heads = nn.ModuleList(
85
+ [
86
+ Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
87
+ for idx, in_channel in enumerate(in_channels)
88
+ ]
89
  )
90
 
91
  def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
 
124
  *,
125
  kernel_size: Tuple[int, int] = (3, 3),
126
  residual: bool = True,
127
+ expand: float = 1.0,
128
  **kwargs
129
  ):
130
  super().__init__()