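"""Utilities for computing perceptual image similarity with LPIPS.

Provides helpers to convert PIL images to normalized tensors, load an LPIPS
model with a VGG or AlexNet backbone, and compute distances between single
images or aligned lists of image pairs.
"""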
import torch
from PIL import Image
from torchvision import transforms
from .lpips import LPIPS


# Normalize a batch of image tensors with one of two conventions:
# "imagenet" uses the ImageNet channel statistics expected by torchvision
# pretrained backbones; "naive" maps inputs from [0, 1] to [-1, 1].
def normalize_tensor(images, norm_type):
    assert norm_type in ["imagenet", "naive"]
    if norm_type == "imagenet":
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:  # "naive"
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    normalize = transforms.Normalize(mean, std)
    return torch.stack([normalize(image) for image in images])


# Convert a list of PIL images to a batched float tensor, optionally normalized.
def to_tensor(images, norm_type="naive"):
    assert isinstance(images, list) and all(
        isinstance(image, Image.Image) for image in images
    )
    images = torch.stack([transforms.ToTensor()(image) for image in images])
    if norm_type is not None:
        images = normalize_tensor(images, norm_type)
    return images


# Load a perceptual metric model; currently only LPIPS with a VGG or
# AlexNet backbone is supported.
def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
    assert metric_name in ["lpips"]
    assert mode in ["vgg", "alex"]
    perceptual_model = LPIPS(net=mode).to(device)
    return perceptual_model


# Compute the perceptual metric between two PIL images. The default "naive"
# normalization in to_tensor maps inputs to [-1, 1], the range LPIPS expects.
def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")):
    assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
    image1_tensor = to_tensor([image1]).to(device)
    image2_tensor = to_tensor([image2]).to(device)
    return perceptual_model(image1_tensor, image2_tensor).cpu().item()


# Compute the LPIPS distance between two images. Note that this reloads the
# model on every call; for many comparisons, prefer compute_lpips_repeated
# with a preloaded model.
def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
    perceptual_model = load_perceptual_models("lpips", mode, device)
    return compute_metric(image1, image2, perceptual_model, device)


# Compute the perceptual metric between aligned pairs of images, returning
# one distance per pair.
def compute_perceptual_metric_repeated(
    images1,
    images2,
    metric_name,
    mode,
    model,
    device,
):
    # Accept two equal-length lists of PIL images
    assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
    assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
    assert len(images1) == len(images2)
    if model is None:
        # Pass device through so the model is loaded directly onto it,
        # rather than onto the default CUDA device first.
        model = load_perceptual_models(metric_name, mode, device)
    return (
        model(to_tensor(images1).to(device), to_tensor(images2).to(device))
        .detach()
        .cpu()
        .numpy()
        .flatten()
        .tolist()
    )


# Compute LPIPS distance between pairs of images
def compute_lpips_repeated(
    images1,
    images2,
    mode="alex",
    model=None,
    device=torch.device("cuda"),
):
    return compute_perceptual_metric_repeated(
        images1, images2, "lpips", mode, model, device
    )
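

# A minimal usage sketch (not part of the original module): the file names
# "a.png" and "b.png" are placeholder assumptions, and a CUDA device is
# assumed to be available. LPIPS also requires both images in a pair to
# share the same spatial size.
if __name__ == "__main__":
    img1 = Image.open("a.png").convert("RGB")
    img2 = Image.open("b.png").convert("RGB")

    # One-off comparison: loads the AlexNet-backed LPIPS model internally.
    print(compute_lpips(img1, img2, mode="alex"))

    # Batched comparison with a preloaded model, avoiding repeated loads.
    model = load_perceptual_models("lpips", "alex")
    print(compute_lpips_repeated([img1, img2], [img2, img1], model=model))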