"""
Modified from guided-diffusion/scripts/image_sample.py
"""

import argparse
import os

import torch
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

from dataset import TMDistilDireDataset
from torch.utils.data import DataLoader
import cv2

import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm.auto import tqdm

import numpy as np
import torch as th
import torchvision
import os.path as osp

from .guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    dict_parse,
    args_to_dict,
)

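# Utility: square-crop and resize a batch to the diffusion model's input size.
# Non-square inputs are center-cropped, then bicubic-resized to image_size x image_size.
# (Not called within this file; it is importable as a preprocessing helper.)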
def reshape_image(imgs: torch.Tensor, image_size: int) -> torch.Tensor:
    if len(imgs.shape) == 3:
        imgs = imgs.unsqueeze(0)
    if imgs.shape[2] != imgs.shape[3]:
        crop_func = transforms.CenterCrop(image_size)
        imgs = crop_func(imgs)
    if imgs.shape[2] != image_size:
        imgs = F.interpolate(imgs, size=(image_size, image_size), mode="bicubic", antialias=True)
    return imgs

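# CLI construction: the defaults below match the guided-diffusion unconditional
# 256x256 ADM configuration (attention at 32/16/8, 256 channels, learned sigma, fp16),
# plus script-specific options (data_root, save_root, compute_dire, compute_eps,
# batch_size). add_dict_to_argparser exposes each key as a command-line flag of the
# same name with the given default.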
def create_argparser():
    defaults = dict(
        clip_denoised=True,
        num_samples=-1,
        use_ddim=True,
        real_step=0,
        continue_reverse=False,
        has_subfolder=False,
    )
    defaults.update(model_and_diffusion_defaults())
    sanic_dict = dict(
        attention_resolutions='32,16,8',
        class_cond=False,
        diffusion_steps=1000,
        image_size=256,
        learn_sigma=True,
        model_path="./models/256x256-adm.pt",
        noise_schedule='linear',
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        data_root="",
        compute_dire=False,
        compute_eps=False,
        save_root="",
        batch_size=32,
    )
    defaults.update(sanic_dict)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    args = parser.parse_args()

    return args_to_dict(args, list(defaults.keys()))

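# Same defaults as create_argparser(), returned directly as a dict so callers can
# build the model/diffusion configuration programmatically without touching argparse.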
def create_dicts_for_static_init():
    defaults = dict(
        clip_denoised=True,
        num_samples=-1,
        use_ddim=True,
        real_step=0,
        continue_reverse=False,
        has_subfolder=False,
    )
    defaults.update(model_and_diffusion_defaults())
    sanic_dict = dict(
        attention_resolutions='32,16,8',
        class_cond=False,
        diffusion_steps=1000,
        image_size=256,
        learn_sigma=True,
        model_path="./models/256x256-adm.pt",
        noise_schedule='linear',
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        data_root="",
        compute_dire=False,
        compute_eps=False,
        save_root="",
        batch_size=32,
    )
    defaults.update(sanic_dict)

    return defaults

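# DIRE (DIffusion Reconstruction Error): DDIM-invert the input images to a latent,
# reconstruct them with DDIM sampling, and take the per-pixel absolute difference
# |x - x_hat|. Inputs are expected in [-1, 1]; the returned `imgs` and `recons` are
# rescaled to [0, 1], and the DIRE map is quantized to uint8 and mapped back to
# [0, 1] so it matches what is written to disk.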
@torch.no_grad()
def dire(img_batch: torch.Tensor, model, diffusion, args, save_img=False, save_path=None):
    print("computing recons & DIRE ...")
    imgs = img_batch.cuda()

    batch_size = imgs.shape[0]
    model_kwargs = {}

    reverse_fn = diffusion.ddim_reverse_sample_loop
    assert (imgs.shape[2] == imgs.shape[3]) and (imgs.shape[3] == args['image_size']), f"Image size mismatch: {imgs.shape[2]} != {args['image_size']}"
    latent = reverse_fn(
        model,
        (batch_size, 3, args['image_size'], args['image_size']),
        noise=imgs,
        clip_denoised=args['clip_denoised'],
        model_kwargs=model_kwargs,
        real_step=args['real_step'],
    )
    sample_fn = diffusion.p_sample_loop if not args['use_ddim'] else diffusion.ddim_sample_loop
    recons = sample_fn(
        model,
        (batch_size, 3, args['image_size'], args['image_size']),
        noise=latent,
        clip_denoised=args['clip_denoised'],
        model_kwargs=model_kwargs,
        real_step=args['real_step'],
    )

    # |x - x_hat| lies in [0, 2]; quantize to uint8, then back to a float map in [0, 1].
    dire = th.abs(imgs - recons)
    dire = (dire * 255. / 2).clamp(0, 255).to(th.uint8)
    dire = dire.contiguous() / 255.
    dire = dire.clamp(0, 1).to(th.float32)

    # Rescale inputs and reconstructions from [-1, 1] to [0, 1] for saving/visualization.
    imgs = (imgs + 1) * 0.5
    recons = (recons + 1) * 0.5

    if save_img:
        assert save_path is not None, "save_img=True requires save_path"
        for i in range(len(img_batch)):
            dire_img = dire[i].detach().cpu().numpy().transpose(1, 2, 0)
            dire_img = cv2.cvtColor(dire_img, cv2.COLOR_RGB2BGR)
            dire_path = os.path.join(save_path, f"dire_{i}.png")
            # cv2.imwrite expects uint8; cast explicitly instead of writing raw floats.
            cv2.imwrite(dire_path, (dire_img * 255).astype(np.uint8))

    return dire, imgs, recons

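# Single-step epsilon: run only the first DDIM reverse step (t = 0) and return the
# model's predicted noise for each image. The __main__ loop below saves these tensors
# under eps/ alongside the DIRE maps.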
@torch.no_grad()
def dire_get_first_step_noise(img_batch: torch.Tensor, model, diffusion, args, device):
    imgs = img_batch.to(device)
    batch_size = imgs.shape[0]
    model_kwargs = {}

    reverse_fn = diffusion.ddim_reverse_sample_only_eps
    assert (imgs.shape[2] == imgs.shape[3]) and (imgs.shape[3] == args['image_size']), f"Image size mismatch: {imgs.shape[2]} != {args['image_size']}"

    eps = reverse_fn(
        model,
        x=imgs,
        t=torch.zeros(imgs.shape[0],).long().to(device),
        clip_denoised=args['clip_denoised'],
        model_kwargs=model_kwargs,
        eta=0.0,
    )

    return eps

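# Distributed entry point. Assumes a torch.distributed launcher (e.g. torchrun) that
# sets LOCAL_RANK and the env:// rendezvous variables; each rank then processes its own
# shard of the dataset via DistributedSampler. Because of the relative import of
# guided_diffusion above, the file is expected to be run as a module inside its package.
# A hypothetical launch (module path and data paths are placeholders, not from the source):
#   torchrun --nproc_per_node=4 -m <package>.<this_module> \
#       --data_root /path/to/dataset --save_root /path/to/output \
#       --compute_dire True --compute_eps True --batch_size 32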
if __name__ == "__main__": |
|
from torch.utils.data.distributed import DistributedSampler |
|
import torch.distributed as dist |
|
import os |
|
|
|
dist.init_process_group(backend='nccl', init_method='env://') |
|
|
|
local_rank = int(os.environ['LOCAL_RANK']) |
|
torch.cuda.set_device(local_rank) |
|
|
|
|
|
device = torch.device("cuda") |
|
|
|
adm_args = create_argparser() |
|
adm_args['timestep_respacing'] = 'ddim20' |
|
adm_model, diffusion = create_model_and_diffusion(**dict_parse(adm_args, model_and_diffusion_defaults().keys())) |
|
print(f"checkpoint: {adm_args['model_path']}") |
|
print(f"model channel: {adm_args['num_channels']}, {adm_model.model_channels}") |
|
adm_model.load_state_dict(torch.load(adm_args['model_path'], map_location="cpu")) |
|
adm_model.to(device) |
|
|
|
adm_model.convert_to_fp16() |
|
adm_model.eval() |
|
|
|
dataset = TMDistilDireDataset(adm_args['data_root'], prepared_dire=False) |
|
sampler = DistributedSampler(dataset, shuffle=False) |
|
os.makedirs(osp.join(adm_args['save_root'], 'images', 'fakes'), exist_ok=True) |
|
os.makedirs(osp.join(adm_args['save_root'], 'images', 'reals'), exist_ok=True) |
|
os.makedirs(osp.join(adm_args['save_root'], 'dire', 'fakes'), exist_ok=True) |
|
os.makedirs(osp.join(adm_args['save_root'], 'dire', 'reals'), exist_ok=True) |
|
os.makedirs(osp.join(adm_args['save_root'], 'eps', 'fakes'), exist_ok=True) |
|
os.makedirs(osp.join(adm_args['save_root'], 'eps', 'reals'), exist_ok=True) |
|
print(f"Dataset length: {len(dataset)}") |
|
dataloader = DataLoader(dataset, batch_size=adm_args['batch_size'], num_workers=4, drop_last=False, pin_memory=True, sampler=sampler) |
|
|
|
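    # Per batch: first check whether every sample already has an image copy, a DIRE map,
    # and an eps tensor on disk; if so, skip the batch. Otherwise compute eps and/or DIRE
    # (depending on --compute_eps / --compute_dire) and write any missing files, routing
    # each sample to fakes/ or reals/ based on its label.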
    for (img_batch, dire_batch, eps_batch, isfake_batch), (img_pathes, dire_pathes, eps_pathes) in tqdm(dataloader):
        haveall = True
        for i in range(len(img_batch)):
            basename = osp.basename(img_pathes[i])
            isfake = isfake_batch[i]

            # Check the *output* paths (image copy, DIRE map, eps tensor).
            img_path = osp.join(adm_args['save_root'], 'images', 'fakes', basename) if isfake else osp.join(adm_args['save_root'], 'images', 'reals', basename)
            dire_path = img_path.replace('/images/', '/dire/')
            # Use splitext so dots elsewhere in the path (e.g. a relative save_root) don't truncate it.
            eps_path = osp.splitext(img_path.replace('/images/', '/eps/'))[0] + '.pt'
            if (not osp.exists(img_path)) or (not osp.exists(dire_path)) or (not osp.exists(eps_path)):
                haveall = False
                break
        if haveall:
            continue

        with torch.no_grad():
            eps = None
            img = (img_batch.detach().cpu() + 1) * 0.5
            if adm_args['compute_eps']:
                eps = dire_get_first_step_noise(img_batch, adm_model, diffusion, adm_args, device)
                eps = eps.detach().cpu()

            if adm_args['compute_dire']:
                dire_img, img, recons = dire(img_batch, adm_model, diffusion, adm_args)
                dire_img = dire_img.detach().cpu()
                img = img.detach().cpu()

        for i in range(len(img_batch)):
            basename = osp.basename(img_pathes[i])
            isfake = isfake_batch[i]

            img_path = osp.join(adm_args['save_root'], 'images', 'fakes', basename) if isfake else osp.join(adm_args['save_root'], 'images', 'reals', basename)
            dire_path = img_path.replace('/images/', '/dire/')
            eps_path = osp.splitext(img_path.replace('/images/', '/eps/'))[0] + '.pt'

            if not osp.exists(img_path):
                torchvision.utils.save_image(img[i], img_path)
            if not osp.exists(dire_path) and adm_args['compute_dire']:
                torchvision.utils.save_image(dire_img[i], dire_path)
            if not osp.exists(eps_path) and eps is not None:
                torch.save(eps[i], eps_path)