|
import os
import shutil
from multiprocessing import get_context

import torch
import torchvision
from pytorch_fid import fid_score
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.autonotebook import tqdm, trange

from .DiffAE_support_renderer import *
from .DiffAE_support_config import *
from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGans as Sampler

import lpips
from ssim import compute_ssim as ssim

# The distributed helpers used below (barrier, get_rank, get_world_size,
# all_gather, broadcast, chunk_size) and the model/render utilities (Model,
# ModelType, SubsetDataset, render_condition, render_uncondition) are
# expected to come in through the wildcard imports above.
|
|
|
|
|
def make_subset_loader(conf: TrainConfig,
                       dataset,
                       batch_size: int,
                       shuffle: bool,
                       parallel: bool,
                       drop_last=True):
    """Build a DataLoader over a fixed-size evaluation subset of ``dataset``."""
    dataset = SubsetDataset(dataset, size=conf.eval_num_images)
    if parallel and distributed.is_initialized():
        sampler = DistributedSampler(dataset, shuffle=shuffle)
    else:
        sampler = None
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # the sampler already handles shuffling; don't shuffle twice
        shuffle=False if sampler else shuffle,
        num_workers=conf.num_workers,
        pin_memory=True,
        drop_last=drop_last,
        multiprocessing_context=get_context('fork'),
    )
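# A minimal usage sketch (hypothetical `conf` and `val_data`; loaders in this
# module yield dict batches with an 'img' key holding tensors in [-1, 1]):
#
#   loader = make_subset_loader(conf,
#                               dataset=val_data,
#                               batch_size=conf.batch_size_eval,
#                               shuffle=False,
#                               parallel=False)
#   for batch in loader:
#       imgs = batch['img']  # (n, c, h, w)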
|
|
|
|
|
def evaluate_lpips(
    sampler: Sampler,
    model: Model,
    conf: TrainConfig,
    device,
    val_data,
    latent_sampler: Sampler = None,
    use_inverted_noise: bool = False,
):
    """
    Compare images generated by the autoencoder against the validation dataset.

    Args:
        use_inverted_noise: if True, the noise map x_T is obtained by
            inverting the input images with the reverse DDIM process.
    """
    lpips_fn = lpips.LPIPS(net='alex').to(device)
    val_loader = make_subset_loader(conf,
                                    dataset=val_data,
                                    batch_size=conf.batch_size_eval,
                                    shuffle=False,
                                    parallel=True)

    model.eval()
    with torch.no_grad():
        scores = {
            'lpips': [],
            'mse': [],
            'ssim': [],
            'psnr': [],
        }
        for batch in tqdm(val_loader, desc='lpips'):
            imgs = batch['img'].to(device)

            if use_inverted_noise:
                # invert the noise map, conditioning on the encoder's output
                model_kwargs = {}
                if conf.model_type.has_autoenc():
                    with torch.no_grad():
                        model_kwargs = model.encode(imgs)
                x_T = sampler.ddim_reverse_sample_loop(
                    model=model,
                    x=imgs,
                    clip_denoised=True,
                    model_kwargs=model_kwargs)
                x_T = x_T['sample']
            else:
                x_T = torch.randn((len(imgs), 3, conf.img_size, conf.img_size),
                                  device=device)

            if conf.model_type == ModelType.ddpm:
                # this branch measures the inversion quality of the DDIM model
                assert use_inverted_noise
                pred_imgs = render_uncondition(
                    conf=conf,
                    model=model,
                    x_T=x_T,
                    sampler=sampler,
                    latent_sampler=latent_sampler,
                )
            else:
                pred_imgs = render_condition(conf=conf,
                                             model=model,
                                             x_T=x_T,
                                             x_start=imgs,
                                             cond=None,
                                             sampler=sampler)

            # LPIPS operates on images in [-1, 1]
            scores['lpips'].append(lpips_fn.forward(imgs, pred_imgs).view(-1))

            # the remaining metrics expect images in [0, 1]
            norm_imgs = (imgs + 1) / 2
            norm_pred_imgs = (pred_imgs + 1) / 2

            # each append is a per-image score tensor of shape (n,)
            scores['ssim'].append(
                ssim(norm_imgs, norm_pred_imgs, size_average=False))
            scores['mse'].append(
                (norm_imgs - norm_pred_imgs).pow(2).mean(dim=[1, 2, 3]))
            scores['psnr'].append(psnr(norm_imgs, norm_pred_imgs))

        for key in scores.keys():
            scores[key] = torch.cat(scores[key]).float()
    model.train()

    barrier()

    # gather the per-image scores from all workers
    outs = {
        key: [
            torch.zeros(len(scores[key]), device=device)
            for i in range(get_world_size())
        ]
        for key in scores.keys()
    }
    for key in scores.keys():
        all_gather(outs[key], scores[key])

    # reduce to the final scalar score per metric
    for key in scores.keys():
        scores[key] = torch.cat(outs[key]).mean().item()

    return scores
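# A minimal call sketch (hypothetical `model`, `conf`, `sampler`, and
# `val_data` built elsewhere in this package; requires an initialized
# process group when run distributed):
#
#   scores = evaluate_lpips(sampler, model, conf,
#                           device='cuda', val_data=val_data)
#   print(scores)  # {'lpips': ..., 'mse': ..., 'ssim': ..., 'psnr': ...}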
|
|
|
|
|
def psnr(img1, img2):
    """
    Per-image peak signal-to-noise ratio in dB.

    Args:
        img1: (n, c, h, w), values in [0, 1]
        img2: (n, c, h, w), values in [0, 1]
    """
    v_max = 1.
    # per-image MSE over channel and spatial dims => (n,)
    mse = torch.mean((img1 - img2)**2, dim=[1, 2, 3])
    return 20 * torch.log10(v_max / torch.sqrt(mse))
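# Worked example: a per-image MSE of 0.01 gives sqrt(mse) = 0.1, hence
# 20 * log10(1 / 0.1) = 20 dB; identical images give mse = 0 and PSNR = inf.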
|
|
|
|
|
def evaluate_fid(
    sampler: Sampler,
    model: Model,
    conf: TrainConfig,
    device,
    train_data,
    val_data,
    latent_sampler: Sampler = None,
    conds_mean=None,
    conds_std=None,
    remove_cache: bool = True,
    clip_latent_noise: bool = False,
):
    assert conf.fid_cache is not None
    if get_rank() == 0:
        # only rank 0 prepares the cache of real images (no parallel loader)
        val_loader = make_subset_loader(conf,
                                        dataset=val_data,
                                        batch_size=conf.batch_size_eval,
                                        shuffle=False,
                                        parallel=False)

        # dump the validation images into the cache directory
        cache_dir = f'{conf.fid_cache}_{conf.eval_num_images}'
        if (os.path.exists(cache_dir)
                and len(os.listdir(cache_dir)) < conf.eval_num_images):
            # incomplete cache; clear it so it gets rebuilt
            shutil.rmtree(cache_dir)

        if not os.path.exists(cache_dir):
            # the loader yields normalized images, so denormalize before saving
            loader_to_path(val_loader, cache_dir, denormalize=True)

        # start from an empty directory for the generated images
        if os.path.exists(conf.generate_dir):
            shutil.rmtree(conf.generate_dir)
        os.makedirs(conf.generate_dir)

    barrier()

    world_size = get_world_size()
    rank = get_rank()
    batch_size = chunk_size(conf.batch_size_eval, rank, world_size)

    def filename(idx):
        # interleave indices across ranks so file names never collide
        return world_size * idx + rank

    model.eval()
    with torch.no_grad():
        if conf.model_type.can_sample():
            eval_num_images = chunk_size(conf.eval_num_images, rank,
                                         world_size)
            desc = "generating images"
            for i in trange(0, eval_num_images, batch_size, desc=desc):
                batch_size = min(batch_size, eval_num_images - i)
                x_T = torch.randn(
                    (batch_size, 3, conf.img_size, conf.img_size),
                    device=device)
                batch_images = render_uncondition(
                    conf=conf,
                    model=model,
                    x_T=x_T,
                    sampler=sampler,
                    latent_sampler=latent_sampler,
                    conds_mean=conds_mean,
                    conds_std=conds_std).cpu()

                # denormalize from [-1, 1] into [0, 1] before saving
                batch_images = (batch_images + 1) / 2
                for j in range(len(batch_images)):
                    img_name = filename(i + j)
                    torchvision.utils.save_image(
                        batch_images[j],
                        os.path.join(conf.generate_dir, f'{img_name}.png'))
        elif conf.model_type == ModelType.autoencoder:
            if conf.train_mode.is_latent_diffusion():
                # unconditional generation through the latent diffusion model
                model: BeatGANsAutoencModel
                eval_num_images = chunk_size(conf.eval_num_images, rank,
                                             world_size)
                desc = "generating images"
                for i in trange(0, eval_num_images, batch_size, desc=desc):
                    batch_size = min(batch_size, eval_num_images - i)
                    x_T = torch.randn(
                        (batch_size, 3, conf.img_size, conf.img_size),
                        device=device)
                    batch_images = render_uncondition(
                        conf=conf,
                        model=model,
                        x_T=x_T,
                        sampler=sampler,
                        latent_sampler=latent_sampler,
                        conds_mean=conds_mean,
                        conds_std=conds_std,
                        clip_latent_noise=clip_latent_noise,
                    ).cpu()
                    batch_images = (batch_images + 1) / 2
                    for j in range(len(batch_images)):
                        img_name = filename(i + j)
                        torchvision.utils.save_image(
                            batch_images[j],
                            os.path.join(conf.generate_dir, f'{img_name}.png'))
            else:
                # reconstruct images conditioned on the training set; using
                # train_data keeps the FID fair (the autoencoder must not see
                # the validation set), and shuffling brings the sample closer
                # to unconditional generation
                train_loader = make_subset_loader(conf,
                                                  dataset=train_data,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  parallel=True)

                i = 0
                for batch in tqdm(train_loader, desc='generating images'):
                    imgs = batch['img'].to(device)
                    x_T = torch.randn(
                        (len(imgs), 3, conf.img_size, conf.img_size),
                        device=device)
                    batch_images = render_condition(
                        conf=conf,
                        model=model,
                        x_T=x_T,
                        x_start=imgs,
                        cond=None,
                        sampler=sampler,
                        latent_sampler=latent_sampler).cpu()

                    # denormalize from [-1, 1] into [0, 1] before saving
                    batch_images = (batch_images + 1) / 2
                    for j in range(len(batch_images)):
                        img_name = filename(i + j)
                        torchvision.utils.save_image(
                            batch_images[j],
                            os.path.join(conf.generate_dir, f'{img_name}.png'))
                    i += len(imgs)
        else:
            raise NotImplementedError()
    model.train()

    barrier()

    if get_rank() == 0:
        fid = fid_score.calculate_fid_given_paths(
            [cache_dir, conf.generate_dir],
            batch_size,
            device=device,
            dims=2048)

        # clean up the generated images
        if remove_cache and os.path.exists(conf.generate_dir):
            shutil.rmtree(conf.generate_dir)

    barrier()

    # broadcast the FID from rank 0 so every worker returns the same value
    if get_rank() == 0:
        fid = torch.tensor(float(fid), device=device)
        broadcast(fid, 0)
    else:
        fid = torch.tensor(0., device=device)
        broadcast(fid, 0)
    fid = fid.item()
    print(f'fid ({get_rank()}):', fid)

    return fid
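# A minimal call sketch (hypothetical objects; `conf.fid_cache` and
# `conf.generate_dir` must point at writable locations):
#
#   fid = evaluate_fid(sampler, model, conf, device='cuda',
#                      train_data=train_data, val_data=val_data)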
|
|
|
|
|
def loader_to_path(loader: DataLoader, path: str, denormalize: bool):
    """Write every image yielded by ``loader`` to ``path`` as numbered PNGs.

    Not process safe; intended to be called from a single rank.
    """
    if not os.path.exists(path):
        os.makedirs(path)

    i = 0
    for batch in tqdm(loader, desc='copy images'):
        imgs = batch['img']
        if denormalize:
            # map [-1, 1] back to [0, 1] for saving
            imgs = (imgs + 1) / 2
        for j in range(len(imgs)):
            torchvision.utils.save_image(imgs[j],
                                         os.path.join(path, f'{i+j}.png'))
        i += len(imgs)