max
reinit
b6dd358
raw
history blame
19.2 kB
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import time
import hashlib
import pickle
import copy
import uuid
import numpy as np
import torch
import dnnlib
import math
import cv2
#----------------------------------------------------------------------------
class MetricOptions:
    """Settings shared by all metric computations of one evaluation run.

    Args:
        G: Generator to evaluate; the compute_* helpers deep-copy it before use.
        G_kwargs: Extra keyword arguments passed to every ``G(...)`` call.
            ``None`` (the default) means no extra arguments.
        dataset_kwargs: Arguments for ``dnnlib.util.construct_class_by_name``
            used to build the evaluation dataset. ``None`` means empty.
        num_gpus: Number of participating processes.
        rank: Index of this process; must satisfy ``0 <= rank < num_gpus``.
        device: Torch device to run on; defaults to ``cuda:<rank>``.
        progress: Optional parent ``ProgressMonitor``. Only rank 0 derives a
            sub-monitor from it; every other rank gets a default (silent)
            ``ProgressMonitor()``.
        cache: Whether dataset feature statistics may be read from / written
            to the on-disk cache.
    """

    def __init__(self, G=None, G_kwargs=None, dataset_kwargs=None, num_gpus=1, rank=0, device=None, progress=None, cache=True):
        assert 0 <= rank < num_gpus
        self.G = G
        # ``None`` defaults instead of ``{}`` so no dict object is shared
        # between calls through the function's default arguments.
        self.G_kwargs = dnnlib.EasyDict(G_kwargs if G_kwargs is not None else {})
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs if dataset_kwargs is not None else {})
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device if device is not None else torch.device('cuda', rank)
        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache = cache
#----------------------------------------------------------------------------
# Memo of loaded detectors, keyed by (url, device); filled by get_feature_detector().
_feature_detector_cache = {}


def get_feature_detector_name(url):
    """Return the file stem of the last ``/``-separated component of ``url``.

    Example: ``'https://host/metrics/inception-2015-12-05.pt'`` gives
    ``'inception-2015-12-05'``.
    """
    filename = url.rsplit('/', 1)[-1]
    stem, _ext = os.path.splitext(filename)
    return stem
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Return the TorchScript feature detector at ``url``, loaded once per ``(url, device)``.

    Loaded modules are memoized in the module-level ``_feature_detector_cache``.
    With ``num_gpus > 1`` the first load is serialized: non-leader ranks wait at a
    barrier while rank 0 opens and loads the URL, then rank 0 enters its own barrier,
    releasing the others to load. Because ``torch.distributed.barrier`` blocks until
    every rank arrives, all ranks must make this call for a not-yet-cached key.
    NOTE(review): presumably rank 0 goes first so that only it downloads / populates
    any local cache used by ``dnnlib.util.open_url`` -- confirm.

    Args:
        url: Location handed to ``dnnlib.util.open_url``.
        device: Device the detector is moved to; also part of the cache key.
        num_gpus: Number of participating processes.
        rank: Index of this process, ``0 <= rank < num_gpus``.
        verbose: Forwarded to ``open_url`` on rank 0 only.

    Returns:
        The loaded module, switched to eval mode and moved to ``device``.
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]
#----------------------------------------------------------------------------
class FeatureStats:
    """Accumulator for per-item feature rows.

    Depending on the flags it keeps every accepted row (``capture_all``) and/or
    running float64 sums from which the mean and the biased covariance can be
    derived (``capture_mean_cov``). Once ``max_items`` rows have been accepted,
    further rows are silently dropped.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        # Configuration.
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        # Accumulated state; buffers are allocated lazily in set_num_features().
        self.num_items = 0
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature dimensionality on first use; later calls must match it."""
        if self.num_features is None:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
            return
        assert num_features == self.num_features

    def is_full(self):
        """Return True once ``max_items`` rows were accepted; always False when unlimited."""
        if self.max_items is None:
            return False
        return self.num_items >= self.max_items

    def append(self, x):
        """Add a 2-D batch (rows x features), truncating it so ``max_items`` is not exceeded."""
        batch = np.asarray(x, dtype=np.float32)
        assert batch.ndim == 2
        limit = self.max_items
        if limit is not None and self.num_items + batch.shape[0] > limit:
            room = limit - self.num_items
            if room <= 0:
                return  # already full: drop the whole batch
            batch = batch[:room]

        self.set_num_features(batch.shape[1])
        self.num_items += batch.shape[0]
        if self.capture_all:
            self.all_features.append(batch)
        if self.capture_mean_cov:
            wide = batch.astype(np.float64)
            self.raw_mean += wide.sum(axis=0)
            self.raw_cov += wide.T @ wide

    def append_torch(self, x, num_gpus=1, rank=0):
        """Gather a 2-D tensor from every rank and append the interleaved result."""
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            per_rank = []
            for src in range(num_gpus):
                buf = x.clone()
                torch.distributed.broadcast(buf, src=src)
                per_rank.append(buf)
            # (rows, ranks, features) -> (rows * ranks, features): row j of rank r
            # ends up at position j * num_gpus + r.
            x = torch.stack(per_rank, dim=1).flatten(0, 1)
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return every captured row as a single array (requires ``capture_all``)."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return :meth:`get_all` wrapped as a CPU tensor."""
        rows = self.get_all()
        return torch.from_numpy(rows)

    def get_mean_cov(self):
        """Return ``(mean, cov)`` of the accepted rows; ``cov`` divides by N (biased)."""
        assert self.capture_mean_cov
        n = self.num_items
        mean = self.raw_mean / n
        cov = self.raw_cov / n - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle the complete internal state to ``pkl_file``."""
        with open(pkl_file, 'wb') as fh:
            pickle.dump(self.__dict__, fh)

    @staticmethod
    def load(pkl_file):
        """Restore an instance previously written by :meth:`save`.

        Uses ``pickle``, so only load files from a trusted location.
        """
        with open(pkl_file, 'rb') as fh:
            state = dnnlib.EasyDict(pickle.load(fh))
        restored = FeatureStats(capture_all=state.capture_all, max_items=state.max_items)
        restored.__dict__.update(state)
        return restored
#----------------------------------------------------------------------------
class ProgressMonitor:
    """Rate-limited progress reporter that supports nested sub-ranges.

    ``update()`` prints a status line (only when ``verbose`` is set and ``tag`` is
    not None) and forwards progress to ``progress_fn(value, pfn_total)``, mapping
    this monitor's item count linearly onto ``[pfn_lo, pfn_hi]``.
    """

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag
        self.num_items = num_items
        self.verbose = verbose
        self.flush_interval = flush_interval
        self.progress_fn = progress_fn
        self.pfn_lo = pfn_lo
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        self.start_time = self.batch_time = time.time()
        self.batch_items = 0
        # Report the start of this monitor's range immediately.
        if self.progress_fn is not None:
            self.progress_fn(self.pfn_lo, self.pfn_total)

    def update(self, cur_items):
        """Record ``cur_items`` completed items.

        Reports only after at least ``flush_interval`` new items, or when
        ``num_items`` is known and has been reached.
        """
        assert (self.num_items is None) or (cur_items <= self.num_items)
        too_soon = cur_items < self.batch_items + self.flush_interval
        finished = self.num_items is not None and cur_items >= self.num_items
        if too_soon and not finished:
            return

        now = time.time()
        elapsed = now - self.start_time
        per_item = (now - self.batch_time) / max(cur_items - self.batch_items, 1)
        if self.verbose and self.tag is not None:
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(elapsed):<12s} ms/item {per_item*1e3:.2f}')
        self.batch_time = now
        self.batch_items = cur_items

        if self.progress_fn is None or self.num_items is None:
            return
        fraction = cur_items / self.num_items
        self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * fraction, self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the ``[rel_lo, rel_hi]`` fraction of this range."""
        width = self.pfn_hi - self.pfn_lo
        return ProgressMonitor(
            tag=tag,
            num_items=num_items,
            flush_interval=flush_interval,
            verbose=self.verbose,
            progress_fn=self.progress_fn,
            pfn_lo=self.pfn_lo + width * rel_lo,
            pfn_hi=self.pfn_lo + width * rel_hi,
            pfn_total=self.pfn_total,
        )
#----------------------------------------------------------------------------
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
    """Compute (or load from the on-disk cache) detector feature statistics of the real dataset.

    The dataset is iterated as ``(images, masks, labels)`` triples (inpainting
    variant); only the images are passed to the feature detector. Rank ``r``
    processes items ``r, r + num_gpus, ...`` and the per-rank features are
    gathered on all ranks by ``FeatureStats.append_torch``.

    Args:
        opts: ``MetricOptions`` with the dataset, device and rank configuration.
        detector_url: URL of the TorchScript feature detector.
        detector_kwargs: Extra keyword arguments for the detector call.
        rel_lo, rel_hi: Fraction of the parent progress range to report into.
        batch_size: Per-rank DataLoader batch size.
        data_loader_kwargs: Extra DataLoader arguments; defaults to
            ``pin_memory=True, num_workers=3, prefetch_factor=2``.
        max_items: Optional cap on the number of dataset items used.
        **stats_kwargs: Forwarded to ``FeatureStats`` (e.g. ``capture_all``,
            ``capture_mean_cov``).

    Returns:
        FeatureStats over at most ``min(len(dataset), max_items)`` items, or the
        cached copy when a matching cache file exists.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to look up the statistics in the on-disk cache.
    cache_file = None
    if opts.cache:
        # Choose cache file name. max_items changes which items are processed, so it
        # must be part of the key (otherwise capped and uncapped runs share one file).
        # It is only added when set, so files for the default max_items=None keep
        # their previous names.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        if max_items is not None:
            args['max_items'] = max_items
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree, so rank 0 decides and broadcasts).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return FeatureStats.load(cache_file)

    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop. Every rank gets the same number of indices; the wrap-around extras
    # beyond num_items are discarded by FeatureStats (max_items=num_items).
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    for images, _masks, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        # Replicate single-channel images to 3 channels before feeding the detector.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

    # Save to cache (rank 0 only): write a uniquely named temp file, then rename it
    # into place so a concurrent reader never sees a partially written pickle.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file)  # atomic
    return stats
#----------------------------------------------------------------------------
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, data_loader_kwargs=None, **stats_kwargs):
    """Compute detector feature statistics for images inpainted by ``opts.G``.

    Each dataset batch of ``(images, masks, labels)`` is run through the generator
    in chunks of ``batch_gen`` with fresh latents ``z`` and labels drawn at random
    from the dataset (the batch's own labels are not used). The uint8 outputs are
    fed to the feature detector.

    Args:
        opts: ``MetricOptions`` with generator, dataset, device and rank configuration.
        detector_url: URL of the TorchScript feature detector.
        detector_kwargs: Extra keyword arguments for the detector call.
        rel_lo, rel_hi: Fraction of the parent progress range to report into.
        batch_size: Per-rank DataLoader batch size; must be a multiple of ``batch_gen``.
        batch_gen: Generator chunk size; defaults to ``min(batch_size, 4)``.
        jit: Accepted for interface compatibility only; it is currently ignored
            (the JIT tracing of the earlier (z, c)-only generator is disabled here).
        data_loader_kwargs: Extra DataLoader arguments; defaults to
            ``pin_memory=True, num_workers=3, prefetch_factor=2``.
        **stats_kwargs: Forwarded to ``FeatureStats``; must include ``max_items``.

    Returns:
        FeatureStats over ``max_items`` generated images.
    """
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and the dataset that supplies input images, masks and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func: map G's output (nominally [-1, 1]) to uint8 [0, 255];
    # out-of-range values are clamped.
    def run_generator(img_in, mask_in, z, c):
        img = G(img_in, mask_in, z, c, **opts.G_kwargs)
        img = ((img + 1.0) * 127.5).clamp(0, 255).round().to(torch.uint8)
        return img

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop. Rank r takes global items r, r + num_gpus, ... (max_items rounded up
    # to a multiple of num_gpus; the surplus is dropped by FeatureStats). Indices wrap
    # modulo len(dataset) rather than max_items, so max_items may exceed the dataset
    # size (the dataset is cycled again with fresh latents) instead of requesting
    # indices past its end. For max_items <= len(dataset) the kept items are unchanged.
    num_dataset_items = len(dataset)
    item_subset = [(i * opts.num_gpus + opts.rank) % num_dataset_items for i in range((stats.max_items - 1) // opts.num_gpus + 1)]
    for imgs_batch, masks_batch, _labels_batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        images = []
        # Assumes dataset images are in [0, 255]; rescale to [-1, 1] for G -- TODO confirm.
        imgs_gen = (imgs_batch.to(opts.device).to(torch.float32) / 127.5 - 1).split(batch_gen)
        masks_gen = masks_batch.to(opts.device).to(torch.float32).split(batch_gen)
        for img_in, mask_in in zip(imgs_gen, masks_gen):
            z = torch.randn([img_in.shape[0], G.z_dim], device=opts.device)
            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(img_in.shape[0])]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            images.append(run_generator(img_in, mask_in, z, c))
        images = torch.cat(images)
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
#----------------------------------------------------------------------------
def compute_image_stats_for_generator(opts, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, data_loader_kwargs=None, **stats_kwargs):
    """Compute per-image ``[PSNR, SSIM, L1]`` between inpainted outputs and dataset images.

    Each dataset batch of ``(images, masks, labels)`` is run through ``opts.G`` in
    chunks of ``batch_gen`` with fresh latents and randomly drawn dataset labels.
    Every output is then compared with the dataset image it was generated from,
    and the three scores are appended as one row of a ``FeatureStats``.

    Args:
        opts: ``MetricOptions`` with generator, dataset, device and rank configuration.
        rel_lo, rel_hi: Fraction of the parent progress range to report into.
        batch_size: Per-rank DataLoader batch size; must be a multiple of ``batch_gen``.
        batch_gen: Generator chunk size; defaults to ``min(batch_size, 4)``.
        jit: Accepted for interface compatibility only; currently ignored.
        data_loader_kwargs: Extra DataLoader arguments; defaults to
            ``pin_memory=True, num_workers=3, prefetch_factor=2``.
        **stats_kwargs: Forwarded to ``FeatureStats``; must include ``max_items``.

    Returns:
        FeatureStats whose rows are ``[psnr, ssim, l1]`` (float32 before gathering).
    """
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and the dataset that supplies input images, masks and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func: map G's output (nominally [-1, 1]) to uint8 [0, 255];
    # out-of-range values are clamped.
    def run_generator(img_in, mask_in, z, c):
        img = G(img_in, mask_in, z, c, **opts.G_kwargs)
        img = ((img + 1.0) * 127.5).clamp(0, 255).round().to(torch.uint8)
        return img

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator images', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)

    # Main loop. Indices wrap modulo len(dataset) rather than max_items, so max_items
    # may exceed the dataset size instead of requesting indices past its end; for
    # max_items <= len(dataset) the kept items are unchanged (surplus rows beyond
    # max_items are dropped by FeatureStats).
    num_dataset_items = len(dataset)
    item_subset = [(i * opts.num_gpus + opts.rank) % num_dataset_items for i in range((stats.max_items - 1) // opts.num_gpus + 1)]
    for imgs_batch, masks_batch, _labels_batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        images = []
        # Assumes dataset images are in [0, 255]; rescale to [-1, 1] for G -- TODO confirm.
        imgs_gen = (imgs_batch.to(opts.device).to(torch.float32) / 127.5 - 1).split(batch_gen)
        masks_gen = masks_batch.to(opts.device).to(torch.float32).split(batch_gen)
        for img_in, mask_in in zip(imgs_gen, masks_gen):
            z = torch.randn([img_in.shape[0], G.z_dim], device=opts.device)
            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(img_in.shape[0])]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            images.append(run_generator(img_in, mask_in, z, c))
        images = torch.cat(images)
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        assert imgs_batch.shape == images.shape

        # Copy both batches to host memory once, instead of once per image.
        real_np = imgs_batch.cpu().numpy()
        gen_np = images.cpu().numpy()
        metrics = []
        for real_chw, gen_chw in zip(real_np, gen_np):
            img_real = np.transpose(real_chw, [1, 2, 0])  # (C, H, W) -> (H, W, C)
            img_gen = np.transpose(gen_chw, [1, 2, 0])
            psnr = calculate_psnr(img_gen, img_real)
            ssim = calculate_ssim(img_gen, img_real)
            l1 = calculate_l1(img_gen, img_real)
            metrics.append([psnr, ssim, l1])
        # NOTE: an output identical to its input yields psnr = inf, which will
        # dominate any downstream mean of this column.
        metrics = torch.from_numpy(np.array(metrics)).to(torch.float32).to(opts.device)
        stats.append_torch(metrics, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio in dB of two images with values in [0, 255].

    Identical images (zero mean squared error) give ``float('inf')``.
    """
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse != 0:
        return 20 * math.log10(255.0 / math.sqrt(mse))
    return float('inf')
def calculate_ssim(img1, img2):
    """Return the mean SSIM of two images with values in [0, 255].

    Local statistics use an 11x11 Gaussian window (sigma 1.5) applied with
    ``cv2.filter2D``; a 5-pixel border is cropped from every filtered map, and
    the result is the mean over all remaining elements (all channels, if any).
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2
    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    g = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(g, g.transpose())

    def local_mean(a):
        # Gaussian-weighted local average with the 5-pixel border cropped away.
        return cv2.filter2D(a, -1, window)[5:-5, 5:-5]

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_x2 = mu_x ** 2
    mu_y2 = mu_y ** 2
    mu_xy = mu_x * mu_y
    var_x = local_mean(x ** 2) - mu_x2
    var_y = local_mean(y ** 2) - mu_y2
    cov_xy = local_mean(x * y) - mu_xy
    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x2 + mu_y2 + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()
def calculate_l1(img1, img2):
    """Return the mean absolute difference of two [0, 255] images after scaling both to [0, 1]."""
    scale = 255.0
    first, second = (img.astype(np.float64) / scale for img in (img1, img2))
    return np.mean(np.abs(first - second))
# def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
# if batch_gen is None:
# batch_gen = min(batch_size, 4)
# assert batch_size % batch_gen == 0
#
# # Setup generator and load labels.
# G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
# dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
#
# # Image generation func.
# def run_generator(z, c):
# img = G(z=z, c=c, **opts.G_kwargs)
# img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
# return img
#
# # JIT.
# if jit:
# z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
# c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
# run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
#
# # Initialize.
# stats = FeatureStats(**stats_kwargs)
# assert stats.max_items is not None
# progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
# detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
#
# # Main loop.
# while not stats.is_full():
# images = []
# for _i in range(batch_size // batch_gen):
# z = torch.randn([batch_gen, G.z_dim], device=opts.device)
# c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
# c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
# images.append(run_generator(z, c))
# images = torch.cat(images)
# if images.shape[1] == 1:
# images = images.repeat([1, 3, 1, 1])
# features = detector(images, **detector_kwargs)
# stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
# progress.update(stats.num_items)
# return stats
#
# #----------------------------------------------------------------------------