import pytest
import torch

from mmdet.models.losses import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss,
                                 IoULoss)
@pytest.mark.parametrize(
    'loss_class', [IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss])
def test_iou_type_loss_zeros_weight(loss_class):
    """Check that an IoU-type loss evaluates to zero under all-zero weights.

    Args:
        loss_class: Loss class under test. It is instantiated with no
            arguments and then called as ``loss(pred, target, weight)``.
    """
    # Random inputs: the actual values should not matter once every weight
    # is zero.
    # NOTE(review): presumably each row is an (x1, y1, x2, y2) box — confirm.
    pred = torch.rand((10, 4))
    target = torch.rand((10, 4))
    # All-zero weights, one entry per row of ``pred`` / ``target``.
    weight = torch.zeros(10)

    loss = loss_class()(pred, target, weight)
    assert loss == 0.