File size: 1,337 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CrossCorrelationLoss(BaseModule):
    """Cross correlation loss function.

    Compute the on-diagonal and off-diagonal loss of a square cross
    correlation matrix ``C``::

        loss = sum_i (C_ii - 1) ** 2 + lambd * sum_{i != j} C_ij ** 2

    Args:
        lambd (float): The weight for the off-diagonal loss.
            Defaults to 0.0051.
    """

    def __init__(self, lambd: float = 0.0051) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor:
        """Forward function of cross correlation loss.

        The input matrix is left unmodified.

        Args:
            cross_correlation_matrix (torch.Tensor): The square (2-D) cross
                correlation matrix.

        Returns:
            torch.Tensor: The scalar cross correlation loss.
        """
        # Use out-of-place ops: ``torch.diagonal`` (and, for small matrices,
        # the off-diagonal flatten) return views of the input, so in-place
        # ``add_``/``pow_`` would overwrite the caller's matrix and can break
        # autograd on leaf tensors.
        on_diag = (torch.diagonal(cross_correlation_matrix) - 1).pow(2).sum()
        off_diag = self.off_diagonal(cross_correlation_matrix).pow(2).sum()
        loss = on_diag + self.lambd * off_diag
        return loss

    def off_diagonal(self, x: torch.Tensor) -> torch.Tensor:
        """Return the off-diagonal elements of a square matrix as a 1-D
        tensor, in row-major order.

        The result may share storage with ``x`` (a view) or be a copy,
        depending on the matrix size; do not modify it in place.

        Args:
            x (torch.Tensor): A square 2-D tensor of shape (n, n).

        Returns:
            torch.Tensor: A 1-D tensor with the ``n * (n - 1)`` off-diagonal
            elements of ``x``.

        Raises:
            ValueError: If ``x`` is not a square 2-D tensor.
        """
        n, m = x.shape
        if n != m:
            raise ValueError(
                f'Expected a square matrix, but got shape ({n}, {m}).')
        # Dropping the last element of the flattened n x n matrix leaves
        # n * n - 1 == (n - 1) * (n + 1) elements. Viewed as (n - 1, n + 1),
        # diagonal element (i, i) (flat index i * (n + 1)) lands in column 0
        # of row i, and the last diagonal element is the one dropped. So
        # slicing off column 0 keeps exactly the off-diagonal entries.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()