import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from isegm.utils import misc
class NormalizedFocalLossSigmoid(nn.Module):
    """Normalized focal loss for binary segmentation.

    The focal term ``(1 - pt) ** gamma`` is rescaled so that its spatial sum
    matches the number of valid pixels, which keeps the loss magnitude stable
    as predictions improve. Pixels equal to ``ignore_label`` are masked out.
    """

    def __init__(
        self,
        axis=-1,
        alpha=0.25,
        gamma=2,
        max_mult=-1,
        eps=1e-12,
        from_sigmoid=False,
        detach_delimeter=True,
        batch_axis=0,
        weight=None,
        size_average=True,
        ignore_label=-1,
    ):
        super(NormalizedFocalLossSigmoid, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._from_logits = from_sigmoid
        self._eps = eps
        self._size_average = size_average
        self._detach_delimeter = detach_delimeter
        self._max_mult = max_mult

        # Running statistics exposed via log_states().
        self._k_sum = 0
        self._m_max = 0

    def forward(self, pred, label):
        one_hot = label > 0.5
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(
            one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
        )
        pt = torch.where(
            sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
        )

        beta = (1 - pt) ** self._gamma

        # Normalize the focal term so its spatial sum matches the number of valid pixels.
        sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
        beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
        mult = sw_sum / (beta_sum + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        beta = beta * mult
        if self._max_mult > 0:
            beta = torch.clamp_max(beta, self._max_mult)

        # Track running statistics for logging only (no gradients involved).
        with torch.no_grad():
            ignore_area = (
                torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim())))
                .cpu()
                .numpy()
            )
            sample_mult = (
                torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
            )
            if np.any(ignore_area == 0):
                self._k_sum = (
                    0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
                )

                beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
                beta_pmax = beta_pmax.mean().item()
                self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

        loss = (
            -alpha
            * beta
            * torch.log(
                torch.min(
                    pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
                )
            )
        )
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = torch.sum(
                sample_weight,
                dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis),
            )
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            ) / (bsum + self._eps)
        else:
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            )

        return loss

    def log_states(self, sw, name, global_step):
        sw.add_scalar(tag=name + "_k", value=self._k_sum, global_step=global_step)
        sw.add_scalar(tag=name + "_m", value=self._m_max, global_step=global_step)
class FocalLoss(nn.Module):
    """Standard binary focal loss with an ``ignore_label`` mask."""

    def __init__(
        self,
        axis=-1,
        alpha=0.25,
        gamma=2,
        from_logits=False,
        batch_axis=0,
        weight=None,
        num_class=None,
        eps=1e-9,
        size_average=True,
        scale=1.0,
        ignore_label=-1,
    ):
        super(FocalLoss, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._scale = scale
        self._num_class = num_class
        self._from_logits = from_logits
        self._eps = eps
        self._size_average = size_average

    def forward(self, pred, label, sample_weight=None):
        one_hot = label > 0.5
        # Note: the sample_weight argument is not used; the valid-pixel mask is
        # always rebuilt from ignore_label.
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(
            one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
        )
        pt = torch.where(
            sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
        )

        beta = (1 - pt) ** self._gamma

        loss = (
            -alpha
            * beta
            * torch.log(
                torch.min(
                    pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
                )
            )
        )
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            tsum = torch.sum(
                sample_weight,
                dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis),
            )
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            ) / (tsum + self._eps)
        else:
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            )

        return self._scale * loss
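
# Usage sketch (illustrative, not part of the original module): plain focal loss on
# logits; ``scale`` multiplies the final per-sample loss. Values are assumptions.
#
#     fl = FocalLoss(alpha=0.25, gamma=2, scale=1.0)
#     loss = fl(torch.randn(2, 1, 32, 32), (torch.rand(2, 1, 32, 32) > 0.5).float())
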
class SoftIoU(nn.Module):
    """Soft (differentiable) IoU loss: 1 - intersection / union per sample."""

    def __init__(self, from_sigmoid=False, ignore_label=-1):
        super().__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label

    def forward(self, pred, label):
        label = label.view(pred.size())
        sample_weight = label != self._ignore_label

        if not self._from_sigmoid:
            pred = torch.sigmoid(pred)

        # Intersection: pred * label; union approximated by the elementwise max.
        loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) / (
            torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8
        )

        return loss
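
# Usage sketch (illustrative, not part of the original module): SoftIoU reduces
# dims 1..3, so it returns one value per sample; a scalar typically requires .mean().
#
#     soft_iou = SoftIoU()
#     per_sample = soft_iou(
#         torch.randn(2, 1, 32, 32), (torch.rand(2, 1, 32, 32) > 0.5).float()
#     )
#     scalar = per_sample.mean()
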
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    """Binary cross-entropy with an ``ignore_label`` mask.

    When ``from_sigmoid`` is False, the loss is computed from logits with the
    numerically stable formulation relu(x) - x * y + softplus(-|x|).
    """

    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

    def forward(self, pred, label):
        label = label.view(pred.size())
        sample_weight = label != self._ignore_label
        label = torch.where(sample_weight, label, torch.zeros_like(label))

        if not self._from_sigmoid:
            # Stable BCE-with-logits: relu(x) - x * y + log(1 + exp(-|x|)).
            loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
        else:
            eps = 1e-12
            loss = -(
                torch.log(pred + eps) * label
                + torch.log(1.0 - pred + eps) * (1.0 - label)
            )

        loss = self._weight * (loss * sample_weight)
        return torch.mean(
            loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
        )
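

# Minimal smoke test (a sketch added for illustration, not present in the original
# module): builds each loss on random tensors and prints the resulting shapes.
# Shapes and hyperparameters are assumptions, not values prescribed by the library.
if __name__ == "__main__":
    logits = torch.randn(2, 1, 16, 16)
    masks = (torch.rand(2, 1, 16, 16) > 0.5).float()

    for criterion in (
        NormalizedFocalLossSigmoid(alpha=0.5, gamma=2),
        FocalLoss(alpha=0.25, gamma=2),
        SoftIoU(),
        SigmoidBinaryCrossEntropyLoss(),
    ):
        per_sample = criterion(logits, masks)
        print(type(criterion).__name__, tuple(per_sample.shape), per_sample.mean().item())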