henry000 commited on
Commit
00c86de
·
1 Parent(s): b679c1b

✨ [Add] v9-c-segment model, inference WIP

Browse files
Files changed (2) hide show
  1. yolo/model/module.py +32 -0
  2. yolo/model/yolo.py +1 -1
yolo/model/module.py CHANGED
@@ -130,6 +130,38 @@ class MultiheadDetection(nn.Module):
130
  return [head(x) for x, head in zip(x_list, self.heads)]
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  class Anchor2Vec(nn.Module):
134
  def __init__(self, reg_max: int = 16) -> None:
135
  super().__init__()
 
130
  return [head(x) for x, head in zip(x_list, self.heads)]
131
 
132
 
133
+ class Segmentation(nn.Module):
134
+ def __init__(self, in_channels: Tuple[int], num_maskes: int):
135
+ super().__init__()
136
+ first_neck, in_channels = in_channels
137
+
138
+ mask_neck = max(first_neck // 4, num_maskes)
139
+ self.mask_conv = nn.Sequential(
140
+ Conv(in_channels, mask_neck, 3), Conv(mask_neck, mask_neck, 3), nn.Conv2d(mask_neck, num_maskes, 1)
141
+ )
142
+
143
+ def forward(self, x: Tensor) -> Tuple[Tensor]:
144
+ x = self.mask_conv(x)
145
+ return x
146
+
147
+
148
+ class MultiheadSegmentation(nn.Module):
149
+ """Mutlihead Segmentation module for Dual segment or Triple segment"""
150
+
151
+ def __init__(self, in_channels: List[int], num_classes: int, num_maskes: int, **head_kwargs):
152
+ super().__init__()
153
+ mask_channels, proto_channels = in_channels[:-1], in_channels[-1]
154
+
155
+ self.detect = MultiheadDetection(mask_channels, num_classes, **head_kwargs)
156
+ self.heads = nn.ModuleList(
157
+ [Segmentation((in_channels[0], in_channel), num_maskes) for in_channel in mask_channels]
158
+ )
159
+ self.heads.append(Conv(proto_channels, num_maskes, 1))
160
+
161
+ def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
162
+ return [head(x) for x, head in zip(x_list, self.heads)]
163
+
164
+
165
  class Anchor2Vec(nn.Module):
166
  def __init__(self, reg_max: int = 16) -> None:
167
  super().__init__()
yolo/model/yolo.py CHANGED
@@ -45,7 +45,7 @@ class YOLO(nn.Module):
45
  # Find in channels
46
  if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
47
  layer_args["in_channels"] = output_dim[source]
48
- if "Detection" in layer_type:
49
  layer_args["in_channels"] = [output_dim[idx] for idx in source]
50
  layer_args["num_classes"] = self.num_classes
51
  layer_args["reg_max"] = self.reg_max
 
45
  # Find in channels
46
  if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
47
  layer_args["in_channels"] = output_dim[source]
48
+ if "Detection" in layer_type or "Segmentation" in layer_type:
49
  layer_args["in_channels"] = [output_dim[idx] for idx in source]
50
  layer_args["num_classes"] = self.num_classes
51
  layer_args["reg_max"] = self.reg_max