# YOLO / yolo / utils / bounding_box_utils.py
# Revision b86ec3e by henry000: "⚡️ [Update] anchor generator, load stride if given" (15.4 kB)
import math
from typing import List, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from torch import Tensor
from torchvision.ops import batched_nms
from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
from yolo.model.yolo import YOLO
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
    """
    Compute IoU, DIoU or CIoU between two sets of boxes given as [x1, y1, x2, y2].

    Args:
        bbox1 [A x 4] or [batch x A x 4]: The first set of boxes.
        bbox2 [B x 4] or [batch x B x 4]: The second set of boxes.
        metrics (str): "iou", "diou" or "ciou" (case-insensitive). Any other value falls
            through to the CIoU branch.
    Returns:
        [A x B] or [batch x A x B]: The pairwise score, cast back to the dtype of `bbox1`.
        When the inputs are not both 2-D or both 3-D they are combined with plain
        broadcasting instead of being paired up.
    """
    metrics = metrics.lower()
    EPS = 1e-9
    dtype = bbox1.dtype
    # Do the arithmetic in float32 regardless of the input precision.
    bbox1 = bbox1.to(torch.float32)
    bbox2 = bbox2.to(torch.float32)

    # Expand dimensions if necessary
    if bbox1.ndim == 2 and bbox2.ndim == 2:
        bbox1 = bbox1.unsqueeze(1)  # (Ax4) -> (Ax1x4)
        bbox2 = bbox2.unsqueeze(0)  # (Bx4) -> (1xBx4)
    elif bbox1.ndim == 3 and bbox2.ndim == 3:
        bbox1 = bbox1.unsqueeze(2)  # (BZxAx4) -> (BZxAx1x4)
        bbox2 = bbox2.unsqueeze(1)  # (BZxBx4) -> (BZx1xBx4)

    # Calculate intersection coordinates
    xmin_inter = torch.max(bbox1[..., 0], bbox2[..., 0])
    ymin_inter = torch.max(bbox1[..., 1], bbox2[..., 1])
    xmax_inter = torch.min(bbox1[..., 2], bbox2[..., 2])
    ymax_inter = torch.min(bbox1[..., 3], bbox2[..., 3])

    # Calculate intersection area (clamped so disjoint boxes contribute zero)
    intersection_area = torch.clamp(xmax_inter - xmin_inter, min=0) * torch.clamp(ymax_inter - ymin_inter, min=0)

    # Calculate area of each bbox
    area_bbox1 = (bbox1[..., 2] - bbox1[..., 0]) * (bbox1[..., 3] - bbox1[..., 1])
    area_bbox2 = (bbox2[..., 2] - bbox2[..., 0]) * (bbox2[..., 3] - bbox2[..., 1])

    # Calculate union area
    union_area = area_bbox1 + area_bbox2 - intersection_area

    # Calculate IoU
    iou = intersection_area / (union_area + EPS)
    if metrics == "iou":
        # Cast back like the CIoU branch does, so every metric returns the caller's dtype.
        return iou.to(dtype)

    # Calculate centroid distance
    cx1 = (bbox1[..., 2] + bbox1[..., 0]) / 2
    cy1 = (bbox1[..., 3] + bbox1[..., 1]) / 2
    cx2 = (bbox2[..., 2] + bbox2[..., 0]) / 2
    cy2 = (bbox2[..., 3] + bbox2[..., 1]) / 2
    cent_dis = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2

    # Calculate diagonal length of the smallest enclosing box
    c_x = torch.max(bbox1[..., 2], bbox2[..., 2]) - torch.min(bbox1[..., 0], bbox2[..., 0])
    c_y = torch.max(bbox1[..., 3], bbox2[..., 3]) - torch.min(bbox1[..., 1], bbox2[..., 1])
    diag_dis = c_x**2 + c_y**2 + EPS

    diou = iou - (cent_dis / diag_dis)
    if metrics == "diou":
        return diou.to(dtype)

    # Compute aspect ratio penalty term
    arctan = torch.atan((bbox1[..., 2] - bbox1[..., 0]) / (bbox1[..., 3] - bbox1[..., 1] + EPS)) - torch.atan(
        (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
    )
    v = (4 / (math.pi**2)) * (arctan**2)
    alpha = v / (v - iou + 1 + EPS)

    # Compute CIoU
    ciou = diou - alpha * v
    return ciou.to(dtype)
def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
    """
    Convert bounding boxes from one coordinate layout to another.

    Supported layouts are "xyxy" (two corners), "xywh" (top-left corner plus size)
    and "xycwh" (center plus size).

    Args:
        bbox [... x 4]: The boxes to convert; the last axis holds the four coordinates.
        indicator (str): "<input format> -> <output format>"; spaces are ignored.
    Returns:
        [... x 4]: The converted boxes, in the dtype of the input.
    Raises:
        ValueError: If either format is not one of the supported layouts.
    """
    original_dtype = bbox.dtype
    supported = ["xyxy", "xywh", "xycwh"]
    source_fmt, target_fmt = indicator.replace(" ", "").split("->")
    if not (source_fmt in supported and target_fmt in supported):
        raise ValueError("Invalid input or output format")

    # Bring every input layout to corner form first.
    c0, c1, c2, c3 = bbox[..., 0], bbox[..., 1], bbox[..., 2], bbox[..., 3]
    if source_fmt == "xyxy":
        left, top, right, bottom = c0, c1, c2, c3
    elif source_fmt == "xywh":
        left, top = c0, c1
        right, bottom = c0 + c2, c1 + c3
    else:  # "xycwh"
        half_w, half_h = c2 / 2, c3 / 2
        left, top = c0 - half_w, c1 - half_h
        right, bottom = c0 + half_w, c1 + half_h

    # Then emit the requested layout from the corners.
    if target_fmt == "xyxy":
        columns = [left, top, right, bottom]
    elif target_fmt == "xywh":
        columns = [left, top, right - left, bottom - top]
    else:  # "xycwh"
        columns = [(left + right) / 2, (top + bottom) / 2, right - left, bottom - top]

    return torch.stack(columns, dim=-1).to(dtype=original_dtype)
def generate_anchors(image_size: List[int], strides: List[int]):
    """
    Build the anchor center points and their strides for every prediction layer.

    Args:
        image_size List: The [W, H] size of the (augmented) input image.
        strides List[8, 16, 32, ...]: The stride size for each predicted layer.
    Returns:
        all_anchors [HW x 2]: The (x, y) center of every grid cell, all layers concatenated.
        all_scalers [HW]: The stride of the layer each anchor belongs to.
    """
    W, H = image_size
    anchors = []
    scaler = []
    for stride in strides:
        # Cell centers sit half a stride inside each cell.
        shift = stride // 2
        h = torch.arange(0, H, stride) + shift
        w = torch.arange(0, W, stride) + shift
        anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
        anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
        anchors.append(anchor)
        # Size the scaler from the anchors actually generated: `W // stride * H // stride`
        # evaluates as ((W // stride) * H) // stride and floors, so it disagrees with the
        # grid above whenever the image size is not a multiple of the stride.
        scaler.append(torch.full((anchor.size(0),), stride))
    all_anchors = torch.cat(anchors, dim=0)
    all_scalers = torch.cat(scaler, dim=0)
    return all_anchors, all_scalers
class BoxMatcher:
    """
    Assign ground-truth targets to anchor predictions.

    Each target/anchor pair is scored from the predicted class probability and the IoU of the
    predicted box, restricted to anchors whose center lies inside the target box. The best
    anchors per target are kept and every anchor ends up with at most one target.
    """

    def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
        """
        Args:
            cfg: Matcher settings; every entry is copied onto the instance. The methods below
                read `iou` (metric name), `topk` and `factor` (with "iou" and "cls" exponents).
            class_num: Number of classes used for the one-hot class targets.
            anchors [anchors x 2]: The (x, y) center of every anchor.
        """
        self.class_num = class_num
        self.anchors = anchors
        for attr_name in cfg:
            setattr(self, attr_name, cfg[attr_name])

    def get_valid_matrix(self, target_bbox: Tensor):
        """
        Get a boolean mask that indicates whether each anchor center lies strictly inside each target box.

        Args:
            target_bbox [batch x targets x 4]: The bounding box of each target, as [x1, y1, x2, y2].
        Returns:
            [batch x targets x anchors]: A boolean tensor indicating if the target box contains the anchor.
        """
        Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
        anchors = self.anchors[None, None]  # add an axis at the first and second dimension
        anchors_x, anchors_y = anchors.unbind(dim=3)
        target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
        target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
        target_on_anchor = target_in_x & target_in_y
        return target_on_anchor

    def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
        """
        Get the predicted probability of each target's class across all anchors.

        Args:
            predict_cls [batch x anchors x class]: The predicted probabilities for each class at each anchor.
            target_cls [batch x targets x 1]: The class index for each target.
        Returns:
            [batch x targets x anchors]: The probabilities from `predict_cls` corresponding to the class
                indices specified in `target_cls`.
        """
        predict_cls = predict_cls.transpose(1, 2)  # -> [batch x class x anchors]
        # Take the anchor count from the prediction itself rather than assuming 8400,
        # so any input resolution works.
        target_cls = target_cls.expand(-1, -1, predict_cls.size(2))
        cls_probabilities = torch.gather(predict_cls, 1, target_cls)
        return cls_probabilities

    def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
        """
        Get the IoU between each target bounding box and each predicted bounding box.

        Args:
            predict_bbox [batch x predicts x 4]: Bounding box with [x1, y1, x2, y2].
            target_bbox [batch x targets x 4]: Bounding box with [x1, y1, x2, y2].
        Returns:
            [batch x targets x predicts]: The IoU scores (metric chosen by `self.iou`) between each
                target and prediction, clamped to [0, 1].
        """
        return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)

    def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
        """
        Keep only the top-k scoring anchors for each target.

        Args:
            target_matrix [batch x targets x anchors]: The suitability for each target-anchor pair.
            topk (int, optional): Number of top scores to retain per target.
        Returns:
            topk_targets [batch x targets x anchors]: `target_matrix` with everything outside the top-k zeroed.
            topk_masks [batch x targets x anchors]: A boolean mask of the retained, strictly positive scores.
        """
        # torch.topk raises when k exceeds the axis length, so cap it at the anchor count.
        topk = min(topk, target_matrix.size(-1))
        values, indices = target_matrix.topk(topk, dim=-1)
        topk_targets = torch.zeros_like(target_matrix)
        topk_targets.scatter_(dim=-1, index=indices, src=values)
        topk_masks = topk_targets > 0
        return topk_targets, topk_masks

    def filter_duplicates(self, target_matrix: Tensor):
        """
        Pick, for each anchor, the index of the target with the maximum suitability.

        Args:
            target_matrix [batch x targets x anchors]: The suitability for each target-anchor pair.
        Returns:
            unique_indices [batch x anchors x 1]: The index of the best target for each anchor.
        """
        unique_indices = target_matrix.argmax(dim=1)
        return unique_indices[..., None]

    def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Match targets to anchors.

        1. For each anchor prediction, find the highest suitability target
        2. Select the targets
        3. Normalize the class probabilities of the targets

        Args:
            target [batch x targets x 5]: Each row is (class, x1, y1, x2, y2).
            predict: Tuple of (predict_cls [batch x anchors x class] raw logits,
                predict_bbox [batch x anchors x 4]).
        Returns:
            [batch x anchors x (class_num + 4)]: The aligned class distribution and box for each anchor.
            [batch x anchors]: A boolean mask of the anchors that received a target.
        """
        predict_cls, predict_bbox = predict
        target_cls, target_bbox = target.split([1, 4], dim=-1)  # B x N x (C B) -> B x N x C, B x N x B
        # Negative class ids are clamped to 0 so they remain valid gather / one-hot indices.
        target_cls = target_cls.long().clamp(0)

        # get valid matrix (each gt appears in which anchor grid)
        grid_mask = self.get_valid_matrix(target_bbox)

        # get iou matrix (iou of each gt bbox with each predicted box)
        iou_mat = self.get_iou_matrix(predict_bbox, target_bbox)

        # get cls matrix (predicted probability of each gt class at each anchor)
        cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)

        target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])

        # choose topk
        topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)

        # resolve anchors assigned to multiple gts: keep the best target per anchor
        unique_indices = self.filter_duplicates(topk_targets)

        # TODO: do we need grid_mask? Filter the valid ground truth
        valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()

        align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
        align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
        align_cls = F.one_hot(align_cls, self.class_num)

        # normalize class distribution: rescale each target's scores so its best anchor gets its best IoU
        max_target = target_matrix.amax(dim=-1, keepdim=True)
        max_iou = iou_mat.amax(dim=-1, keepdim=True)
        normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
        normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
        align_cls = align_cls * normalize_term * valid_mask[:, :, None]

        return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
class Vec2Box:
    """Convert the per-layer raw model outputs into flat class scores and [x1, y1, x2, y2] boxes."""

    def __init__(self, model: YOLO, image_size, device):
        """
        Args:
            model: The detector; its `strides` attribute is used when present and non-empty,
                otherwise the strides are inferred from a dummy forward pass.
            image_size: The [W, H] input size used to lay out the anchor grid.
            device: The device the anchor grid and scaler are moved to.
        """
        self.device = device

        # Use a default so a model without a `strides` attribute reaches the auto-anchor
        # fallback instead of raising AttributeError.
        if getattr(model, "strides", None):
            logger.info(f"🈶 Found stride of model {model.strides}")
            self.strides = model.strides
        else:
            logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
            self.strides = self.create_auto_anchor(model, image_size)

        # TODO: this is an exception for onnx, remove it when the onnx device is fixed
        # NOTE(review): only this initial placement is forced to CPU; `update` still uses self.device.
        if not isinstance(model, YOLO):
            device = torch.device("cpu")

        anchor_grid, scaler = generate_anchors(image_size, self.strides)
        self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)

    def create_auto_anchor(self, model: YOLO, image_size):
        """
        Infer the stride of each prediction head by running an all-zero image through the model.

        Args:
            model: A callable returning a mapping whose "Main" entry lists the per-layer outputs.
            image_size: The size used to build the dummy input.
        Returns:
            List[int]: One stride per prediction head.
        """
        dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
        dummy_output = model(dummy_input)
        strides = []
        for predict_head in dummy_output["Main"]:
            # The third head output is the box tensor; everything after its first two axes is the grid size.
            _, _, *anchor_num = predict_head[2].shape
            strides.append(image_size[1] // anchor_num[1])
        return strides

    def update(self, image_size):
        """Rebuild the anchor grid and scaler for a new image size, keeping the current strides."""
        anchor_grid, scaler = generate_anchors(image_size, self.strides)
        self.anchor_grid, self.scaler = anchor_grid.to(self.device), scaler.to(self.device)

    def __call__(self, predicts):
        """
        Flatten the per-layer predictions and decode the boxes.

        Args:
            predicts: An iterable of (pred_cls [B x C x h x w], pred_anc [B x A x R x h x w],
                pred_box [B x X x h x w]) tuples, one per layer.
        Returns:
            preds_cls [B x HW x C]: Class predictions of all layers concatenated.
            preds_anc [B x HW x R x A]: Anchor distributions of all layers concatenated.
            preds_box [B x HW x 4]: Boxes as [x1, y1, x2, y2] (assumes X == 4, i.e. left/top/right/bottom distances).
        """
        preds_cls, preds_anc, preds_box = [], [], []
        for layer_output in predicts:
            pred_cls, pred_anc, pred_box = layer_output
            preds_cls.append(rearrange(pred_cls, "B C h w -> B (h w) C"))
            preds_anc.append(rearrange(pred_anc, "B A R h w -> B (h w) R A"))
            preds_box.append(rearrange(pred_box, "B X h w -> B (h w) X"))
        preds_cls = torch.concat(preds_cls, dim=1)
        preds_anc = torch.concat(preds_anc, dim=1)
        preds_box = torch.concat(preds_box, dim=1)

        # Distances are predicted in grid units; scale them by each anchor's stride.
        pred_LTRB = preds_box * self.scaler.view(1, -1, 1)
        lt, rb = pred_LTRB.chunk(2, dim=-1)
        preds_box = torch.cat([self.anchor_grid - lt, self.anchor_grid + rb], dim=-1)
        return preds_cls, preds_anc, preds_box
def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig):
    """
    Filter predictions by confidence and apply non-maximum suppression per image.

    Args:
        cls_dist [batch x anchors x class]: The raw class logits of each prediction.
        bbox [batch x anchors x 4]: The predicted boxes as [x1, y1, x2, y2].
        nms_cfg: Provides `min_confidence` (score threshold) and `min_iou` (NMS IoU threshold).
    Returns:
        List of length batch; each item is a [kept x 6] tensor with rows (class, x1, y1, x2, y2, confidence).
    """
    cls_dist = cls_dist.sigmoid()

    # filter class by confidence: keep the best class of each prediction if it is confident enough
    cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
    valid_mask = cls_val > nms_cfg.min_confidence
    valid_cls = cls_idx[valid_mask].float()
    valid_con = cls_val[valid_mask].float()
    valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)

    batch_idx, *_ = torch.where(valid_mask)
    # batched_nms takes (boxes, scores, idxs, iou_threshold): rank by confidence, not by class id.
    # NOTE(review): grouping by batch index alone makes suppression class-agnostic within an
    # image; fold the class id into the group index if per-class NMS is wanted.
    nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
    predicts_nms = []
    for idx in range(cls_dist.size(0)):
        instance_idx = nms_idx[idx == batch_idx[nms_idx]]

        predict_nms = torch.cat(
            [valid_cls[instance_idx][:, None], valid_box[instance_idx], valid_con[instance_idx][:, None]], dim=-1
        )

        predicts_nms.append(predict_nms)
    return predicts_nms
def calculate_map(predictions, ground_truths, iou_thresholds):
    """
    Compute the average precision at each IoU threshold, and their mean, for one image.

    Args:
        predictions [n_preds x 6]: Rows of (class, x1, y1, x2, y2, confidence).
        ground_truths [n x 5]: Rows of (class, x1, y1, x2, y2); rows with class -1 are padding
            and are assumed to come after all valid rows.
        iou_thresholds: Iterable of IoU thresholds at which to evaluate AP.
    Returns:
        mean_ap: The mean of the per-threshold APs (scalar tensor).
        aps: The list of per-threshold APs (scalar tensors).
    """
    # TODO: Refactor this block
    device = predictions.device
    n_preds = predictions.size(0)
    n_gts = int((ground_truths[:, 0] != -1).sum())
    ground_truths = ground_truths[:n_gts]

    # Without predictions or ground truths nothing can match; report zero AP
    # instead of failing on a reduction over an empty axis.
    if n_preds == 0 or n_gts == 0:
        aps = [torch.zeros(()) for _ in iou_thresholds]
        return torch.mean(torch.stack(aps)), aps

    ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:])  # [n_preds, n_gts]

    # None of the following depends on the threshold, so compute it once.
    max_iou, max_indices = torch.max(ious, dim=1)
    matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
    # Rank by confidence, which is the last column; column 1 is x_min.
    _, indices = torch.sort(predictions[:, -1], descending=True)
    recall_thresholds = torch.arange(0, 1, 0.1)

    aps = []
    for threshold in iou_thresholds:
        tp = torch.zeros(n_preds, device=device)
        fp = torch.zeros(n_preds, device=device)

        above_threshold = max_iou >= threshold
        tp[above_threshold & matched_classes] = 1
        fp[above_threshold & ~matched_classes] = 1
        fp[max_iou < threshold] = 1

        tp_sorted = tp[indices]
        fp_sorted = fp[indices]

        tp_cumsum = torch.cumsum(tp_sorted, dim=0)
        fp_cumsum = torch.cumsum(fp_sorted, dim=0)

        precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
        recall = tp_cumsum / (n_gts + 1e-6)

        # Interpolated AP: best precision reachable at or beyond each recall level.
        precision_at_recall = torch.zeros_like(recall_thresholds)
        for i, r in enumerate(recall_thresholds):
            precision_at_recall[i] = precision[recall >= r].max().item() if torch.any(recall >= r) else 0

        ap = precision_at_recall.mean()
        aps.append(ap)

    mean_ap = torch.mean(torch.stack(aps))
    return mean_ap, aps