File size: 1,221 Bytes
b55d767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from __future__ import annotations

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F


class PairwizeDiffLoss(nn.Module):
    """Hinge loss on the mismatch between pairwise differences of input and target.

    For every pair ``(i, j)`` along dim 0 it compares ``input[i] - input[j]``
    with ``target[i] - target[j]``. It then penalises only the part of their
    discrepancy (measured by ``norm``) that exceeds ``margin``.

    Args:
        margin: Discrepancies whose norm is at most this value contribute
            zero loss.
        norm: ``"l1"`` (absolute difference) or ``"l2_squared"`` (squared
            difference). Validated lazily, in ``forward``.
    """

    # Maps a ``norm`` name to the elementwise function applied to the
    # discrepancy. Built once per class instead of on every forward call.
    _NORM_FNS = {
        "l1": torch.abs,
        "l2_squared": lambda x: x**2,
    }

    def __init__(self, margin: float = 0.2, norm: str = "l1"):
        super().__init__()
        self.margin = margin
        self.norm = norm

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the scalar pairwise-difference hinge loss.

        Args:
            input: Predictions; pairs are formed along dim 0.
            target: Ground truth; expected to have the same shape as ``input``.

        Returns:
            A 0-dim tensor: the mean hinge loss over all ordered pairs
            (including ``i == j``), halved.

        Raises:
            ValueError: If ``self.norm`` is not a supported norm name.
        """
        if self.norm not in self._NORM_FNS:
            raise ValueError(
                f'Unknown norm: {self.norm}. Must be one of ["l1", "l2_squared"]'
            )
        if input.shape != target.shape:
            # A mismatch such as (N, 1) vs (N,) turns (N, N, 1) - (N, N) into
            # an (N, N, N) tensor and silently yields a wrong loss. Warn, as
            # torch's built-in losses do, rather than raise, so that existing
            # callers keep working.
            warnings.warn(
                f"Using a target size ({target.size()}) that is different to "
                f"the input size ({input.size()}). This will likely lead to "
                "incorrect results due to broadcasting. Please ensure they "
                "have the same size.",
                stacklevel=2,
            )
        norm_fn = self._NORM_FNS[self.norm]

        # s[i, j] = input[i] - input[j]; t[i, j] = target[i] - target[j].
        s = input.unsqueeze(1) - input.unsqueeze(0)
        t = target.unsqueeze(1) - target.unsqueeze(0)
        loss = F.relu(norm_fn(s - t) - self.margin)

        # The pair matrix is symmetric, so each unordered pair appears twice.
        # mean() already normalises over all N*N entries, which makes the
        # extra halving a constant 0.5 scale. It is kept for backward
        # compatibility of loss values.
        return loss.mean().div(2)


class CombinedLoss(nn.Module):
    """Evaluate several weighted loss functions on the same ``(input, target)``.

    The weights are not applied here. ``forward`` returns ``(weight, loss)``
    pairs and leaves the reduction (e.g. a weighted sum) to the caller.

    Args:
        weighted_losses: ``(loss_fn, weight)`` pairs, evaluated in order.
            Each ``loss_fn`` is called as ``loss_fn(input, target)``.
    """

    def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
        super().__init__()
        self.weighted_losses = weighted_losses
        # nn.Module cannot see modules stored inside a plain list. Without
        # this, .to()/.double()/.cuda(), .train()/.eval(), parameters() and
        # state_dict() would skip the losses (e.g. a CrossEntropyLoss weight
        # buffer would never be moved or cast). Registering them here fixes
        # that; forward() still calls the very same objects.
        # Caveat: their buffers/parameters now appear in state_dict() under
        # "loss_modules.*". Non-Module callables (plain functions) are skipped
        # because ModuleList only accepts modules.
        self.loss_modules = nn.ModuleList(
            [loss for loss, _ in weighted_losses if isinstance(loss, nn.Module)]
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> list[tuple[float, torch.Tensor]]:
        """Return ``[(weight, loss_fn(input, target)), ...]`` in construction order."""
        return [(w, loss(input, target)) for loss, w in self.weighted_losses]