import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from isegm.utils import misc


class NormalizedFocalLossSigmoid(nn.Module):
    def __init__(
        self,
        axis=-1,
        alpha=0.25,
        gamma=2,
        max_mult=-1,
        eps=1e-12,
        from_sigmoid=False,
        detach_delimeter=True,
        batch_axis=0,
        weight=None,
        size_average=True,
        ignore_label=-1,
    ):
        super(NormalizedFocalLossSigmoid, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._from_logits = from_sigmoid
        self._eps = eps
        self._size_average = size_average
        self._detach_delimeter = detach_delimeter
        self._max_mult = max_mult

        # Exponential moving averages of the normalization multiplier and the
        # peak focal term; used only for logging via log_states().
        self._k_sum = 0
        self._m_max = 0

    def forward(self, pred, label):
        one_hot = label > 0.5
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        # Standard focal-loss ingredients: class-balancing alpha and the per-pixel
        # "easiness" pt (pixels with ignore_label get pt = 1, i.e. zero loss).
        alpha = torch.where(
            one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
        )
        pt = torch.where(
            sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
        )

        beta = (1 - pt) ** self._gamma

        # Normalization: rescale the focal term so that its spatial sum matches the
        # number of valid (non-ignored) pixels in each sample.
        sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
        beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
        mult = sw_sum / (beta_sum + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        beta = beta * mult
        if self._max_mult > 0:
            beta = torch.clamp_max(beta, self._max_mult)

        # Update logging statistics; no gradients are needed here.
        with torch.no_grad():
            ignore_area = (
                torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim())))
                .cpu()
                .numpy()
            )
            sample_mult = (
                torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
            )
            if np.any(ignore_area == 0):
                self._k_sum = (
                    0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
                )

                beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
                beta_pmax = beta_pmax.mean().item()
                self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

        loss = (
            -alpha
            * beta
            * torch.log(
                torch.min(
                    pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
                )
            )
        )
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = torch.sum(
                sample_weight,
                dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis),
            )
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            ) / (bsum + self._eps)
        else:
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            )

        return loss

    def log_states(self, sw, name, global_step):
        sw.add_scalar(tag=name + "_k", value=self._k_sum, global_step=global_step)
        sw.add_scalar(tag=name + "_m", value=self._m_max, global_step=global_step)


class FocalLoss(nn.Module):
    def __init__(
        self,
        axis=-1,
        alpha=0.25,
        gamma=2,
        from_logits=False,
        batch_axis=0,
        weight=None,
        num_class=None,
        eps=1e-9,
        size_average=True,
        scale=1.0,
        ignore_label=-1,
    ):
        super(FocalLoss, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._scale = scale
        self._num_class = num_class
        self._from_logits = from_logits
        self._eps = eps
        self._size_average = size_average

    def forward(self, pred, label, sample_weight=None):
        one_hot = label > 0.5
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(
            one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
        )
        pt = torch.where(
            sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
        )

        beta = (1 - pt) ** self._gamma

        loss = (
            -alpha
            * beta
            * torch.log(
                torch.min(
                    pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
                )
            )
        )
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            tsum = torch.sum(
                sample_weight,
                dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis),
            )
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            ) / (tsum + self._eps)
        else:
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            )

        return self._scale * loss


class SoftIoU(nn.Module):
    def __init__(self, from_sigmoid=False, ignore_label=-1):
        super().__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label

    def forward(self, pred, label):
        label = label.view(pred.size())
        sample_weight = label != self._ignore_label

        if not self._from_sigmoid:
            pred = torch.sigmoid(pred)

        # Soft IoU: intersection / union over valid pixels, with the element-wise
        # maximum acting as a soft union.
        loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) / (
            torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8
        )

        return loss


class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

    def forward(self, pred, label):
        label = label.view(pred.size())
        sample_weight = label != self._ignore_label
        label = torch.where(sample_weight, label, torch.zeros_like(label))

        if not self._from_sigmoid:
            # Numerically stable BCE with logits: max(x, 0) - x * z + log(1 + exp(-|x|)).
            loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
        else:
            eps = 1e-12
            loss = -(
                torch.log(pred + eps) * label
                + torch.log(1.0 - pred + eps) * (1.0 - label)
            )

        loss = self._weight * (loss * sample_weight)
        return torch.mean(
            loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
        )
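if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original training code. It assumes
    # isegm.utils.misc.get_dims_with_exclusion is importable, as at the top of this
    # file, and illustrates that the losses take raw logits and a binary mask, with
    # ignore_label (-1) pixels excluded from the loss.
    torch.manual_seed(0)

    logits = torch.randn(2, 1, 64, 64, requires_grad=True)  # raw network outputs
    target = (torch.rand(2, 1, 64, 64) > 0.5).float()       # binary ground-truth mask
    target[0, :, :8, :8] = -1                                # example ignored region

    nfl = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
    bce = SigmoidBinaryCrossEntropyLoss()
    soft_iou = SoftIoU()

    # Each loss returns one value per sample; reduce over the batch before backward.
    nfl_loss = nfl(logits, target).mean()
    bce_loss = bce(logits, target).mean()
    iou_loss = soft_iou(logits, target).mean()
    (nfl_loss + 0.5 * bce_loss).backward()                   # all differentiable w.r.t. logits

    print("NFL:", nfl_loss.item(), "BCE:", bce_loss.item(), "SoftIoU:", iou_loss.item())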