# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            before output. If True, raw logits are accepted and the fused
            ``binary_cross_entropy_with_logits`` is used, which is more
            numerically stable. Defaults to False.
    """

    def __init__(self,
                 use_target_weight=False,
                 loss_weight=1.,
                 use_sigmoid=False):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        # Select the criterion once at construction time; both accept a
        # `reduction` keyword, so forward() works unchanged for either.
        self.criterion = (F.binary_cross_entropy_with_logits
                          if use_sigmoid else F.binary_cross_entropy)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification.
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Required when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss scaled by ``loss_weight``.
        """
        if self.use_target_weight:
            assert target_weight is not None
            # Keep per-element losses so the weights apply before reduction.
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                # [N] weights broadcast over the label dimension as [N, 1].
                target_weight = target_weight[:, None]
            loss = (loss * target_weight).mean()
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
class JSDiscretLoss(nn.Module):
    """Discrete JS Divergence loss for DSNT with Gaussian Heatmap.

    Modified from `the official implementation
    <https://github.com/anibali/dsntnn/blob/master/dsntnn/__init__.py>`_.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        size_average (bool): Option to average the loss by the batch_size.
    """

    def __init__(self, use_target_weight=True, size_average: bool = True):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.size_average = size_average
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def kl(self, p, q):
        """Kullback-Leibler Divergence."""
        # KLDivLoss expects log-probabilities as its input argument; the
        # tiny eps guards against log(0) on empty heatmap bins.
        return self.kl_loss((q + 1e-24).log(), p)

    def js(self, pred_hm, gt_hm):
        """Jensen-Shannon Divergence."""
        mixture = 0.5 * (pred_hm + gt_hm)
        return 0.5 * (self.kl(pred_hm, mixture) + self.kl(gt_hm, mixture))

    def forward(self, pred_hm, gt_hm, target_weight=None):
        """Forward function.

        Args:
            pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
            gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.

        Returns:
            torch.Tensor: Loss value.
        """
        if not self.use_target_weight:
            loss = self.js(pred_hm, gt_hm)
        else:
            assert target_weight is not None
            assert pred_hm.ndim >= target_weight.ndim
            # Append trailing singleton dims so weights broadcast over H, W.
            while target_weight.ndim < pred_hm.ndim:
                target_weight = target_weight.unsqueeze(-1)
            loss = self.js(pred_hm * target_weight, gt_hm * target_weight)

        if self.size_average:
            loss /= len(gt_hm)

        return loss.sum()
class KLDiscretLoss(nn.Module):
    """Discrete KL Divergence loss for SimCC with Gaussian Label Smoothing.

    Modified from `the official implementation.
    <https://github.com/leeyegy/SimCC>`_.

    Args:
        beta (float): Temperature factor of Softmax.
        label_softmax (bool): Whether to use Softmax on labels.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
    """

    def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True):
        super().__init__()
        self.beta = beta
        self.label_softmax = label_softmax
        self.use_target_weight = use_target_weight
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def criterion(self, dec_outs, labels):
        """Per-row KL divergence between prediction and target distributions.

        Args:
            dec_outs (torch.Tensor[B, L]): Raw predicted logits.
            labels (torch.Tensor[B, L]): Target distributions (softmaxed
                here when ``label_softmax`` is True).

        Returns:
            torch.Tensor[B]: Mean KL value per row.
        """
        log_pt = self.log_softmax(dec_outs * self.beta)
        if self.label_softmax:
            labels = F.softmax(labels * self.beta, dim=1)
        loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
        return loss

    def forward(self, pred_simcc, gt_simcc, target_weight=None):
        """Forward function.

        Args:
            pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
                x-axis and y-axis.
            gt_simcc (Tuple[Tensor, Tensor]): Target representations.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N], optional):
                Weights across different labels. Required when
                ``use_target_weight`` is True; may be omitted otherwise.
                Defaults to None (consistent with BCELoss / JSDiscretLoss).
        """
        num_joints = pred_simcc[0].size(1)
        loss = 0

        if self.use_target_weight:
            # Fail fast with a clear message instead of an opaque
            # AttributeError on None.reshape below.
            assert target_weight is not None, \
                'target_weight is required when use_target_weight=True'
            weight = target_weight.reshape(-1)
        else:
            weight = 1.

        for pred, target in zip(pred_simcc, gt_simcc):
            # Fold the joint dim into the batch: [N, K, L] -> [N*K, L].
            pred = pred.reshape(-1, pred.size(-1))
            target = target.reshape(-1, target.size(-1))

            loss += self.criterion(pred, target).mul(weight).sum()

        return loss / num_joints
class InfoNCELoss(nn.Module):
    """InfoNCE loss for training a discriminative representation space with a
    contrastive manner.

    `Representation Learning with Contrastive Predictive Coding
    arXiv: <https://arxiv.org/abs/1611.05424>`_.

    Args:
        temperature (float, optional): The temperature to use in the softmax
            function. Higher temperatures lead to softer probability
            distributions. Defaults to 1.0.
        loss_weight (float, optional): The weight to apply to the loss.
            Defaults to 1.0.
    """

    def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
        super().__init__()
        assert temperature > 0, f'the argument `temperature` must be ' \
                                f'positive, but got {temperature}'
        self.temp = temperature
        self.loss_weight = loss_weight

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Computes the InfoNCE loss.

        Args:
            features (Tensor): A tensor containing the feature
                representations of different samples.

        Returns:
            Tensor: A tensor of shape (1,) containing the InfoNCE loss.
        """
        batch_size = features.size(0)
        # Temperature-scaled cosine-similarity matrix between all pairs.
        normed = F.normalize(features, dim=1)
        logits = torch.mm(normed, normed.t()) / self.temp
        # Each row's positive index is its own position (the diagonal).
        labels = torch.arange(
            batch_size, dtype=torch.long, device=features.device)
        loss = F.cross_entropy(logits, labels, reduction='sum')
        return loss * self.loss_weight