# Snapshot of a Hugging Face Spaces source file (commit ad552d8, 2,858 bytes).
# NOTE: the hosting page reported "Runtime error"; the page chrome and the
# line-number gutter from the scrape have been removed so the file parses.
import torch
from PIL import Image
from torchvision import transforms
from .lpips import LPIPS
# Normalize image tensors
# Normalize image tensors
def normalize_tensor(images, norm_type):
    """Channel-wise normalize a batch of 3xHxW float image tensors.

    Args:
        images: iterable of CHW float tensors (e.g. output of ToTensor),
            or an already-stacked NCHW tensor.
        norm_type: "imagenet" (standard ImageNet mean/std) or "naive"
            (maps [0, 1] inputs to [-1, 1]).

    Returns:
        torch.Tensor: stacked normalized batch of shape (N, C, H, W).

    Raises:
        ValueError: if `norm_type` is not a known convention.
    """
    # Per-channel (mean, std) for each supported convention; a table avoids
    # the duplicated branch bodies (and dead fallback) of the original.
    stats = {
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "naive": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    }
    if norm_type not in stats:
        raise ValueError(f"Unknown norm_type: {norm_type!r}")
    mean, std = stats[norm_type]
    # (-1, 1, 1) views broadcast over H and W, matching what
    # torchvision.transforms.Normalize does for CHW tensors.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    return torch.stack([(image - mean_t) / std_t for image in images])
def to_tensor(images, norm_type="naive"):
    """Convert a list of PIL images to one stacked, optionally normalized tensor.

    Args:
        images: list of PIL.Image.Image instances.
        norm_type: normalization convention forwarded to `normalize_tensor`,
            or None to return the raw [0, 1] ToTensor output.

    Returns:
        torch.Tensor of shape (N, C, H, W).
    """
    assert isinstance(images, list)
    for image in images:
        assert isinstance(image, Image.Image)
    converter = transforms.ToTensor()
    stacked = torch.stack([converter(image) for image in images])
    if norm_type is None:
        return stacked
    return normalize_tensor(stacked, norm_type)
def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
    """Instantiate a perceptual-similarity model on the given device.

    Args:
        metric_name: only "lpips" is currently supported.
        mode: LPIPS backbone network, "vgg" or "alex".
        device: torch device to place the model on (defaults to CUDA).

    Returns:
        The LPIPS model moved to `device`.

    Raises:
        ValueError: on an unsupported metric name or backbone.
    """
    # Explicit validation replaces the original assert + unreachable
    # `else: assert False` fallback (asserts are stripped under -O).
    if metric_name != "lpips":
        raise ValueError(f"Unsupported metric: {metric_name!r}")
    if mode not in ("vgg", "alex"):
        raise ValueError(f"Unsupported LPIPS backbone: {mode!r}")
    return LPIPS(net=mode).to(device)
# Compute metric between two images
# Compute metric between two images
def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")):
    """Return the perceptual distance between two PIL images as a Python float."""
    for img in (image1, image2):
        assert isinstance(img, Image.Image)
    batch1, batch2 = (to_tensor([img]).to(device) for img in (image1, image2))
    distance = perceptual_model(batch1, batch2)
    return distance.cpu().item()
# Compute LPIPS distance between two images
# Compute LPIPS distance between two images
def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
    """Convenience wrapper: build an LPIPS model and score a single image pair."""
    model = load_perceptual_models("lpips", mode, device)
    return compute_metric(image1, image2, model, device)
# Compute metrics between pairs of images
# Compute metrics between pairs of images
def compute_perceptual_metric_repeated(
    images1,
    images2,
    metric_name,
    mode,
    model,
    device,
):
    """Compute a perceptual metric for each aligned pair (images1[i], images2[i]).

    Args:
        images1: list of PIL.Image.Image instances.
        images2: list of PIL.Image.Image instances, same length as `images1`.
        metric_name: metric to load when `model` is None (e.g. "lpips").
        mode: backbone passed to `load_perceptual_models` when loading.
        model: pre-loaded perceptual model, or None to load one here.
        device: torch device used for the model and the image batches.

    Returns:
        list[float]: one distance per image pair.
    """
    # Validate lengths before touching element [0] so an empty/mismatched
    # input fails with a clear AssertionError rather than an IndexError.
    assert isinstance(images1, list) and isinstance(images2, list)
    assert len(images1) == len(images2)
    assert all(isinstance(img, Image.Image) for img in images1 + images2)
    if model is None:
        # Forward `device` so the model is created on the requested device.
        # The original called load_perceptual_models(metric_name, mode) and
        # then .to(device), which first instantiated on the default CUDA
        # device and therefore failed on CPU-only machines; it was also
        # inconsistent with compute_lpips, which passes device through.
        model = load_perceptual_models(metric_name, mode, device)
    distances = model(to_tensor(images1).to(device), to_tensor(images2).to(device))
    return distances.detach().cpu().numpy().flatten().tolist()
# Compute LPIPS distance between pairs of images
# Compute LPIPS distance between pairs of images
def compute_lpips_repeated(
    images1,
    images2,
    mode="alex",
    model=None,
    device=torch.device("cuda"),
):
    """Pairwise LPIPS distances for two equally sized lists of PIL images.

    Thin wrapper over compute_perceptual_metric_repeated with the metric
    fixed to "lpips"; pass a pre-loaded `model` to avoid reloading weights.
    """
    return compute_perceptual_metric_repeated(
        images1,
        images2,
        "lpips",
        mode,
        model,
        device,
    )
|