henry000 commited on
Commit
38f4931
Β·
1 Parent(s): cbbfcfe

πŸ› [Fix] the batch size of loss function

Browse files
Files changed (1) hide show
  1. yolo/utils/loss.py +1 -2
yolo/utils/loss.py CHANGED
@@ -4,7 +4,6 @@ from typing import Any, List, Tuple
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
- from hydra import main
8
  from loguru import logger
9
  from torch import Tensor, nn
10
  from torch.nn import BCEWithLogitsLoss
@@ -144,7 +143,7 @@ class YOLOLoss:
144
  # Batch_Size x (Anchor + Class) x H x W
145
  # TODO: check datatype, why targets has a little bit error with origin version
146
  predicts, predicts_anc = self.parse_predicts(predicts[0])
147
- targets = self.parse_targets(targets)
148
 
149
  align_targets, valid_masks = self.matcher(targets, predicts)
150
  # calculate loss between with instance and predict
 
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
 
7
  from loguru import logger
8
  from torch import Tensor, nn
9
  from torch.nn import BCEWithLogitsLoss
 
143
  # Batch_Size x (Anchor + Class) x H x W
144
  # TODO: check datatype, why targets has a little bit error with origin version
145
  predicts, predicts_anc = self.parse_predicts(predicts[0])
146
+ targets = self.parse_targets(targets, batch_size=predicts.size(0))
147
 
148
  align_targets, valid_masks = self.matcher(targets, predicts)
149
  # calculate loss between with instance and predict