"""Generic losses and error functions for optimization or training deep networks.""" |
|
|
|
from typing import Callable, Tuple |
|
|
|
import torch |
|
|
|
|
|


def scaled_loss(
    x: torch.Tensor, fn: Callable, a: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply a loss function to a tensor and pre- and post-scale it.

    Args:
        x: the data tensor, should already be squared: `x = y**2`.
        fn: the loss function, with signature `fn(x) -> (loss, loss_d1, loss_d2)`.
        a: the scale parameter.

    Returns:
        The value of the loss, and its first and second derivatives.
    """
    a2 = a**2
    loss, loss_d1, loss_d2 = fn(x / a2)
    return loss * a2, loss_d1, loss_d2 / a2
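

# A minimal usage sketch (not part of the original module; the helper name is hypothetical):
# wrapping `huber_loss` (defined below) with `scaled_loss` moves its inlier/outlier
# crossover from x = 1 to x = a**2, i.e. to raw residuals |y| = a.
def _demo_scaled_loss(a: float = 2.0) -> None:
    y = torch.linspace(0, 5, 6)  # raw residuals
    x = y**2  # the losses expect squared residuals
    loss, loss_d1, loss_d2 = scaled_loss(x, huber_loss, a)
    # Quadratic (y**2) for |y| <= a, then linear in |y|:
    print(loss)  # ~ tensor([ 0.,  1.,  4.,  8., 12., 16.]) for a = 2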


def squared_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """A dummy squared loss."""
    return x, torch.ones_like(x), torch.zeros_like(x)


def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """The classical robust Huber loss, with first and second derivatives."""
    # x is the squared residual; the inlier/outlier threshold sits at x = 1.
    mask = x <= 1
    sx = torch.sqrt(x + 1e-8)  # small epsilon keeps the sqrt differentiable at x = 0
    isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
    loss = torch.where(mask, x, 2 * sx - 1)
    loss_d1 = torch.where(mask, torch.ones_like(x), isx)
    loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
    return loss, loss_d1, loss_d2
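

# A minimal check sketch (not part of the original module; the helper name is hypothetical):
# the returned first derivative is taken w.r.t. the squared residual x, so it should match
# what autograd computes for the loss itself.
def _check_huber_derivative() -> None:
    x = torch.tensor([0.25, 1.0, 4.0, 25.0], requires_grad=True)
    loss, loss_d1, _ = huber_loss(x)
    (grad,) = torch.autograd.grad(loss.sum(), x)
    assert torch.allclose(loss_d1, grad, atol=1e-4)  # up to the 1e-8 epsilon in the sqrt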


def barron_loss(
    x: torch.Tensor, alpha: torch.Tensor, derivatives: bool = True, eps: float = 1e-7
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Parameterized & adaptive robust loss function.

    Described in:
        A General and Adaptive Robust Loss Function, Barron, CVPR 2019

    alpha = 2 -> L2 loss
    alpha = 1 -> Charbonnier loss (smooth L1)
    alpha = 0 -> Cauchy loss
    alpha = -2 -> Geman-McClure loss
    alpha = -inf -> Welsch loss

    Contrary to the original implementation, assume that the input is already
    squared and scaled (basically scale=1). Computes the first derivative, but
    not the second (TODO if needed).
    """
    loss_two = x
    loss_zero = 2 * torch.log1p(torch.clamp(0.5 * x, max=33e37))

    # The loss when alpha is neither 0 nor 2.
    # Clamp |alpha - 2| and |alpha| away from zero so the divisions below are safe.
    beta_safe = torch.abs(alpha - 2.0).clamp(min=eps)
    alpha_safe = torch.where(alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha))
    alpha_safe = alpha_safe * torch.abs(alpha).clamp(min=eps)

    loss_otherwise = (
        2 * (beta_safe / alpha_safe) * (torch.pow(x / beta_safe + 1.0, 0.5 * alpha) - 1.0)
    )

    loss = torch.where(alpha == 0, loss_zero, torch.where(alpha == 2, loss_two, loss_otherwise))
    dummy = torch.zeros_like(x)

    if derivatives:
        loss_two_d1 = torch.ones_like(x)
        loss_zero_d1 = 2 / (x + 2)
        loss_otherwise_d1 = torch.pow(x / beta_safe + 1.0, 0.5 * alpha - 1.0)
        loss_d1 = torch.where(
            alpha == 0, loss_zero_d1, torch.where(alpha == 2, loss_two_d1, loss_otherwise_d1)
        )
        return loss, loss_d1, dummy
    else:
        return loss, dummy, dummy
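

# A minimal check sketch (not part of the original module; the helper name is hypothetical):
# the alpha = 0 and alpha = 2 branches are the limits of the general expression, which is
# undefined at exactly those points, so nearby alphas should give nearly the same loss.
# Note that this implementation keeps a factor of 2 relative to the paper, so alpha = 2
# returns exactly x rather than x / 2.
def _check_barron_limits() -> None:
    x = torch.linspace(0.0, 10.0, 11)
    for a0, a1 in [(2.0, 2.001), (0.0, 0.001)]:
        l0, _, _ = barron_loss(x, x.new_tensor(a0))
        l1, _, _ = barron_loss(x, x.new_tensor(a1))
        assert torch.allclose(l0, l1, rtol=1e-2, atol=1e-2)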


def scaled_barron(a, c):
    """Return a scaled Barron loss function."""
    return lambda x: scaled_loss(x, lambda y: barron_loss(y, y.new_tensor(a)), c)
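

# A minimal usage sketch (not part of the original module; the helper name is hypothetical):
# the returned closure maps squared residuals to a robust cost; its first derivative is what
# one would typically reuse as per-residual weights, e.g. in an IRLS-style scheme.
def _demo_scaled_barron() -> None:
    cost_fn = scaled_barron(a=0.0, c=1.0)  # Cauchy-like robust cost with scale 1
    residuals = torch.tensor([0.1, 1.0, 10.0])
    cost, weight, _ = cost_fn(residuals**2)  # the losses expect squared residuals
    print(cost, weight)  # large residuals incur a slowly growing cost and a small weight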