# YOLO / yolo/tools/loss_functions.py
# Author: henry000
# Commit b5fa3f1 — "🧱 [Update] config files struct, make it clearly" (6.3 kB)
from typing import Any, Dict, List, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss
from yolo.config.config import Config
from yolo.utils.bounding_box_utils import (
AnchorBoxConverter,
BoxMatcher,
calculate_iou,
generate_anchors,
)
from yolo.utils.module_utils import divide_into_chunks
class BCELoss(nn.Module):
    """Binary cross-entropy on logits, summed over all elements and divided by ``cls_norm``."""

    def __init__(self) -> None:
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # pos_weight=1.0 applies no extra weight to positives; the tensor is kept so the
        # criterion is configured exactly as before.
        self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=device), reduction="none")

    def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
        """Return ``sum(BCEWithLogits(predicts_cls, targets_cls)) / cls_norm``.

        Args:
            predicts_cls: classification logits.
            targets_cls: targets with the same shape as ``predicts_cls``.
            cls_norm: normaliser the summed loss is divided by.
        """
        # pos_weight was created on the default device at construction time. The owning
        # YOLOLoss is a plain class (not an nn.Module), so nothing moves this module later.
        # Align the buffer with the inputs to avoid a device-mismatch error, e.g. CPU
        # inputs on a CUDA host, or a non-default GPU.
        if self.bce.pos_weight.device != predicts_cls.device:
            self.bce.to(predicts_cls.device)
        return self.bce(predicts_cls, targets_cls).sum() / cls_norm
class BoxLoss(nn.Module):
    """CIoU-based box regression loss over the anchors selected by ``valid_masks``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
    ) -> Any:
        """Return ``sum((1 - CIoU) * box_norm) / cls_norm`` over the masked boxes.

        ``valid_masks`` is broadcast across the last (4-coordinate) axis, so it is assumed
        to match the first two dimensions of ``predicts_bbox`` / ``targets_bbox``.
        """
        box_mask = valid_masks.unsqueeze(-1).expand(-1, -1, 4)
        matched_pred = predicts_bbox[box_mask].view(-1, 4)
        matched_true = targets_bbox[box_mask].view(-1, 4)

        # calculate_iou is called on the two box sets and only the diagonal is kept, i.e.
        # each prediction is compared with its own matched target.
        pairwise_ciou = calculate_iou(matched_pred, matched_true, "ciou")
        per_box_loss = 1.0 - pairwise_ciou.diag()

        weighted = per_box_loss * box_norm
        return weighted.sum() / cls_norm
class DFLoss(nn.Module):
    """Distribution Focal Loss over the ``reg_max`` bins of each box side."""

    def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
        super().__init__()
        self.anchors = anchors
        self.scaler = scaler
        self.reg_max = reg_max

    def forward(
        self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
    ) -> Any:
        """Return the masked, ``box_norm``-weighted DFL divided by ``cls_norm``.

        Each of the four target distances (anchor -> left/top, right/bottom -> anchor) is
        clamped to ``[0, reg_max - 1.01]``. It is then split between its two neighbouring
        integer bins with linear weights, and the cross-entropy against each bin is mixed
        by those weights.
        """
        side_mask = valid_masks.unsqueeze(-1).expand(-1, -1, 4)

        # Anchor positions divided by their per-anchor scaler, with a leading batch axis.
        centers = (self.anchors / self.scaler[:, None]).unsqueeze(0)
        top_left, bottom_right = targets_bbox.chunk(2, dim=-1)
        distances = torch.cat([centers - top_left, bottom_right - centers], dim=-1)
        # Upper bound < reg_max - 1 keeps the right-hand bin (floor + 1) a valid class index.
        distances = distances.clamp(min=0, max=self.reg_max - 1.01)

        target_dist = distances[side_mask].view(-1)
        logits = predicts_anc[side_mask].view(-1, self.reg_max)

        lower_bin = target_dist.floor()
        upper_bin = lower_bin + 1
        lower_weight = upper_bin - target_dist
        upper_weight = target_dist - lower_bin

        lower_ce = F.cross_entropy(logits, lower_bin.long(), reduction="none")
        upper_ce = F.cross_entropy(logits, upper_bin.long(), reduction="none")
        per_side = lower_ce * lower_weight + upper_ce * upper_weight

        # Average the four sides of each box, then weight and normalise.
        per_box = per_side.view(-1, 4).mean(dim=-1)
        return (per_box * box_norm).sum() / cls_norm
class YOLOLoss:
    """Box (CIoU), DFL and classification losses for one YOLO prediction branch."""

    def __init__(self, cfg: Config) -> None:
        self.reg_max = cfg.model.anchor.reg_max
        self.class_num = cfg.class_num
        self.image_size = list(cfg.image_size)
        self.strides = cfg.model.anchor.strides
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
        # List repetition, not arithmetic: [a, b] * 2 -> [a, b, a, b], one factor per box coordinate.
        self.scale_up = torch.tensor(self.image_size * 2, device=device)
        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)

        self.cls = BCELoss()
        self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
        self.iou = BoxLoss()

        self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
        self.box_converter = AnchorBoxConverter(cfg, device)

    def separate_anchor(self, anchors):
        """Split per-anchor vectors into class entries and bounding-box entries.

        The last dimension is split into ``class_num`` class entries followed by 4 box
        entries. The box entries are divided by ``self.scaler`` broadcast as
        ``[None, :, None]``, i.e. along dim 1 (assumed to be the anchor axis — TODO confirm).
        """
        anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
        anchors_box = anchors_box / self.scaler[None, :, None]
        return anchors_cls, anchors_box

    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute ``(loss_iou, loss_dfl, loss_cls)`` for one branch.

        Args:
            predicts: raw head outputs. Per the original author's note each is
                Batch_Size x (Anchor + Class) x H x W. They are converted by
                ``self.box_converter`` into per-anchor predictions plus DFL logits.
            targets: ground truth, matched to anchors by ``self.matcher``.
        """
        # TODO: check datatype, why targets has a little bit error with origin version
        predicts, predicts_anc = self.box_converter(predicts)

        # For each predicted target, assign the best-suited ground-truth box.
        align_targets, valid_masks = self.matcher(targets, predicts)

        targets_cls, targets_bbox = self.separate_anchor(align_targets)
        predicts_cls, predicts_bbox = self.separate_anchor(predicts)

        # Floor the normaliser at 1. A batch with no ground-truth boxes has
        # targets_cls.sum() == 0, which would make every loss inf/NaN and corrupt training.
        # This matches the `max(target_scores.sum(), 1)` normalisation of upstream YOLOv9.
        cls_norm = targets_cls.sum().clamp(min=1.0)
        box_norm = targets_cls.sum(-1)[valid_masks]

        ## -- CLS -- ##
        loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
        ## -- IOU -- ##
        loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
        ## -- DFL -- ##
        loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)

        return loss_iou, loss_dfl, loss_cls
class DualLoss:
    """Weighted combination of YOLOLoss over an auxiliary and a main prediction branch."""

    def __init__(self, cfg: Config) -> None:
        self.loss = YOLOLoss(cfg)
        # Multiplier applied to the auxiliary branch before it is added to the main branch.
        self.aux_rate = cfg.task.loss.aux

        # Per-objective weights, read from the config by loss name.
        self.iou_rate = cfg.task.loss.objective["BoxLoss"]
        self.dfl_rate = cfg.task.loss.objective["DFLoss"]
        self.cls_rate = cfg.task.loss.objective["BCELoss"]

    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the combined training loss.

        Args:
            predicts: model outputs. Only ``predicts[0]`` is used; it is split into two
                chunks by ``divide_into_chunks``. Chunk 0 is treated as the auxiliary
                branch and chunk 1 as the main branch.
            targets: ground-truth tensor. Its ``[:, :, 1:]`` columns are multiplied by
                ``self.loss.scale_up`` before matching (presumably normalised box
                coordinates -> pixels; TODO confirm). The caller's tensor is left unchanged.

        Returns:
            ``(loss_sum, loss_dict)``. ``loss_dict`` maps "BoxLoss", "DFLoss" and "BCELoss"
            to their weighted values, and ``loss_sum`` is the mean of those three values.
        """
        # Scale a copy rather than writing into the caller's tensor; in-place scaling
        # would compound if the same targets were passed in again.
        targets = targets.clone()
        targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up

        # TODO: Need Refactor this region, make it flexible!
        predicts = divide_into_chunks(predicts[0], 2)
        aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
        main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)

        loss_dict = {
            "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
            "DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
            "BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
        }
        # Mean (not sum) of the three weighted objectives.
        loss_sum = sum(loss_dict.values()) / len(loss_dict)
        return loss_sum, loss_dict
def get_loss_function(cfg: Config) -> DualLoss:
    """Build the training loss for ``cfg`` and log that it was created.

    Returns:
        A ``DualLoss``. It is not a ``YOLOLoss`` subclass; it wraps one as ``.loss``.
    """
    loss_function = DualLoss(cfg)
    logger.info("✅ Success load loss function")
    return loss_function