# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS


@MODELS.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification.
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                target_weight = target_weight[:, None]
            loss = (loss * target_weight).mean()
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
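

# A minimal usage sketch for BCELoss (illustrative, not part of the original
# module). BCELoss wraps F.binary_cross_entropy, which expects probabilities,
# so the sigmoid below is an assumption about the upstream prediction head;
# the shapes follow the forward() docstring.
def _bce_loss_example():
    loss_fn = BCELoss(use_target_weight=True)
    output = torch.sigmoid(torch.randn(2, 17))  # [N, K] probabilities
    target = (torch.rand(2, 17) > 0.5).float()  # [N, K] binary labels
    weight = torch.ones(2, 17)                  # [N, K] per-label weights
    return loss_fn(output, target, weight)      # scalar loss tensor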


@MODELS.register_module()
class JSDiscretLoss(nn.Module):
    """Discrete JS Divergence loss for DSNT with Gaussian Heatmap.

    Modified from `the official implementation
    <https://github.com/anibali/dsntnn/blob/master/dsntnn/__init__.py>`_.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        size_average (bool): Option to average the loss by the batch_size.
    """

    def __init__(
        self,
        use_target_weight=True,
        size_average: bool = True,
    ):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.size_average = size_average
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def kl(self, p, q):
        """Kullback-Leibler Divergence."""
        eps = 1e-24
        kl_values = self.kl_loss((q + eps).log(), p)
        return kl_values

    def js(self, pred_hm, gt_hm):
        """Jensen-Shannon Divergence."""
        m = 0.5 * (pred_hm + gt_hm)
        js_values = 0.5 * (self.kl(pred_hm, m) + self.kl(gt_hm, m))
        return js_values

    def forward(self, pred_hm, gt_hm, target_weight=None):
        """Forward function.

        Args:
            pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
            gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.

        Returns:
            torch.Tensor: Loss value.
        """
        if self.use_target_weight:
            assert target_weight is not None
            assert pred_hm.ndim >= target_weight.ndim
            for _ in range(pred_hm.ndim - target_weight.ndim):
                target_weight = target_weight.unsqueeze(-1)
            loss = self.js(pred_hm * target_weight, gt_hm * target_weight)
        else:
            loss = self.js(pred_hm, gt_hm)

        if self.size_average:
            loss /= len(gt_hm)

        return loss.sum()
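

# A minimal usage sketch for JSDiscretLoss (illustrative, not part of the
# original module). Both heatmaps are assumed to be normalized distributions
# over the spatial grid, matching JS(P, Q) = 0.5*KL(P||M) + 0.5*KL(Q||M)
# with M = 0.5*(P + Q) as implemented by `kl` and `js` above; the heatmap
# size is an assumption.
def _js_discret_loss_example():
    N, K, H, W = 2, 17, 64, 48
    pred_hm = F.softmax(torch.randn(N, K, H * W), dim=-1).reshape(N, K, H, W)
    gt_hm = F.softmax(torch.randn(N, K, H * W), dim=-1).reshape(N, K, H, W)
    weight = torch.ones(N, K)               # [N, K] per-keypoint weights
    loss_fn = JSDiscretLoss(use_target_weight=True)
    return loss_fn(pred_hm, gt_hm, weight)  # scalar loss tensor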


@MODELS.register_module()
class KLDiscretLoss(nn.Module):
    """Discrete KL Divergence loss for SimCC with Gaussian Label Smoothing.

    Modified from `the official implementation
    <https://github.com/leeyegy/SimCC>`_.

    Args:
        beta (float): Temperature factor of Softmax.
        label_softmax (bool): Whether to use Softmax on labels.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
    """

    def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True):
        super().__init__()
        self.beta = beta
        self.label_softmax = label_softmax
        self.use_target_weight = use_target_weight

        self.log_softmax = nn.LogSoftmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def criterion(self, dec_outs, labels):
        """Criterion function."""
        log_pt = self.log_softmax(dec_outs * self.beta)
        if self.label_softmax:
            labels = F.softmax(labels * self.beta, dim=1)
        loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
        return loss

    def forward(self, pred_simcc, gt_simcc, target_weight):
        """Forward function.

        Args:
            pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
                x-axis and y-axis.
            gt_simcc (Tuple[Tensor, Tensor]): Target representations.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.
        """
        num_joints = pred_simcc[0].size(1)
        loss = 0

        if self.use_target_weight:
            weight = target_weight.reshape(-1)
        else:
            weight = 1.

        for pred, target in zip(pred_simcc, gt_simcc):
            pred = pred.reshape(-1, pred.size(-1))
            target = target.reshape(-1, target.size(-1))
            loss += self.criterion(pred, target).mul(weight).sum()

        return loss / num_joints
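

# A minimal usage sketch for KLDiscretLoss (illustrative, not part of the
# original module). SimCC predictions arrive as an (x, y) tuple of logits
# over discretized coordinate bins; the bin counts and the beta value are
# assumptions.
def _kl_discret_loss_example():
    N, K, Wx, Wy = 2, 17, 192, 256
    # Random stand-ins for predicted logits and Gaussian-smoothed labels.
    pred_x, pred_y = torch.randn(N, K, Wx), torch.randn(N, K, Wy)
    gt_x, gt_y = torch.randn(N, K, Wx), torch.randn(N, K, Wy)
    weight = torch.ones(N, K)  # [N, K] per-keypoint weights
    loss_fn = KLDiscretLoss(beta=10., label_softmax=True)
    return loss_fn((pred_x, pred_y), (gt_x, gt_y), weight)  # scalar tensor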


@MODELS.register_module()
class InfoNCELoss(nn.Module):
    """InfoNCE loss for training a discriminative representation space in a
    contrastive manner.

    `Representation Learning with Contrastive Predictive Coding
    <https://arxiv.org/abs/1807.03748>`_.

    Args:
        temperature (float, optional): The temperature to use in the softmax
            function. Higher temperatures lead to softer probability
            distributions. Defaults to 1.0.
        loss_weight (float, optional): The weight to apply to the loss.
            Defaults to 1.0.
    """

    def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
        super().__init__()
        assert temperature > 0, f'the argument `temperature` must be ' \
                                f'positive, but got {temperature}'
        self.temp = temperature
        self.loss_weight = loss_weight

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Compute the InfoNCE loss.

        Args:
            features (Tensor): A tensor of shape [N, D] containing the
                feature representations of different samples.

        Returns:
            Tensor: A scalar tensor containing the InfoNCE loss.
        """
        n = features.size(0)
        features_norm = F.normalize(features, dim=1)
        logits = features_norm.mm(features_norm.t()) / self.temp
        targets = torch.arange(n, dtype=torch.long, device=features.device)
        loss = F.cross_entropy(logits, targets, reduction='sum')
        return loss * self.loss_weight
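

# A minimal usage sketch for InfoNCELoss (illustrative, not part of the
# original module). Each sample acts as its own positive class: row i of the
# cosine-similarity matrix is trained to peak at column i, so every other
# sample in the batch serves as a negative. The batch and feature sizes
# below are assumptions.
def _info_nce_example():
    features = torch.randn(8, 256)  # [N, D] one embedding per sample
    loss_fn = InfoNCELoss(temperature=0.1)
    return loss_fn(features)        # scalar loss tensor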