Spaces:
Runtime error
Runtime error
File size: 4,167 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmpose.registry import MODELS
@MODELS.register_module()
class AssociativeEmbeddingLoss(nn.Module):
    """Associative Embedding loss.

    Details can be found in
    `Associative Embedding <https://arxiv.org/abs/1611.05424>`_

    The loss has two terms, both computed per image:

        - pull loss: mean squared distance between every visible keypoint
          tag of an instance and that instance's mean (reference) tag
        - push loss: ``exp(-d^2)`` penalty on the element-wise difference
          between the reference tags of every pair of instances

    Note:

        - batch size: B
        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - embedding tag dimension: L
        - heatmap size: [W, H]

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0
        push_loss_factor (float): A factor that controls the weight between
            the push loss and the pull loss. Defaults to 0.5
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 push_loss_factor: float = 0.5) -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.push_loss_factor = push_loss_factor

    def _ae_loss_per_image(self, tags: Tensor, keypoint_indices: Tensor):
        """Compute associative embedding loss for one image.

        Args:
            tags (Tensor): Tagging heatmaps in shape (L*K, H, W). The
                channel axis is laid out tag-dimension-major, i.e. it is
                reshaped to (L, K, H*W)
            keypoint_indices (Tensor): Ground-truth keypoint position
                indices in shape (N, K, 2). Each keypoint is [i, v], where
                i is the flattened position index in the heatmap and v is
                the visibility; keypoints with ``v == 0`` are skipped

        Returns:
            tuple:
            - pull_loss (Tensor): Scalar pull loss, divided by the number
              of instances that have at least one visible keypoint
            - push_loss (Tensor): Scalar push loss, divided by
              ``N_valid * (N_valid - 1)``
        """
        K = keypoint_indices.shape[1]
        C, H, W = tags.shape
        L = C // K

        # ``reshape`` instead of ``view``: identical result for
        # view-compatible inputs, but it also accepts tag maps whose memory
        # layout cannot be viewed (e.g. transposed/permuted tensors).
        tags = tags.reshape(L, K, H * W)

        instance_tags = []  # per-instance reference tag, each (L, )
        instance_kpt_tags = []  # per-instance keypoint tags, each (k_i, L)

        for keypoint_indices_n in keypoint_indices:
            # Gather the L-dim tag at every visible keypoint location.
            _kpt_tags = []
            for k in range(K):
                if keypoint_indices_n[k, 1]:
                    _kpt_tags.append(tags[:, k, keypoint_indices_n[k, 0]])

            # Instances without any visible keypoint do not contribute.
            if _kpt_tags:
                kpt_tags = torch.stack(_kpt_tags)
                instance_kpt_tags.append(kpt_tags)
                instance_tags.append(kpt_tags.mean(dim=0))

        N = len(instance_kpt_tags)  # number of instances with valid keypoints

        if N == 0:
            # Nothing to supervise: return zero scalars that still require
            # grad so the result can be used like a regular loss tensor.
            pull_loss = tags.new_zeros(size=(), requires_grad=True)
            push_loss = tags.new_zeros(size=(), requires_grad=True)
        else:
            # Pull every keypoint tag towards its instance's reference tag.
            pull_loss = sum(
                F.mse_loss(_kpt_tags, _tag.expand_as(_kpt_tags))
                for (_kpt_tags, _tag) in zip(instance_kpt_tags, instance_tags))

            if N == 1:
                # A single instance has nothing to be pushed away from.
                push_loss = tags.new_zeros(size=(), requires_grad=True)
            else:
                tag_mat = torch.stack(instance_tags)  # (N, L)
                diff = tag_mat[None] - tag_mat[:, None]  # (N, N, L)
                # NOTE(review): this sum includes the i == j diagonal, where
                # diff is 0 and every element contributes exp(0) = 1. That
                # adds a constant without gradient, while the divisor below
                # counts only the N * (N - 1) off-diagonal pairs. Confirm
                # whether the constant should be subtracted.
                push_loss = torch.sum(torch.exp(-diff.pow(2)))

            # normalization
            eps = 1e-6
            pull_loss = pull_loss / (N + eps)
            push_loss = push_loss / ((N - 1) * N + eps)

        return pull_loss, push_loss

    def forward(self, tags: Tensor, keypoint_indices: Union[List[Tensor],
                                                            Tensor]):
        """Compute associative embedding loss on a batch of data.

        Args:
            tags (Tensor): Tagging heatmaps in shape (B, L*K, H, W)
            keypoint_indices (Tensor|List[Tensor]): Ground-truth keypoint
                position indices represented by a Tensor in shape
                (B, N, K, 2), or a list of B Tensors in shape (N_i, K, 2)
                Each keypoint's index is represented as [i, v], where i is the
                position index in the heatmap (:math:`i=y*w+x`) and v is the
                visibility

        Returns:
            tuple:
            - pull_loss (Tensor): Sum over the batch of the per-image pull
              losses, each scaled by ``loss_weight``
            - push_loss (Tensor): Sum over the batch of the per-image push
              losses, each scaled by ``loss_weight * push_loss_factor``
        """
        assert tags.shape[0] == len(keypoint_indices)

        pull_loss = 0.
        push_loss = 0.

        # Per-image losses are accumulated (summed, not averaged) over the
        # batch.
        for i in range(tags.shape[0]):
            _pull, _push = self._ae_loss_per_image(tags[i],
                                                   keypoint_indices[i])
            pull_loss += _pull * self.loss_weight
            push_loss += _push * self.loss_weight * self.push_loss_factor

        return pull_loss, push_loss
|