"""Generic losses and error functions for optimization or training deep networks."""
from typing import Callable, Tuple

import torch


def scaled_loss(
x: torch.Tensor, fn: Callable, a: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Apply a loss function to a tensor and pre- and post-scale it.
Args:
x: the data tensor, should already be squared: `x = y**2`.
fn: the loss function, with signature `fn(x) -> y`.
a: the scale parameter.
Returns:
The value of the loss, and its first and second derivatives.
"""
a2 = a**2
loss, loss_d1, loss_d2 = fn(x / a2)
return loss * a2, loss_d1, loss_d2 / a2


def squared_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """A dummy squared loss: the identity, since the input is already squared."""
    return x, torch.ones_like(x), torch.zeros_like(x)
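

# Usage sketch (illustrative, not part of the original API): wrapping the
# squared loss above with scale a=2 leaves the loss value unchanged, since
# a**2 * (x / a**2) == x, while the curvature is rescaled by 1 / a**2.
#   >>> loss, d1, d2 = scaled_loss(torch.tensor([4.0]), squared_loss, 2.0)
#   >>> (loss, d1, d2)  # (tensor([4.]), tensor([1.]), tensor([0.]))

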
def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""The classical robust Huber loss, with first and second derivatives."""
mask = x <= 1
sx = torch.sqrt(x + 1e-8) # avoid nan in backward pass
isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
loss = torch.where(mask, x, 2 * sx - 1)
loss_d1 = torch.where(mask, torch.ones_like(x), isx)
loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
return loss, loss_d1, loss_d2
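

# Sanity check (illustrative): inliers (x <= 1) keep the quadratic branch,
# outliers switch to the linear branch 2*sqrt(x) - 1, flattening the gradient.
#   >>> loss, d1, d2 = huber_loss(torch.tensor([0.25, 4.0]))
#   >>> loss  # tensor([0.2500, 3.0000])
#   >>> d1    # tensor([1.0000, 0.5000])

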
def barron_loss(
x: torch.Tensor, alpha: torch.Tensor, derivatives: bool = True, eps: float = 1e-7
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Parameterized & adaptive robust loss function.
Described in:
A General and Adaptive Robust Loss Function, Barron, CVPR 2019
alpha = 2 -> L2 loss
alpha = 1 -> Charbonnier loss (smooth L1)
alpha = 0 -> Cauchy loss
alpha = -2 -> Geman-McClure loss
alpha = -inf -> Welsch loss
Contrary to the original implementation, assume the the input is already
squared and scaled (basically scale=1). Computes the first derivative, but
not the second (TODO if needed).
"""
loss_two = x
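    # Clamp the argument of log1p to avoid overflow in float32 (max ~3.4e38).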
loss_zero = 2 * torch.log1p(torch.clamp(0.5 * x, max=33e37))
# The loss when not in one of the above special cases.
# Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
beta_safe = torch.abs(alpha - 2.0).clamp(min=eps)
# Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
alpha_safe = torch.where(alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha))
alpha_safe = alpha_safe * torch.abs(alpha).clamp(min=eps)
loss_otherwise = (
2 * (beta_safe / alpha_safe) * (torch.pow(x / beta_safe + 1.0, 0.5 * alpha) - 1.0)
)
# Select which of the cases of the loss to return.
loss = torch.where(alpha == 0, loss_zero, torch.where(alpha == 2, loss_two, loss_otherwise))
dummy = torch.zeros_like(x)
if derivatives:
loss_two_d1 = torch.ones_like(x)
loss_zero_d1 = 2 / (x + 2)
loss_otherwise_d1 = torch.pow(x / beta_safe + 1.0, 0.5 * alpha - 1.0)
loss_d1 = torch.where(
alpha == 0, loss_zero_d1, torch.where(alpha == 2, loss_two_d1, loss_otherwise_d1)
)
return loss, loss_d1, dummy
else:
return loss, dummy, dummy
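

# Consistency check (illustrative): the closed-form special cases match the
# generic branch, e.g. alpha=1 gives the Charbonnier loss 2*(sqrt(x+1) - 1).
#   >>> x = torch.tensor([4.0])
#   >>> barron_loss(x, x.new_tensor(2.0))[0]  # tensor([4.]), the L2 case
#   >>> barron_loss(x, x.new_tensor(1.0))[0]  # ~tensor([2.4721])

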
def scaled_barron(a: float, c: float) -> Callable:
    """Return the Barron loss with alpha=a, pre- and post-scaled by c."""
    return lambda x: scaled_loss(x, lambda y: barron_loss(y, y.new_tensor(a)), c)
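

if __name__ == "__main__":
    # Minimal smoke test (added for illustration, not part of the original
    # module): the analytic first derivative returned by scaled_barron should
    # match the gradient computed by autograd.
    x = torch.tensor([0.1, 1.0, 10.0], requires_grad=True)
    loss, d1, _ = scaled_barron(1.0, 0.5)(x)
    loss.sum().backward()
    print("analytic d1:", d1.detach())
    print("autograd d1:", x.grad)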