import lpips
import torch

class LPIPS(torch.nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS)"""
    def __init__(self, linear_calibration=False, normalize=False):
        super().__init__()
        self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration)  # 'vgg' is the variant intended for use as a loss
        self.normalize = normalize  # If True, inputs in [0, 1] are rescaled to [-1, 1] inside lpips

    def forward(self, pred, target):
        # Grayscale inputs do not need to be expanded to RGB here: since version 0.1
        # the lpips ScalingLayer broadcasts 1-channel inputs to 3 channels.
        # if pred.shape[1] == 1:  # convert 1-channel gray images to 3-channel RGB
        #     pred = torch.concat([pred, pred, pred], dim=1)
        # if target.shape[1] == 1:  # convert 1-channel gray images to 3-channel RGB
        #     target = torch.concat([target, target, target], dim=1)

        if pred.ndim == 5:  # 3D volume (B, C, D, H, W): apply the 2D metric slice-wise and average over depth
            depth = pred.shape[2]
            losses = torch.stack([self.loss_fn(pred[:, :, d], target[:, :, d], normalize=self.normalize) for d in range(depth)], dim=2)
            return torch.mean(losses, dim=2, keepdim=True)
        else:
            return self.loss_fn(pred, target, normalize=self.normalize)
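

# Minimal usage sketch (illustrative addition, assuming the standard lpips package API):
# with normalize=True, inputs are expected in [0, 1]; otherwise they should already be in [-1, 1].
if __name__ == "__main__":
    metric = LPIPS(normalize=True)

    # 2D batch (B, C, H, W) -> per-sample distance of shape (B, 1, 1, 1)
    pred_2d = torch.rand(2, 3, 64, 64)
    target_2d = torch.rand(2, 3, 64, 64)
    print(metric(pred_2d, target_2d).shape)

    # 3D volume (B, C, D, H, W) -> slice-wise distances averaged over the depth dimension
    pred_3d = torch.rand(1, 1, 8, 64, 64)
    target_3d = torch.rand(1, 1, 8, 64, 64)
    print(metric(pred_3d, target_3d).shape)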