File size: 1,332 Bytes
6723494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import numpy.typing as npt

from torch.nn import functional as F
from torch import Tensor


def calculate_dice_loss(inputs: Tensor, targets: Tensor, num_masks: int = 1) -> Tensor:
    """Compute the smoothed soft Dice loss between mask logits and target masks.

    Args:
        inputs: Raw logits; a sigmoid is applied here. Everything after the
            first dimension is flattened, so each row along dim 0 is one mask.
        targets: Target masks, either already flattened to match the
            flattened ``inputs`` or shaped like ``inputs`` (in which case they
            are flattened the same way). Presumably binary/soft values in
            [0, 1] -- confirm against callers.
        num_masks: Normaliser for the summed per-mask losses.

    Returns:
        Scalar tensor: sum of per-mask Dice losses divided by ``num_masks``.
    """
    probs = inputs.sigmoid().flatten(1)
    # Flatten targets the same way as the predictions so targets shaped like
    # ``inputs`` (e.g. (N, H, W)) line up with the (N, H*W) predictions
    # instead of failing to broadcast. 1-D/2-D targets are left untouched,
    # which preserves the previous behaviour for pre-flattened targets.
    if targets.dim() > 2:
        targets = targets.flatten(1)
    numerator = 2 * (probs * targets).sum(-1)
    denominator = probs.sum(-1) + targets.sum(-1)
    # The +1 on both terms smooths the ratio and avoids 0/0 for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def calculate_sigmoid_focal_loss(
    inputs: Tensor,
    targets: Tensor,
    num_masks: int = 1,
    alpha: float = 0.25,
    gamma: float = 2,
) -> Tensor:
    """Compute the sigmoid focal loss between logits and targets.

    Args:
        inputs: Raw logits (a sigmoid is applied internally).
        targets: Targets with the same shape as ``inputs``; must be a float
            tensor because they are passed to
            ``binary_cross_entropy_with_logits``.
        num_masks: Normaliser for the final summed loss.
        alpha: Positive/negative balancing weight; any negative value
            disables the weighting.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor that
            down-weights well-classified elements.

    Returns:
        Scalar tensor: the element losses are averaged over dim 1, summed,
        and divided by ``num_masks``.
    """
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p = inputs.sigmoid()
    # Probability the model assigns to the ground-truth label of each element.
    p_true = p * targets + (1 - p) * (1 - targets)
    focal = bce * (1 - p_true) ** gamma

    if alpha >= 0:
        weight = alpha * targets + (1 - alpha) * (1 - targets)
        per_row = (weight * focal).mean(1)
    else:
        per_row = focal.mean(1)

    return per_row.sum() / num_masks


def calculate_iou(
    mask1: npt.NDArray,
    mask2: npt.NDArray,
    *,
    foreground_value: int = 128,
    zero_division: float = 0.0,
) -> float:
    """Compute the intersection-over-union of two colour-encoded masks.

    Each mask is summed over axis 2 (presumably the colour-channel axis of an
    (H, W, C) image -- confirm against callers). A pixel counts as foreground
    when that channel sum equals ``foreground_value``.

    Args:
        mask1: First mask array with at least 3 dimensions.
        mask2: Second mask array, shape-compatible with ``mask1``.
        foreground_value: Channel-sum value marking a foreground pixel
            (default 128, the value previously hard-coded here).
        zero_division: Value returned when neither mask has any foreground
            pixel (union is empty). Previously this case produced NaN and a
            RuntimeWarning.

    Returns:
        The IoU as a Python float in [0, 1], or ``zero_division``.
    """
    fg1 = mask1.sum(axis=2) == foreground_value
    fg2 = mask2.sum(axis=2) == foreground_value

    union = int(np.count_nonzero(fg1 | fg2))
    if union == 0:
        # Both masks are empty: avoid 0/0 and return the caller's chosen value.
        return zero_division
    intersection = int(np.count_nonzero(fg1 & fg2))
    return intersection / union