Spaces:
Build error
Build error
import numpy as np | |
import numpy.typing as npt | |
from torch.nn import functional as F | |
from torch import Tensor | |
def calculate_dice_loss(inputs: Tensor, targets: Tensor, num_masks: int = 1) -> Tensor:
    """Compute the smoothed Dice loss between predicted mask logits and targets.

    ``inputs`` are passed through a sigmoid and flattened from dim 1 onward, so
    each entry along dim 0 is treated as one mask. ``targets`` is flattened the
    same way so it lines up element-wise with the flattened predictions.

    Args:
        inputs: Raw (pre-sigmoid) mask predictions with at least two dims;
            dim 0 indexes masks.
        targets: Ground-truth mask values, either already flattened to
            ``(N, P)`` or with the same shape as ``inputs``. Presumably values
            in ``[0, 1]`` — TODO confirm with callers.
        num_masks: Normalizer; the summed per-mask loss is divided by it.

    Returns:
        Scalar tensor: the sum of per-mask Dice losses divided by ``num_masks``.
    """
    inputs = inputs.sigmoid().flatten(1)
    # Flatten targets too, so unflattened targets (e.g. (N, H, W)) line up with
    # the (N, H*W) predictions instead of failing to broadcast. Already-flat
    # 2-D targets are left unchanged by flatten(1).
    if targets.dim() > 1:
        targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # The +1 on both terms smooths the ratio and keeps it defined for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks
def calculate_sigmoid_focal_loss(
    inputs: Tensor,
    targets: Tensor,
    num_masks: int = 1,
    alpha: float = 0.25,
    gamma: float = 2,
) -> Tensor:
    """Compute the sigmoid focal loss for logits ``inputs`` against ``targets``.

    The element-wise binary cross-entropy (computed from logits) is scaled by
    the modulating factor ``(1 - p_t) ** gamma``. Here ``p_t`` is the
    predicted probability assigned to the target value. When ``alpha`` is
    non-negative, positives are additionally weighted by ``alpha`` and
    negatives by ``1 - alpha``.

    Args:
        inputs: Raw (pre-sigmoid) predictions.
        targets: Float targets with the same shape as ``inputs``.
        num_masks: Normalizer; the summed per-mask loss is divided by it.
        alpha: Positive/negative balancing weight; a negative value disables it.
        gamma: Focusing exponent applied to ``1 - p_t``.

    Returns:
        Scalar tensor: the sum over dim 0 of the dim-1 means of the weighted
        loss, divided by ``num_masks``.
    """
    # NOTE(review): the mean over dim 1 assumes inputs shaped (N, P) with
    # pixels already flattened; for higher-rank inputs the remaining dims get
    # summed, not averaged — confirm callers flatten first.
    per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    probs = inputs.sigmoid()
    prob_of_target = targets * probs + (1 - targets) * (1 - probs)
    weighted = per_element_bce * ((1 - prob_of_target) ** gamma)

    if not alpha < 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        weighted = weighted * class_weight

    per_mask = weighted.mean(1)
    return per_mask.sum() / num_masks
def calculate_iou(
    mask1: npt.NDArray, mask2: npt.NDArray, mask_value: int = 128
) -> float:
    """Return the intersection-over-union of two channel-encoded masks.

    Each mask is summed over axis 2 (so it needs at least three dims,
    presumably (H, W, C) images — TODO confirm with callers). A pixel counts
    as foreground when that channel sum equals ``mask_value``.

    Args:
        mask1: First mask array.
        mask2: Second mask array; must broadcast against ``mask1`` after the
            channel sum.
        mask_value: Channel sum that identifies a foreground pixel.
            Defaults to 128.

    Returns:
        |intersection| / |union| of the two foreground pixel sets, as a float.
        If neither mask has any foreground pixel, the masks agree trivially
        and 1.0 is returned. Without this special case the result would be
        0 / 0, which yields NaN.
    """
    foreground1 = mask1.sum(axis=2) == mask_value
    foreground2 = mask2.sum(axis=2) == mask_value
    union = int(np.count_nonzero(foreground1 | foreground2))
    if union == 0:
        return 1.0
    intersection = int(np.count_nonzero(foreground1 & foreground2))
    return intersection / union