import torch
import torch.nn.functional as F


def bce_loss(pred, mask, reduction='none'):
    # Standard per-pixel binary cross-entropy; `pred` is expected to hold
    # probabilities in [0, 1] and `mask` the binary ground-truth map.
    bce = F.binary_cross_entropy(pred, mask, reduction=reduction)
    return bce


def weighted_bce_loss(pred, mask, reduction='none'):
    # Weight each pixel by how much it differs from its 31x31 neighbourhood
    # average, so pixels near object boundaries contribute more to the loss.
    # Expects 4D (N, C, H, W) float tensors because of avg_pool2d.
    weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    weight = weight.flatten()
    bce = weight * bce_loss(pred, mask, reduction='none').flatten()
    if reduction == 'mean':
        bce = bce.mean()
    return bce


def iou_loss(pred, mask, reduction='none'):
    # Soft, per-pixel IoU loss; the +1 terms smooth the ratio and avoid
    # division by zero where both prediction and mask are zero.
    inter = pred * mask
    union = pred + mask
    iou = 1 - (inter + 1) / (union - inter + 1)
    if reduction == 'mean':
        iou = iou.mean()
    return iou


# The *_with_logits variants accept raw network outputs and apply the
# sigmoid before delegating to the probability-based losses above.
def bce_loss_with_logits(pred, mask, reduction='none'):
    return bce_loss(torch.sigmoid(pred), mask, reduction=reduction)


def weighted_bce_loss_with_logits(pred, mask, reduction='none'):
    return weighted_bce_loss(torch.sigmoid(pred), mask, reduction=reduction)


def iou_loss_with_logits(pred, mask, reduction='none'):
    return iou_loss(torch.sigmoid(pred), mask, reduction=reduction)
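

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file): assumes
# `pred_logits` are raw, unbounded model outputs and `gt_mask` is a binary
# ground-truth mask, both shaped (N, 1, H, W) as required by avg_pool2d above.
# A common pattern is to combine a BCE-style term with the IoU term.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pred_logits = torch.randn(2, 1, 64, 64)              # fake model output
    gt_mask = (torch.rand(2, 1, 64, 64) > 0.5).float()   # fake binary mask

    wbce = weighted_bce_loss_with_logits(pred_logits, gt_mask, reduction='mean')
    iou = iou_loss_with_logits(pred_logits, gt_mask, reduction='mean')
    total_loss = wbce + iou                               # combined training loss
    print(f'wbce={wbce.item():.4f}  iou={iou.item():.4f}  total={total_loss.item():.4f}')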