Spaces:
Runtime error
Runtime error
import lpips | |
import torch | |
class LPIPS(torch.nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS).

    Thin wrapper around ``lpips.LPIPS`` that additionally accepts 5-D
    (volumetric) inputs by scoring them slice by slice with the 2-D model.

    Args:
        linear_calibration: Forwarded to ``lpips.LPIPS`` as its ``lpips``
            keyword argument.
        normalize: Forwarded as ``normalize=`` on every call of the wrapped
            metric. Per the original author's note, ``True`` means inputs in
            [0, 1] are rescaled to [-1, 1].
    """

    def __init__(self, linear_calibration=False, normalize=False):
        super().__init__()
        # Original author's note: only the 'vgg' backbone is valid as a loss.
        self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration)
        self.normalize = normalize

    def forward(self, pred, target):
        """Return the perceptual distance between ``pred`` and ``target``.

        Inputs that are not 5-D are handed to the wrapped metric unchanged.
        For 5-D inputs, axis 2 is treated as the slice (depth) axis: each
        slice pair is scored independently and the scores are averaged over
        that axis, which is kept as a singleton dimension in the result.

        Single-channel inputs are deliberately not expanded to three
        channels here; according to the original author's note, the
        ScalingLayer introduced in lpips 0.1 takes care of that indirectly.
        """
        flag = self.normalize

        # Plain 2-D case: delegate directly to the wrapped metric.
        if pred.ndim != 5:
            return self.loss_fn(pred, target, normalize=flag)

        # Volumetric case: run the 2-D model once per slice along axis 2.
        slice_scores = []
        for idx in range(pred.shape[2]):
            score = self.loss_fn(pred[:, :, idx], target[:, :, idx], normalize=flag)
            slice_scores.append(score)

        stacked = torch.stack(slice_scores, dim=2)
        return torch.mean(stacked, dim=2, keepdim=True)