# Commit 7f49ac7 ("First") by vlbthambawita.
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
def load_params(model, new_param):
    """Overwrite ``model``'s parameters in place with the tensors in ``new_param``.

    Parameters are paired positionally with ``model.parameters()``. Pairing
    stops at the shorter of the two sequences (``zip`` semantics), so extra or
    missing entries are silently ignored.

    Args:
        model: module whose parameters are overwritten.
        new_param: iterable of tensors in the same order as
            ``model.parameters()``.
    """
    pairs = zip(model.parameters(), new_param)
    for target, source in pairs:
        # Write through ``.data`` so the copy is not recorded by autograd.
        target.data.copy_(source)
def resize(img, size=256):
    """Resize ``img`` with ``F.interpolate`` to the given spatial ``size``.

    Args:
        img: tensor accepted by ``F.interpolate``; presumably a batch shaped
            (N, C, H, W) - confirm against callers.
        size: output spatial size. An int is applied to every spatial
            dimension; a tuple sets each one. Defaults to 256, the value that
            used to be hard-coded, so existing callers are unaffected.

    Returns:
        The interpolated tensor (``F.interpolate``'s default 'nearest' mode).
    """
    return F.interpolate(img, size=size)
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks of ``batch`` rows and concatenate.

    Full chunks are processed first, then one trailing chunk holding the
    remainder (if any). Each output is moved to the CPU before concatenation.
    No gradients are recorded.

    Args:
        zs: sliceable input (e.g. a latent tensor) whose first axis is split.
        netG: callable applied to each chunk.
        batch: maximum number of rows passed to ``netG`` per call.

    Returns:
        ``torch.cat`` of all per-chunk outputs, in input order.
    """
    n_full, n_rest = divmod(len(zs), batch)
    with torch.no_grad():
        chunks = [zs[k * batch:(k + 1) * batch] for k in range(n_full)]
        if n_rest > 0:
            chunks.append(zs[-n_rest:])
        outputs = [netG(chunk).cpu() for chunk in chunks]
    return torch.cat(outputs)
def batch_save(images, folder_name):
    """Save each image in ``images`` as ``<folder_name>/<index>.jpg``.

    Each image is mapped with ``(x + 1) * 0.5`` before being passed to
    ``vutils.save_image``. This assumes values in [-1, 1]; confirm against the
    generator that produced them.

    Args:
        images: iterable of image tensors; file names follow iteration order.
        folder_name: output directory; it and any missing parents are created.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
    # missing parent directories, which os.mkdir() would reject.
    os.makedirs(folder_name, exist_ok=True)
    for index, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5),
                          os.path.join(folder_name, '%d.jpg' % index))
def sample_spans(n_sample, batch):
    """Return ``(start, size)`` pairs covering ``n_sample`` items in batches.

    Every span has ``size == batch`` except possibly the last, which holds the
    remainder, so the sizes always add up to ``n_sample``.

    Raises:
        ValueError: if ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError('batch must be positive, got %d' % batch)
    return [(start, min(batch, n_sample - start))
            for start in range(0, n_sample, batch)]


def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with leading ``prefix`` removed from keys.

    Only a *leading* prefix is stripped (repeatedly, for nested wrappers).
    A plain ``str.replace`` would also mangle keys that merely contain the
    text, e.g. ``'submodule.w'`` -> ``'subw'``.
    """
    cleaned = {}
    for key, value in state_dict.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    # NOTE(review): --ckpt, --size and --big are parsed but not used below;
    # checkpoints are read from <artifacts>/models/<iter>.pth instead.
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    # Checkpoints evaluated: iterations 10000*start_iter ... 10000*end_iter.
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)
    # Root directory for the eval_<iter>/img output folders.
    parser.add_argument('--dist', type=str, default='.')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=16, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=2000)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=1024)
    parser.set_defaults(big=False)
    args = parser.parse_args()

    noise_dim = 256
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.cuda)
    else:
        # Without this fallback net_ig.to(device) fails on machines without CUDA.
        print('CUDA not available, falling back to CPU')
        device = torch.device('cpu')

    # NOTE(review): --big is not forwarded to Generator.
    net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=args.im_size)
    net_ig.to(device)

    for epoch in [10000 * i for i in range(args.start_iter, args.end_iter + 1)]:
        ckpt = f"{args.artifacts}/models/{epoch}.pth"
        # The map_location callable returns each storage unchanged, keeping the
        # tensors on the CPU. load_state_dict then copies them into parameters
        # already on `device`, so no further .to(device) is needed.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location=lambda a, b: a)
        # Remove the leading `module.` prefix from parameter names
        # (presumably added by a DataParallel wrapper during training).
        checkpoint['g'] = strip_module_prefix(checkpoint['g'])
        net_ig.load_state_dict(checkpoint['g'])
        # Disabled alternatives kept for reference: load EMA weights instead,
        # and/or switch the generator to eval mode before sampling.
        #load_params(net_ig, checkpoint['g_ema'])
        #net_ig.eval()
        print('load checkpoint success, epoch %d' % epoch)
        del checkpoint

        dist = os.path.join(args.dist, 'eval_%d' % epoch, 'img')
        os.makedirs(dist, exist_ok=True)
        with torch.no_grad():
            # Ceil-sized loop: the final batch holds the remainder, so exactly
            # n_sample images are written even when batch does not divide it.
            for start, size in tqdm(sample_spans(args.n_sample, args.batch)):
                noise = torch.randn(size, noise_dim).to(device)
                # NOTE(review): the generator appears to return a sequence of
                # outputs; element 0 is the one saved - confirm against
                # models.Generator.
                g_imgs = net_ig(noise)[0]
                g_imgs = F.interpolate(g_imgs, 512)
                for j, g_img in enumerate(g_imgs):
                    # (x + 1) * 0.5 maps an assumed [-1, 1] range to [0, 1].
                    vutils.save_image(g_img.add(1).mul(0.5),
                                      os.path.join(dist, '%d.png' % (start + j)))