🔨 [Add] weight initailize for Detection
Browse files- yolo/model/module.py +4 -1
- yolo/tools/trainer.py +1 -1
yolo/model/module.py
CHANGED
@@ -25,7 +25,7 @@ class Conv(nn.Module):
|
|
25 |
super().__init__()
|
26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
27 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
28 |
-
self.bn = nn.BatchNorm2d(out_channels)
|
29 |
self.act = get_activation(activation)
|
30 |
|
31 |
def forward(self, x: Tensor) -> Tensor:
|
@@ -69,6 +69,9 @@ class Detection(nn.Module):
|
|
69 |
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
70 |
)
|
71 |
|
|
|
|
|
|
|
72 |
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
73 |
anchor_x = self.anchor_conv(x)
|
74 |
class_x = self.class_conv(x)
|
|
|
25 |
super().__init__()
|
26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
27 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
28 |
+
self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=3e-2)
|
29 |
self.act = get_activation(activation)
|
30 |
|
31 |
def forward(self, x: Tensor) -> Tensor:
|
|
|
69 |
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
70 |
)
|
71 |
|
72 |
+
self.anchor_conv[-1].bias.data.fill_(1.0)
|
73 |
+
self.class_conv[-1].bias.data.fill_(-10)
|
74 |
+
|
75 |
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
76 |
anchor_x = self.anchor_conv(x)
|
77 |
class_x = self.class_conv(x)
|
yolo/tools/trainer.py
CHANGED
@@ -79,7 +79,7 @@ class Trainer:
|
|
79 |
self.progress.start_train(num_epochs)
|
80 |
for epoch in range(num_epochs):
|
81 |
|
82 |
-
epoch_loss = self.train_one_epoch(dataloader
|
83 |
self.progress.one_epoch()
|
84 |
|
85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
|
|
79 |
self.progress.start_train(num_epochs)
|
80 |
for epoch in range(num_epochs):
|
81 |
|
82 |
+
epoch_loss = self.train_one_epoch(dataloader)
|
83 |
self.progress.one_epoch()
|
84 |
|
85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|