|
import torch |
|
from torch import nn |
|
from torch import optim |
|
import torch.nn.functional as F |
|
from torchvision.datasets import ImageFolder |
|
from torch.utils.data import DataLoader |
|
from torchvision import utils as vutils |
|
|
|
import os |
|
import random |
|
import argparse |
|
from tqdm import tqdm |
|
|
|
from models import Generator |
|
|
|
|
|
def load_params(model, new_param):
    """Overwrite ``model``'s parameters in place with the tensors in ``new_param``.

    Tensors are paired positionally with ``model.parameters()``; pairing stops
    at the shorter of the two sequences (plain ``zip`` semantics), so surplus
    entries on either side are left untouched.
    Values are written through ``.data``, which bypasses autograd tracking.
    """
    params = model.parameters()
    for dst, src in zip(params, new_param):
        dst.data.copy_(src)
|
|
|
def resize(img, size=256):
    """Resize ``img`` with ``F.interpolate`` (default ``'nearest'`` mode).

    Args:
        img: tensor accepted by ``F.interpolate``; for images presumably
            ``(N, C, H, W)`` — not checked here.
        size: target spatial size; an int yields a square output. Defaults to
            256, the value this helper previously hard-coded.

    Returns:
        The resized tensor.
    """
    return F.interpolate(img, size=size)
|
|
|
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks and return all outputs on the CPU.

    Args:
        zs: sliceable input (e.g. a tensor of latent codes), chunked along
            its first dimension.
        netG: callable returning either a tensor or a list/tuple of tensors;
            for a list/tuple only the first element is kept.
        batch: maximum number of rows passed to ``netG`` per call; the final
            chunk holds any remainder.

    Returns:
        The per-chunk outputs concatenated along dim 0, on the CPU.

    Raises:
        ValueError: if ``batch`` is not a positive integer.
        RuntimeError: from ``torch.cat`` when ``zs`` is empty.
    """
    if batch < 1:
        raise ValueError(f"batch must be a positive integer, got {batch}")

    g_images = []
    with torch.no_grad():
        # Stride over zs; the last slice is automatically the short remainder.
        for start in range(0, len(zs), batch):
            out = netG(zs[start:start + batch])
            # Multi-output generators return a sequence; keep the first output
            # instead of crashing on list.cpu().
            if isinstance(out, (list, tuple)):
                out = out[0]
            g_images.append(out.cpu())
    return torch.cat(g_images)
|
|
|
def batch_save(images, folder_name):
    """Save each image in ``images`` as ``<folder_name>/<index>.jpg``.

    Each image is mapped with ``(x + 1) * 0.5`` (i.e. ``[-1, 1]`` -> ``[0, 1]``)
    before ``vutils.save_image`` writes it — presumably generator outputs in
    ``[-1, 1]``; confirm for other callers.

    Args:
        images: iterable of image tensors; files are numbered by position.
        folder_name: output directory; created, including any missing parent
            directories, if it does not already exist.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
    # missing parents, which os.mkdir cannot.
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5),
                          os.path.join(folder_name, '%d.jpg' % i))
|
|
|
|
|
def _parse_args():
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    parser.add_argument('--ckpt', type=str)  # NOTE(review): not used below
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)

    parser.add_argument('--dist', type=str, default='.',
                        help='root directory for the eval_<iter>/img output folders')
    parser.add_argument('--size', type=int, default=256)  # NOTE(review): not used below
    parser.add_argument('--batch', default=16, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=2000)
    parser.add_argument('--big', action='store_true')  # NOTE(review): not used below
    parser.add_argument('--im_size', type=int, default=1024)
    parser.set_defaults(big=False)
    args = parser.parse_args()
    if args.batch < 1:
        parser.error('--batch must be a positive integer')
    return args


def _load_generator_checkpoint(net_ig, ckpt_path):
    """Load the ``'g'`` state dict stored in ``ckpt_path`` into ``net_ig`` in place."""
    # The map_location callable returns each storage unchanged, so tensors are
    # loaded onto the CPU and then copied into net_ig's existing parameters.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=lambda a, b: a)
    # Drop the 'module.' prefix (presumably added by nn.DataParallel at save time).
    state = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
    net_ig.load_state_dict(state)


def _generate_and_save(net_ig, out_dir, n_sample, batch, noise_dim, device):
    """Sample ``n_sample`` images from ``net_ig`` and save them as ``<idx>.png`` in ``out_dir``."""
    # Ceiling division so a final partial batch is generated too; the old
    # n_sample // batch silently produced fewer than n_sample images.
    n_batches = (n_sample + batch - 1) // batch
    with torch.no_grad():
        for i in tqdm(range(n_batches)):
            cur = min(batch, n_sample - i * batch)
            noise = torch.randn(cur, noise_dim).to(device)
            # Generator output is indexed with [0]; presumably a sequence of
            # outputs whose first entry is the full-resolution image batch.
            g_imgs = net_ig(noise)[0]
            g_imgs = F.interpolate(g_imgs, 512)
            for j, g_img in enumerate(g_imgs):
                # (x + 1) * 0.5 maps [-1, 1] to [0, 1] for save_image.
                vutils.save_image(g_img.add(1).mul(0.5),
                                  os.path.join(out_dir, '%d.png' % (i * batch + j)))


if __name__ == "__main__":
    args = _parse_args()

    noise_dim = 256
    # Fall back to the CPU instead of crashing on .to() when CUDA is missing.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % (args.cuda))
    else:
        print('CUDA is not available, falling back to CPU')
        device = torch.device('cpu')

    net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=args.im_size)
    net_ig.to(device)

    # Checkpoints are evaluated at iterations 10000*start_iter .. 10000*end_iter.
    for epoch in [10000 * i for i in range(args.start_iter, args.end_iter + 1)]:
        ckpt = os.path.join(args.artifacts, 'models', f"{epoch}.pth")
        _load_generator_checkpoint(net_ig, ckpt)
        print('load checkpoint success, epoch %d' % epoch)

        net_ig.to(device)

        # With the default --dist '.', this is the same eval_<epoch>/img path as before.
        dist = os.path.join(args.dist, 'eval_%d' % (epoch), 'img')
        os.makedirs(dist, exist_ok=True)

        _generate_and_save(net_ig, dist, args.n_sample, args.batch, noise_dim, device)
|
|