henry000 committed on
Commit
c601a4c
·
1 Parent(s): 995ae20

🔨 [Add] weight initialize for Detection

Browse files
Files changed (2) hide show
  1. yolo/model/module.py +4 -1
  2. yolo/tools/trainer.py +1 -1
yolo/model/module.py CHANGED
@@ -25,7 +25,7 @@ class Conv(nn.Module):
25
  super().__init__()
26
  kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
27
  self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
28
- self.bn = nn.BatchNorm2d(out_channels)
29
  self.act = get_activation(activation)
30
 
31
  def forward(self, x: Tensor) -> Tensor:
@@ -69,6 +69,9 @@ class Detection(nn.Module):
69
  Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
70
  )
71
 
 
 
 
72
  def forward(self, x: List[Tensor]) -> List[Tensor]:
73
  anchor_x = self.anchor_conv(x)
74
  class_x = self.class_conv(x)
 
25
  super().__init__()
26
  kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
27
  self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
28
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=3e-2)
29
  self.act = get_activation(activation)
30
 
31
  def forward(self, x: Tensor) -> Tensor:
 
69
  Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
70
  )
71
 
72
+ self.anchor_conv[-1].bias.data.fill_(1.0)
73
+ self.class_conv[-1].bias.data.fill_(-10)
74
+
75
  def forward(self, x: List[Tensor]) -> List[Tensor]:
76
  anchor_x = self.anchor_conv(x)
77
  class_x = self.class_conv(x)
yolo/tools/trainer.py CHANGED
@@ -79,7 +79,7 @@ class Trainer:
79
  self.progress.start_train(num_epochs)
80
  for epoch in range(num_epochs):
81
 
82
- epoch_loss = self.train_one_epoch(dataloader, self.progress)
83
  self.progress.one_epoch()
84
 
85
  logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
 
79
  self.progress.start_train(num_epochs)
80
  for epoch in range(num_epochs):
81
 
82
+ epoch_loss = self.train_one_epoch(dataloader)
83
  self.progress.one_epoch()
84
 
85
  logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")