π [Fix] the batch size of loss function
Browse files- yolo/utils/loss.py +1 -2
yolo/utils/loss.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Any, List, Tuple
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
7 |
-
from hydra import main
|
8 |
from loguru import logger
|
9 |
from torch import Tensor, nn
|
10 |
from torch.nn import BCEWithLogitsLoss
|
@@ -144,7 +143,7 @@ class YOLOLoss:
|
|
144 |
# Batch_Size x (Anchor + Class) x H x W
|
145 |
# TODO: check datatype, why targets has a little bit error with origin version
|
146 |
predicts, predicts_anc = self.parse_predicts(predicts[0])
|
147 |
-
targets = self.parse_targets(targets)
|
148 |
|
149 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
150 |
# calculate loss between with instance and predict
|
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
|
|
7 |
from loguru import logger
|
8 |
from torch import Tensor, nn
|
9 |
from torch.nn import BCEWithLogitsLoss
|
|
|
143 |
# Batch_Size x (Anchor + Class) x H x W
|
144 |
# TODO: check datatype, why targets has a little bit error with origin version
|
145 |
predicts, predicts_anc = self.parse_predicts(predicts[0])
|
146 |
+
targets = self.parse_targets(targets, batch_size=predicts.size(0))
|
147 |
|
148 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
149 |
# calculate loss between with instance and predict
|