Spaces:
Runtime error
Runtime error
import os | |
import tempfile | |
import torch | |
from PIL import Image | |
from tqdm.auto import tqdm | |
from concurrent.futures import ProcessPoolExecutor | |
from functools import partial | |
from PIL import Image | |
from .clean_fid import fid | |
def save_single_image_to_temp(i, image, temp_dir): | |
save_path = os.path.join(temp_dir, f"{i}.png") | |
image.save(save_path, "PNG") | |
def save_images_to_temp(images, num_workers, verbose=False): | |
assert isinstance(images, list) and isinstance(images[0], Image.Image) | |
temp_dir = tempfile.mkdtemp() | |
# Using ProcessPoolExecutor to save images in parallel | |
func = partial(save_single_image_to_temp, temp_dir=temp_dir) | |
with ProcessPoolExecutor(max_workers=num_workers) as executor: | |
tasks = executor.map(func, range(len(images)), images) | |
list(tasks) if not verbose else list( | |
tqdm( | |
tasks, | |
total=len(images), | |
desc="Saving images ", | |
) | |
) | |
return temp_dir | |
# Compute FID between two sets of images | |
def compute_fid( | |
images1, | |
images2, | |
mode="legacy", | |
device=None, | |
batch_size=64, | |
num_workers=None, | |
verbose=False, | |
): | |
# Support four types of FID scores | |
assert mode in ["legacy", "clean", "clip"] | |
if mode == "legacy": | |
mode = "legacy_pytorch" | |
model_name = "inception_v3" | |
elif mode == "clean": | |
mode = "clean" | |
model_name = "inception_v3" | |
elif mode == "clip": | |
mode = "clean" | |
model_name = "clip_vit_b_32" | |
else: | |
assert False | |
# Set up device and num_workers | |
if device is None: | |
device = ( | |
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
) | |
if num_workers is not None: | |
assert 1 <= num_workers <= os.cpu_count() | |
else: | |
num_workers = max(torch.cuda.device_count() * 4, 8) | |
# Check images, can be paths or lists of PIL images | |
if not isinstance(images1, list): | |
assert isinstance(images1, str) and os.path.exists(images1) | |
assert isinstance(images2, str) and os.path.exists(images2) | |
path1 = images1 | |
path2 = images2 | |
else: | |
assert isinstance(images1, list) and isinstance(images1[0], Image.Image) | |
assert isinstance(images2, list) and isinstance(images2[0], Image.Image) | |
# Save images to temp dir if needed | |
path1 = save_images_to_temp(images1, num_workers=num_workers, verbose=verbose) | |
path2 = save_images_to_temp(images2, num_workers=num_workers, verbose=verbose) | |
# Attempt to cache statistics for path1 | |
if not fid.test_stats_exists(name=str(os.path.abspath(path1)).replace("/", "_"), mode=mode, model_name=model_name): | |
fid.make_custom_stats( | |
name=str(os.path.abspath(path1)).replace("/", "_"), | |
fdir=path1, | |
mode=mode, | |
model_name=model_name, | |
device=device, | |
num_workers=num_workers, | |
verbose=verbose, | |
) | |
fid_score = fid.compute_fid( | |
path2, | |
dataset_name=str(os.path.abspath(path1)).replace("/", "_"), | |
dataset_split="custom", | |
mode=mode, | |
model_name=model_name, | |
device=device, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
verbose=verbose, | |
) | |
return fid_score | |