Spaces:
Runtime error
Runtime error
File size: 3,339 Bytes
ad552d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import tempfile
import torch
from PIL import Image
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from PIL import Image
from .clean_fid import fid
def save_single_image_to_temp(i, image, temp_dir):
    """Write ``image`` as PNG to ``<temp_dir>/<i>.png``.

    Kept at module level so it can be wrapped in ``functools.partial`` and
    dispatched to ``ProcessPoolExecutor`` workers by ``save_images_to_temp``.

    Args:
        i: Index used as the file stem.
        image: Object exposing ``save(path, format)`` (a PIL image in practice).
        temp_dir: Directory to write into.
    """
    filename = f"{i}.png"
    target = os.path.join(temp_dir, filename)
    image.save(target, "PNG")
def save_images_to_temp(images, num_workers, verbose=False):
    """Save a list of PIL images as PNG files in a fresh temporary directory.

    Files are named ``0.png``, ``1.png``, ... following list order and are
    written in parallel by a ``ProcessPoolExecutor``. Each image is sent to
    a worker process, so it must be picklable.

    Args:
        images: Non-empty list of ``PIL.Image.Image``.
        num_workers: ``max_workers`` for the process pool.
        verbose: If True, show a tqdm progress bar while saving.

    Returns:
        Path of the newly created temporary directory. The caller owns it
        and is responsible for deleting it.

    Raises:
        TypeError: ``images`` is not a list, or contains non-PIL elements.
        ValueError: ``images`` is empty.
    """
    import shutil

    # Explicit checks instead of ``assert`` so validation survives ``-O``.
    if not isinstance(images, list):
        raise TypeError(
            f"images must be a list of PIL images, got {type(images).__name__}"
        )
    if not images:
        raise ValueError("images must be a non-empty list of PIL images")
    if not all(isinstance(img, Image.Image) for img in images):
        raise TypeError("every element of images must be a PIL.Image.Image")

    temp_dir = tempfile.mkdtemp()
    save_one = partial(save_single_image_to_temp, temp_dir=temp_dir)
    try:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            results = executor.map(save_one, range(len(images)), images)
            if verbose:
                results = tqdm(results, total=len(images), desc="Saving images ")
            # Drain the lazy iterator so any worker exception is re-raised here.
            for _ in results:
                pass
    except BaseException:
        # Do not leave a partially written directory behind on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
# Maps the public ``mode`` argument to the (clean-fid mode, model_name) pair
# passed to the bundled clean_fid module.
_FID_MODES = {
    "legacy": ("legacy_pytorch", "inception_v3"),
    "clean": ("clean", "inception_v3"),
    "clip": ("clean", "clip_vit_b_32"),
}


def _validate_image_source(images, arg_name):
    """Raise unless *images* is an existing path or a non-empty list of PIL images."""
    if isinstance(images, (str, os.PathLike)):
        if not os.path.exists(images):
            raise ValueError(
                f"{arg_name}: path does not exist: {os.fspath(images)!r}"
            )
    elif isinstance(images, list):
        if not images:
            raise ValueError(f"{arg_name}: image list is empty")
        if not all(isinstance(img, Image.Image) for img in images):
            raise TypeError(
                f"{arg_name}: every list element must be a PIL.Image.Image"
            )
    else:
        raise TypeError(
            f"{arg_name}: expected a path or a list of PIL images, "
            f"got {type(images).__name__}"
        )


def _stats_name(path):
    """Return the custom-stats cache name for *path* (absolute path, '/' -> '_')."""
    return str(os.path.abspath(path)).replace("/", "_")


def _as_image_dir(images, num_workers, verbose, temp_dirs):
    """Return a filesystem path for *images*.

    Path inputs are returned unchanged (as ``str``). PIL lists are written
    to a new temporary directory, whose path is also appended to *temp_dirs*
    so the caller can delete it afterwards.
    """
    if isinstance(images, list):
        path = save_images_to_temp(images, num_workers=num_workers, verbose=verbose)
        temp_dirs.append(path)
        return path
    return os.fspath(images)


def compute_fid(
    images1,
    images2,
    mode="legacy",
    device=None,
    batch_size=64,
    num_workers=None,
    verbose=False,
):
    """Compute the FID between two sets of images using the bundled clean-fid.

    ``images1`` is the reference set. Its feature statistics are stored with
    ``fid.make_custom_stats`` under a name derived from its absolute path,
    unless ``fid.test_stats_exists`` already reports them. ``images2`` is
    then scored against those statistics. For a directory input the stats
    are looked up by name only, so presumably stale stats are reused if the
    directory's contents change — confirm against clean_fid if relevant.

    Args:
        images1: Reference images, either a path (``str`` / ``os.PathLike``)
            that exists, or a non-empty list of ``PIL.Image.Image``.
        images2: Images to evaluate, in either of the same two forms. The two
            arguments may use different forms.
        mode: ``"legacy"`` -> clean-fid mode ``"legacy_pytorch"`` with
            ``"inception_v3"``; ``"clean"`` -> ``"clean"`` with
            ``"inception_v3"``; ``"clip"`` -> ``"clean"`` with
            ``"clip_vit_b_32"``.
        device: torch device; defaults to CUDA when available, else CPU.
        batch_size: Batch size forwarded to ``fid.compute_fid``.
        num_workers: Worker count for saving images and for clean-fid. When
            given it must be >= 1 and, if ``os.cpu_count()`` is known, no
            larger than it. Defaults to ``max(4 * cuda_device_count, 8)``.
        verbose: Show progress output.

    Returns:
        The value returned by ``fid.compute_fid`` (the FID score).

    Raises:
        ValueError: Unknown ``mode``, out-of-range ``num_workers``, a path
            that does not exist, or an empty image list.
        TypeError: An input that is neither a path nor a list of PIL images.
    """
    import shutil

    if mode not in _FID_MODES:
        raise ValueError(f"mode must be one of {sorted(_FID_MODES)}, got {mode!r}")
    fid_mode, model_name = _FID_MODES[mode]

    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    if num_workers is None:
        num_workers = max(torch.cuda.device_count() * 4, 8)
    else:
        max_workers = os.cpu_count()  # None when the count cannot be determined
        if num_workers < 1 or (max_workers is not None and num_workers > max_workers):
            raise ValueError(
                f"num_workers must be between 1 and {max_workers}, got {num_workers}"
            )

    # Validate both inputs before writing anything to disk.
    _validate_image_source(images1, "images1")
    _validate_image_source(images2, "images2")

    temp_dirs = []
    try:
        path1 = _as_image_dir(images1, num_workers, verbose, temp_dirs)
        path2 = _as_image_dir(images2, num_workers, verbose, temp_dirs)

        stats_name = _stats_name(path1)
        if not fid.test_stats_exists(
            name=stats_name, mode=fid_mode, model_name=model_name
        ):
            fid.make_custom_stats(
                name=stats_name,
                fdir=path1,
                mode=fid_mode,
                model_name=model_name,
                device=device,
                num_workers=num_workers,
                verbose=verbose,
            )
        return fid.compute_fid(
            path2,
            dataset_name=stats_name,
            dataset_split="custom",
            mode=fid_mode,
            model_name=model_name,
            device=device,
            batch_size=batch_size,
            num_workers=num_workers,
            verbose=verbose,
        )
    finally:
        # Temp dirs exist only to feed clean-fid; remove them once scored.
        # NOTE(review): stats cached under a temp-dir name stay in clean-fid's
        # stats store and will never be hit again.
        for temp_dir in temp_dirs:
            shutil.rmtree(temp_dir, ignore_errors=True)
|