File size: 3,339 Bytes
ad552d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import tempfile
import torch
from PIL import Image
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from PIL import Image
from .clean_fid import fid


def save_single_image_to_temp(i, image, temp_dir):
    """Save ``image`` as ``<i>.png`` inside ``temp_dir``.

    Args:
        i: Index used as the file stem.
        image: Object exposing ``save(path, format)``; callers in this file
            pass PIL images.
        temp_dir: Directory the PNG file is written into.
    """
    filename = "{}.png".format(i)
    image.save(os.path.join(temp_dir, filename), "PNG")


def save_images_to_temp(images, num_workers, verbose=False):
    """Write a list of PIL images to a fresh temporary directory as PNG files.

    Each image is saved by ``save_single_image_to_temp`` in a worker process
    as ``<index>.png``, where the index is its position in ``images``.

    Args:
        images: Non-empty list of ``PIL.Image.Image`` objects. Only the first
            element is type-checked.
        num_workers: Passed as ``max_workers`` to ``ProcessPoolExecutor``.
        verbose: When true, show a tqdm progress bar while saving.

    Returns:
        Path of the newly created temporary directory. The caller owns the
        directory and is responsible for deleting it.

    Raises:
        AssertionError: If ``images`` is not a non-empty list whose first
            element is a PIL image.
    """
    import shutil  # local import: only needed to clean up after a failure

    # Check emptiness before indexing so an empty list fails with a clear
    # message instead of an IndexError from ``images[0]``.
    assert isinstance(images, list) and len(images) > 0, (
        "images must be a non-empty list of PIL images"
    )
    assert isinstance(images[0], Image.Image), "images must contain PIL images"
    temp_dir = tempfile.mkdtemp()

    # Using ProcessPoolExecutor to save images in parallel
    func = partial(save_single_image_to_temp, temp_dir=temp_dir)
    try:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            tasks = executor.map(func, range(len(images)), images)
            if verbose:
                tasks = tqdm(tasks, total=len(images), desc="Saving images ")
            # Drain the iterator so any exception raised in a worker
            # surfaces here rather than being silently dropped.
            for _ in tasks:
                pass
    except BaseException:
        # Do not leave a half-filled temporary directory behind on failure
        # (including KeyboardInterrupt); re-raise the original error.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir


# Compute FID between two sets of images
def compute_fid(
    images1,
    images2,
    mode="legacy",
    device=None,
    batch_size=64,
    num_workers=None,
    verbose=False,
):
    """Compute the FID score between two image sets.

    Statistics for ``images1`` are cached through ``fid.make_custom_stats``
    under a name derived from its absolute path, and ``images2`` is then
    scored against them with ``fid.compute_fid``.

    Args:
        images1: Reference set: either a path to an existing image directory
            or a non-empty list of PIL images.
        images2: Set to evaluate; same accepted forms as ``images1``. The two
            arguments do not have to use the same form.
        mode: One of ``"legacy"``, ``"clean"`` or ``"clip"``.
        device: Torch device; defaults to CUDA when available, else CPU.
        batch_size: Forwarded to ``fid.compute_fid``.
        num_workers: Worker count for saving images and for the ``fid``
            calls. Must be between 1 and ``os.cpu_count()`` when given;
            defaults to ``max(4 * cuda_device_count, 8)``.
        verbose: Forwarded to the helpers to enable progress output.

    Returns:
        The value returned by ``fid.compute_fid``.

    Raises:
        AssertionError: If ``mode``, ``num_workers`` or the image arguments
            are invalid.
    """
    import shutil  # local import: only used to remove temporary directories

    # Three FID variants are supported; each maps the public ``mode`` to the
    # (mode, model_name) pair passed on to the ``fid`` module.
    fid_settings = {
        "legacy": ("legacy_pytorch", "inception_v3"),
        "clean": ("clean", "inception_v3"),
        "clip": ("clean", "clip_vit_b_32"),
    }
    # Membership is tested against a list so an unhashable ``mode`` still
    # fails the assertion instead of raising TypeError.
    assert mode in list(fid_settings), f"unsupported mode: {mode!r}"
    mode, model_name = fid_settings[mode]

    # Set up device and num_workers
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if num_workers is not None:
        # os.cpu_count() can return None; then only the lower bound applies.
        cpu_count = os.cpu_count()
        assert num_workers >= 1 and (cpu_count is None or num_workers <= cpu_count)
    else:
        num_workers = max(torch.cuda.device_count() * 4, 8)

    # Check images, can be paths or lists of PIL images. Validate both inputs
    # up front so nothing is written to disk when one of them is invalid.
    for images in (images1, images2):
        if isinstance(images, list):
            assert len(images) > 0 and isinstance(images[0], Image.Image)
        else:
            assert isinstance(images, str) and os.path.exists(images)

    temp_dirs = []
    try:
        # Save images to temp dir if needed
        paths = []
        for images in (images1, images2):
            if isinstance(images, list):
                images = save_images_to_temp(
                    images, num_workers=num_workers, verbose=verbose
                )
                temp_dirs.append(images)
            paths.append(images)
        path1, path2 = paths

        # Attempt to cache statistics for path1
        stats_name = str(os.path.abspath(path1)).replace("/", "_")
        if not fid.test_stats_exists(
            name=stats_name, mode=mode, model_name=model_name
        ):
            fid.make_custom_stats(
                name=stats_name,
                fdir=path1,
                mode=mode,
                model_name=model_name,
                device=device,
                num_workers=num_workers,
                verbose=verbose,
            )
        fid_score = fid.compute_fid(
            path2,
            dataset_name=stats_name,
            dataset_split="custom",
            mode=mode,
            model_name=model_name,
            device=device,
            batch_size=batch_size,
            num_workers=num_workers,
            verbose=verbose,
        )
    finally:
        # The temporary directories are never exposed to the caller, so
        # remove them here instead of leaking them on every call.
        # NOTE(review): statistics cached under a temp-dir-derived name are
        # not removed here — confirm whether the fid module offers a way to
        # drop them.
        for temp_dir in temp_dirs:
            shutil.rmtree(temp_dir, ignore_errors=True)
    return fid_score