File size: 2,644 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch.nn as nn

from mmpose.registry import MODELS
from mmpose.utils.typing import ConfigType


@MODELS.register_module()
class MultipleLossWrapper(nn.Module):
    """A wrapper to collect multiple loss functions together and return a list
    of losses in the same order.

    Each loss module is built from its config with ``MODELS.build``. In
    :meth:`forward`, the i-th loss is applied to the i-th input/target pair.

    Args:
        losses (list): List of loss configs, one per wrapped loss function.
    """

    def __init__(self, losses: list):
        super().__init__()
        self.num_losses = len(losses)
        self.loss_modules = nn.ModuleList(
            [MODELS.build(loss_cfg) for loss_cfg in losses])

    def forward(self, input_list, target_list, keypoint_weights=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_keypoints: K
            - dimension of keypoints: D (D=2 or D=3)

        Args:
            input_list (List[Tensor]): List of inputs, one per wrapped loss.
            target_list (List[Tensor]): List of targets, one per wrapped loss.
            keypoint_weights (Tensor[N, K, D]):
                Weights across different joint types. The same object is
                passed unchanged to every wrapped loss. Defaults to ``None``.

        Returns:
            list: The output of each wrapped loss, in the same order as the
            ``losses`` configs given at construction.
        """
        assert isinstance(input_list, list), (
            f'input_list should be a list, but got {type(input_list)}')
        assert isinstance(target_list, list), (
            f'target_list should be a list, but got {type(target_list)}')
        assert len(input_list) == len(target_list), (
            f'input_list and target_list should have the same length, but '
            f'got {len(input_list)} and {len(target_list)}')
        # Every wrapped loss needs exactly one input/target pair. Without
        # this check, too few inputs would surface as an IndexError and too
        # many would be silently ignored.
        assert len(input_list) == self.num_losses, (
            f'Expected {self.num_losses} input/target pairs (one per loss), '
            f'but got {len(input_list)}')

        return [
            loss_module(input_i, target_i, keypoint_weights)
            for loss_module, input_i, target_i in zip(
                self.loss_modules, input_list, target_list)
        ]


@MODELS.register_module()
class CombinedLoss(nn.ModuleDict):
    """A wrapper to combine multiple loss functions. These loss functions can
    have different input types (e.g. heatmaps or regression values), and can
    only be invoked individually and explicitly.

    Each loss is built from its config with ``MODELS.build`` and registered
    as a submodule under its name. It can then be accessed as an attribute
    (``loss_module.heatmap_loss``) or by key (``loss_module['heatmap_loss']``).
    Calling the wrapper itself does not run the losses.

    Args:
        losses (Dict[str, ConfigType]): The names and configs of loss
            functions to be wrapped.

    Example::
        >>> heatmap_loss_cfg = dict(type='KeypointMSELoss')
        >>> ae_loss_cfg = dict(type='AssociativeEmbeddingLoss')
        >>> loss_module = CombinedLoss(
        ...     losses=dict(
        ...         heatmap_loss=heatmap_loss_cfg,
        ...         ae_loss=ae_loss_cfg))
        >>> loss_hm = loss_module.heatmap_loss(pred_heatmap, gt_heatmap)
        >>> loss_ae = loss_module.ae_loss(pred_tags, keypoint_indices)
    """

    def __init__(self, losses: Dict[str, ConfigType]):
        super().__init__()
        # Register each built loss under its config key so it is tracked as
        # a child module (parameters, device moves, state_dict).
        for loss_name, loss_cfg in losses.items():
            self.add_module(loss_name, MODELS.build(loss_cfg))