# deepfake_gi_fastGAN / generate_4ch_from_huggingface.py
# Last commit d5a0b8d by vlbthambawita: "updated generate 4ch huggingface" (6.07 kB)
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from huggingface_hub import PyTorchModelHubMixin
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
def load_params(model, new_param):
    """Overwrite ``model``'s parameters, in order, with the tensors in ``new_param``.

    Pairing is positional and follows ``zip`` semantics: it stops at the
    shorter of ``model.parameters()`` and ``new_param``, so a length
    mismatch is not reported. Values are copied into the existing
    parameter storage through ``.data``, so the copy is not tracked by autograd.
    """
    targets = model.parameters()
    for target, source in zip(targets, new_param):
        target.data.copy_(source)
def resize(img, size=256):
    """Resize ``img`` with ``F.interpolate`` (its default ``nearest`` mode).

    Args:
        img: batched tensor in the layout ``F.interpolate`` expects,
            i.e. ``(N, C, *spatial)``.
        size: target spatial size. An int is applied to every spatial
            dimension, and a tuple gives one size per dimension. It defaults
            to 256, the value that used to be hard-coded, so existing
            ``resize(img)`` calls are unchanged.

    Returns:
        The resized tensor.
    """
    return F.interpolate(img, size=size)
def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks and concatenate the CPU results.

    Args:
        zs: sliceable collection of generator inputs (e.g. a latent tensor
            whose first dimension indexes samples).
        netG: callable mapping a slice of ``zs`` to a tensor.
        batch: maximum number of inputs per ``netG`` call; must be >= 1.

    Returns:
        The outputs of every chunk, moved to CPU and concatenated along
        dim 0. Chunking is ``zs[0:batch], zs[batch:2*batch], ...``, and the
        last chunk holds the remainder.

    Raises:
        ValueError: if ``batch`` < 1 or ``zs`` is empty. Previously these
            surfaced as ZeroDivisionError and an opaque ``torch.cat`` error.
    """
    if batch < 1:
        raise ValueError("batch must be a positive integer, got %r" % (batch,))
    # len() rather than truthiness: bool() of a multi-element tensor raises.
    if len(zs) == 0:
        raise ValueError("zs must contain at least one input")
    g_images = []
    with torch.no_grad():
        # Stepping by `batch` yields full chunks followed by one remainder chunk.
        for start in range(0, len(zs), batch):
            g_images.append(netG(zs[start:start + batch]).cpu())
    return torch.cat(g_images)
def batch_save(images, folder_name):
    """Save each image in ``images`` as ``<folder_name>/<index>.jpg``.

    Each image is mapped with ``(x + 1) / 2`` before being passed to
    ``vutils.save_image``. This assumes generator-style inputs in [-1, 1]
    (NOTE(review): confirm against callers).

    Args:
        images: iterable of image tensors; the index from ``enumerate``
            becomes the file name.
        folder_name: output directory. It is created, including any missing
            parent directories, if it does not exist.
    """
    # makedirs(exist_ok=True) creates missing parents and avoids the
    # race between a separate exists() check and mkdir().
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5), os.path.join(folder_name, '%d.jpg' % i))
# Wrapper that lets the FastGAN generator be saved to / loaded from the Huggingface model hub
class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
    """FastGAN ``Generator`` wrapped for use with the Hugging Face Hub.

    Mixing in ``PyTorchModelHubMixin`` is what makes ``from_pretrained``,
    ``save_pretrained`` and ``push_to_hub`` available on this class.

    Args:
        config: dict with keys ``"ngf"``, ``"noise_dim"``, ``"nc"`` and
            ``"im_size"``. They are forwarded to ``Generator`` as ``ngf``,
            ``nz``, ``nc`` and ``im_size`` respectively.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        # Translate config keys into Generator's keyword names (noise_dim -> nz).
        generator_kwargs = {
            "ngf": config["ngf"],
            "nz": config["noise_dim"],
            "nc": config["nc"],
            "im_size": config["im_size"],
        }
        self.model = Generator(**generator_kwargs)

    def forward(self, x):
        """Delegate to the wrapped generator and return its output unchanged."""
        generated = self.model(x)
        return generated
def _parse_args():
    """Build the command-line interface, parse ``sys.argv`` and validate it."""
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    # NOTE: --artifacts, --start_iter, --end_iter, --size and --big are not
    # read by main(); they are kept so existing command lines keep working.
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)
    parser.add_argument('--dist', type=str, default='test_out')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=1, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=1)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=256)
    parser.add_argument('--out_size', type=int, default=512,
                        help='side length generated samples are resized to before saving')
    parser.add_argument("--save_option", default="image_and_mask",
                        help="Options to save output, image_only, mask_only, image_and_mask",
                        choices=["image_only", "mask_only", "image_and_mask"])
    parser.set_defaults(big=False)
    args = parser.parse_args()
    if args.batch < 1:
        # Otherwise the sampling loop would fail later with an unhelpful error.
        parser.error('--batch must be a positive integer')
    return args


def _split_image_and_mask(sample):
    """Split one generated sample into an RGB image and a 3-channel binary mask.

    Channels 0-2 become the image and the last channel becomes the mask.
    Values are first mapped with ``(x + 1) / 2`` (i.e. [-1, 1] -> [0, 1];
    presumably the generator's output range, TODO confirm). The image is
    then clamped to [0, 1] and the mask is thresholded at 0.5.

    Returns:
        ``(image, mask)``, each with 3 channels. Mask values are 0.0 or 1.0.
    """
    scaled = sample.add(1).mul(0.5)
    image = torch.clamp(scaled[0:3, :, :], min=0, max=1)
    # No clamp needed before thresholding: clamp(x, 0, 1) > 0.5 iff x > 0.5.
    mask = (scaled[-1, :, :].expand(3, -1, -1) > 0.5) * 1.0
    return image, mask


def main():
    """Load the pretrained 4-channel generator from the Hub and save samples.

    Writes ``--n_sample`` samples to ``--dist`` as ``<k>_img.png`` and/or
    ``<k>_mask.png`` (per ``--save_option``), generating ``--batch`` samples
    per forward pass.
    """
    args = _parse_args()
    noise_dim = 256
    # Fall back to CPU so the script still runs on machines without CUDA.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.cuda)
    else:
        device = torch.device('cpu')

    config = {"ngf": 64, "noise_dim": noise_dim, "nc": 4, "im_size": args.im_size}
    # Load the published weights directly; no need to build a random model first.
    net_ig = MyFastGanModel.from_pretrained("deepsynthbody/deepfake_gi_fastGAN", config=config)
    # NOTE(review): no explicit eval()/train() call is made (the original had
    # net_ig.eval() commented out); the mode is whatever from_pretrained
    # returns. Confirm this is intended.
    print("load checkpoint success")
    net_ig.to(device)
    # To (re-)publish the weights: net_ig.save_pretrained(local_dir, config=config)
    # or net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config).

    dist = args.dist
    os.makedirs(dist, exist_ok=True)
    # argparse `choices` guarantees save_option is one of the three values.
    save_img = args.save_option in ("image_and_mask", "image_only")
    save_mask = args.save_option in ("image_and_mask", "mask_only")

    with torch.no_grad():
        for start in tqdm(range(0, args.n_sample, args.batch)):
            # The final batch may be smaller so that exactly n_sample samples are written.
            current = min(args.batch, args.n_sample - start)
            noise = torch.randn(current, noise_dim).to(device)
            # The generator returns a sequence; element 0 is presumably the
            # full-resolution output (TODO confirm against models.Generator).
            g_imgs = net_ig(noise)[0]
            g_imgs = F.interpolate(g_imgs, args.out_size)
            for j, g_sample in enumerate(g_imgs):
                index = start + j
                g_img, g_mask = _split_image_and_mask(g_sample)
                if save_img:
                    vutils.save_image(g_img, os.path.join(dist, '%d_img.png' % index))
                if save_mask:
                    vutils.save_image(g_mask, os.path.join(dist, '%d_mask.png' % index))


if __name__ == "__main__":
    main()