|
import torch |
|
import torch.nn.functional as F |
|
|
|
def bce_loss(pred, mask, reduction='none'):
    """Return the binary cross-entropy between ``pred`` and ``mask``.

    Thin wrapper around ``torch.nn.functional.binary_cross_entropy`` whose
    default reduction is ``'none'`` instead of PyTorch's ``'mean'``.

    Args:
        pred: Predicted probabilities (not logits).
        mask: Target tensor with the same shape as ``pred``.
        reduction: Forwarded unchanged to ``F.binary_cross_entropy``.

    Returns:
        The loss tensor produced by ``F.binary_cross_entropy``.
    """
    return F.binary_cross_entropy(input=pred, target=mask, reduction=reduction)
|
|
|
def weighted_bce_loss(pred, mask, reduction='none'):
    """Binary cross-entropy with extra weight where the mask varies locally.

    Each element is weighted by ``1 + 5 * |avg_pool(mask) - mask|``, where the
    average is taken over a 31x31 window (stride 1, padding 15, so the spatial
    size is preserved). The weight is 1 where the mask equals its local
    average and grows where the two differ.

    Args:
        pred: Predicted probabilities (not logits), as required by
            ``F.binary_cross_entropy``.
        mask: Target tensor with the same shape as ``pred``. It must have a
            layout accepted by ``F.avg_pool2d``.
        reduction: ``'none'`` returns the flattened 1-D per-element losses,
            ``'mean'`` returns their mean and ``'sum'`` returns their sum.

    Returns:
        A 1-D tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
    """
    # Reject unknown modes up front instead of silently returning the
    # unreduced vector, matching how F.binary_cross_entropy treats them.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    weight = 1 + 5 * torch.abs(local_mean - mask)

    per_element = F.binary_cross_entropy(pred, mask, reduction='none')
    loss = weight.flatten() * per_element.flatten()

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss
|
|
|
def iou_loss(pred, mask, reduction='none'):
    """Element-wise soft IoU-style loss between ``pred`` and ``mask``.

    For every element the loss is
    ``1 - (p * m + 1) / (p + m - p * m + 1)``; the ``+ 1`` terms keep the
    ratio defined when both inputs are zero. No summation over any dimension
    is performed before the ratio is taken.

    Args:
        pred: Predicted values (probabilities when used with a sigmoid).
        mask: Target tensor broadcastable with ``pred``.
        reduction: ``'none'`` returns the per-element losses with their
            original shape, ``'mean'`` returns their mean and ``'sum'``
            returns their sum.

    Returns:
        A tensor of per-element losses for ``'none'``, otherwise a scalar
        tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
    """
    # Reject unknown modes instead of silently returning the unreduced
    # tensor, consistent with the reduction handling of torch's own losses.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    intersection = pred * mask
    union = pred + mask - intersection
    loss = 1 - (intersection + 1) / (union + 1)

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss
|
|
|
def bce_loss_with_logits(pred, mask, reduction='none'):
    """Return the binary cross-entropy of raw logits ``pred`` against ``mask``.

    Uses ``F.binary_cross_entropy_with_logits`` rather than applying a sigmoid
    and then ``F.binary_cross_entropy``. The two-step form saturates for
    large-magnitude logits: the sigmoid rounds to exactly 0 or 1, so the loss
    is clamped and its gradient vanishes. The fused form stays accurate.

    Args:
        pred: Raw, unnormalised logits.
        mask: Target tensor with the same shape as ``pred``.
        reduction: Forwarded unchanged to
            ``F.binary_cross_entropy_with_logits``.

    Returns:
        The loss tensor produced by ``F.binary_cross_entropy_with_logits``.
    """
    return F.binary_cross_entropy_with_logits(pred, mask, reduction=reduction)
|
|
|
def weighted_bce_loss_with_logits(pred, mask, reduction='none'):
    """Locally weighted binary cross-entropy computed from raw logits.

    Each element is weighted by ``1 + 5 * |avg_pool(mask) - mask|`` with a
    31x31 averaging window (stride 1, padding 15), the same weighting used by
    ``weighted_bce_loss``. The cross-entropy term is evaluated with
    ``F.binary_cross_entropy_with_logits`` instead of sigmoid followed by
    ``F.binary_cross_entropy``, so it does not saturate for large-magnitude
    logits.

    Args:
        pred: Raw, unnormalised logits.
        mask: Target tensor with the same shape as ``pred``. It must have a
            layout accepted by ``F.avg_pool2d``.
        reduction: ``'none'`` returns the flattened 1-D per-element losses,
            ``'mean'`` returns their mean and ``'sum'`` returns their sum.

    Returns:
        A 1-D tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
    """
    # Reject unknown modes instead of silently returning the unreduced vector.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    weight = 1 + 5 * torch.abs(local_mean - mask)

    per_element = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    loss = weight.flatten() * per_element.flatten()

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss
|
|
|
def iou_loss_with_logits(pred, mask, reduction='none'):
    """Apply a sigmoid to the logits ``pred`` and delegate to ``iou_loss``.

    Args:
        pred: Raw, unnormalised logits.
        mask: Target tensor, passed through to ``iou_loss``.
        reduction: Passed through to ``iou_loss``.

    Returns:
        Whatever ``iou_loss`` returns for the sigmoid-activated predictions.
    """
    probabilities = torch.sigmoid(pred)
    return iou_loss(probabilities, mask, reduction=reduction)