File size: 6,302 Bytes
f2370d7
cbbfcfe
 
 
 
 
 
 
 
 
dcceddd
 
 
 
 
 
 
cbbfcfe
 
 
 
 
dc55a8e
 
cbbfcfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5fa3f1
 
cbbfcfe
710e371
cbbfcfe
033231b
cbbfcfe
 
dcceddd
cbbfcfe
 
 
 
 
b5fa3f1
dcceddd
cbbfcfe
 
 
 
 
 
 
 
 
 
 
 
253e9b1
cbbfcfe
253e9b1
cbbfcfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9338ee
 
 
 
 
 
b5fa3f1
c9338ee
b5fa3f1
 
 
c9338ee
f2370d7
c9338ee
 
 
dcceddd
c9338ee
 
 
f2370d7
 
 
 
 
 
 
584d5bd
 
 
c9338ee
584d5bd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss

from yolo.config.config import Config
from yolo.utils.bounding_box_utils import (
    AnchorBoxConverter,
    BoxMatcher,
    calculate_iou,
    generate_anchors,
)
from yolo.utils.module_utils import divide_into_chunks


class BCELoss(nn.Module):
    """Binary cross-entropy on raw logits, summed and divided by a normaliser."""

    def __init__(self) -> None:
        super().__init__()
        # No pos_weight: the previous weight of 1.0 was an identity factor, and
        # creating it on an auto-detected device made the criterion fail with a
        # device mismatch whenever the inputs lived on a different device.
        self.bce = BCEWithLogitsLoss(reduction="none")

    def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
        """Return the summed element-wise BCE divided by ``cls_norm``.

        Args:
            predicts_cls: Classification logits.
            targets_cls: Classification targets, same shape as ``predicts_cls``.
            cls_norm: Normalisation factor the summed loss is divided by.

        Returns:
            The normalised classification loss.
        """
        return self.bce(predicts_cls, targets_cls).sum() / cls_norm


class BoxLoss(nn.Module):
    """Box regression loss built on the "ciou" metric of ``calculate_iou``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
    ) -> Any:
        """Return the weighted ``1 - IoU`` loss over the masked anchors.

        Args:
            predicts_bbox: Predicted boxes; the last dimension holds 4 coordinates.
            targets_bbox: Target boxes, laid out like ``predicts_bbox``.
            valid_masks: Mask selecting which anchors contribute to the loss.
            box_norm: Weight applied to each selected box.
            cls_norm: Normalisation factor the summed loss is divided by.

        Returns:
            The normalised box loss.
        """
        # Broadcast the per-anchor mask over the 4 box coordinates.
        coord_mask = valid_masks.unsqueeze(-1).expand(-1, -1, 4)
        matched_predict = predicts_bbox[coord_mask].view(-1, 4)
        matched_targets = targets_bbox[coord_mask].view(-1, 4)

        # Only the diagonal is kept, i.e. the i-th prediction against the i-th
        # target (presumably calculate_iou returns a pairwise matrix - confirm).
        pairwise_iou = calculate_iou(matched_predict, matched_targets, "ciou")
        weighted_loss = (1.0 - pairwise_iou.diag()) * box_norm
        return weighted_loss.sum() / cls_norm


class DFLoss(nn.Module):
    """Distribution focal loss on the per-side anchor-to-box distances."""

    def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
        super().__init__()
        self.anchors = anchors
        self.scaler = scaler
        self.reg_max = reg_max

    def forward(
        self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
    ) -> Any:
        """Return the weighted distribution focal loss over the masked anchors.

        Args:
            predicts_anc: Distance logits; the last dimension has ``reg_max`` bins
                and the one before it the 4 box sides.
            targets_bbox: Target boxes; the last dimension is split into two halves
                (named left-top and right-bottom here).
            valid_masks: Mask selecting which anchors contribute to the loss.
            box_norm: Weight applied to each selected box.
            cls_norm: Normalisation factor the summed loss is divided by.

        Returns:
            The normalised distribution focal loss.
        """
        side_mask = valid_masks.unsqueeze(-1).expand(-1, -1, 4)
        corner_lt, corner_rb = torch.chunk(targets_bbox, 2, dim=-1)

        # Anchor points divided by their scaler, with a leading batch axis.
        scaled_anchors = (self.anchors / self.scaler[:, None]).unsqueeze(0)
        distances = torch.cat((scaled_anchors - corner_lt, corner_rb - scaled_anchors), dim=-1)
        # Upper bound keeps floor(distance) + 1 a valid bin index (< reg_max).
        distances = distances.clamp(0, self.reg_max - 1.01)

        flat_targets = distances[side_mask].view(-1)
        flat_logits = predicts_anc[side_mask].view(-1, self.reg_max)

        # Each continuous distance is shared between its two neighbouring bins,
        # weighted by how close it lies to each of them.
        lower_bin = flat_targets.floor()
        upper_bin = lower_bin + 1
        lower_weight = upper_bin - flat_targets
        upper_weight = flat_targets - lower_bin

        lower_ce = F.cross_entropy(flat_logits, lower_bin.long(), reduction="none")
        upper_ce = F.cross_entropy(flat_logits, upper_bin.long(), reduction="none")

        # Average the 4 sides of every box before weighting.
        per_box = (lower_ce * lower_weight + upper_ce * upper_weight).view(-1, 4).mean(-1)
        return (per_box * box_norm).sum() / cls_norm


class YOLOLoss:
    """Box, distribution-focal and classification losses for one set of predictions."""

    def __init__(self, cfg: Config) -> None:
        """Read the anchor/class settings from ``cfg`` and build the loss parts.

        Args:
            cfg: Project configuration; ``model.anchor``, ``class_num``,
                ``image_size`` and ``task.loss.matcher`` are read here.
        """
        self.reg_max = cfg.model.anchor.reg_max
        self.class_num = cfg.class_num
        self.image_size = list(cfg.image_size)
        self.strides = cfg.model.anchor.strides
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
        # image_size repeated twice, so it lines up with the 4 box coordinates.
        self.scale_up = torch.tensor(self.image_size * 2, device=device)

        self.anchors, self.scaler = generate_anchors(self.image_size, self.strides, device)

        self.cls = BCELoss()
        self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
        self.iou = BoxLoss()

        self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
        self.box_converter = AnchorBoxConverter(cfg, device)

    def separate_anchor(self, anchors: Tensor) -> Tuple[Tensor, Tensor]:
        """Split the last dimension into class scores and a bounding box.

        Args:
            anchors: Tensor whose last dimension holds ``class_num`` class
                entries followed by 4 box coordinates.

        Returns:
            ``(anchors_cls, anchors_box)``; the box part is divided by the
            per-anchor scaler.
        """
        anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
        anchors_box = anchors_box / self.scaler[None, :, None]
        return anchors_cls, anchors_box

    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the three loss terms.

        Args:
            predicts: Raw model outputs handed to the anchor box converter.
            targets: Ground-truth tensor handed to the box matcher.

        Returns:
            ``(loss_iou, loss_dfl, loss_cls)``.
        """
        # Batch_Size x (Anchor + Class) x H x W
        # TODO: check datatype, why targets has a little bit error with origin version
        predicts, predicts_anc = self.box_converter(predicts)

        # For each predicted targets, assign a best suitable ground truth box.
        align_targets, valid_masks = self.matcher(targets, predicts)

        targets_cls, targets_bbox = self.separate_anchor(align_targets)
        predicts_cls, predicts_bbox = self.separate_anchor(predicts)

        # Floor the normaliser at 1: when no anchor is matched the sum is 0 and
        # dividing by it would turn every loss into inf/NaN.
        cls_norm = targets_cls.sum().clamp(min=1)
        box_norm = targets_cls.sum(-1)[valid_masks]

        ## -- CLS -- ##
        loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
        ## -- IOU -- ##
        loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
        ## -- DFL -- ##
        loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)

        return loss_iou, loss_dfl, loss_cls


class DualLoss:
    """Combines the losses of an auxiliary and a main prediction branch."""

    def __init__(self, cfg: Config) -> None:
        """Build the underlying ``YOLOLoss`` and read the loss weights.

        Args:
            cfg: Project configuration; ``task.loss.aux`` and the
                ``task.loss.objective`` weights are read here.
        """
        self.loss = YOLOLoss(cfg)
        self.aux_rate = cfg.task.loss.aux

        self.iou_rate = cfg.task.loss.objective["BoxLoss"]
        self.dfl_rate = cfg.task.loss.objective["DFLoss"]
        self.cls_rate = cfg.task.loss.objective["BCELoss"]

    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the weighted loss of both branches.

        Args:
            predicts: Model outputs; ``predicts[0]`` is divided into two chunks,
                the first used as the auxiliary branch and the second as main.
            targets: Ground-truth tensor; everything after the first entry of
                the last dimension is multiplied by ``YOLOLoss.scale_up``.
                The tensor passed in is left unmodified.

        Returns:
            ``(loss_sum, loss_dict)`` where ``loss_dict`` maps ``"BoxLoss"``,
            ``"DFLoss"`` and ``"BCELoss"`` to their weighted values and
            ``loss_sum`` is their mean.
        """
        # Scale a copy: writing into the caller's tensor would compound the
        # scaling every time the same targets are passed in again.
        scaled_targets = targets.clone()
        scaled_targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up

        # TODO: Need Refactor this region, make it flexible!
        predicts = divide_into_chunks(predicts[0], 2)
        aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], scaled_targets)
        main_iou, main_dfl, main_cls = self.loss(predicts[1], scaled_targets)

        loss_dict = {
            "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
            "DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
            "BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
        }
        loss_sum = sum(loss_dict.values()) / len(loss_dict)
        return loss_sum, loss_dict


def get_loss_function(cfg: Config) -> DualLoss:
    """Build the loss object from the configuration and log that it was created.

    Args:
        cfg: Project configuration forwarded to ``DualLoss``.

    Returns:
        The constructed ``DualLoss`` instance.
    """
    loss_function = DualLoss(cfg)
    logger.info("✅ Success load loss function")
    return loss_function