File size: 1,468 Bytes
2a13495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from typing import Optional
from torch import nn, Tensor
import torch
import torch.nn.functional as F
from ._functional import label_smoothed_nll_loss
# Public API of this module: only the loss class is exported by
# ``from <module> import *``.
__all__ = ["SoftCrossEntropyLoss"]
class SoftCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with optional label smoothing.

    Drop-in replacement for ``torch.nn.CrossEntropyLoss`` with label
    smoothing. ``forward`` applies ``log_softmax`` to the predictions along
    ``dim`` and delegates the actual loss computation to
    ``label_smoothed_nll_loss``.

    Shape
        - **y_pred** - torch.Tensor of shape (N, C, H, W)
        - **y_true** - torch.Tensor of shape (N, H, W)

    Reference
        https://github.com/BloodAxe/pytorch-toolbelt
    """

    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(
        self,
        reduction: str = "mean",
        smooth_factor: Optional[float] = None,
        ignore_index: Optional[int] = -100,
        dim: int = 1,
    ):
        """Store the loss configuration.

        Args:
            reduction: Reduction mode, forwarded unchanged to
                ``label_smoothed_nll_loss``.
            smooth_factor: Factor to smooth target (e.g. if
                smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05]).
                NOTE(review): the exact smoothed distribution is decided by
                ``label_smoothed_nll_loss`` - confirm the example against it.
                ``None`` (the default) is treated as ``0.0`` when the loss
                is evaluated.
            ignore_index: Forwarded unchanged to ``label_smoothed_nll_loss``;
                presumably the target value excluded from the loss - verify
                against the helper.
            dim: Axis used for ``log_softmax`` and forwarded to the helper.
        """
        super().__init__()
        self.smooth_factor = smooth_factor
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.dim = dim

    def _smoothing_epsilon(self) -> float:
        """Return the numeric smoothing factor, mapping ``None`` to ``0.0``."""
        # The constructor default is None; hand the helper a real number so
        # it never has to do arithmetic on None.
        # NOTE(review): upstream label_smoothed_nll_loss divides and
        # multiplies by epsilon directly - confirm for the local helper.
        if self.smooth_factor is None:
            return 0.0
        return self.smooth_factor

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed loss for ``y_pred`` against ``y_true``.

        Args:
            y_pred: Raw predictions; ``log_softmax`` is taken along
                ``self.dim``.
            y_true: Targets, passed to ``label_smoothed_nll_loss`` as-is.

        Returns:
            Whatever ``label_smoothed_nll_loss`` returns for the configured
            ``reduction``.
        """
        log_prob = F.log_softmax(y_pred, dim=self.dim)
        return label_smoothed_nll_loss(
            log_prob,
            y_true,
            epsilon=self._smoothing_epsilon(),
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )
|