vlbthambawita committed
Commit: df20d82
Parent(s): c7c9ff6

added files

Files changed:
- .gitignore +1 -0
- __pycache__/models.cpython-310.pyc +0 -0
- generate_4ch.py +36 -4
- generate_4ch_from_huggingface.py +160 -0
- test_out/0_img.png +0 -0
- test_out/0_mask.png +0 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+pre_trained_checkpoint_4ch
__pycache__/models.cpython-310.pyc
ADDED
Binary file (12.2 kB).
generate_4ch.py
CHANGED
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from torchvision.datasets import ImageFolder
 from torch.utils.data import DataLoader
 from torchvision import utils as vutils
+from huggingface_hub import PyTorchModelHubMixin
 
 import os
 import random
@@ -36,12 +37,22 @@ def batch_save(images, folder_name):
     for i, image in enumerate(images):
         vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
 
+# To push the model to Huggingface model hub
+class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
+
+    def __init__(self, config: dict) -> None:
+        super().__init__()
+
+        self.model = Generator( ngf=config["ngf"], nz=config["noise_dim"], nc=config["nc"], im_size=config["im_size"])
+
+    def forward(self, x):
+        return self.model(x)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description='generate images'
     )
-    parser.add_argument('--ckpt', type=str, default="
+    parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
     parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
     parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
     parser.add_argument('--start_iter', type=int, default=6)
@@ -50,7 +61,7 @@ if __name__ == "__main__":
     parser.add_argument('--dist', type=str, default='test_out')
     parser.add_argument('--size', type=int, default=256)
     parser.add_argument('--batch', default=1, type=int, help='batch size')
-    parser.add_argument('--n_sample', type=int, default=
+    parser.add_argument('--n_sample', type=int, default=1)
     parser.add_argument('--big', action='store_true')
     parser.add_argument('--im_size', type=int, default=256)
     parser.add_argument("--save_option", default="image_and_mask", help="Options to save output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
@@ -59,8 +70,17 @@ if __name__ == "__main__":
 
     noise_dim = 256
     device = torch.device('cuda:%d'%(args.cuda))
+
+    # adding the model to the model hub
+    config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
+    net_ig = MyFastGanModel(config=config)
+
 
-
+
+    # exit
+    #exit()
+
+    #net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
     net_ig.to(device)
 
     #for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
@@ -69,13 +89,25 @@ if __name__ == "__main__":
     checkpoint = torch.load(ckpt)
     # Remove prefix `module`.
     checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
-    net_ig.load_state_dict(checkpoint['g'])
+    net_ig.model.load_state_dict(checkpoint['g'])
     #load_params(net_ig, checkpoint['g_ema'])
 
     #net_ig.eval()
     print("load checkpoint success")
 
     net_ig.to(device)
+    # Save locally
+    net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
+    print("Model saved locally. Pushing to Huggingface model hub...")
+
+    # Push to the Huggingface model hub
+    # push to the hub
+    net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
+
+
+    print("pushed to the Huggingface model hub. Done.")
+    exit()
+
 
     del checkpoint
 
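For context on the change above: wrapping the generator in PyTorchModelHubMixin is what lets the checkpoint round-trip through the Hub. save_pretrained writes the weights and config locally, push_to_hub uploads them, and the same wrapper class can later rebuild the model with from_pretrained, which is what the new generate_4ch_from_huggingface.py below relies on. A minimal sketch of that loading side, assuming MyFastGanModel can be imported from generate_4ch.py and the deepsynthbody/deepfake_gi_fastGAN repo is reachable:

# Sketch only: reload the generator pushed above and draw one sample.
# Assumes MyFastGanModel (defined in generate_4ch.py) is importable and the
# Hub repo "deepsynthbody/deepfake_gi_fastGAN" is accessible from this machine.
import torch

from generate_4ch import MyFastGanModel

config = {"ngf": 64, "noise_dim": 256, "nc": 4, "im_size": 256}
net = MyFastGanModel.from_pretrained("deepsynthbody/deepfake_gi_fastGAN", config=config)
net.eval()

with torch.no_grad():
    noise = torch.randn(1, config["noise_dim"])
    imgs = net(noise)[0]   # the Generator returns a list; index 0 is the full-resolution batch, as used in the scripts here
    print(imgs.shape)      # expected: torch.Size([1, 4, 256, 256]) -> 3 RGB channels plus 1 mask channel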
generate_4ch_from_huggingface.py
ADDED
@@ -0,0 +1,160 @@
+import torch
+from torch import nn
+from torch import optim
+import torch.nn.functional as F
+from torchvision.datasets import ImageFolder
+from torch.utils.data import DataLoader
+from torchvision import utils as vutils
+from huggingface_hub import PyTorchModelHubMixin
+
+import os
+import random
+import argparse
+from tqdm import tqdm
+
+from models import Generator
+
+
+def load_params(model, new_param):
+    for p, new_p in zip(model.parameters(), new_param):
+        p.data.copy_(new_p)
+
+def resize(img):
+    return F.interpolate(img, size=256)
+
+def batch_generate(zs, netG, batch=8):
+    g_images = []
+    with torch.no_grad():
+        for i in range(len(zs)//batch):
+            g_images.append( netG(zs[i*batch:(i+1)*batch]).cpu() )
+        if len(zs)%batch>0:
+            g_images.append( netG(zs[-(len(zs)%batch):]).cpu() )
+    return torch.cat(g_images)
+
+def batch_save(images, folder_name):
+    if not os.path.exists(folder_name):
+        os.mkdir(folder_name)
+    for i, image in enumerate(images):
+        vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
+
+# To push the model to Huggingface model hub
+class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
+
+    def __init__(self, config: dict) -> None:
+        super().__init__()
+
+        self.model = Generator( ngf=config["ngf"], nz=config["noise_dim"], nc=config["nc"], im_size=config["im_size"])
+
+    def forward(self, x):
+        return self.model(x)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='generate images'
+    )
+    parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
+    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
+    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
+    parser.add_argument('--start_iter', type=int, default=6)
+    parser.add_argument('--end_iter', type=int, default=10)
+
+    parser.add_argument('--dist', type=str, default='test_out')
+    parser.add_argument('--size', type=int, default=256)
+    parser.add_argument('--batch', default=1, type=int, help='batch size')
+    parser.add_argument('--n_sample', type=int, default=1)
+    parser.add_argument('--big', action='store_true')
+    parser.add_argument('--im_size', type=int, default=256)
+    parser.add_argument("--save_option", default="image_and_mask", help="Options to save output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
+    parser.set_defaults(big=False)
+    args = parser.parse_args()
+
+    noise_dim = 256
+    device = torch.device('cuda:%d'%(args.cuda))
+
+    # adding the model to the model hub
+    config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
+    net_ig = MyFastGanModel(config=config)
+
+
+
+    # exit
+    #exit()
+
+    #net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
+    #net_ig.to(device)
+
+    #for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
+    #ckpt = args.ckpt #f"{args.artifacts}/models/{epoch}.pth"
+    #checkpoint = torch.load(ckpt, map_location=lambda a,b: a)
+    #checkpoint = torch.load(ckpt)
+    # Remove prefix `module`.
+    #checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
+    #net_ig.model.load_state_dict(checkpoint['g'])
+    #load_params(net_ig, checkpoint['g_ema'])
+
+    net_ig = MyFastGanModel.from_pretrained("deepsynthbody/deepfake_gi_fastGAN", config=config) # Load the model from the hub
+
+    #net_ig.eval()
+    print("load checkpoint success")
+
+    net_ig.to(device)
+    # Save locally
+    # net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
+    # print("Model saved locally. Pushing to Huggingface model hub...")
+
+    # Push to the Huggingface model hub
+    # push to the hub
+    # net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
+
+
+    #print("pushed to the Huggingface model hub. Done.")
+    #exit()
+
+
+    #del checkpoint
+
+    #dist = 'eval_%d'%(epoch)
+    #dist = os.path.join(args.dist, 'img')
+    dist = args.dist
+    os.makedirs(dist, exist_ok=True)
+
+    with torch.no_grad():
+        for i in tqdm(range(args.n_sample//args.batch)):
+            noise = torch.randn(args.batch, noise_dim).to(device)
+            g_imgs = net_ig(noise)[0]
+            g_imgs = F.interpolate(g_imgs, 512)
+
+
+            for j, g_img in enumerate( g_imgs ):
+                #print("img shape=", g_img.shape)
+                g_mask = g_img.add(1).mul(0.5)[-1, :, :].expand(3, -1, -1)
+                g_img = g_img.add(1).mul(0.5)[0:3, :, :]
+
+                # Clean generated data using clamping
+                g_mask = torch.clamp(g_mask, min=0, max=1)
+                g_img = torch.clamp(g_img, min=0, max=1)
+                #print(g_mask.type())
+                g_mask = (g_mask > 0.5) * 1.0
+                #print(g_mask.type())
+
+                #print("gmask_min:", g_mask.min())
+                #print("gmask_max:", g_mask.max())
+                #exit()
+
+                #print("img shape=", g_img.shape)
+
+                if args.save_option == "image_and_mask":
+                    vutils.save_image(g_img,
+                        os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
+                    vutils.save_image(g_mask,
+                        os.path.join(dist, '%d_mask.png'%(i*args.batch+j))) #, normalize=True, range=(0,1))
+
+                elif args.save_option == "image_only":
+                    vutils.save_image(g_img,
+                        os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
+
+                elif args.save_option == "mask_only":
+                    vutils.save_image(g_mask,
+                        os.path.join(dist, '%d_mask.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
+                else:
+                    print("wrong choice of save option.")
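In the sampling loop above, each 4-channel output is split into an RGB image (channels 0-2) and a segmentation mask (last channel, thresholded at 0.5). The paired test_out/0_img.png and test_out/0_mask.png added below were presumably produced by a run along the lines of: python generate_4ch_from_huggingface.py --n_sample 1 --dist test_out --save_option image_and_mask. The split itself, as a standalone sketch on a dummy tensor (illustrative values, not real generator output):

# Standalone sketch of the image/mask split used in the loop above.
# "fake" is a dummy 4-channel tensor standing in for one generator output in [-1, 1].
import torch

fake = torch.rand(4, 512, 512) * 2 - 1                     # dummy sample after F.interpolate(..., 512)
rgb  = fake.add(1).mul(0.5)[0:3, :, :]                     # channels 0-2 -> RGB image, rescaled to [0, 1]
mask = fake.add(1).mul(0.5)[-1, :, :].expand(3, -1, -1)    # last channel repeated to 3 channels for saving
rgb  = torch.clamp(rgb, min=0, max=1)
mask = (torch.clamp(mask, min=0, max=1) > 0.5) * 1.0       # binarise the mask at 0.5
print(rgb.shape, mask.shape)                               # both torch.Size([3, 512, 512])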
test_out/0_img.png
ADDED
test_out/0_mask.png
ADDED