Spaces:
Runtime error
Runtime error
File size: 1,251 Bytes
f85e212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import lpips
import torch
class LPIPS(torch.nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) as a torch module.

    Thin wrapper around ``lpips.LPIPS`` using the VGG backbone. Inputs with
    five dimensions are treated as volumes: the 2-D metric is evaluated on
    every index of dim 2 and the per-slice scores are averaged.

    Args:
        linear_calibration: Passed to ``lpips.LPIPS`` as its ``lpips`` flag.
        normalize: Passed as ``normalize=`` on every metric call; if true,
            inputs in [0, 1] are rescaled to [-1, 1] by the metric.
    """

    def __init__(self, linear_calibration=False, normalize=False):
        super().__init__()
        # Only the 'vgg' backbone is meant to be used as a loss.
        self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration)
        self.normalize = normalize

    def forward(self, pred, target):
        """Return the LPIPS distance between ``pred`` and ``target``.

        For non-5-D inputs the wrapped metric's result is returned directly.
        For 5-D inputs (presumably laid out as (N, C, D, H, W) -- TODO
        confirm against callers), each slice ``[:, :, i]`` for
        ``i < pred.shape[2]`` is scored separately. The scores are stacked
        on dim 2 and averaged there with ``keepdim=True``, so the result
        has one more (singleton) axis than a single-slice result.
        """
        # Grayscale inputs are not expanded to RGB here. According to the
        # original author, the ScalingLayer added in lpips 0.1 handles
        # 1-channel inputs itself.
        if pred.ndim != 5:
            return self.loss_fn(pred, target, normalize=self.normalize)

        # Volumetric input: reuse the 2-D metric slice by slice along dim 2.
        # Slices are indexed by pred's depth; target is assumed to match.
        num_slices = pred.shape[2]
        slice_scores = []
        for idx in range(num_slices):
            score = self.loss_fn(
                pred[:, :, idx],
                target[:, :, idx],
                normalize=self.normalize,
            )
            slice_scores.append(score)
        stacked = torch.stack(slice_scores, dim=2)
        return stacked.mean(dim=2, keepdim=True)
|