import torch import torch.nn.functional as F def bce_loss(pred, mask, reduction='none'): bce = F.binary_cross_entropy(pred, mask, reduction=reduction) return bce def weighted_bce_loss(pred, mask, reduction='none'): weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) weight = weight.flatten() bce = weight * bce_loss(pred, mask, reduction='none').flatten() if reduction == 'mean': bce = bce.mean() return bce def iou_loss(pred, mask, reduction='none'): inter = pred * mask union = pred + mask iou = 1 - (inter + 1) / (union - inter + 1) if reduction == 'mean': iou = iou.mean() return iou def bce_loss_with_logits(pred, mask, reduction='none'): return bce_loss(torch.sigmoid(pred), mask, reduction=reduction) def weighted_bce_loss_with_logits(pred, mask, reduction='none'): return weighted_bce_loss(torch.sigmoid(pred), mask, reduction=reduction) def iou_loss_with_logits(pred, mask, reduction='none'): return iou_loss(torch.sigmoid(pred), mask, reduction=reduction)