import numpy as np
import torch
import tops
from dp2 import utils
from torch_fidelity.helpers import get_kwarg, vassert
from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS
from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity
from torchvision.transforms.functional import resize


def slerp(a, b, t):
    """Spherical interpolation between latent batches `a` and `b` at ratio `t`,
    as used by StyleGAN for sampling PPL endpoints on the unit hypersphere."""
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)  # cosine of the angle between a and b
    p = t * torch.acos(d)  # angle to travel from a towards b
    c = b - d * a  # component of b orthogonal to a
    c = c / c.norm(dim=-1, keepdim=True)
    d = a * torch.cos(p) + c * torch.sin(p)  # rotate a by p in the (a, c) plane
    d = d / d.norm(dim=-1, keepdim=True)
    return d
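

# A minimal sanity check for `slerp` (an illustrative addition, not part of
# the original module): every interpolant stays on the unit hypersphere, and
# t=0 recovers the direction of `a`.
def _slerp_sanity_check():
    a, b = torch.randn(4, 512), torch.randn(4, 512)
    mid = slerp(a, b, 0.5)
    assert torch.allclose(mid.norm(dim=-1), torch.ones(4), atol=1e-4)
    assert torch.allclose(slerp(a, b, 0.0), a / a.norm(dim=-1, keepdim=True), atol=1e-4)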


@torch.no_grad()
def calculate_ppl(
        dataloader,
        generator,
        latent_space=None,
        data_len=None,
        upsample_size=None,
        **kwargs) -> dict:
    """
    Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
    """
    if latent_space is None:
        latent_space = generator.latent_space
    assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
    assert upsample_size is not None and len(upsample_size) == 2, \
        f"upsample_size must be (height, width), got: {upsample_size}"
    epsilon = PPL_DEFAULTS["ppl_epsilon"]
    interp = PPL_DEFAULTS['ppl_z_interp_mode']
    similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
    sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
    sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
    discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
    discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']

    vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
    vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
    vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
    if discard_percentile_lower is not None and discard_percentile_higher is not None:
        vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')

    # The sample similarity measure (LPIPS by default) is built on CPU and
    # moved to the local device afterwards.
    sample_similarity = create_sample_similarity(
        similarity_name,
        sample_similarity_resize=sample_similarity_resize,
        sample_similarity_dtype=sample_similarity_dtype,
        cuda=False,
        **kwargs
    )
    sample_similarity = tops.to_cuda(sample_similarity)
    rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    # Sample both path endpoints in Z; for Z-space PPL, move the second
    # endpoint to interpolation distance epsilon from the first already here.
    z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
    z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
    if latent_space == "Z":
        z1 = batch_interp(z0, z1, epsilon, interp)
    print("Computing PPL IN", latent_space)
    distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
    print(distances.shape)
    end = 0
    n_samples = 0
    for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")):
        start = end
        end = start + batch["img"].shape[0]
        n_samples += batch["img"].shape[0]
        batch_lat_e0 = tops.to_cuda(z0[start:end])
        batch_lat_e1 = tops.to_cuda(z1[start:end])
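        # In W space, map both latents through the mapping network and lerp by
        # epsilon to form the second endpoint; in Z space, z1 was already
        # placed at distance epsilon from z0 above.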
        if latent_space == "W":
            w0 = generator.get_w(batch_lat_e0, update_emas=False)
            w1 = generator.get_w(batch_lat_e1, update_emas=False)
            w1 = w0.lerp(w1, epsilon)  # PPL end
            rgb1 = generator(**batch, w=w0)["img"]
            rgb2 = generator(**batch, w=w1)["img"]
        else:
            rgb1 = generator(**batch, z=batch_lat_e0)["img"]
            rgb2 = generator(**batch, z=batch_lat_e1)["img"]
        if rgb1.shape[-2] < upsample_size[0] or rgb1.shape[-1] < upsample_size[1]:
            # Upsample small generator outputs to a common resolution before
            # computing the similarity.
            rgb1 = resize(rgb1, upsample_size, antialias=True)
            rgb2 = resize(rgb2, upsample_size, antialias=True)
        # The sample similarity measure expects uint8 images in [0, 255].
        rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
        rgb2 = utils.denormalize_img(rgb2).mul(255).byte()

        # PPL distance: perceptual distance between the two endpoint images,
        # scaled by 1 / epsilon^2 (a finite-difference path-length estimate).
        sim = sample_similarity(rgb1, rgb2)
        dist_lat_e01 = sim / (epsilon ** 2)
        distances[start:end] = dist_lat_e01.view(-1)
    distances = distances[:n_samples]
    distances = tops.all_gather_uneven(distances).cpu().numpy()
    if tops.rank() != 0:
        # Statistics are reported by rank 0 only.
        return {"ppl/mean": -1, "ppl/std": -1}
    # Discard outliers outside the configured percentiles before aggregating,
    # as in the original StyleGAN implementation.
    cond, lo, hi = None, None, None
    if discard_percentile_lower is not None:
        # `method=` replaces the `interpolation=` keyword deprecated in NumPy 1.22.
        lo = np.percentile(distances, discard_percentile_lower, method='lower')
        cond = lo <= distances
    if discard_percentile_higher is not None:
        hi = np.percentile(distances, discard_percentile_higher, method='higher')
        cond = distances <= hi if cond is None else np.logical_and(cond, distances <= hi)
    if cond is not None:
        distances = np.extract(cond, distances)
    return {
        "ppl/mean": float(np.mean(distances)),
        "ppl/std": float(np.std(distances)),
    }