# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.losses.utils import weight_reduce_loss
from mmdet.structures.bbox import HorizontalBoxes

from mmyolo.registry import MODELS


def bbox_overlaps(pred: torch.Tensor,
                  target: torch.Tensor,
                  iou_mode: str = 'ciou',
                  bbox_format: str = 'xywh',
                  siou_theta: float = 4.0,
                  eps: float = 1e-7) -> torch.Tensor:
    r"""Calculate overlap between two sets of bboxes.

    Implementation of paper `Enhancing Geometric Factors in Model Learning
    and Inference for Object Detection and Instance Segmentation
    <https://arxiv.org/abs/2005.03572>`_.

    In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
    difference in the way the alpha parameter is computed.

    mmdet version:
        alpha = (ious > 0.5).float() * v / (1 - ious + v)
    YOLOv5 version (used here):
        alpha = v / (v - ious + (1 + eps))

    Supported modes:

    - ``'iou'``:  plain IoU.
    - ``'giou'``: IoU - (A_c - union) / A_c, with A_c the enclosing-box area.
    - ``'ciou'``: IoU - (rho^2 / c^2 + alpha * v).
    - ``'siou'``: IoU - (distance_cost + shape_cost) / 2, see
      https://arxiv.org/pdf/2205.12740.pdf.

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
            or (cx, cy, w, h), shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou').
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        siou_theta (float): Exponent θ of the SIoU shape cost.
            Defaults to 4.0.
        eps (float): Small constant added to denominators to avoid division
            by zero. Defaults to 1e-7.

    Returns:
        Tensor: shape (n, ), clamped to the range [-1, 1].
    """
    assert iou_mode in ('iou', 'ciou', 'giou', 'siou')
    assert bbox_format in ('xyxy', 'xywh')
    if bbox_format == 'xywh':
        pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
        target = HorizontalBoxes.cxcywh_to_xyxy(target)

    # bbox1 is the prediction, bbox2 is the target; both are xyxy from here.
    bbox1_x1, bbox1_y1 = pred[..., 0], pred[..., 1]
    bbox1_x2, bbox1_y2 = pred[..., 2], pred[..., 3]
    bbox2_x1, bbox2_y1 = target[..., 0], target[..., 1]
    bbox2_x2, bbox2_y2 = target[..., 2], target[..., 3]

    # Overlap (intersection area)
    overlap = ((torch.min(bbox1_x2, bbox2_x2) -
                torch.max(bbox1_x1, bbox2_x1)).clamp(0) *
               (torch.min(bbox1_y2, bbox2_y2) -
                torch.max(bbox1_y1, bbox2_y1)).clamp(0))

    # Union
    w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
    w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
    union = (w1 * h1) + (w2 * h2) - overlap + eps

    # Heights are re-derived with ``eps`` (after the union) so the aspect
    # ratio terms below never divide by a zero height.
    h1 = bbox1_y2 - bbox1_y1 + eps
    h2 = bbox2_y2 - bbox2_y1 + eps

    # IoU
    ious = overlap / union

    # Smallest enclosing box
    enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
    enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
    enclose_w = enclose_wh[..., 0]  # cw
    enclose_h = enclose_wh[..., 1]  # ch

    if iou_mode == 'ciou':
        # CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )

        # Squared diagonal of the enclosing box (c^2).
        enclose_area = enclose_w**2 + enclose_h**2 + eps

        # ρ^2(b_pred, b_gt): squared euclidean distance between the centers
        # of pred (bbox1) and target (bbox2). Boxes are xyxy, so a center is
        # (x1 + x2) / 2; squaring the difference of sums therefore needs / 4.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) -
                          (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)

        # Width and height ratio (v)
        wh_ratio = (4 / (math.pi**2)) * torch.pow(
            torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

        # alpha (YOLOv5 formula) is a trade-off weight computed without
        # gradient, i.e. treated as a constant.
        with torch.no_grad():
            alpha = wh_ratio / (wh_ratio - ious + (1 + eps))

        # CIoU
        ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))

    elif iou_mode == 'giou':
        # GIoU = IoU - ( (A_c - union) / A_c )
        convex_area = enclose_w * enclose_h + eps  # convex area (A_c)
        ious = ious - (convex_area - union) / convex_area

    elif iou_mode == 'siou':
        # SIoU: https://arxiv.org/pdf/2205.12740.pdf
        # SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )

        # sigma (σ): offset between the centers of target (bbox2) and
        # pred (bbox1).
        # sigma_cw = b_cx_gt - b_cx
        sigma_cw = ((bbox2_x1 + bbox2_x2) / 2 -
                    (bbox1_x1 + bbox1_x2) / 2 + eps)
        # sigma_ch = b_cy_gt - b_cy
        sigma_ch = ((bbox2_y1 + bbox2_y2) / 2 -
                    (bbox1_y1 + bbox1_y2) / 2 + eps)
        # sigma = √( (sigma_cw ** 2) + (sigma_ch ** 2) )
        sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)

        # Use the smaller of the two angles: sin(alpha) when alpha <= pi/4,
        # otherwise sin(beta) with beta = pi/2 - alpha.
        sin_alpha = torch.abs(sigma_ch) / sigma
        sin_beta = torch.abs(sigma_cw) / sigma
        sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
                                sin_beta)

        # Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
        #            = cos( 2 * arcsin(x) - pi / 2 )
        angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)

        # Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
        rho_x = (sigma_cw / enclose_w)**2  # ρ_x
        rho_y = (sigma_ch / enclose_h)**2  # ρ_y
        gamma = 2 - angle_cost  # γ
        distance_cost = (1 - torch.exp(-1 * gamma * rho_x)) + (
            1 - torch.exp(-1 * gamma * rho_y))

        # Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
        # Denominators are clamped to ``eps`` so two zero-width (or
        # zero-height) boxes give ω = 0 instead of NaN from 0 / 0.
        omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2).clamp(min=eps)  # ω_w
        omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2).clamp(min=eps)  # ω_h
        shape_cost = (
            torch.pow(1 - torch.exp(-1 * omiga_w), siou_theta) +
            torch.pow(1 - torch.exp(-1 * omiga_h), siou_theta))

        ious = ious - ((distance_cost + shape_cost) * 0.5)

    return ious.clamp(min=-1.0, max=1.0)


@MODELS.register_module()
class IoULoss(nn.Module):
    """IoULoss.

    Computes ``loss_weight * reduce(1 - IoU)`` between a set of predicted
    bboxes and target bboxes, where IoU is the variant selected by
    ``iou_mode``.

    Args:
        iou_mode (str): Options are "iou", "ciou", "giou" and "siou".
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        eps (float): Small constant added to denominators to avoid division
            by zero. Defaults to 1e-7.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean".
        loss_weight (float): Weight of loss. Defaults to 1.0.
        return_iou (bool): If True, return loss and iou. Defaults to True.
        siou_theta (float): Exponent θ of the SIoU shape cost; only used
            when ``iou_mode='siou'``. Defaults to 4.0.
    """

    def __init__(self,
                 iou_mode: str = 'ciou',
                 bbox_format: str = 'xywh',
                 eps: float = 1e-7,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_iou: bool = True,
                 siou_theta: float = 4.0):
        super().__init__()
        assert bbox_format in ('xywh', 'xyxy')
        assert iou_mode in ('iou', 'ciou', 'siou', 'giou')
        self.iou_mode = iou_mode
        self.bbox_format = bbox_format
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_iou = return_iou
        self.siou_theta = siou_theta

    def _compute_iou(self, pred: torch.Tensor,
                     target: torch.Tensor) -> torch.Tensor:
        """Run ``bbox_overlaps`` with this loss's configuration."""
        return bbox_overlaps(
            pred,
            target,
            iou_mode=self.iou_mode,
            bbox_format=self.bbox_format,
            siou_theta=self.siou_theta,
            eps=self.eps)

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        avg_factor: Optional[float] = None,
        reduction_override: Optional[str] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
                or (cx, cy, w, h), shape (n, 4).
            target (Tensor): Corresponding gt bboxes, shape (n, 4).
            weight (Tensor, optional): Element-wise weights.
            avg_factor (float, optional): Average factor when computing the
                mean of losses.
            reduction_override (str, optional): If given, overrides
                ``self.reduction``. Options are "none", "mean" and "sum".
                Defaults to None.

        Returns:
            Tensor or tuple[Tensor, Tensor]: The loss, or ``(loss, iou)``
            when ``self.return_iou`` is True.
        """
        if weight is not None and not torch.any(weight > 0):
            # No positive weight: return a zero computed from ``pred`` so it
            # stays connected to pred's autograd graph.
            if pred.dim() == weight.dim() + 1:
                # unsqueeze(-1) lines the weight up with the box-coordinate
                # axis for any number of leading dims (== unsqueeze(1) in 2D).
                weight = weight.unsqueeze(-1)
            zero_loss = (pred * weight).sum()  # 0
            if self.return_iou:
                # Honour the ``(loss, iou)`` contract on this path too, so
                # callers that unpack the result do not fail.
                return zero_loss, self._compute_iou(pred, target)
            return zero_loss

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # Multi-dimensional weights are averaged over their last dim,
        # presumably per-coordinate (n, 4) weights -> one weight per box.
        if weight is not None and weight.dim() > 1:
            weight = weight.mean(-1)

        iou = self._compute_iou(pred, target)
        loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
                                                     reduction, avg_factor)

        if self.return_iou:
            return loss, iou
        return loss