File size: 3,226 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""This file contains the definition of the perceptual loss."""

import torch

from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor


class PerceptualLoss(torch.nn.Module):
    """Perceptual loss computed with a frozen, ImageNet-pretrained torchvision backbone.

    Both images are resized to 224x224, normalized with the ImageNet mean/std and
    passed through the backbone; the loss is the mean-squared error between the
    resulting outputs. The backbone's parameters are frozen and the backbone is
    kept in eval mode at all times (see ``train``), so it never switches to its
    training-mode behaviour (e.g. BatchNorm batch statistics).
    """

    def __init__(
        self,
        model_name: str = "resnet50",
        compute_perceptual_loss_on_logits: bool = True,
    ):
        """Initialize the perceptual loss.

        Args:
            model_name -> str: The backbone to use, either ``"resnet50"`` or
                ``"convnext_s"``.
            compute_perceptual_loss_on_logits -> bool: If True, the loss is the MSE
                between the classifier logits only. If False, the loss is the MSE
                between the final feature maps plus the MSE between the logits.

        Raises:
            ValueError: If ``model_name`` is not a supported backbone.
        """
        super().__init__()
        if model_name == "resnet50":
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            return_nodes = {"layer4": "features", "fc": "logits"}
        elif model_name == "convnext_s":
            model = models.convnext_small(
                weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
            )
            return_nodes = {"features": "features", "classifier": "logits"}
        else:
            # Without this branch an unknown name would fall through and fail
            # later with an UnboundLocalError on ``model``.
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                "expected 'resnet50' or 'convnext_s'."
            )

        if compute_perceptual_loss_on_logits:
            self.model = model
        else:
            # Returns a dict with "features" and "logits" entries per return_nodes.
            self.model = create_feature_extractor(model, return_nodes=return_nodes)

        self.compute_perceptual_loss_on_logits = compute_perceptual_loss_on_logits

        # ImageNet normalization constants shaped (1, 3, 1, 1) so they broadcast
        # over (N, 3, H, W) batches; buffers follow the module across .to()/.cuda().
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406])[None, :, None, None]
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225])[None, :, None, None]
        )

        for param in self.parameters():
            param.requires_grad = False

        # Freshly built torchvision models start in training mode; the frozen
        # backbone must run in eval mode so BatchNorm uses (and keeps) its
        # pretrained running statistics and train-only stochastic layers are off.
        self.model.eval()

    def train(self, mode: bool = True) -> "PerceptualLoss":
        """Set the training mode while keeping the frozen backbone in eval mode.

        ``nn.Module.train`` propagates recursively, so a parent model calling
        ``.train()`` would otherwise switch the backbone back to training mode.

        Args:
            mode -> bool: Training flag for this module itself.

        Returns:
            PerceptualLoss: ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.model.eval()
        return self

    def _preprocess(self, images: torch.Tensor) -> torch.Tensor:
        """Resize to 224x224 (antialiased bilinear) and apply ImageNet normalization."""
        images = torch.nn.functional.interpolate(
            images, size=224, mode="bilinear", antialias=True, align_corners=False
        )
        return (images - self.mean) / self.std

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the perceptual loss between ``input`` and ``target``.

        Args:
            input -> torch.Tensor: Image batch of shape (N, 3, H, W); presumably RGB
                in [0, 1], matching the ImageNet mean/std normalization.
            target -> torch.Tensor: Reference images with the same layout as ``input``.

        Returns:
            loss -> torch.Tensor: Scalar MSE between the backbone outputs (logits
                only, or features + logits, per ``compute_perceptual_loss_on_logits``).
        """
        features_input = self.model(self._preprocess(input))
        features_target = self.model(self._preprocess(target))

        if self.compute_perceptual_loss_on_logits:
            return torch.nn.functional.mse_loss(
                features_input, features_target, reduction="mean"
            )

        loss = torch.nn.functional.mse_loss(
            features_input["features"],
            features_target["features"],
            reduction="mean",
        )
        loss = loss + torch.nn.functional.mse_loss(
            features_input["logits"], features_target["logits"], reduction="mean"
        )
        return loss


if __name__ == "__main__":
    # Smoke test with random images uniformly sampled in [0, 1] at a resolution
    # other than 224, so the resize path in forward() is exercised.
    # (The previous randn().clamp_(0, 1) set about half of all pixels to exactly 0.)
    loss_fn = PerceptualLoss()
    prediction = torch.rand(2, 3, 256, 256)
    reference = torch.rand(2, 3, 256, 256)
    loss = loss_fn(prediction, reference)
    print(loss)