# Source: HumanSD / mmpose/models/losses/classification_loss.py
# Uploaded by liyy201912 via huggingface_hub (commit cc0dd3c).
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmpose.registry import MODELS
@MODELS.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss computed on probabilities.

    Args:
        use_target_weight (bool): If True, the element-wise loss is scaled
            by ``target_weight`` before being averaged. Different joint
            types may have different target weights. Default: False.
        loss_weight (float): Multiplier applied to the final loss.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification, given as
                probabilities (``F.binary_cross_entropy`` takes no logits).
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Required only when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss multiplied by ``loss_weight``.
        """
        if not self.use_target_weight:
            # Unweighted path: rely on the default 'mean' reduction.
            return self.criterion(output, target) * self.loss_weight

        assert target_weight is not None
        per_element = self.criterion(output, target, reduction='none')
        # A per-sample weight vector [N] is broadcast across the K labels.
        if target_weight.dim() == 1:
            weights = target_weight.unsqueeze(-1)
        else:
            weights = target_weight
        return (per_element * weights).mean() * self.loss_weight
@MODELS.register_module()
class JSDiscretLoss(nn.Module):
    """Discrete Jensen-Shannon divergence loss for DSNT with Gaussian heatmaps.

    Modified from `the official implementation
    <https://github.com/anibali/dsntnn/blob/master/dsntnn/__init__.py>`_.

    Args:
        use_target_weight (bool): If True, both heatmaps are multiplied by
            ``target_weight`` before the divergence is computed. Different
            joint types may have different target weights. Default: True.
        size_average (bool): If True, the element-wise loss is divided by
            the batch size (``len(gt_hm)``) before summation. Default: True.
    """

    def __init__(
        self,
        use_target_weight=True,
        size_average: bool = True,
    ):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.size_average = size_average
        # Unreduced so that callers receive an element-wise divergence map.
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def kl(self, p, q):
        """Element-wise Kullback-Leibler divergence KL(p || q).

        ``nn.KLDivLoss`` takes log-probabilities as input and probabilities
        as target. A tiny epsilon is added to ``q`` so its log stays finite
        where ``q`` is zero.
        """
        eps = 1e-24
        log_q = torch.log(q + eps)
        return self.kl_loss(log_q, p)

    def js(self, pred_hm, gt_hm):
        """Element-wise Jensen-Shannon divergence.

        JS(p, g) = 0.5 * KL(p || m) + 0.5 * KL(g || m) with m = (p + g) / 2.
        """
        mixture = 0.5 * (pred_hm + gt_hm)
        return 0.5 * (self.kl(pred_hm, mixture) + self.kl(gt_hm, mixture))

    def forward(self, pred_hm, gt_hm, target_weight=None):
        """Forward function.

        Args:
            pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
            gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Required only when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Loss value (sum of the element-wise divergence,
            optionally divided by the batch size first).
        """
        if self.use_target_weight:
            assert target_weight is not None
            assert pred_hm.ndim >= target_weight.ndim
            # Append trailing singleton axes so the weight broadcasts over
            # the remaining heatmap dimensions.
            extra_dims = pred_hm.ndim - target_weight.ndim
            weight = target_weight[(Ellipsis, ) + (None, ) * extra_dims]
            loss = self.js(pred_hm * weight, gt_hm * weight)
        else:
            loss = self.js(pred_hm, gt_hm)

        if self.size_average:
            loss = loss / len(gt_hm)

        return loss.sum()
@MODELS.register_module()
class KLDiscretLoss(nn.Module):
    """Discrete KL Divergence loss for SimCC with Gaussian Label Smoothing.

    Modified from `the official implementation
    <https://github.com/leeyegy/SimCC>`_.

    Args:
        beta (float): Temperature factor of Softmax. Logits (and labels,
            when ``label_softmax`` is True) are multiplied by it before the
            softmax. Default: 1.0.
        label_softmax (bool): Whether to use Softmax on labels.
            Default: False.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
            Default: True.
    """

    def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True):
        super().__init__()
        self.beta = beta
        self.label_softmax = label_softmax
        self.use_target_weight = use_target_weight

        self.log_softmax = nn.LogSoftmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def criterion(self, dec_outs, labels):
        """Compute the per-row KL divergence between labels and predictions.

        Args:
            dec_outs (torch.Tensor[M, W]): Predicted logits, one row per
                flattened (sample, joint) pair.
            labels (torch.Tensor[M, W]): Target distributions (or target
                logits when ``label_softmax`` is True).

        Returns:
            torch.Tensor[M]: KL divergence of each row, averaged over the
            ``W`` bins.
        """
        log_pt = self.log_softmax(dec_outs * self.beta)
        if self.label_softmax:
            labels = F.softmax(labels * self.beta, dim=1)
        # KLDivLoss expects log-probabilities as input and probabilities as
        # target; averaging over the bin axis yields one value per row.
        loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
        return loss

    def forward(self, pred_simcc, gt_simcc, target_weight=None):
        """Forward function.

        Args:
            pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
                x-axis and y-axis, each presumably shaped [N, K, W]
                (dim 1 is taken as the number of joints K).
            gt_simcc (Tuple[Tensor, Tensor]): Target representations,
                matching ``pred_simcc`` element-wise.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N], optional):
                Weights across different labels. Required only when
                ``use_target_weight`` is True. Defaults to None.

        Returns:
            torch.Tensor: Weighted KL loss summed over all rows of both axes
            and divided by the number of joints K.
        """
        pred_x = pred_simcc[0]
        num_joints = pred_x.size(1)

        if self.use_target_weight:
            assert target_weight is not None
            weight = target_weight.reshape(-1)
            # Predictions are flattened to N * K rows below. Per-sample
            # weights ([N] or [N, 1]) must be repeated once per joint to line
            # up with those rows; [N, K] weights already match.
            if (pred_x.dim() == 3 and num_joints != 1
                    and weight.numel() == pred_x.size(0)):
                weight = weight.repeat_interleave(num_joints)
        else:
            weight = 1.

        loss = 0
        for pred, target in zip(pred_simcc, gt_simcc):
            pred = pred.reshape(-1, pred.size(-1))
            target = target.reshape(-1, target.size(-1))
            loss = loss + self.criterion(pred, target).mul(weight).sum()

        return loss / num_joints
@MODELS.register_module()
class InfoNCELoss(nn.Module):
    """InfoNCE loss for training a discriminative representation space in a
    contrastive manner.

    `Representation Learning with Contrastive Predictive Coding
    arXiv: <https://arxiv.org/abs/1611.05424>`_.

    Args:
        temperature (float, optional): The temperature to use in the softmax
            function. Must be positive. Higher temperatures lead to softer
            probability distributions. Defaults to 1.0.
        loss_weight (float, optional): The weight to apply to the loss.
            Defaults to 1.0.
    """

    def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
        super().__init__()
        assert temperature > 0, (
            f'the argument `temperature` must be positive, '
            f'but got {temperature}')
        self.temp = temperature
        self.loss_weight = loss_weight

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Computes the InfoNCE loss.

        Each sample's similarity with itself is treated as the positive;
        every other sample in the batch acts as a negative.

        Args:
            features (Tensor[N, C]): Feature representations of N different
                samples (must be 2-D, as required by ``Tensor.mm``).

        Returns:
            Tensor: A scalar (0-dim) tensor holding the InfoNCE loss summed
            over the N samples and multiplied by ``loss_weight``.
        """
        unit_feats = F.normalize(features, dim=1)
        # Cosine-similarity matrix [N, N], sharpened by the temperature.
        similarity = unit_feats.mm(unit_feats.t()) / self.temp
        # Row i's positive is column i, i.e. the diagonal.
        positive_idx = torch.arange(
            features.size(0), dtype=torch.long, device=features.device)
        total = F.cross_entropy(similarity, positive_idx, reduction='sum')
        return total * self.loss_weight