"""Perceptual image-similarity metrics (LPIPS) computed over PIL images."""
import torch
from PIL import Image
from torchvision import transforms

from .lpips import LPIPS
# Channel mean/std for the two supported normalization conventions.
_NORM_STATS = {
    "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    "naive": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # maps [0, 1] -> [-1, 1]
}


def normalize_tensor(images, norm_type):
    """Channel-wise normalize a sequence of CHW image tensors.

    Args:
        images: iterable of CHW float tensors (assumed in [0, 1] — as
            produced by ``transforms.ToTensor``).
        norm_type: ``"imagenet"`` (ImageNet statistics) or ``"naive"``
            (mean/std 0.5, i.e. rescale to [-1, 1]).

    Returns:
        One stacked tensor of shape (N, C, H, W).

    Raises:
        ValueError: if ``norm_type`` is not a supported convention.
    """
    if norm_type not in _NORM_STATS:
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")
    mean, std = _NORM_STATS[norm_type]

    def _normalize(image):
        # Same arithmetic as transforms.Normalize for tensor input:
        # (x - mean) / std with per-channel broadcasting.
        m = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
        s = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
        return (image - m) / s

    return torch.stack([_normalize(image) for image in images])
def to_tensor(images, norm_type="naive"):
    """Convert a list of PIL images into one stacked (N, C, H, W) tensor.

    Args:
        images: list of ``PIL.Image.Image`` instances.
        norm_type: normalization convention forwarded to
            ``normalize_tensor``; ``None`` skips normalization entirely.

    Returns:
        A float tensor of the converted (and optionally normalized) images.
    """
    assert isinstance(images, list) and all(
        isinstance(image, Image.Image) for image in images
    )
    converter = transforms.ToTensor()
    stacked = torch.stack([converter(image) for image in images])
    if norm_type is None:
        return stacked
    return normalize_tensor(stacked, norm_type)
def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
    """Instantiate a perceptual-similarity model on the given device.

    Args:
        metric_name: currently only ``"lpips"`` is supported.
        mode: LPIPS backbone, ``"vgg"`` or ``"alex"``.
        device: target device for the model (defaults to CUDA).

    Returns:
        The model moved to ``device``.

    Raises:
        ValueError: if ``metric_name`` or ``mode`` is unsupported.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently let bad arguments through.
    if metric_name != "lpips":
        raise ValueError(f"Unsupported metric: {metric_name!r}")
    if mode not in ("vgg", "alex"):
        raise ValueError(f"Unsupported LPIPS backbone: {mode!r}")
    return LPIPS(net=mode).to(device)
# Score a single image pair with a preloaded perceptual model.
def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")):
    """Return the perceptual distance between two PIL images as a float.

    Args:
        image1, image2: ``PIL.Image.Image`` instances to compare.
        perceptual_model: a loaded model (see ``load_perceptual_models``).
        device: device used for the image tensors.
    """
    assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
    tensor1, tensor2 = (to_tensor([img]).to(device) for img in (image1, image2))
    score = perceptual_model(tensor1, tensor2)
    return score.cpu().item()
# Convenience wrapper: load an LPIPS model, then score one image pair.
def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
    """Compute the LPIPS distance between two PIL images.

    Note: this loads a fresh model per call; for repeated use, load once
    with ``load_perceptual_models`` and call ``compute_metric`` directly.
    """
    model = load_perceptual_models("lpips", mode, device)
    return compute_metric(image1, image2, model, device)
# Compute metrics between pairs of images
def compute_perceptual_metric_repeated(
    images1,
    images2,
    metric_name,
    mode,
    model,
    device,
):
    """Compute per-pair perceptual distances for two equal-length image lists.

    Args:
        images1, images2: lists of PIL images, compared element-wise.
        metric_name: metric identifier (e.g. ``"lpips"``); used only when
            ``model`` is None.
        mode: backbone forwarded to ``load_perceptual_models`` when
            ``model`` is None.
        model: optional preloaded perceptual model (avoids reloading per call).
        device: device for the model and the image tensors.

    Returns:
        A list of float distances, one per image pair.
    """
    # Validate every element, not just the first; also safe on empty lists.
    assert isinstance(images1, list) and all(
        isinstance(image, Image.Image) for image in images1
    )
    assert isinstance(images2, list) and all(
        isinstance(image, Image.Image) for image in images2
    )
    assert len(images1) == len(images2)
    if model is None:
        # Pass `device` through to construction. Previously the model was
        # built on the default (CUDA) device and then moved with `.to(device)`,
        # which crashes on CPU-only machines and ignores the argument.
        model = load_perceptual_models(metric_name, mode, device)
    scores = model(to_tensor(images1).to(device), to_tensor(images2).to(device))
    return scores.detach().cpu().numpy().flatten().tolist()
# Batched LPIPS: score each (images1[i], images2[i]) pair.
def compute_lpips_repeated(
    images1,
    images2,
    mode="alex",
    model=None,
    device=torch.device("cuda"),
):
    """Compute LPIPS distances for element-wise pairs of two image lists.

    Thin wrapper over ``compute_perceptual_metric_repeated`` with the
    metric fixed to ``"lpips"``.
    """
    return compute_perceptual_metric_repeated(
        images1,
        images2,
        "lpips",
        mode,
        model,
        device,
    )