# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmpose.registry import MODELS


@MODELS.register_module()
class KeypointMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss.
        """
        _mask = self._get_mask(target, target_weights, mask)
        if _mask is None:
            loss = F.mse_loss(output, target)
        else:
            _loss = F.mse_loss(output, target, reduction='none')
            loss = (_loss * _mask).mean()

        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching shape with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} vs. '
                    f'{target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} vs. {target.shape}')

            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.view(target_weights.shape +
                                        (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.view(_mask.shape + (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        return mask
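
# Example (illustrative sketch; dummy tensors whose shapes follow the
# docstring above):
#
#   loss_fn = KeypointMSELoss(skip_empty_channel=True)
#   pred = torch.rand(2, 17, 64, 48)     # [B, K, H, W] predicted heatmaps
#   gt = torch.rand(2, 17, 64, 48)       # [B, K, H, W] target heatmaps
#   weights = torch.ones(2, 17)          # [B, K] keypoint-wise target weights
#   loss = loss_fn(pred, gt, target_weights=weights)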


@MODELS.register_module()
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_channels: C
            - heatmaps height: H
            - heatmaps width: W
            - num_keypoints: K
            - Here, C = 3 * K

        Args:
            output (Tensor): The output feature maps with shape [B, C, H, W].
            target (Tensor): The target feature maps with shape [B, C, H, W].
            target_weights (Tensor): The target weights of different
                keypoints, with shape [B, K].

        Returns:
            Tensor: The calculated loss.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        heatmaps_pred = output.reshape(
            (batch_size, num_channels, -1)).split(1, 1)
        heatmaps_gt = target.reshape(
            (batch_size, num_channels, -1)).split(1, 1)
        loss = 0.
        num_joints = num_channels // 3
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx * 3].squeeze()
            heatmap_gt = heatmaps_gt[idx * 3].squeeze()
            offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze()
            offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze()
            offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze()
            offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze()
            if self.use_target_weight:
                target_weight = target_weights[:, idx, None]
                heatmap_pred = heatmap_pred * target_weight
                heatmap_gt = heatmap_gt * target_weight
            # classification loss
            loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            # regression loss
            loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                         heatmap_gt * offset_x_gt)
            loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                         heatmap_gt * offset_y_gt)
        return loss / num_joints * self.loss_weight
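
# Example (illustrative sketch): channels are grouped per keypoint as
# (response map, x-offset map, y-offset map), so C = 3 * K.
#
#   loss_fn = CombinedTargetMSELoss(use_target_weight=True)
#   pred = torch.rand(2, 3 * 17, 64, 48)   # [B, 3K, H, W]
#   gt = torch.rand(2, 3 * 17, 64, 48)     # [B, 3K, H, W]
#   weights = torch.ones(2, 17)            # [B, K]
#   loss = loss_fn(pred, gt, weights)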


@MODELS.register_module()
class KeypointOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        topk (int): Only top k joint losses are kept. Defaults to 8
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 topk: int = 8,
                 loss_weight: float = 1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, losses: Tensor) -> Tensor:
        """Online hard keypoint mining.

        Note:
            - batch_size: B
            - num_keypoints: K

        Args:
            losses (Tensor): The losses with shape [B, K]

        Returns:
            Tensor: The calculated loss.
        """
        ohkm_loss = 0.
        B = losses.shape[0]
        for i in range(B):
            sub_loss = losses[i]
            _, topk_idx = torch.topk(
                sub_loss, k=self.topk, dim=0, sorted=False)
            tmp_loss = torch.gather(sub_loss, 0, topk_idx)
            ohkm_loss += torch.sum(tmp_loss) / self.topk
        ohkm_loss /= B
        return ohkm_loss

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W].
            target (Tensor): The target heatmaps with shape [B, K, H, W].
            target_weights (Tensor): The target weights of different
                keypoints, with shape [B, K].

        Returns:
            Tensor: The calculated loss.
        """
        num_keypoints = output.size(1)
        if num_keypoints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not be '
                             f'larger than num_keypoints ({num_keypoints}).')

        losses = []
        for idx in range(num_keypoints):
            if self.use_target_weight:
                target_weight = target_weights[:, idx, None, None]
                losses.append(
                    self.criterion(output[:, idx] * target_weight,
                                   target[:, idx] * target_weight))
            else:
                losses.append(self.criterion(output[:, idx], target[:, idx]))

        losses = [loss.mean(dim=(1, 2)).unsqueeze(dim=1) for loss in losses]
        losses = torch.cat(losses, dim=1)

        return self._ohkm(losses) * self.loss_weight
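
# Example (illustrative sketch): only the topk hardest keypoints of each
# sample contribute to the loss, so topk must not exceed K.
#
#   loss_fn = KeypointOHKMMSELoss(use_target_weight=True, topk=8)
#   pred = torch.rand(2, 17, 64, 48)     # [B, K, H, W]
#   gt = torch.rand(2, 17, 64, 48)       # [B, K, H, W]
#   weights = torch.ones(2, 17)          # [B, K]
#   loss = loss_fn(pred, gt, weights)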


@MODELS.register_module()
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss.

    Paper ref: 'Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of wingloss.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            pred (torch.Tensor[N, K, H, W]): Predicted heatmaps.
            target (torch.Tensor[N, K, H, W]): Target heatmaps.
        """
        H, W = pred.shape[2:4]
        delta = (target - pred).abs()

        A = self.omega * (
            1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
        ) * (self.alpha - target) * (torch.pow(
            self.theta / self.epsilon,
            self.alpha - target - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, self.alpha - target))

        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 +
                      torch.pow(delta / self.epsilon, self.alpha - target)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, H, W]): Output heatmaps.
            target (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weights (torch.Tensor[N, K]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} vs. {target.shape}')

            ndim_pad = target.ndim - target_weights.ndim
            target_weights = target_weights.view(target_weights.shape +
                                                 (1, ) * ndim_pad)
            loss = self.criterion(output * target_weights,
                                  target * target_weights)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
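
# Example (illustrative sketch): the loss is logarithmic for small residuals
# (|target - output| < theta) and linear for large ones.
#
#   loss_fn = AdaptiveWingLoss(use_target_weight=True)
#   pred = torch.rand(2, 17, 64, 48)     # [N, K, H, W]
#   gt = torch.rand(2, 17, 64, 48)       # [N, K, H, W]
#   weights = torch.ones(2, 17)          # [N, K]
#   loss = loss_fn(pred, gt, weights)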


@MODELS.register_module()
class FocalHeatmapLoss(KeypointMSELoss):
    """A class for calculating the modified focal loss for heatmap prediction.

    This loss function is exactly the same as the one used in CornerNet. It
    runs faster and costs a little bit more memory.

    `CornerNet: Detecting Objects as Paired Keypoints
    <https://arxiv.org/abs/1808.01244>`_.

    Args:
        alpha (int): The alpha parameter in the focal loss equation.
        beta (int): The beta parameter in the focal loss equation.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 alpha: int = 2,
                 beta: int = 4,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.0):
        super().__init__(use_target_weight, skip_empty_channel, loss_weight)
        self.alpha = alpha
        self.beta = beta

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None) -> Tensor:
        """Calculate the modified focal loss for heatmap prediction.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss.
        """
        _mask = self._get_mask(target, target_weights, mask)

        pos_inds = target.eq(1).float()
        neg_inds = target.lt(1).float()

        if _mask is not None:
            pos_inds = pos_inds * _mask
            neg_inds = neg_inds * _mask

        neg_weights = torch.pow(1 - target, self.beta)

        pos_loss = torch.log(output) * torch.pow(1 - output,
                                                 self.alpha) * pos_inds
        neg_loss = torch.log(1 - output) * torch.pow(
            output, self.alpha) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        if num_pos == 0:
            loss = -neg_loss.sum()
        else:
            loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
        return loss * self.loss_weight
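
# Example (illustrative sketch): the focal formulation applies torch.log to
# both output and 1 - output, so predictions are assumed to lie in (0, 1),
# e.g. after a sigmoid.
#
#   loss_fn = FocalHeatmapLoss(alpha=2, beta=4)
#   pred = torch.rand(2, 17, 64, 48).clamp(1e-4, 1 - 1e-4)   # [B, K, H, W]
#   gt = torch.rand(2, 17, 64, 48)                           # [B, K, H, W]
#   loss = loss_fn(pred, gt)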