File size: 5,927 Bytes
7f49ac7 df20d82 7f49ac7 df20d82 7f49ac7 df20d82 7f49ac7 c7c9ff6 7f49ac7 df20d82 7f49ac7 df20d82 7f49ac7 df20d82 7f49ac7 df20d82 7f49ac7 df20d82 7f49ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from huggingface_hub import PyTorchModelHubMixin
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
def load_params(model, new_param):
    """Copy a list of tensors into ``model``'s parameters, in order.

    Args:
        model: module whose ``parameters()`` are overwritten in place.
        new_param: iterable of tensors, one per parameter of ``model``, in the
            same order as ``model.parameters()``.

    Raises:
        ValueError: if the number of tensors does not match the number of
            parameters. (A plain ``zip`` would silently load only a prefix.)
    """
    params = list(model.parameters())
    new_param = list(new_param)
    # Validate before copying anything, so a mismatch never leaves the model half-loaded.
    if len(params) != len(new_param):
        raise ValueError(
            "load_params: model has %d parameters but %d tensors were given"
            % (len(params), len(new_param))
        )
    for p, new_p in zip(params, new_param):
        p.data.copy_(new_p)
def resize(img, size=256):
    """Resize ``img`` spatially to ``size`` using ``F.interpolate``'s default mode.

    Args:
        img: batched image tensor accepted by ``F.interpolate``
            (e.g. ``(N, C, H, W)`` — assumed by callers, not checked here).
        size: target spatial size; an int gives a square output. Defaults to
            256, matching the previous hard-coded value.
    """
    return F.interpolate(img, size=size)
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks of ``batch`` and concatenate on CPU.

    Full chunks are processed first; a trailing partial chunk, if any, is
    processed last. Gradients are disabled for the whole run. Every output is
    moved to CPU before the results are concatenated along dim 0.
    """
    n_full, remainder = divmod(len(zs), batch)
    chunks = []
    with torch.no_grad():
        for k in range(n_full):
            lo = k * batch
            chunks.append(netG(zs[lo:lo + batch]).cpu())
        if remainder > 0:
            # The tail chunk holds the last `remainder` latents.
            chunks.append(netG(zs[-remainder:]).cpu())
    return torch.cat(chunks)
def batch_save(images, folder_name):
    """Save each image in ``images`` as ``<folder_name>/<index>.jpg``.

    Images are mapped from [-1, 1] to [0, 1] (``(x + 1) * 0.5``) before saving.

    Args:
        images: iterable of image tensors accepted by ``vutils.save_image``.
        folder_name: output directory; created (with any missing parents) if
            it does not exist.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
    # missing parent directories, which os.mkdir would fail on.
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5), os.path.join(folder_name, '%d.jpg' % i))
class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
    """Thin wrapper around ``Generator`` so it can be pushed to the Hugging Face Hub.

    ``PyTorchModelHubMixin`` supplies ``save_pretrained`` / ``push_to_hub``.
    The wrapped generator is stored as ``self.model``.

    Args:
        config: dict with keys ``"ngf"``, ``"noise_dim"``, ``"nc"`` and
            ``"im_size"``, forwarded to ``Generator`` as ``ngf``, ``nz``,
            ``nc`` and ``im_size``.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        generator_kwargs = dict(
            ngf=config["ngf"],
            nz=config["noise_dim"],
            nc=config["nc"],
            im_size=config["im_size"],
        )
        self.model = Generator(**generator_kwargs)

    def forward(self, x):
        """Delegate directly to the wrapped generator."""
        return self.model(x)
def _parse_args():
    """Build and parse the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
    # NOTE(review): --artifacts, --start_iter, --end_iter, --size and --big are
    # parsed but not used anywhere below.
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)
    parser.add_argument('--dist', type=str, default='test_out')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=1, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=1)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=256)
    parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
    # The generation path used to sit behind an unconditional exit() and was
    # unreachable. The default keeps the previous behaviour: save + push to the Hub.
    parser.add_argument('--mode', type=str, default='push_to_hub', choices=['push_to_hub', 'generate'],
                        help='push_to_hub: save locally and upload the generator; generate: write samples to --dist.')
    parser.add_argument('--local_dir', type=str, default="pre_trained_checkpoint_4ch",
                        help='directory passed to save_pretrained.')
    parser.add_argument('--hub_repo', type=str, default="deepsynthbody/deepfake_gi_fastGAN",
                        help='Hugging Face Hub repository id passed to push_to_hub.')
    parser.set_defaults(big=False)
    return parser.parse_args()


def _load_generator(ckpt_path, config, device):
    """Build ``MyFastGanModel`` from ``config``, load ``checkpoint['g']`` into it and move it to ``device``."""
    net_ig = MyFastGanModel(config=config)
    # map_location lets a checkpoint saved on another device (or a GPU checkpoint
    # on a CPU-only machine) load. NOTE(review): torch.load unpickles the file;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Remove prefix `module.` (presumably added by a DataParallel-style wrapper at training time).
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
    net_ig.model.load_state_dict(state_dict)
    del checkpoint
    # NOTE(review): the model is deliberately left in train mode (eval() was
    # commented out in the original); normalisation layers, if any, use batch stats.
    print("load checkpoint success")
    return net_ig.to(device)


def _push_to_hub(net_ig, config, local_dir, hub_repo):
    """Save ``net_ig`` to ``local_dir`` and upload it to the Hub repository ``hub_repo``."""
    net_ig.save_pretrained(local_dir, config=config)  # Save the model locally
    print("Model saved locally. Pushing to Huggingface model hub...")
    net_ig.push_to_hub(hub_repo, config=config)
    print("pushed to the Huggingface model hub. Done.")


def _save_sample(g_img, dist, index, save_option):
    """Split one generated sample into an RGB image and a binary mask and save per ``save_option``.

    Assumes ``g_img`` is ``(C, H, W)`` in [-1, 1], with channels 0-2 holding RGB
    and the last channel holding the mask (nc=4 in this script's config).
    """
    save_img = save_option in ("image_and_mask", "image_only")
    save_mask = save_option in ("image_and_mask", "mask_only")
    if not (save_img or save_mask):
        print("wrong choise to save option.")
        return
    # Map [-1, 1] -> [0, 1] once, then split channels.
    g_img = g_img.add(1).mul(0.5)
    # Broadcast the single mask channel to 3 channels so it saves as an RGB image.
    g_mask = g_img[-1, :, :].expand(3, -1, -1)
    rgb = g_img[0:3, :, :]
    # Clean generated data using clamping
    g_mask = torch.clamp(g_mask, min=0, max=1)
    rgb = torch.clamp(rgb, min=0, max=1)
    g_mask = (g_mask > 0.5) * 1.0
    if save_img:
        vutils.save_image(rgb, os.path.join(dist, '%d_img.png' % index))
    if save_mask:
        vutils.save_image(g_mask, os.path.join(dist, '%d_mask.png' % index))


def _generate_samples(net_ig, device, noise_dim, dist, n_sample, batch, save_option):
    """Generate ``n_sample`` samples in batches of ``batch`` and write them to ``dist``."""
    if batch < 1:
        raise ValueError("--batch must be >= 1, got %d" % batch)
    os.makedirs(dist, exist_ok=True)
    with torch.no_grad():
        # Step by `batch` so a trailing partial batch is generated too;
        # `n_sample // batch` iterations silently dropped the remainder.
        for start in tqdm(range(0, n_sample, batch)):
            cur = min(batch, n_sample - start)
            noise = torch.randn(cur, noise_dim, device=device)
            # Element 0 of the generator's output is used as the image batch.
            g_imgs = net_ig(noise)[0]
            # NOTE(review): upsampled to a fixed 512; the --size argument is not used.
            g_imgs = F.interpolate(g_imgs, 512)
            for j, g_img in enumerate(g_imgs):
                _save_sample(g_img, dist, start + j, save_option)


if __name__ == "__main__":
    args = _parse_args()
    noise_dim = 256
    # Fall back to CPU instead of failing on .to(device) when CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.cuda)
    else:
        device = torch.device('cpu')
    config = {"ngf": 64, "noise_dim": noise_dim, "nc": 4, "im_size": args.im_size}
    net_ig = _load_generator(args.ckpt, config, device)
    if args.mode == "push_to_hub":
        _push_to_hub(net_ig, config, args.local_dir, args.hub_repo)
    else:
        _generate_samples(net_ig, device, noise_dim, args.dist,
                          args.n_sample, args.batch, args.save_option)
|