✨ [Add] Detection Head and Multiple Head class
Browse files- yolo/model/module.py +43 -2
yolo/model/module.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
from typing import Optional, Tuple
|
2 |
|
3 |
import torch
|
4 |
from torch import Tensor, nn
|
5 |
from torch.nn.common_types import _size_2_t
|
6 |
|
7 |
-
from yolo.tools.module_helper import auto_pad, get_activation
|
8 |
|
9 |
|
10 |
class Conv(nn.Module):
|
@@ -99,6 +99,47 @@ class SPPELAN(nn.Module):
|
|
99 |
#### -- ####
|
100 |
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
# RepVGG
|
103 |
class RepConv(nn.Module):
|
104 |
# https://github.com/DingXiaoH/RepVGG
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
|
3 |
import torch
|
4 |
from torch import Tensor, nn
|
5 |
from torch.nn.common_types import _size_2_t
|
6 |
|
7 |
+
from yolo.tools.module_helper import auto_pad, get_activation, round_up
|
8 |
|
9 |
|
10 |
class Conv(nn.Module):
|
|
|
99 |
#### -- ####
|
100 |
|
101 |
|
102 |
+
class Detection(nn.Module):
|
103 |
+
"""A single YOLO Detection head for detection models"""
|
104 |
+
|
105 |
+
def __init__(self, in_channels: int, num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
groups = 4 if use_group else 1
|
109 |
+
anchor_channels = 4 * reg_max
|
110 |
+
# TODO: round up head[0] channels or each head?
|
111 |
+
anchor_neck = max(round_up(in_channels // 4, groups), anchor_channels, 16)
|
112 |
+
class_neck = max(in_channels, min(num_classes * 2, 128))
|
113 |
+
|
114 |
+
self.anchor_conv = nn.Sequential(
|
115 |
+
Conv(in_channels, anchor_neck, 3),
|
116 |
+
Conv(anchor_neck, anchor_neck, 3, groups=groups),
|
117 |
+
nn.Conv2d(anchor_neck, anchor_channels, 1, groups=groups),
|
118 |
+
)
|
119 |
+
self.class_conv = nn.Sequential(
|
120 |
+
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
121 |
+
)
|
122 |
+
|
123 |
+
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
124 |
+
anchor_x = self.anchor_conv(x)
|
125 |
+
class_x = self.class_conv(x)
|
126 |
+
return torch.cat([anchor_x, class_x], dim=1)
|
127 |
+
|
128 |
+
|
129 |
+
class MultiheadDetection(nn.Module):
|
130 |
+
"""Mutlihead Detection module for Dual detect or Triple detect"""
|
131 |
+
|
132 |
+
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
133 |
+
super().__init__()
|
134 |
+
self.heads = nn.ModuleList(
|
135 |
+
[Detection(head_in_channels, num_classes, **head_kwargs) for head_in_channels in in_channels]
|
136 |
+
)
|
137 |
+
|
138 |
+
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
139 |
+
return [head(x) for x, head in zip(x_list, self.heads)]
|
140 |
+
|
141 |
+
|
142 |
+
#### -- ####
|
143 |
# RepVGG
|
144 |
class RepConv(nn.Module):
|
145 |
# https://github.com/DingXiaoH/RepVGG
|