File size: 6,066 Bytes
df20d82 d5a0b8d df20d82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from huggingface_hub import PyTorchModelHubMixin
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
def load_params(model, new_param):
    """Overwrite the parameters of ``model`` in place with tensors from ``new_param``.

    Tensors are paired positionally with ``model.parameters()``. Like ``zip``,
    pairing stops at the shorter of the two sequences, so surplus entries on
    either side are ignored.
    """
    for target, source in zip(model.parameters(), new_param):
        # Write through ``.data`` so the copy is not recorded by autograd.
        target.data.copy_(source)
def resize(img, size=256):
    """Resize ``img`` with ``F.interpolate`` using its default (nearest) mode.

    Args:
        img: tensor accepted by ``F.interpolate``; presumably a (N, C, H, W)
            image batch - confirm against callers.
        size: target spatial size (an int gives a square output). Defaults to
            256, the value that used to be hard-coded.

    Returns:
        The resized tensor.
    """
    return F.interpolate(img, size=size)
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks and concatenate the outputs on the CPU.

    Args:
        zs: latent codes; sliced along dim 0.
        netG: callable mapping a slice of ``zs`` to a tensor (``.cpu()`` is
            called on its result).
        batch: maximum chunk size; must be >= 1. The last chunk holds the
            remainder when ``len(zs)`` is not a multiple of ``batch``.

    Returns:
        ``torch.cat`` of the per-chunk outputs along dim 0, in input order.

    Raises:
        ValueError: if ``batch`` < 1.
        RuntimeError: raised by ``torch.cat`` when ``zs`` is empty (no chunks).
    """
    if batch < 1:
        raise ValueError("batch must be a positive integer, got %r" % (batch,))
    g_images = []
    with torch.no_grad():
        # A stepped range yields the full chunks and the trailing partial chunk
        # in a single loop, in the same order as before.
        for start in range(0, len(zs), batch):
            g_images.append(netG(zs[start:start + batch]).cpu())
    return torch.cat(g_images)
def batch_save(images, folder_name):
    """Save every image in ``images`` as ``<folder_name>/<index>.jpg``.

    Each image is mapped with ``add(1).mul(0.5)`` before saving, which presumes
    the values lie in [-1, 1] (generator output) - confirm for other callers.
    ``folder_name`` and any missing parent directories are created as needed;
    an already existing folder is not an error.
    """
    # makedirs(exist_ok=True) avoids the exists/mkdir race and handles nested paths.
    os.makedirs(folder_name, exist_ok=True)
    for idx, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5),
                          os.path.join(folder_name, '%d.jpg' % idx))
# Hub-aware wrapper around Generator, so the model can be saved to and loaded from the Hugging Face model hub.
class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
    """``nn.Module`` wrapper around the FastGAN ``Generator`` with Hugging Face hub support.

    ``PyTorchModelHubMixin`` supplies ``from_pretrained`` / ``save_pretrained`` /
    ``push_to_hub``. The wrapped generator is stored as ``self.model``, so its
    weights are saved under the ``model.`` prefix.

    Args:
        config: dict with keys ``"ngf"``, ``"noise_dim"``, ``"nc"`` and
            ``"im_size"``, forwarded to ``Generator`` as ``ngf``, ``nz``, ``nc``
            and ``im_size`` respectively.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        generator_kwargs = dict(
            ngf=config["ngf"],
            nz=config["noise_dim"],
            nc=config["nc"],
            im_size=config["im_size"],
        )
        self.model = Generator(**generator_kwargs)

    def forward(self, x):
        """Delegate directly to the wrapped generator."""
        return self.model(x)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='generate images'
)
#parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
parser.add_argument('--start_iter', type=int, default=6)
parser.add_argument('--end_iter', type=int, default=10)
parser.add_argument('--dist', type=str, default='test_out')
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--batch', default=1, type=int, help='batch size')
parser.add_argument('--n_sample', type=int, default=1)
parser.add_argument('--big', action='store_true')
parser.add_argument('--im_size', type=int, default=256)
parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
parser.set_defaults(big=False)
args = parser.parse_args()
noise_dim = 256
device = torch.device('cuda:%d'%(args.cuda))
# adding the model to the model hub
config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
net_ig = MyFastGanModel(config=config)
# exit
#exit()
#net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
#net_ig.to(device)
#for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
#ckpt = args.ckpt #f"{args.artifacts}/models/{epoch}.pth"
#checkpoint = torch.load(ckpt, map_location=lambda a,b: a)
#checkpoint = torch.load(ckpt)
# Remove prefix `module`.
#checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
#net_ig.model.load_state_dict(checkpoint['g'])
#load_params(net_ig, checkpoint['g_ema'])
net_ig = MyFastGanModel.from_pretrained("deepsynthbody/deepfake_gi_fastGAN", config=config) # Load the model from the hub
#net_ig.eval()
print("load checkpoint success")
net_ig.to(device)
# Save locally
# net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
# print("Model saved locally. Pushing to Huggingface model hub...")
# Push to the Huggingface model hub
# push to the hub
# net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
#print("pushed to the Huggingface model hub. Done.")
#exit()
#del checkpoint
#dist = 'eval_%d'%(epoch)
#dist = os.path.join(args.dist, 'img')
dist = args.dist
os.makedirs(dist, exist_ok=True)
with torch.no_grad():
for i in tqdm(range(args.n_sample//args.batch)):
noise = torch.randn(args.batch, noise_dim).to(device)
g_imgs = net_ig(noise)[0]
g_imgs = F.interpolate(g_imgs, 512)
for j, g_img in enumerate( g_imgs ):
#print("img sahpe=", g_img.shape)
g_mask = g_img.add(1).mul(0.5)[-1, :, :].expand(3, -1, -1)
g_img = g_img.add(1).mul(0.5)[0:3, :, :]
# Clean generated data using clamping
g_mask = torch.clamp(g_mask, min=0, max=1)
g_img = torch.clamp(g_img, min=0, max=1)
#print(g_mask.type())
g_mask = (g_mask > 0.5) * 1.0
#print(g_mask.type())
#print("gmask_min:", g_mask.min())
#print("gmask_max:", g_mask.max())
#exit()
#print("img sahpe=", g_img.shape)
if args.save_option == "image_and_mask":
vutils.save_image(g_img,
os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
vutils.save_image(g_mask,
os.path.join(dist, '%d_mask.png'%(i*args.batch+j))) #, normalize=True, range=(0,1))
elif args.save_option == "image_only":
vutils.save_image(g_img,
os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
elif args.save_option == "mask_only":
vutils.save_image(g_mask,
os.path.join(dist, '%d_mask.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
else:
print("wrong choise to save option.")
|