import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


def gmof(x, sigma):
    """Geman-McClure error function.

    Args:
        x (torch.Tensor): The input tensor.
        sigma (float): The sigma value used in the calculation.

    Returns:
        torch.Tensor: The computed Geman-McClure error.
    """
    x_squared = x**2
    sigma_squared = sigma**2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)

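# Illustrative note (not part of the original module): gmof saturates at
# sigma ** 2 for large residuals, which bounds the influence of outliers, e.g.
#   gmof(torch.tensor([0.1, 1.0, 100.0]), sigma=1.0)
#   -> tensor([0.0099, 0.5000, 0.9999])   # approaches sigma ** 2 = 1.0
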
@weighted_loss
def mse_loss(pred, target):
    """Wrapper for Mean Squared Error (MSE) loss.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: MSE loss.
    """
    return F.mse_loss(pred, target, reduction='none')

@weighted_loss
def smooth_l1_loss(pred, target):
    """Wrapper for Smooth L1 loss.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: Smooth L1 loss.
    """
    return F.smooth_l1_loss(pred, target, reduction='none')

@weighted_loss
def l1_loss(pred, target):
    """Wrapper for L1 loss.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: L1 loss.
    """
    return F.l1_loss(pred, target, reduction='none')

@weighted_loss
def mse_loss_with_gmof(pred, target, sigma):
    """Extended MSE loss with the Geman-McClure function applied.

    Args:
        pred (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.
        sigma (float): The sigma value for the Geman-McClure function.

    Returns:
        torch.Tensor: The loss value.
    """
    loss = F.mse_loss(pred, target, reduction='none')
    loss = gmof(loss, sigma)
    return loss

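# Descriptive note: F.mse_loss(..., reduction='none') returns the elementwise
# squared residual, so gmof is applied to (pred - target) ** 2 here, i.e. the
# squared error itself is robustified rather than the raw residual.
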
@LOSSES.register_module()
class MSELoss(nn.Module):
    """Mean Squared Error (MSE) Loss.

    Args:
        reduction (str, optional): The method to reduce the loss to a scalar.
            Options are 'none', 'mean', and 'sum'. Defaults to 'mean'.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        self.reduction = 'none' if reduction is None else reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function to compute loss.

        Args:
            pred (torch.Tensor): Predictions.
            target (torch.Tensor): Ground truth.
            weight (torch.Tensor, optional): Optional weight per sample.
            avg_factor (int, optional): Factor for averaging the loss.
            reduction_override (str, optional): Option to override the
                reduction method.

        Returns:
            torch.Tensor: Calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

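# Minimal usage sketch (illustrative, not part of the original module; tensor
# shapes below are assumptions):
#   criterion = MSELoss(reduction='mean', loss_weight=0.5)
#   loss = criterion(torch.rand(2, 17, 3), torch.rand(2, 17, 3))
# Since the class is registered in LOSSES, it is typically built from a config
# dict such as dict(type='MSELoss', reduction='mean', loss_weight=0.5).
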
@LOSSES.register_module()
class KinematicLoss(nn.Module):
    """Kinematic Loss for hierarchical motion prediction.

    Args:
        reduction (str, optional): Reduction method ('none', 'mean' or 'sum').
            Defaults to 'mean'.
        loss_type (str, optional): The type of loss to use ('mse', 'smooth_l1'
            or 'l1'). Defaults to 'mse'.
        loss_weight (list[float], optional): List of weights for each stage of
            the hierarchy. Defaults to [1.0].
    """

    def __init__(self, reduction='mean', loss_type='mse', loss_weight=[1.0]):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        self.reduction = 'none' if reduction is None else reduction
        self.loss_weight = loss_weight
        self.num_stages = len(loss_weight)
        # Select the per-element loss function based on loss_type.
        if loss_type == 'mse':
            self.loss_func = mse_loss
        elif loss_type == 'smooth_l1':
            self.loss_func = smooth_l1_loss
        elif loss_type == 'l1':
            self.loss_func = l1_loss
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function for the hierarchical kinematic loss.

        Args:
            pred (torch.Tensor): The prediction tensor.
            target (torch.Tensor): The target tensor.
            weight (torch.Tensor, optional): Weights for each prediction.
                Defaults to None.
            avg_factor (int, optional): Factor to average the loss.
                Defaults to None.
            reduction_override (str, optional): Override reduction method.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated hierarchical loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        total_loss = 0
        pred_t = pred.clone()
        target_t = target.clone()
        # Apply the loss function at each stage of the hierarchy.
        for i in range(self.num_stages):
            stage_loss = self.loss_weight[i] * self.loss_func(
                pred_t, target_t, weight, reduction=reduction,
                avg_factor=avg_factor)
            total_loss += stage_loss
            # Replace frames 1..T-1 with differences between consecutive
            # frames, so the next stage penalizes higher-order motion.
            pred_t = torch.cat(
                (pred_t[:, :1, :], pred_t[:, 1:] - pred_t[:, :-1]), dim=1)
            target_t = torch.cat(
                (target_t[:, :1, :], target_t[:, 1:] - target_t[:, :-1]),
                dim=1)
        return total_loss

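# Minimal usage sketch (illustrative; the (batch, time, channels) layout is an
# assumption based on the frame differencing along dim=1 above):
#   criterion = KinematicLoss(loss_type='l1', loss_weight=[1.0, 0.5])
#   pred = torch.rand(4, 16, 72)     # e.g. 24 joints x 3 rotation params
#   target = torch.rand(4, 16, 72)
#   loss = criterion(pred, target)   # stage 0: poses, stage 1: velocities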