Erasing-Invisible-Demo / kit /metrics /distributional.py
mcding
published version
ad552d8
import os
import tempfile
import torch
from PIL import Image
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from PIL import Image
from .clean_fid import fid
def save_single_image_to_temp(i, image, temp_dir):
    """Save ``image`` as a PNG file named ``<i>.png`` inside ``temp_dir``.

    This is the per-image worker that ``save_images_to_temp`` hands to a
    ProcessPoolExecutor (bound via ``functools.partial``). It therefore has to
    stay a module-level function so it can be pickled into worker processes.

    Args:
        i: Index used as the file stem.
        image: Object with a PIL-style ``save(path, format)`` method.
        temp_dir: Existing directory that receives the file.
    """
    file_name = f"{i}.png"
    target_path = os.path.join(temp_dir, file_name)
    image.save(target_path, "PNG")
def save_images_to_temp(images, num_workers, verbose=False):
    """Write a list of PIL images as PNGs into a freshly created temp directory.

    Image ``k`` is written to ``<temp_dir>/<k>.png`` by
    ``save_single_image_to_temp``, running in a ProcessPoolExecutor.

    Args:
        images: Non-empty list of ``PIL.Image.Image`` objects (only the first
            element's type is checked).
        num_workers: ``max_workers`` for the ProcessPoolExecutor.
        verbose: If True, show a tqdm progress bar while saving.

    Returns:
        Path of the new temporary directory. The caller owns it and is
        responsible for deleting it.

    Raises:
        AssertionError: If ``images`` is not a non-empty list whose first
            element is a PIL image.
        Exception: Any error raised while saving an image is re-raised here;
            the partially filled temp directory is removed first.
    """
    import shutil

    # Explicit emptiness check: previously ``images[0]`` raised a bare IndexError.
    assert (
        isinstance(images, list)
        and len(images) > 0
        and isinstance(images[0], Image.Image)
    )
    temp_dir = tempfile.mkdtemp()
    save_one = partial(save_single_image_to_temp, temp_dir=temp_dir)
    try:
        # Save images in parallel. executor.map is lazy; iterating the results
        # waits for every task and re-raises the first worker exception.
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            results = executor.map(save_one, range(len(images)), images)
            if verbose:
                results = tqdm(results, total=len(images), desc="Saving images ")
            for _ in results:
                pass
    except BaseException:
        # Don't leak a half-written temp directory when saving fails.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
def _resolve_image_dir(images, num_workers, verbose, temp_dirs):
    """Return a folder path for ``images``, saving PIL lists to a new temp dir.

    Any temp directory created here is appended to ``temp_dirs`` so the caller
    can delete it afterwards.
    """
    if isinstance(images, list):
        assert len(images) > 0 and isinstance(images[0], Image.Image)
        temp_dir = save_images_to_temp(images, num_workers=num_workers, verbose=verbose)
        temp_dirs.append(temp_dir)
        return temp_dir
    assert isinstance(images, (str, os.PathLike)) and os.path.exists(images)
    path = os.fspath(images)
    assert isinstance(path, str)
    return path


# Compute FID between two sets of images
def compute_fid(
    images1,
    images2,
    mode="legacy",
    device=None,
    batch_size=64,
    num_workers=None,
    verbose=False,
):
    """Compute the FID score between two image sets using the bundled clean-fid.

    Each of ``images1`` / ``images2`` may independently be either:

    * a path (``str`` or ``os.PathLike``) to an existing image folder, or
    * a non-empty list of ``PIL.Image.Image`` objects.

    Lists are written as PNGs to temporary directories, which are deleted
    before this function returns (also on error).

    Feature statistics for ``images1`` are cached via ``fid.make_custom_stats``
    under a name derived from the folder's absolute path. Later calls with the
    same path reuse that cache, so if the folder's contents change, the stale
    stats are still used.

    Args:
        images1: Reference image set (its statistics are cached).
        images2: Image set compared against ``images1``.
        mode: One of three modes:
            * ``"legacy"`` — clean-fid ``legacy_pytorch`` mode with ``inception_v3``.
            * ``"clean"`` — ``clean`` mode with ``inception_v3``.
            * ``"clip"`` — ``clean`` mode with ``clip_vit_b_32``.
        device: Torch device. Defaults to CUDA when available, otherwise CPU.
        batch_size: Batch size forwarded to ``fid.compute_fid``.
        num_workers: Worker count for saving images and for clean-fid. If
            given, it must satisfy ``1 <= num_workers <= os.cpu_count()``.
            Defaults to ``max(4 * torch.cuda.device_count(), 8)``, capped at
            the CPU count.
        verbose: Forwarded to clean-fid and the image-saving progress bar.

    Returns:
        The FID score returned by ``fid.compute_fid``.

    Raises:
        AssertionError: On an unknown ``mode``, an out-of-range
            ``num_workers``, or invalid/empty inputs.
    """
    import shutil

    # Supported modes: public name -> (clean-fid mode, feature-extractor name).
    mode_options = {
        "legacy": ("legacy_pytorch", "inception_v3"),
        "clean": ("clean", "inception_v3"),
        "clip": ("clean", "clip_vit_b_32"),
    }
    assert mode in mode_options
    mode, model_name = mode_options[mode]

    # Set up device and num_workers
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    cpu_count = os.cpu_count()  # may be None when the count is undeterminable
    if num_workers is not None:
        assert num_workers >= 1 and (cpu_count is None or num_workers <= cpu_count)
    else:
        num_workers = max(torch.cuda.device_count() * 4, 8)
        # Keep the default inside the same bound enforced for explicit values.
        if cpu_count is not None:
            num_workers = min(num_workers, cpu_count)

    temp_dirs = []
    try:
        path1 = _resolve_image_dir(images1, num_workers, verbose, temp_dirs)
        path2 = _resolve_image_dir(images2, num_workers, verbose, temp_dirs)

        # Cache statistics for path1 under a name derived from its absolute path.
        stats_name = str(os.path.abspath(path1)).replace("/", "_")
        if not fid.test_stats_exists(name=stats_name, mode=mode, model_name=model_name):
            fid.make_custom_stats(
                name=stats_name,
                fdir=path1,
                mode=mode,
                model_name=model_name,
                device=device,
                num_workers=num_workers,
                verbose=verbose,
            )
        fid_score = fid.compute_fid(
            path2,
            dataset_name=stats_name,
            dataset_split="custom",
            mode=mode,
            model_name=model_name,
            device=device,
            batch_size=batch_size,
            num_workers=num_workers,
            verbose=verbose,
        )
    finally:
        # Temp image folders are only needed while computing; remove them.
        for temp_dir in temp_dirs:
            shutil.rmtree(temp_dir, ignore_errors=True)
    return fid_score