# Imported from ghlee94's upload — commit "Init" (2a13495).
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, Tensor
# Explicit public API: only the loss class is exported via `from module import *`.
__all__ = ["SoftBCEWithLogitsLoss"]
class SoftBCEWithLogitsLoss(nn.Module):
    """Drop-in replacement for ``torch.nn.BCEWithLogitsLoss`` with two additions:
    ``ignore_index`` support and label smoothing.

    Args:
        weight: Manual per-element rescaling weight (registered as a buffer so it
            follows the module on ``.to(device)`` / ``.cuda()``).
        ignore_index: Target value that is ignored and does not contribute to the
            loss; set to ``None`` to disable masking.
        reduction: ``"mean"`` | ``"sum"`` | ``"none"``.
        smooth_factor: Factor to smooth targets
            (e.g. ``smooth_factor=0.1`` maps ``[1, 0, 1]`` -> ``[0.9, 0.1, 0.9]``).
        pos_weight: Weight of positive examples (registered as a buffer).

    Shape
        - **y_pred** - torch.Tensor of shape NxCxHxW
        - **y_true** - torch.Tensor of shape NxHxW or Nx1xHxW

    Reference
        https://github.com/BloodAxe/pytorch-toolbelt
    """

    __constants__ = [
        "weight",
        "pos_weight",
        "reduction",
        "ignore_index",
        "smooth_factor",
    ]

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = -100,
        reduction: str = "mean",
        smooth_factor: Optional[float] = None,
        pos_weight: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.smooth_factor = smooth_factor
        # Buffers, not parameters: no gradients, but saved in state_dict and
        # moved along with the module's device/dtype.
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally smoothed, optionally masked) BCE-with-logits loss.

        Args:
            y_pred: torch.Tensor of shape (N, C, H, W) — raw logits.
            y_true: torch.Tensor of shape (N, H, W) or (N, 1, H, W) — float targets,
                possibly containing ``ignore_index`` sentinel values.

        Returns:
            loss: torch.Tensor (scalar for "mean"/"sum", element-wise for "none").
        """
        if self.smooth_factor is not None:
            soft_targets = (1 - y_true) * self.smooth_factor + y_true * (
                1 - self.smooth_factor
            )
        else:
            soft_targets = y_true

        if self.ignore_index is not None:
            not_ignored_mask: Tensor = y_true != self.ignore_index
            # Replace sentinel targets (e.g. -100) with 0 BEFORE computing the
            # element-wise loss. Feeding the raw sentinel into the BCE term
            # ``-target * logit`` can produce huge/overflowing values that turn
            # into NaN after the zero-masking below (0 * inf == nan). The masked
            # positions contribute 0 either way, so normal results are unchanged.
            soft_targets = torch.where(
                not_ignored_mask, soft_targets, torch.zeros_like(soft_targets)
            )

        loss = F.binary_cross_entropy_with_logits(
            y_pred,
            soft_targets,
            self.weight,
            pos_weight=self.pos_weight,
            reduction="none",
        )

        if self.ignore_index is not None:
            # Zero out the loss at ignored positions.
            loss *= not_ignored_mask.type_as(loss)

        if self.reduction == "mean":
            # NOTE(review): the mean is taken over ALL elements, including the
            # ignored ones — this matches the original/upstream behavior.
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss