glenn-jocher commited on
Commit
a84cd02
·
unverified ·
1 Parent(s): 7dafd1c

CIoU protected divides (#8546)

Browse files

Protected divides in IOU function to resolve https://github.com/ultralytics/yolov5/issues/8539

Files changed (1) hide show
  1. utils/metrics.py +3 -3
utils/metrics.py CHANGED
@@ -225,8 +225,8 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
225
  else: # x1, y1, x2, y2 = box1
226
  b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
227
  b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
228
- w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
229
- w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
230
 
231
  # Intersection area
232
  inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
@@ -244,7 +244,7 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
244
  c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
245
  rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
246
  if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
247
- v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
248
  with torch.no_grad():
249
  alpha = v / (v - iou + (1 + eps))
250
  return iou - (rho2 / c2 + v * alpha) # CIoU
 
225
  else: # x1, y1, x2, y2 = box1
226
  b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
227
  b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
228
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
229
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
230
 
231
  # Intersection area
232
  inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
 
244
  c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
245
  rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
246
  if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
247
+ v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2)
248
  with torch.no_grad():
249
  alpha = v / (v - iou + (1 + eps))
250
  return iou - (rho2 / c2 + v * alpha) # CIoU