echen01 committed
Commit 5c10e4d
Parent: 58aeb21

update PTI

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. PTI/{color_transfer_loss.py → criteria/color_transfer_loss.py} +0 -0
  2. PTI/models/StyleCLIP/__init__.py +0 -0
  3. PTI/models/StyleCLIP/criteria/__init__.py +0 -0
  4. PTI/models/StyleCLIP/criteria/clip_loss.py +0 -17
  5. PTI/models/StyleCLIP/criteria/id_loss.py +0 -39
  6. PTI/models/StyleCLIP/global_directions/GUI.py +0 -103
  7. PTI/models/StyleCLIP/global_directions/GenerateImg.py +0 -50
  8. PTI/models/StyleCLIP/global_directions/GetCode.py +0 -232
  9. PTI/models/StyleCLIP/global_directions/GetGUIData.py +0 -67
  10. PTI/models/StyleCLIP/global_directions/Inference.py +0 -106
  11. PTI/models/StyleCLIP/global_directions/MapTS.py +0 -394
  12. PTI/models/StyleCLIP/global_directions/PlayInteractively.py +0 -197
  13. PTI/models/StyleCLIP/global_directions/SingleChannel.py +0 -109
  14. PTI/models/StyleCLIP/global_directions/__init__.py +0 -0
  15. PTI/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy +0 -3
  16. PTI/models/StyleCLIP/global_directions/dnnlib/__init__.py +0 -9
  17. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py +0 -20
  18. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py +0 -193
  19. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py +0 -181
  20. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/network.py +0 -781
  21. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py +0 -9
  22. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu +0 -220
  23. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py +0 -211
  24. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu +0 -359
  25. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py +0 -418
  26. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py +0 -372
  27. PTI/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py +0 -262
  28. PTI/models/StyleCLIP/global_directions/dnnlib/util.py +0 -472
  29. PTI/models/StyleCLIP/global_directions/manipulate.py +0 -278
  30. PTI/models/StyleCLIP/global_directions/utils/__init__.py +0 -0
  31. PTI/models/StyleCLIP/global_directions/utils/editor.py +0 -507
  32. PTI/models/StyleCLIP/global_directions/utils/train_boundary.py +0 -158
  33. PTI/models/StyleCLIP/global_directions/utils/visualizer.py +0 -605
  34. PTI/models/StyleCLIP/mapper/__init__.py +0 -0
  35. PTI/models/StyleCLIP/mapper/datasets/__init__.py +0 -0
  36. PTI/models/StyleCLIP/mapper/datasets/latents_dataset.py +0 -15
  37. PTI/models/StyleCLIP/mapper/latent_mappers.py +0 -81
  38. PTI/models/StyleCLIP/mapper/options/__init__.py +0 -0
  39. PTI/models/StyleCLIP/mapper/options/test_options.py +0 -42
  40. PTI/models/StyleCLIP/mapper/options/train_options.py +0 -49
  41. PTI/models/StyleCLIP/mapper/scripts/inference.py +0 -80
  42. PTI/models/StyleCLIP/mapper/scripts/train.py +0 -32
  43. PTI/models/StyleCLIP/mapper/styleclip_mapper.py +0 -76
  44. PTI/models/StyleCLIP/mapper/training/__init__.py +0 -0
  45. PTI/models/StyleCLIP/mapper/training/coach.py +0 -242
  46. PTI/models/StyleCLIP/mapper/training/ranger.py +0 -164
  47. PTI/models/StyleCLIP/mapper/training/train_utils.py +0 -13
  48. PTI/models/StyleCLIP/models/__init__.py +0 -0
  49. PTI/models/StyleCLIP/models/facial_recognition/__init__.py +0 -0
  50. PTI/models/StyleCLIP/models/facial_recognition/helpers.py +0 -119
PTI/{color_transfer_loss.py → criteria/color_transfer_loss.py} RENAMED
File without changes
PTI/models/StyleCLIP/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/criteria/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/criteria/clip_loss.py DELETED
@@ -1,17 +0,0 @@
-
- import torch
- import clip
-
-
- class CLIPLoss(torch.nn.Module):
-
-     def __init__(self, opts):
-         super(CLIPLoss, self).__init__()
-         self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
-         self.upsample = torch.nn.Upsample(scale_factor=7)
-         self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
-
-     def forward(self, image, text):
-         image = self.avg_pool(self.upsample(image))
-         similarity = 1 - self.model(image, text)[0] / 100
-         return similarity
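For orientation, the deleted CLIPLoss measures how far a generated image drifts from a text prompt in CLIP space; the Upsample/AvgPool2d pair resizes a `stylegan_size` image down to CLIP's 224x224 input (for 1024px models, 1024 * 7 / 32 = 224). A minimal, hypothetical usage sketch; the `opts` object and prompt are assumptions, not part of this commit:

# Hypothetical driver for the CLIPLoss above; assumes CUDA, the `clip`
# package, and a 1024x1024 generator output in [-1, 1].
from types import SimpleNamespace
import torch
import clip

opts = SimpleNamespace(stylegan_size=1024)   # assumed options object
loss_fn = CLIPLoss(opts)
image = torch.randn(1, 3, 1024, 1024, device="cuda")
text = clip.tokenize(["a photo of a smiling face"]).to("cuda")
loss = loss_fn(image, text)                  # lower = closer to the prompt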
PTI/models/StyleCLIP/criteria/id_loss.py DELETED
@@ -1,39 +0,0 @@
- import torch
- from torch import nn
-
- from models.facial_recognition.model_irse import Backbone
-
-
- class IDLoss(nn.Module):
-     def __init__(self, opts):
-         super(IDLoss, self).__init__()
-         print('Loading ResNet ArcFace')
-         self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
-         self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
-         self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
-         self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
-         self.facenet.eval()
-         self.opts = opts
-
-     def extract_feats(self, x):
-         if x.shape[2] != 256:
-             x = self.pool(x)
-         x = x[:, :, 35:223, 32:220]  # crop the central face region
-         x = self.face_pool(x)
-         x_feats = self.facenet(x)
-         return x_feats
-
-     def forward(self, y_hat, y):
-         n_samples = y.shape[0]
-         y_feats = self.extract_feats(y)
-         y_hat_feats = self.extract_feats(y_hat)
-         y_feats = y_feats.detach()  # do not backpropagate through the targets
-         loss = 0
-         sim_improvement = 0  # kept for API compatibility; always 0 here
-         count = 0
-         for i in range(n_samples):
-             diff_target = y_hat_feats[i].dot(y_feats[i])
-             loss += 1 - diff_target
-             count += 1
-
-         return loss / count, sim_improvement / count
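The deleted IDLoss penalizes identity drift between a reconstruction and its target; since the IR-SE50 backbone returns L2-normalized embeddings, `1 - dot(...)` is effectively a cosine distance. A hedged usage sketch; the checkpoint path and batch contents are assumptions:

# Hypothetical usage of the IDLoss above; `ir_se50_weights` must point to an
# ArcFace IR-SE50 checkpoint (path is an assumption).
from types import SimpleNamespace
import torch

opts = SimpleNamespace(ir_se50_weights="pretrained/model_ir_se50.pth")
id_loss = IDLoss(opts).cuda()
y = torch.randn(4, 3, 256, 256, device="cuda")       # target images
y_hat = torch.randn(4, 3, 256, 256, device="cuda")   # reconstructions
loss, _ = id_loss(y_hat, y)   # mean of 1 - <f(y_hat_i), f(y_i)>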
PTI/models/StyleCLIP/global_directions/GUI.py DELETED
@@ -1,103 +0,0 @@
-
- from tkinter import Tk, Frame, Label, Button, messagebox, Canvas, Text, Scale
- from tkinter import HORIZONTAL
-
- class View():
-     def __init__(self, master):
-
-         self.width = 600
-         self.height = 600
-
-         self.root = master
-         self.root.geometry("600x600")
-
-         self.left_frame = Frame(self.root, width=600)
-         self.left_frame.pack_propagate(0)
-         self.left_frame.pack(fill='both', side='left', expand='True')
-
-         self.retrieval_frame = Frame(self.root, bg='snow3')
-         self.retrieval_frame.pack_propagate(0)
-         self.retrieval_frame.pack(fill='both', side='right', expand='True')
-
-         self.bg_frame = Frame(self.left_frame, bg='snow3', height=600, width=600)
-         self.bg_frame.pack_propagate(0)
-         self.bg_frame.pack(fill='both', side='top', expand='True')
-
-         self.command_frame = Frame(self.left_frame, bg='snow3')
-         self.command_frame.pack_propagate(0)
-         self.command_frame.pack(fill='both', side='bottom', expand='True')
-         # self.command_frame.grid(row=1, column=0, padx=0, pady=0)
-
-         self.bg = Canvas(self.bg_frame, width=self.width, height=self.height, bg='gray')
-         self.bg.place(relx=0.5, rely=0.5, anchor='center')
-
-         self.mani = Canvas(self.retrieval_frame, width=1024, height=1024, bg='gray')
-         self.mani.grid(row=0, column=0, padx=0, pady=42)
-
-         self.SetCommand()
-
-     def run(self):
-         self.root.mainloop()
-
-     def helloCallBack(self):
-         # relies on a set_category widget that this View does not create
-         category = self.set_category.get()
-         messagebox.showinfo("Hello Python", category)
-
-     def SetCommand(self):
-
-         tmp = Label(self.command_frame, text="neutral", width=10, bg='snow3')
-         tmp.grid(row=1, column=0, padx=10, pady=10)
-
-         tmp = Label(self.command_frame, text="a photo of a", width=10, bg='snow3')
-         tmp.grid(row=1, column=1, padx=10, pady=10)
-
-         self.neutral = Text(self.command_frame, height=2, width=30)
-         self.neutral.grid(row=1, column=2, padx=10, pady=10)
-
-         tmp = Label(self.command_frame, text="target", width=10, bg='snow3')
-         tmp.grid(row=2, column=0, padx=10, pady=10)
-
-         tmp = Label(self.command_frame, text="a photo of a", width=10, bg='snow3')
-         tmp.grid(row=2, column=1, padx=10, pady=10)
-
-         self.target = Text(self.command_frame, height=2, width=30)
-         self.target.grid(row=2, column=2, padx=10, pady=10)
-
-         tmp = Label(self.command_frame, text="strength", width=10, bg='snow3')
-         tmp.grid(row=3, column=0, padx=10, pady=10)
-
-         self.alpha = Scale(self.command_frame, from_=-15, to=25, orient=HORIZONTAL, bg='snow3', length=250, resolution=0.01)
-         self.alpha.grid(row=3, column=2, padx=10, pady=10)
-
-         tmp = Label(self.command_frame, text="disentangle", width=10, bg='snow3')
-         tmp.grid(row=4, column=0, padx=10, pady=10)
-
-         self.beta = Scale(self.command_frame, from_=0.08, to=0.4, orient=HORIZONTAL, bg='snow3', length=250, resolution=0.001)
-         self.beta.grid(row=4, column=2, padx=10, pady=10)
-
-         self.reset = Button(self.command_frame, text='Reset')
-         self.reset.grid(row=5, column=1, padx=10, pady=10)
-
-         self.set_init = Button(self.command_frame, text='Accept')
-         self.set_init.grid(row=5, column=2, padx=10, pady=10)
-
- if __name__ == "__main__":
-     master = Tk()
-     view = View(master)
-     view.run()
PTI/models/StyleCLIP/global_directions/GenerateImg.py DELETED
@@ -1,50 +0,0 @@
-
- import os
- import numpy as np
- import argparse
- from manipulate import Manipulator
-
- from PIL import Image
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='Generate sample images and W+ codes for a dataset.')
-
-     parser.add_argument('--dataset_name', type=str, default='ffhq',
-                         help='name of dataset, for example, ffhq')
-
-     args = parser.parse_args()
-     dataset_name = args.dataset_name
-
-     if not os.path.isdir('./data/' + dataset_name):
-         os.system('mkdir ./data/' + dataset_name)
-
-     M = Manipulator(dataset_name=dataset_name)
-     np.set_printoptions(suppress=True)
-     print(M.dataset_name)
-
-     M.img_index = 0
-     M.num_images = 50
-     M.alpha = [0]
-     M.step = 1
-     lindex, bname = 0, 0
-
-     M.manipulate_layers = [lindex]
-     codes, out = M.EditOneC(bname)
-
-     for i in range(len(out)):
-         img = out[i, 0]
-         img = Image.fromarray(img)
-         img.save('./data/' + dataset_name + '/' + str(i) + '.jpg')
-
-     w = np.load('./npy/' + dataset_name + '/W.npy')
-
-     tmp = w[:M.num_images]
-     tmp = tmp[:, None, :]
-     tmp = np.tile(tmp, (1, M.Gs.components.synthesis.input_shape[1], 1))
-
-     np.save('./data/' + dataset_name + '/w_plus.npy', tmp)
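Both this script and GetGUIData.py below broadcast each 512-dimensional W code across every synthesis layer to form a W+ code. A self-contained numpy sketch of that step; the layer count (18 for a 1024x1024 StyleGAN2) is an assumption:

# Minimal sketch of the W -> W+ broadcast used above.
import numpy as np

num_layers = 18                                       # assumed layer count
w = np.random.randn(50, 512).astype(np.float32)      # stand-in for W.npy
w_plus = np.tile(w[:, None, :], (1, num_layers, 1))  # shape (50, 18, 512)
assert w_plus.shape == (50, num_layers, 512)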
PTI/models/StyleCLIP/global_directions/GetCode.py DELETED
@@ -1,232 +0,0 @@
-
- import os
- import pickle
- import numpy as np
- from dnnlib import tflib
- import tensorflow as tf
-
- import argparse
-
- def LoadModel(dataset_name):
-     # Initialize TensorFlow.
-     tflib.init_tf()
-     model_path = './model/'
-     model_name = dataset_name + '.pkl'
-
-     tmp = os.path.join(model_path, model_name)
-     with open(tmp, 'rb') as f:
-         _, _, Gs = pickle.load(f)
-     return Gs
-
- def lerp(a, b, t):
-     return a + (b - a) * t
-
- # stylegan-ada
- def SelectName(layer_name, suffix):
-     if suffix is None:
-         tmp1 = 'add:0' in layer_name
-         tmp2 = 'shape=(?,' in layer_name
-         tmp4 = 'G_synthesis_1' in layer_name
-         tmp = tmp1 and tmp2 and tmp4
-     else:
-         tmp1 = ('/Conv0_up' + suffix) in layer_name
-         tmp2 = ('/Conv1' + suffix) in layer_name
-         tmp3 = ('4x4/Conv' + suffix) in layer_name
-         tmp4 = 'G_synthesis_1' in layer_name
-         tmp5 = ('/ToRGB' + suffix) in layer_name
-         tmp = (tmp1 or tmp2 or tmp3 or tmp5) and tmp4
-     return tmp
-
- def GetSNames(suffix):
-     # get style tensor names from the graph
-     with tf.Session() as sess:
-         op = sess.graph.get_operations()
-         layers = [m.values() for m in op]
-
-     select_layers = []
-     for layer in layers:
-         layer_name = str(layer)
-         if SelectName(layer_name, suffix):
-             select_layers.append(layer[0])
-     return select_layers
-
- def SelectName2(layer_name):
-     tmp1 = 'mod_bias' in layer_name
-     tmp2 = 'mod_weight' in layer_name
-     tmp3 = 'ToRGB' in layer_name
-     tmp = (tmp1 or tmp2) and (not tmp3)
-     return tmp
-
- def GetKName(Gs):
-     layers = [var for name, var in Gs.components.synthesis.vars.items()]
-
-     select_layers = []
-     for layer in layers:
-         layer_name = str(layer)
-         if SelectName2(layer_name):
-             select_layers.append(layer)
-     return select_layers
-
- def GetCode(Gs, random_state, num_img, num_once, dataset_name):
-     rnd = np.random.RandomState(random_state)
-
-     truncation_psi = 0.7
-     truncation_cutoff = 8
-
-     dlatent_avg = Gs.get_var('dlatent_avg')
-
-     dlatents = np.zeros((num_img, 512), dtype='float32')
-     for i in range(int(num_img / num_once)):
-         src_latents = rnd.randn(num_once, Gs.input_shape[1])
-         src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]
-
-         # Apply truncation trick.
-         if truncation_psi is not None and truncation_cutoff is not None:
-             layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
-             ones = np.ones(layer_idx.shape, dtype=np.float32)
-             coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
-             src_dlatents_np = lerp(dlatent_avg, src_dlatents, coefs)
-             src_dlatents = src_dlatents_np[:, 0, :].astype('float32')
-         dlatents[(i * num_once):((i + 1) * num_once), :] = src_dlatents
-     print('get all z and w')
-
-     tmp = './npy/' + dataset_name + '/W'
-     np.save(tmp, dlatents)
-
- def GetImg(Gs, num_img, num_once, dataset_name, save_name='images'):
-     print('Generate Image')
-     tmp = './npy/' + dataset_name + '/W.npy'
-     dlatents = np.load(tmp)
-     fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
-
-     all_images = []
-     for i in range(int(num_img / num_once)):
-         print(i)
-         images = []
-         for k in range(num_once):
-             tmp = dlatents[i * num_once + k]
-             tmp = tmp[None, None, :]
-             tmp = np.tile(tmp, (1, Gs.components.synthesis.input_shape[1], 1))
-             image2 = Gs.components.synthesis.run(tmp, randomize_noise=False, output_transform=fmt)
-             images.append(image2)
-
-         images = np.concatenate(images)
-         all_images.append(images)
-
-     all_images = np.concatenate(all_images)
-
-     tmp = './npy/' + dataset_name + '/' + save_name
-     np.save(tmp, all_images)
-
- def GetS(dataset_name, num_img):
-     print('Generate S')
-     tmp = './npy/' + dataset_name + '/W.npy'
-     dlatents = np.load(tmp)[:num_img]
-
-     with tf.Session() as sess:
-         init = tf.global_variables_initializer()
-         sess.run(init)
-
-         Gs = LoadModel(dataset_name)
-         Gs.print_layers()  # for ada
-         select_layers1 = GetSNames(suffix=None)  # None, '/mul_1:0', '/mod_weight/read:0', '/MatMul:0'
-         dlatents = dlatents[:, None, :]
-         dlatents = np.tile(dlatents, (1, Gs.components.synthesis.input_shape[1], 1))
-
-         all_s = sess.run(
-             select_layers1,
-             feed_dict={'G_synthesis_1/dlatents_in:0': dlatents})
-
-     layer_names = [layer.name for layer in select_layers1]
-     save_tmp = [layer_names, all_s]
-     return save_tmp
-
- def convert_images_to_uint8(images, drange=[-1, 1], nchw_to_nhwc=False):
-     """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
-     Can be used as an output transformation for Network.run().
-     """
-     if nchw_to_nhwc:
-         images = np.transpose(images, [0, 2, 3, 1])
-
-     scale = 255 / (drange[1] - drange[0])
-     images = images * scale + (0.5 - drange[0] * scale)
-
-     np.clip(images, 0, 255, out=images)
-     images = images.astype('uint8')
-     return images
-
- def GetCodeMS(dlatents):
-     m = []
-     std = []
-     for i in range(len(dlatents)):
-         tmp = dlatents[i]
-         tmp_mean = tmp.mean(axis=0)
-         tmp_std = tmp.std(axis=0)
-         m.append(tmp_mean)
-         std.append(tmp_std)
-     return m, std
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='Sample latent codes (W, S, or S statistics) from a StyleGAN2 model.')
-
-     parser.add_argument('--dataset_name', type=str, default='ffhq',
-                         help='name of dataset, for example, ffhq')
-     parser.add_argument('--code_type', choices=['w', 's', 's_mean_std'], default='w')
-
-     args = parser.parse_args()
-     random_state = 5
-     num_img = 100_000
-     num_once = 1_000
-     dataset_name = args.dataset_name
-
-     if not os.path.isfile('./model/' + dataset_name + '.pkl'):
-         url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
-         name = 'stylegan2-' + dataset_name + '-config-f.pkl'
-         os.system('wget ' + url + name + ' -P ./model/')
-         os.system('mv ./model/' + name + ' ./model/' + dataset_name + '.pkl')
-
-     if not os.path.isdir('./npy/' + dataset_name):
-         os.system('mkdir ./npy/' + dataset_name)
-
-     if args.code_type == 'w':
-         Gs = LoadModel(dataset_name=dataset_name)
-         GetCode(Gs, random_state, num_img, num_once, dataset_name)
-         # GetImg(Gs, num_img=num_img, num_once=num_once, dataset_name=dataset_name, save_name='images_100K')  # not needed
-     elif args.code_type == 's':
-         save_name = 'S'
-         save_tmp = GetS(dataset_name, num_img=2_000)
-         tmp = './npy/' + dataset_name + '/' + save_name
-         with open(tmp, "wb") as fp:
-             pickle.dump(save_tmp, fp)
-     elif args.code_type == 's_mean_std':
-         save_tmp = GetS(dataset_name, num_img=num_img)
-         dlatents = save_tmp[1]
-         m, std = GetCodeMS(dlatents)
-         save_tmp = [m, std]
-         save_name = 'S_mean_std'
-         tmp = './npy/' + dataset_name + '/' + save_name
-         with open(tmp, "wb") as fp:
-             pickle.dump(save_tmp, fp)
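GetCode's truncation trick pulls each mapped latent toward the average latent for layers below `truncation_cutoff`. A standalone numpy sketch of just that step; the shapes (8 latents, 18 layers, 512 channels) are assumptions:

# Numpy sketch of the truncation trick from GetCode above.
import numpy as np

psi, cutoff = 0.7, 8
dlatent_avg = np.zeros(512, dtype=np.float32)        # stand-in for Gs var
w = np.random.randn(8, 18, 512).astype(np.float32)   # mapped latents

layer_idx = np.arange(w.shape[1])[np.newaxis, :, np.newaxis]
coefs = np.where(layer_idx < cutoff, psi, 1.0).astype(np.float32)
w_trunc = dlatent_avg + (w - dlatent_avg) * coefs    # lerp toward the average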
PTI/models/StyleCLIP/global_directions/GetGUIData.py DELETED
@@ -1,67 +0,0 @@
-
- import os
- import numpy as np
- import argparse
- from manipulate import Manipulator
- import torch
- from PIL import Image
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='Prepare images and W+ codes for the interactive GUI.')
-
-     parser.add_argument('--dataset_name', type=str, default='ffhq',
-                         help='name of dataset, for example, ffhq')
-     parser.add_argument('--real', action='store_true')
-
-     args = parser.parse_args()
-     dataset_name = args.dataset_name
-
-     if not os.path.isdir('./data/' + dataset_name):
-         os.system('mkdir ./data/' + dataset_name)
-
-     M = Manipulator(dataset_name=dataset_name)
-     np.set_printoptions(suppress=True)
-     print(M.dataset_name)
-
-     # remove any previously generated .jpg files
-     names = os.listdir('./data/' + dataset_name + '/')
-     for name in names:
-         if '.jpg' in name:
-             os.system('rm ./data/' + dataset_name + '/' + name)
-
-     if args.real:
-         latents = torch.load('./data/' + dataset_name + '/latents.pt')
-         w_plus = latents.cpu().detach().numpy()
-     else:
-         w = np.load('./npy/' + dataset_name + '/W.npy')
-         tmp = w[:50]  # only use 50 images
-         tmp = tmp[:, None, :]
-         w_plus = np.tile(tmp, (1, M.Gs.components.synthesis.input_shape[1], 1))
-     np.save('./data/' + dataset_name + '/w_plus.npy', w_plus)
-
-     tmp = M.W2S(w_plus)
-     M.dlatents = tmp
-
-     M.img_index = 0
-     M.num_images = len(w_plus)
-     M.alpha = [0]
-     M.step = 1
-     lindex, bname = 0, 0
-
-     M.manipulate_layers = [lindex]
-     codes, out = M.EditOneC(bname)
-
-     for i in range(len(out)):
-         img = out[i, 0]
-         img = Image.fromarray(img)
-         img.save('./data/' + dataset_name + '/' + str(i) + '.jpg')
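For the `--real` branch above, the expected input is a tensor of inverted latents saved by an earlier inversion step. A hedged sketch of producing `w_plus` from it; the path and shape are assumptions:

# Hypothetical sketch of loading inverted latents for the --real branch.
import torch

latents = torch.load('./data/ffhq/latents.pt', map_location='cpu')
w_plus = latents.detach().numpy()   # expected shape (N, num_layers, 512)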
PTI/models/StyleCLIP/global_directions/Inference.py DELETED
@@ -1,106 +0,0 @@
-
-
- from manipulate import Manipulator
- import tensorflow as tf
- import numpy as np
- import torch
- import clip
- from MapTS import GetBoundary, GetDt
-
- class StyleCLIP():
-
-     def __init__(self, dataset_name='ffhq'):
-         print('load clip')
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.model, preprocess = clip.load("ViT-B/32", device=device)
-         self.LoadData(dataset_name)
-
-     def LoadData(self, dataset_name):
-         tf.keras.backend.clear_session()
-         M = Manipulator(dataset_name=dataset_name)
-         np.set_printoptions(suppress=True)
-         fs3 = np.load('./npy/' + dataset_name + '/fs3.npy')
-
-         self.M = M
-         self.fs3 = fs3
-
-         w_plus = np.load('./data/' + dataset_name + '/w_plus.npy')
-         self.M.dlatents = M.W2S(w_plus)
-
-         if dataset_name == 'ffhq':
-             self.c_threshold = 20
-         else:
-             self.c_threshold = 100
-         self.SetInitP()
-
-     def SetInitP(self):
-         self.M.alpha = [3]
-         self.M.num_images = 1
-
-         self.target = ''
-         self.neutral = ''
-         self.GetDt2()
-         img_index = 0
-         self.M.dlatent_tmp = [tmp[img_index:(img_index + 1)] for tmp in self.M.dlatents]
-
-     def GetDt2(self):
-         classnames = [self.target, self.neutral]
-         dt = GetDt(classnames, self.model)
-
-         self.dt = dt
-         num_cs = []
-         betas = np.arange(0.1, 0.3, 0.01)
-         for i in range(len(betas)):
-             boundary_tmp2, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=betas[i])
-             print(betas[i])
-             num_cs.append(num_c)
-
-         num_cs = np.array(num_cs)
-         select = num_cs > self.c_threshold
-
-         # pick the largest beta that still manipulates enough channels
-         if sum(select) == 0:
-             self.beta = 0.1
-         else:
-             self.beta = betas[select][-1]
-
-     def GetCode(self):
-         boundary_tmp2, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=self.beta)
-         codes = self.M.MSCode(self.M.dlatent_tmp, boundary_tmp2)
-         return codes
-
-     def GetImg(self):
-         codes = self.GetCode()
-         out = self.M.GenerateImg(codes)
-         img = out[0, 0]
-         return img
-
- if __name__ == "__main__":
-     style_clip = StyleCLIP()
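A hedged sketch of driving the class above programmatically, outside the GUI; the prompts are illustrative, and the precomputed ./npy/ffhq/fs3.npy and ./data/ffhq/w_plus.npy files are assumed to exist:

# Hypothetical driver for the StyleCLIP inference class above.
style_clip = StyleCLIP('ffhq')
style_clip.neutral = 'face'
style_clip.target = 'face with glasses'
style_clip.GetDt2()            # re-derive the text direction and beta
style_clip.M.alpha = [4.0]     # manipulation strength
img = style_clip.GetImg()      # uint8 HxWx3 edited image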
PTI/models/StyleCLIP/global_directions/MapTS.py DELETED
@@ -1,394 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Thu Feb 4 17:36:31 2021
-
- @author: wuzongze
- """
-
- import os
- import sys
-
- import tensorflow as tf
- import numpy as np
- import torch
- import clip
- from PIL import Image
- import pickle
- import copy
- import matplotlib.pyplot as plt
-
- def GetAlign(out, dt, model, preprocess):
-     imgs = out
-     imgs1 = imgs.reshape([-1] + list(imgs.shape[2:]))
-
-     tmp = []
-     for i in range(len(imgs1)):
-         img = Image.fromarray(imgs1[i])
-         image = preprocess(img).unsqueeze(0).to(device)  # `device` is set in __main__
-         tmp.append(image)
-
-     image = torch.cat(tmp)
-
-     with torch.no_grad():
-         image_features = model.encode_image(image)
-         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-
-     image_features1 = image_features.cpu().numpy()
-     image_features1 = image_features1.reshape(list(imgs.shape[:2]) + [512])
-
-     fd = image_features1[:, 1:, :] - image_features1[:, :-1, :]
-
-     fd1 = fd.reshape([-1, 512])
-     fd2 = fd1 / np.linalg.norm(fd1, axis=1)[:, None]
-
-     tmp = np.dot(fd2, dt)
-     m = tmp.mean()
-     acc = np.sum(tmp > 0) / len(tmp)
-     print(m, acc)
-     return m, acc
-
- def SplitS(ds_p, M, if_std):
-     all_ds = []
-     start = 0
-     for i in M.mindexs:
-         tmp = M.dlatents[i].shape[1]
-         end = start + tmp
-         tmp = ds_p[start:end]
-         all_ds.append(tmp)
-         start = end
-
-     all_ds2 = []
-     tmp_index = 0
-     for i in range(len(M.s_names)):
-         if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index]) == 0):
-             if if_std:
-                 tmp = all_ds[tmp_index] * M.code_std[i]
-             else:
-                 tmp = all_ds[tmp_index]
-             all_ds2.append(tmp)
-             tmp_index += 1
-         else:
-             tmp = np.zeros(len(M.dlatents[i][0]))
-             all_ds2.append(tmp)
-     return all_ds2
-
- imagenet_templates = [
-     'a bad photo of a {}.',
-     # 'a photo of many {}.',
-     'a sculpture of a {}.',
-     'a photo of the hard to see {}.',
-     'a low resolution photo of the {}.',
-     'a rendering of a {}.',
-     'graffiti of a {}.',
-     'a bad photo of the {}.',
-     'a cropped photo of the {}.',
-     'a tattoo of a {}.',
-     'the embroidered {}.',
-     'a photo of a hard to see {}.',
-     'a bright photo of a {}.',
-     'a photo of a clean {}.',
-     'a photo of a dirty {}.',
-     'a dark photo of the {}.',
-     'a drawing of a {}.',
-     'a photo of my {}.',
-     'the plastic {}.',
-     'a photo of the cool {}.',
-     'a close-up photo of a {}.',
-     'a black and white photo of the {}.',
-     'a painting of the {}.',
-     'a painting of a {}.',
-     'a pixelated photo of the {}.',
-     'a sculpture of the {}.',
-     'a bright photo of the {}.',
-     'a cropped photo of a {}.',
-     'a plastic {}.',
-     'a photo of the dirty {}.',
-     'a jpeg corrupted photo of a {}.',
-     'a blurry photo of the {}.',
-     'a photo of the {}.',
-     'a good photo of the {}.',
-     'a rendering of the {}.',
-     'a {} in a video game.',
-     'a photo of one {}.',
-     'a doodle of a {}.',
-     'a close-up photo of the {}.',
-     'a photo of a {}.',
-     'the origami {}.',
-     'the {} in a video game.',
-     'a sketch of a {}.',
-     'a doodle of the {}.',
-     'a origami {}.',
-     'a low resolution photo of a {}.',
-     'the toy {}.',
-     'a rendition of the {}.',
-     'a photo of the clean {}.',
-     'a photo of a large {}.',
-     'a rendition of a {}.',
-     'a photo of a nice {}.',
-     'a photo of a weird {}.',
-     'a blurry photo of a {}.',
-     'a cartoon {}.',
-     'art of a {}.',
-     'a sketch of the {}.',
-     'a embroidered {}.',
-     'a pixelated photo of a {}.',
-     'itap of the {}.',
-     'a jpeg corrupted photo of the {}.',
-     'a good photo of a {}.',
-     'a plushie {}.',
-     'a photo of the nice {}.',
-     'a photo of the small {}.',
-     'a photo of the weird {}.',
-     'the cartoon {}.',
-     'art of the {}.',
-     'a drawing of the {}.',
-     'a photo of the large {}.',
-     'a black and white photo of a {}.',
-     'the plushie {}.',
-     'a dark photo of a {}.',
-     'itap of a {}.',
-     'graffiti of the {}.',
-     'a toy {}.',
-     'itap of my {}.',
-     'a photo of a cool {}.',
-     'a photo of a small {}.',
-     'a tattoo of the {}.',
- ]
-
- def zeroshot_classifier(classnames, templates, model):
-     with torch.no_grad():
-         zeroshot_weights = []
-         for classname in classnames:
-             texts = [template.format(classname) for template in templates]  # format with class
-             texts = clip.tokenize(texts).cuda()  # tokenize
-             class_embeddings = model.encode_text(texts)  # embed with text encoder
-             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
-             class_embedding = class_embeddings.mean(dim=0)
-             class_embedding /= class_embedding.norm()
-             zeroshot_weights.append(class_embedding)
-         zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
-     return zeroshot_weights
-
- def GetDt(classnames, model):
-     text_features = zeroshot_classifier(classnames, imagenet_templates, model).t()
-
-     dt = text_features[0] - text_features[1]
-     dt = dt.cpu().numpy()
-     print(np.linalg.norm(dt))
-     dt = dt / np.linalg.norm(dt)
-     return dt
-
- def GetBoundary(fs3, dt, M, threshold):
-     tmp = np.dot(fs3, dt)
-
-     ds_imp = copy.copy(tmp)
-     select = np.abs(tmp) < threshold
-     num_c = np.sum(~select)
-
-     ds_imp[select] = 0
-     tmp = np.abs(ds_imp).max()
-     ds_imp /= tmp
-
-     boundary_tmp2 = SplitS(ds_imp, M, if_std=True)
-     print('num of channels being manipulated:', num_c)
-     return boundary_tmp2, num_c
-
- def GetFs(file_path):
-     fs = np.load(file_path + 'single_channel.npy')
-     tmp = np.linalg.norm(fs, axis=-1)
-     fs1 = fs / tmp[:, :, :, None]
-     fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :]  # 5*sigma - (-5)*sigma
-     fs3 = fs2 / np.linalg.norm(fs2, axis=-1)[:, :, None]
-     fs3 = fs3.mean(axis=1)
-     fs3 = fs3 / np.linalg.norm(fs3, axis=-1)[:, None]
-     return fs3
-
- if __name__ == "__main__":
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model, preprocess = clip.load("ViT-B/32", device=device)
-
-     # experimental driver; the paths below refer to the original author's cluster
-     sys.path.append('/cs/labs/danix/wuzongze/Gan_Manipulation/play')
-     from example_try import Manipulator4
-
-     M = Manipulator4(dataset_name='ffhq', code_type='S')
-     np.set_printoptions(suppress=True)
-
-     file_path = '/cs/labs/danix/wuzongze/Tansformer_Manipulation/CLIP/results/' + M.dataset_name + '/'
-     fs3 = GetFs(file_path)
-
-     M.img_index = 0
-     M.num_images = 30
-     dlatent_tmp = [tmp[M.img_index:(M.img_index + M.num_images)] for tmp in M.dlatents]
-
-     classnames = ['face', 'face with glasses']
-     # other examples: ['car', 'classic car'], ['dog', 'happy dog'],
-     # ['bedroom', 'modern bedroom'], ['church', 'church without watermark'],
-     # ['natural scene', 'natural scene without grass']
-     dt = GetDt(classnames, model)
-
-     boundary_tmp2, num_c = GetBoundary(fs3, dt, M, threshold=0.13)
-
-     M.start_distance = -20
-     M.end_distance = 20
-     M.step = 7
-     codes = M.MSCode(dlatent_tmp, boundary_tmp2)
-     out = M.GenerateImg(codes)
-     M.Vis2('tmp', 'filter2', out)
-
-     # optional: restrict a primary direction to the support of a condition direction
-     boundary_tmp3 = copy.copy(boundary_tmp2)  # primary
-     boundary_tmp4 = copy.copy(boundary_tmp2)  # condition
-     boundary_tmp2 = copy.copy(boundary_tmp3)
-     for i in range(len(boundary_tmp3)):
-         select = boundary_tmp4[i] == 0
-         boundary_tmp2[i][~select] = 0
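GetDt averages roughly 80 prompt-engineered CLIP text embeddings per class and takes the normalized difference between the target and neutral embeddings as the edit direction. A minimal standalone sketch, using two templates instead of the full list for brevity:

# Minimal sketch of the text-direction computation in GetDt above.
import numpy as np
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)
templates = ['a photo of a {}.', 'a close-up photo of a {}.']

def embed(classname):
    with torch.no_grad():
        tokens = clip.tokenize([t.format(classname) for t in templates]).to(device)
        e = model.encode_text(tokens)
        e = e / e.norm(dim=-1, keepdim=True)   # normalize each prompt
        e = e.mean(dim=0)                      # average over templates
        return e / e.norm()

dt = (embed('face with glasses') - embed('face')).cpu().numpy()
dt = dt / np.linalg.norm(dt)                   # unit edit direction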
PTI/models/StyleCLIP/global_directions/PlayInteractively.py DELETED
@@ -1,197 +0,0 @@
-
-
- from tkinter import Tk
- from PIL import Image, ImageTk
- from tkinter.filedialog import askopenfilename
- from GUI import View
- from Inference import StyleCLIP
- import argparse
-
- class PlayInteractively():  # Controller
-     '''
-     Follows the Model-View-Controller design pattern:
-     controller, model, view.
-     '''
-     def __init__(self, dataset_name='ffhq'):
-         self.root = Tk()
-         self.view = View(self.root)
-         self.img_ratio = 2
-         self.style_clip = StyleCLIP(dataset_name)
-
-         self.view.neutral.bind("<Return>", self.text_n)
-         self.view.target.bind("<Return>", self.text_t)
-         self.view.alpha.bind('<ButtonRelease-1>', self.ChangeAlpha)
-         self.view.beta.bind('<ButtonRelease-1>', self.ChangeBeta)
-         self.view.set_init.bind('<ButtonPress-1>', self.SetInit)
-         self.view.reset.bind('<ButtonPress-1>', self.Reset)
-         self.view.bg.bind('<Double-1>', self.open_img)
-
-         self.drawn = None
-
-         self.view.target.delete(1.0, "end")
-         self.view.target.insert("end", self.style_clip.target)
-
-         self.view.neutral.delete(1.0, "end")
-         self.view.neutral.insert("end", self.style_clip.neutral)
-
-     def Reset(self, event):
-         self.style_clip.GetDt2()
-         self.style_clip.M.alpha = [0]
-
-         self.view.beta.set(self.style_clip.beta)
-         self.view.alpha.set(0)
-
-         img = self.style_clip.GetImg()
-         img = Image.fromarray(img)
-         img = ImageTk.PhotoImage(img)
-         self.addImage_m(img)
-
-     def SetInit(self, event):
-         codes = self.style_clip.GetCode()
-         self.style_clip.M.dlatent_tmp = [tmp[:, 0] for tmp in codes]
-         print('set init')
-
-     def ChangeAlpha(self, event):
-         tmp = self.view.alpha.get()
-         self.style_clip.M.alpha = [float(tmp)]
-
-         img = self.style_clip.GetImg()
-         print('manipulate one')
-         img = Image.fromarray(img)
-         img = ImageTk.PhotoImage(img)
-         self.addImage_m(img)
-
-     def ChangeBeta(self, event):
-         tmp = self.view.beta.get()
-         self.style_clip.beta = float(tmp)
-
-         img = self.style_clip.GetImg()
-         print('manipulate one')
-         img = Image.fromarray(img)
-         img = ImageTk.PhotoImage(img)
-         self.addImage_m(img)
-
-     def ChangeDataset(self, event):
-         dataset_name = self.view.set_category.get()
-         self.style_clip.LoadData(dataset_name)
-
-         self.view.target.delete(1.0, "end")
-         self.view.target.insert("end", self.style_clip.target)
-
-         self.view.neutral.delete(1.0, "end")
-         self.view.neutral.insert("end", self.style_clip.neutral)
-
-     def text_t(self, event):
-         tmp = self.view.target.get("1.0", 'end')
-         tmp = tmp.replace('\n', '')
-
-         self.view.target.delete(1.0, "end")
-         self.view.target.insert("end", tmp)
-
-         print('target', tmp, '###')
-         self.style_clip.target = tmp
-         self.style_clip.GetDt2()
-         self.view.beta.set(self.style_clip.beta)
-         self.view.alpha.set(3)
-         self.style_clip.M.alpha = [3]
-
-         img = self.style_clip.GetImg()
-         print('manipulate one')
-         img = Image.fromarray(img)
-         img = ImageTk.PhotoImage(img)
-         self.addImage_m(img)
-
-     def text_n(self, event):
-         tmp = self.view.neutral.get("1.0", 'end')
-         tmp = tmp.replace('\n', '')
-
-         self.view.neutral.delete(1.0, "end")
-         self.view.neutral.insert("end", tmp)
-
-         print('neutral', tmp, '###')
-         self.style_clip.neutral = tmp
-         self.view.target.delete(1.0, "end")
-         self.view.target.insert("end", tmp)
-
-     def run(self):
-         self.root.mainloop()
-
-     def addImage(self, img):
-         self.view.bg.create_image(self.view.width / 2, self.view.height / 2, image=img, anchor='center')
-         self.image = img  # keep a reference so the image is not garbage-collected
-
-     def addImage_m(self, img):
-         self.view.mani.create_image(512, 512, image=img, anchor='center')
-         self.image2 = img
-
-     def openfn(self):
-         filename = askopenfilename(title='open', initialdir='./data/' + self.style_clip.M.dataset_name + '/', filetypes=[("all image format", ".jpg"), ("all image format", ".png")])
-         return filename
-
-     def open_img(self, event):
-         x = self.openfn()
-         print(x)
-
-         img = Image.open(x)
-         img2 = img.resize((512, 512), Image.ANTIALIAS)
-         img2 = ImageTk.PhotoImage(img2)
-         self.addImage(img2)
-
-         img = ImageTk.PhotoImage(img)
-         self.addImage_m(img)
-
-         img_index = x.split('/')[-1].split('.')[0]
-         img_index = int(img_index)
-         print(img_index)
-         self.style_clip.M.img_index = img_index
-         self.style_clip.M.dlatent_tmp = [tmp[img_index:(img_index + 1)] for tmp in self.style_clip.M.dlatents]
-
-         self.style_clip.GetDt2()
-         self.view.beta.set(self.style_clip.beta)
-         self.view.alpha.set(3)
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='Interactive StyleCLIP global-directions GUI.')
-
-     parser.add_argument('--dataset_name', type=str, default='ffhq',
-                         help='name of dataset, for example, ffhq')
-
-     args = parser.parse_args()
-     dataset_name = args.dataset_name
-
-     app = PlayInteractively(dataset_name)
-     app.run()
PTI/models/StyleCLIP/global_directions/SingleChannel.py DELETED
@@ -1,109 +0,0 @@
-
-
- import numpy as np
- import torch
- import clip
- from PIL import Image
- import copy
- from manipulate import Manipulator
- import argparse
-
- def GetImgF(out, model, preprocess):
-     imgs = out
-     imgs1 = imgs.reshape([-1] + list(imgs.shape[2:]))
-
-     tmp = []
-     for i in range(len(imgs1)):
-         img = Image.fromarray(imgs1[i])
-         image = preprocess(img).unsqueeze(0).to(device)  # `device` is set in __main__
-         tmp.append(image)
-
-     image = torch.cat(tmp)
-     with torch.no_grad():
-         image_features = model.encode_image(image)
-
-     image_features1 = image_features.cpu().numpy()
-     image_features1 = image_features1.reshape(list(imgs.shape[:2]) + [512])
-
-     return image_features1
-
- def GetFs(fs):
-     tmp = np.linalg.norm(fs, axis=-1)
-     fs1 = fs / tmp[:, :, :, None]
-     fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :]  # 5*sigma - (-5)*sigma
-     fs3 = fs2 / np.linalg.norm(fs2, axis=-1)[:, :, None]
-     fs3 = fs3.mean(axis=1)
-     fs3 = fs3 / np.linalg.norm(fs3, axis=-1)[:, None]
-     return fs3
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='Compute per-channel CLIP direction vectors (fs3).')
-
-     parser.add_argument('--dataset_name', type=str, default='cat',
-                         help='name of dataset, for example, ffhq')
-     args = parser.parse_args()
-     dataset_name = args.dataset_name
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model, preprocess = clip.load("ViT-B/32", device=device)
-
-     M = Manipulator(dataset_name=dataset_name)
-     np.set_printoptions(suppress=True)
-     print(M.dataset_name)
-
-     img_sindex = 0
-     num_images = 100
-     dlatents_o = []
-     tmp = img_sindex * num_images
-     for i in range(len(M.dlatents)):
-         tmp1 = M.dlatents[i][tmp:(tmp + num_images)]
-         dlatents_o.append(tmp1)
-
-     all_f = []
-     M.alpha = [-5, 5]  # ffhq: 5
-     M.step = 2
-     M.num_images = num_images
-     select = np.array(M.mindexs) <= 16  # at or below 128x128 resolution
-     mindexs2 = np.array(M.mindexs)[select]
-     for lindex in mindexs2:  # ignore ToRGB layers
-         print(lindex)
-         num_c = M.dlatents[lindex].shape[1]
-         for cindex in range(num_c):
-             M.dlatents = copy.copy(dlatents_o)
-             M.dlatents[lindex][:, cindex] = M.code_mean[lindex][cindex]
-
-             M.manipulate_layers = [lindex]
-             codes, out = M.EditOneC(cindex)
-             image_features1 = GetImgF(out, model, preprocess)
-             all_f.append(image_features1)
-
-     all_f = np.array(all_f)
-
-     fs3 = GetFs(all_f)
-
-     file_path = './npy/' + M.dataset_name + '/'
-     np.save(file_path + 'fs3', fs3)
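GetFs converts per-channel CLIP features measured at alpha = -5 and +5 into unit direction vectors, averaged over images. A shape-annotated numpy sketch; the array sizes are assumptions:

# Numpy sketch of GetFs above; input shape (channels, images, 2, 512) is
# assumed, where axis 2 holds features at alpha = -5 and alpha = +5.
import numpy as np

fs = np.random.randn(10, 100, 2, 512)
fs1 = fs / np.linalg.norm(fs, axis=-1, keepdims=True)   # unit features
fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :]                 # +5 sigma minus -5 sigma
fs2 = fs2 / np.linalg.norm(fs2, axis=-1, keepdims=True)
fs3 = fs2.mean(axis=1)                                  # average over images
fs3 = fs3 / np.linalg.norm(fs3, axis=-1, keepdims=True) # (channels, 512)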
PTI/models/StyleCLIP/global_directions/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:394f0f166305654f49cd1b0cd3d4f2b7a51e740a449a1ebfa1c69f79d01399fa
- size 2506880
PTI/models/StyleCLIP/global_directions/dnnlib/__init__.py DELETED
@@ -1,9 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- from .util import EasyDict, make_cache_dir_path
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py DELETED
@@ -1,20 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- from . import autosummary
- from . import network
- from . import optimizer
- from . import tfutil
- from . import custom_ops
-
- from .tfutil import *
- from .network import Network
-
- from .optimizer import Optimizer
-
- from .custom_ops import get_plugin
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py DELETED
@@ -1,193 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Helper for adding automatically tracked values to Tensorboard.
-
- Autosummary creates an identity op that internally keeps track of the input
- values and automatically shows up in TensorBoard. The reported value
- represents an average over input components. The average is accumulated
- constantly over time and flushed when save_summaries() is called.
-
- Notes:
- - The output tensor must be used as an input for something else in the
-   graph. Otherwise, the autosummary op will not get executed, and the average
-   value will not get accumulated.
- - It is perfectly fine to include autosummaries with the same name in
-   several places throughout the graph, even if they are executed concurrently.
- - It is ok to also pass in a python scalar or numpy array. In this case, it
-   is added to the average immediately.
- """
-
- from collections import OrderedDict
- import numpy as np
- import tensorflow as tf
- from tensorboard import summary as summary_lib
- from tensorboard.plugins.custom_scalar import layout_pb2
-
- from . import tfutil
- from .tfutil import TfExpression
- from .tfutil import TfExpressionEx
-
- # Enable "Custom scalars" tab in TensorBoard for advanced formatting.
- # Disabled by default to reduce tfevents file size.
- enable_custom_scalars = False
-
- _dtype = tf.float64
- _vars = OrderedDict()  # name => [var, ...]
- _immediate = OrderedDict()  # name => update_op, update_value
- _finalized = False
- _merge_op = None
-
-
- def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
-     """Internal helper for creating autosummary accumulators."""
-     assert not _finalized
-     name_id = name.replace("/", "_")
-     v = tf.cast(value_expr, _dtype)
-
-     if v.shape.is_fully_defined():
-         size = np.prod(v.shape.as_list())
-         size_expr = tf.constant(size, dtype=_dtype)
-     else:
-         size = None
-         size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
-
-     if size == 1:
-         if v.shape.ndims != 0:
-             v = tf.reshape(v, [])
-         v = [size_expr, v, tf.square(v)]
-     else:
-         v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
-     v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
-
-     with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
-         var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
-         update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
-
-     if name in _vars:
-         _vars[name].append(var)
-     else:
-         _vars[name] = [var]
-     return update_op
-
-
- def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
-     """Create a new autosummary.
-
-     Args:
-         name:     Name to use in TensorBoard
-         value:    TensorFlow expression or python value to track
-         passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
-
-     Example use of the passthru mechanism:
-
-     n = autosummary('l2loss', loss, passthru=n)
-
-     This is a shorthand for the following code:
-
-     with tf.control_dependencies([autosummary('l2loss', loss)]):
-         n = tf.identity(n)
-     """
-     tfutil.assert_tf_initialized()
-     name_id = name.replace("/", "_")
-
-     if tfutil.is_tf_expression(value):
-         with tf.name_scope("summary_" + name_id), tf.device(value.device):
-             condition = tf.convert_to_tensor(condition, name='condition')
-             update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
-             with tf.control_dependencies([update_op]):
-                 return tf.identity(value if passthru is None else passthru)
-
-     else:  # python scalar or numpy array
-         assert not tfutil.is_tf_expression(passthru)
-         assert not tfutil.is_tf_expression(condition)
-         if condition:
-             if name not in _immediate:
-                 with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
-                     update_value = tf.placeholder(_dtype)
-                     update_op = _create_var(name, update_value)
-                     _immediate[name] = update_op, update_value
-             update_op, update_value = _immediate[name]
-             tfutil.run(update_op, {update_value: value})
-         return value if passthru is None else passthru
-
-
- def finalize_autosummaries() -> None:
-     """Create the necessary ops to include autosummaries in TensorBoard report.
-     Note: This should be done only once per graph.
-     """
-     global _finalized
-     tfutil.assert_tf_initialized()
-
-     if _finalized:
-         return None
-
-     _finalized = True
-     tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
-
-     # Create summary ops.
-     with tf.device(None), tf.control_dependencies(None):
-         for name, vars_list in _vars.items():
-             name_id = name.replace("/", "_")
-             with tfutil.absolute_name_scope("Autosummary/" + name_id):
-                 moments = tf.add_n(vars_list)
-                 moments /= moments[0]
-                 with tf.control_dependencies([moments]):  # read before resetting
-                     reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
-                     with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
-                         mean = moments[1]
-                         std = tf.sqrt(moments[2] - tf.square(moments[1]))
-                         tf.summary.scalar(name, mean)
-                         if enable_custom_scalars:
-                             tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
-                             tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
-
-     # Setup layout for custom scalars.
-     layout = None
-     if enable_custom_scalars:
-         cat_dict = OrderedDict()
-         for series_name in sorted(_vars.keys()):
-             p = series_name.split("/")
-             cat = p[0] if len(p) >= 2 else ""
-             chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
-             if cat not in cat_dict:
-                 cat_dict[cat] = OrderedDict()
-             if chart not in cat_dict[cat]:
-                 cat_dict[cat][chart] = []
-             cat_dict[cat][chart].append(series_name)
-         categories = []
-         for cat_name, chart_dict in cat_dict.items():
-             charts = []
-             for chart_name, series_names in chart_dict.items():
-                 series = []
-                 for series_name in series_names:
-                     series.append(layout_pb2.MarginChartContent.Series(
-                         value=series_name,
-                         lower="xCustomScalars/" + series_name + "/margin_lo",
-                         upper="xCustomScalars/" + series_name + "/margin_hi"))
-                 margin = layout_pb2.MarginChartContent(series=series)
-                 charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
-             categories.append(layout_pb2.Category(title=cat_name, chart=charts))
-         layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
-     return layout
-
-
- def save_summaries(file_writer, global_step=None):
-     """Call FileWriter.add_summary() with all summaries in the default graph,
-     automatically finalizing and merging them on the first call.
-     """
-     global _merge_op
-     tfutil.assert_tf_initialized()
-
-     if _merge_op is None:
-         layout = finalize_autosummaries()
-         if layout is not None:
-             file_writer.add_summary(layout)
-         with tf.device(None), tf.control_dependencies(None):
-             _merge_op = tf.summary.merge_all()
-
-     file_writer.add_summary(_merge_op.eval(), global_step)
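A hedged, TF1-era sketch of the module's scalar-logging path; the log directory and training loop are illustrative, and a default session from dnnlib's init_tf is assumed:

# Hypothetical usage of the autosummary module above (TF1 API).
import tensorflow as tf
from dnnlib import tflib
from dnnlib.tflib import autosummary

tflib.init_tf()                              # installs a default TF1 session
writer = tf.summary.FileWriter('logs')
for step in range(100):
    fake_loss = 1.0 / (step + 1)             # stand-in training metric
    autosummary.autosummary('Loss/train', fake_loss)   # python-scalar path
    if step % 10 == 0:
        autosummary.save_summaries(writer, global_step=step)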
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py DELETED
@@ -1,181 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """TensorFlow custom ops builder.
- """
-
- import glob
- import os
- import re
- import uuid
- import hashlib
- import tempfile
- import shutil
- import tensorflow as tf
- from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
-
- from .. import util
-
- #----------------------------------------------------------------------------
- # Global configs.
-
- cuda_cache_path = None
- cuda_cache_version_tag = 'v1'
- do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change.
- verbose = True # Print status messages to stdout.
-
- #----------------------------------------------------------------------------
- # Internal helper funcs.
-
- def _find_compiler_bindir():
-     hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
-     if hostx64_paths != []:
-         return hostx64_paths[0]
-     hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
-     if hostx64_paths != []:
-         return hostx64_paths[0]
-     hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
-     if hostx64_paths != []:
-         return hostx64_paths[0]
-     vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
-     if os.path.isdir(vc_bin_dir):
-         return vc_bin_dir
-     return None
-
- def _get_compute_cap(device):
-     caps_str = device.physical_device_desc
-     m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
-     major = m.group(1)
-     minor = m.group(2)
-     return (major, minor)
-
- def _get_cuda_gpu_arch_string():
-     gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
-     if len(gpus) == 0:
-         raise RuntimeError('No GPU devices found')
-     (major, minor) = _get_compute_cap(gpus[0])
-     return 'sm_%s%s' % (major, minor)
-
- def _run_cmd(cmd):
-     with os.popen(cmd) as pipe:
-         output = pipe.read()
-         status = pipe.close()
-     if status is not None:
-         raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
-
- def _prepare_nvcc_cli(opts):
-     cmd = 'nvcc ' + opts.strip()
-     cmd += ' --disable-warnings'
-     cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
-     cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
-     cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
-     cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
-
-     compiler_bindir = _find_compiler_bindir()
-     if compiler_bindir is None:
-         # Require that _find_compiler_bindir succeeds on Windows. Allow
-         # nvcc to use whatever is the default on Linux.
-         if os.name == 'nt':
-             raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
-     else:
-         cmd += ' --compiler-bindir "%s"' % compiler_bindir
-     cmd += ' 2>&1'
-     return cmd
-
- #----------------------------------------------------------------------------
- # Main entry point.
-
- _plugin_cache = dict()
-
- def get_plugin(cuda_file, extra_nvcc_options=[]):
-     cuda_file_base = os.path.basename(cuda_file)
-     cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
-
-     # Already in cache?
-     if cuda_file in _plugin_cache:
-         return _plugin_cache[cuda_file]
-
-     # Setup plugin.
-     if verbose:
-         print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
-     try:
-         # Hash CUDA source.
-         md5 = hashlib.md5()
-         with open(cuda_file, 'rb') as f:
-             md5.update(f.read())
-         md5.update(b'\n')
-
-         # Hash headers included by the CUDA code by running it through the preprocessor.
-         if not do_not_hash_included_headers:
-             if verbose:
-                 print('Preprocessing... ', end='', flush=True)
-             with tempfile.TemporaryDirectory() as tmp_dir:
-                 tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
-                 _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
-                 with open(tmp_file, 'rb') as f:
-                     bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
-                     good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
-                     for ln in f:
-                         if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
-                             ln = ln.replace(bad_file_str, good_file_str)
-                             md5.update(ln)
-                     md5.update(b'\n')
-
-         # Select compiler configs.
-         compile_opts = ''
-         if os.name == 'nt':
-             compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
-         elif os.name == 'posix':
-             compile_opts += f' --compiler-options \'-fPIC\''
-             compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\''
-             compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\''
-         else:
-             assert False # not Windows or Linux, w00t?
-         compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}'
-         compile_opts += ' --use_fast_math'
-         for opt in extra_nvcc_options:
-             compile_opts += ' ' + opt
-         nvcc_cmd = _prepare_nvcc_cli(compile_opts)
-
-         # Hash build configuration.
-         md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
-         md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
-         md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
-
-         # Compile if not already compiled.
-         cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path
-         bin_file_ext = '.dll' if os.name == 'nt' else '.so'
-         bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
-         if not os.path.isfile(bin_file):
-             if verbose:
-                 print('Compiling... ', end='', flush=True)
-             with tempfile.TemporaryDirectory() as tmp_dir:
-                 tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
-                 _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
-                 os.makedirs(cache_dir, exist_ok=True)
-                 intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
-                 shutil.copyfile(tmp_file, intermediate_file)
-                 os.rename(intermediate_file, bin_file) # atomic
-
-         # Load.
-         if verbose:
-             print('Loading... ', end='', flush=True)
-         plugin = tf.load_op_library(bin_file)
-
-         # Add to cache.
-         _plugin_cache[cuda_file] = plugin
-         if verbose:
-             print('Done.', flush=True)
-         return plugin
-
-     except:
-         if verbose:
-             print('Failed!', flush=True)
-         raise
-
- #----------------------------------------------------------------------------
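
For reference, the deleted ops modules consume this builder through get_plugin(), which compiles the .cu source with nvcc on first use, caches the binary under an MD5 of the source plus the build configuration, and loads it with tf.load_op_library(). A minimal sketch of the call pattern, matching how fused_bias_act.py later in this diff resolves its kernel; the path shown here is illustrative only:

    import os
    from dnnlib.tflib import custom_ops

    # First call compiles and caches; later calls hit the in-process
    # _plugin_cache and the on-disk 'tflib-cudacache' directory.
    plugin = custom_ops.get_plugin(os.path.splitext('/path/to/fused_bias_act.py')[0] + '.cu')
    fused_bias_act_op = plugin.fused_bias_act  # op registered by the .cu file
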
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/network.py DELETED
@@ -1,781 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Helper for managing networks."""
-
- import types
- import inspect
- import re
- import uuid
- import sys
- import copy
- import numpy as np
- import tensorflow as tf
-
- from collections import OrderedDict
- from typing import Any, List, Tuple, Union, Callable
-
- from . import tfutil
- from .. import util
-
- from .tfutil import TfExpression, TfExpressionEx
-
- # pylint: disable=protected-access
- # pylint: disable=attribute-defined-outside-init
- # pylint: disable=too-many-public-methods
-
- _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
- _import_module_src = dict() # Source code for temporary modules created during pickle import.
-
-
- def import_handler(handler_func):
-     """Function decorator for declaring custom import handlers."""
-     _import_handlers.append(handler_func)
-     return handler_func
-
-
- class Network:
-     """Generic network abstraction.
-
-     Acts as a convenience wrapper for a parameterized network construction
-     function, providing several utility methods and convenient access to
-     the inputs/outputs/weights.
-
-     Network objects can be safely pickled and unpickled for long-term
-     archival purposes. The pickling works reliably as long as the underlying
-     network construction function is defined in a standalone Python module
-     that has no side effects or application-specific imports.
-
-     Args:
-         name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None.
-         func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
-         static_kwargs: Keyword arguments to be passed in to the network construction function.
-     """
-
-     def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
-         # Locate the user-specified build function.
-         assert isinstance(func_name, str) or util.is_top_level_function(func_name)
-         if util.is_top_level_function(func_name):
-             func_name = util.get_top_level_function_name(func_name)
-         module, func_name = util.get_module_from_obj_name(func_name)
-         func = util.get_obj_from_module(module, func_name)
-
-         # Dig up source code for the module containing the build function.
-         module_src = _import_module_src.get(module, None)
-         if module_src is None:
-             module_src = inspect.getsource(module)
-
-         # Initialize fields.
-         self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src)
-
-     def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None:
-         tfutil.assert_tf_initialized()
-         assert isinstance(name, str)
-         assert len(name) >= 1
-         assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name)
-         assert isinstance(static_kwargs, dict)
-         assert util.is_pickleable(static_kwargs)
-         assert callable(build_func)
-         assert isinstance(build_func_name, str)
-         assert isinstance(build_module_src, str)
-
-         # Choose TensorFlow name scope.
-         with tf.name_scope(None):
-             scope = tf.get_default_graph().unique_name(name, mark_as_used=True)
-
-         # Query current TensorFlow device.
-         with tfutil.absolute_name_scope(scope), tf.control_dependencies(None):
-             device = tf.no_op(name="_QueryDevice").device
-
-         # Immutable state.
-         self._name = name
-         self._scope = scope
-         self._device = device
-         self._static_kwargs = util.EasyDict(copy.deepcopy(static_kwargs))
-         self._build_func = build_func
-         self._build_func_name = build_func_name
-         self._build_module_src = build_module_src
-
-         # State before _init_graph().
-         self._var_inits = dict() # var_name => initial_value, set to None by _init_graph()
-         self._all_inits_known = False # Do we know for sure that _var_inits covers all the variables?
-         self._components = None # subnet_name => Network, None if the components are not known yet
-
-         # Initialized by _init_graph().
-         self._input_templates = None
-         self._output_templates = None
-         self._own_vars = None
-
-         # Cached values initialized by the respective methods.
-         self._input_shapes = None
-         self._output_shapes = None
-         self._input_names = None
-         self._output_names = None
-         self._vars = None
-         self._trainables = None
-         self._var_global_to_local = None
-         self._run_cache = dict()
-
-     def _init_graph(self) -> None:
-         assert self._var_inits is not None
-         assert self._input_templates is None
-         assert self._output_templates is None
-         assert self._own_vars is None
-
-         # Initialize components.
-         if self._components is None:
-             self._components = util.EasyDict()
-
-         # Choose build func kwargs.
-         build_kwargs = dict(self.static_kwargs)
-         build_kwargs["is_template_graph"] = True
-         build_kwargs["components"] = self._components
-
-         # Override scope and device, and ignore surrounding control dependencies.
-         with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None):
-             assert tf.get_variable_scope().name == self.scope
-             assert tf.get_default_graph().get_name_scope() == self.scope
-
-             # Create input templates.
-             self._input_templates = []
-             for param in inspect.signature(self._build_func).parameters.values():
-                 if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
-                     self._input_templates.append(tf.placeholder(tf.float32, name=param.name))
-
-             # Call build func.
-             out_expr = self._build_func(*self._input_templates, **build_kwargs)
-
-         # Collect output templates and variables.
-         assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
-         self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
-         self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
-
-         # Check for errors.
-         if len(self._input_templates) == 0:
-             raise ValueError("Network build func did not list any inputs.")
-         if len(self._output_templates) == 0:
-             raise ValueError("Network build func did not return any outputs.")
-         if any(not tfutil.is_tf_expression(t) for t in self._output_templates):
-             raise ValueError("Network outputs must be TensorFlow expressions.")
-         if any(t.shape.ndims is None for t in self._input_templates):
-             raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
-         if any(t.shape.ndims is None for t in self._output_templates):
-             raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
-         if any(not isinstance(comp, Network) for comp in self._components.values()):
-             raise ValueError("Components of a Network must be Networks themselves.")
-         if len(self._components) != len(set(comp.name for comp in self._components.values())):
-             raise ValueError("Components of a Network must have unique names.")
-
-         # Initialize variables.
-         if len(self._var_inits):
-             tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()})
-         remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits]
-         if self._all_inits_known:
-             assert len(remaining_inits) == 0
-         else:
-             tfutil.run(remaining_inits)
-         self._var_inits = None
-
-     @property
-     def name(self):
-         """User-specified name string."""
-         return self._name
-
-     @property
-     def scope(self):
-         """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name."""
-         return self._scope
-
-     @property
-     def device(self):
-         """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time."""
-         return self._device
-
-     @property
-     def static_kwargs(self):
-         """EasyDict of arguments passed to the user-supplied build func."""
-         return copy.deepcopy(self._static_kwargs)
-
-     @property
-     def components(self):
-         """EasyDict of sub-networks created by the build func."""
-         return copy.copy(self._get_components())
-
-     def _get_components(self):
-         if self._components is None:
-             self._init_graph()
-             assert self._components is not None
-         return self._components
-
-     @property
-     def input_shapes(self):
-         """List of input tensor shapes, including minibatch dimension."""
-         if self._input_shapes is None:
-             self._input_shapes = [t.shape.as_list() for t in self.input_templates]
-         return copy.deepcopy(self._input_shapes)
-
-     @property
-     def output_shapes(self):
-         """List of output tensor shapes, including minibatch dimension."""
-         if self._output_shapes is None:
-             self._output_shapes = [t.shape.as_list() for t in self.output_templates]
-         return copy.deepcopy(self._output_shapes)
-
-     @property
-     def input_shape(self):
-         """Short-hand for input_shapes[0]."""
-         return self.input_shapes[0]
-
-     @property
-     def output_shape(self):
-         """Short-hand for output_shapes[0]."""
-         return self.output_shapes[0]
-
-     @property
-     def num_inputs(self):
-         """Number of input tensors."""
-         return len(self.input_shapes)
-
-     @property
-     def num_outputs(self):
-         """Number of output tensors."""
-         return len(self.output_shapes)
-
-     @property
-     def input_names(self):
-         """Name string for each input."""
-         if self._input_names is None:
-             self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates]
-         return copy.copy(self._input_names)
-
-     @property
-     def output_names(self):
-         """Name string for each output."""
-         if self._output_names is None:
-             self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
-         return copy.copy(self._output_names)
-
-     @property
-     def input_templates(self):
-         """Input placeholders in the template graph."""
-         if self._input_templates is None:
-             self._init_graph()
-             assert self._input_templates is not None
-         return copy.copy(self._input_templates)
-
-     @property
-     def output_templates(self):
-         """Output tensors in the template graph."""
-         if self._output_templates is None:
-             self._init_graph()
-             assert self._output_templates is not None
-         return copy.copy(self._output_templates)
-
-     @property
-     def own_vars(self):
-         """Variables defined by this network (local_name => var), excluding sub-networks."""
-         return copy.copy(self._get_own_vars())
-
-     def _get_own_vars(self):
-         if self._own_vars is None:
-             self._init_graph()
-             assert self._own_vars is not None
-         return self._own_vars
-
-     @property
-     def vars(self):
-         """All variables (local_name => var)."""
-         return copy.copy(self._get_vars())
-
-     def _get_vars(self):
-         if self._vars is None:
-             self._vars = OrderedDict(self._get_own_vars())
-             for comp in self._get_components().values():
-                 self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items())
-         return self._vars
-
-     @property
-     def trainables(self):
-         """All trainable variables (local_name => var)."""
-         return copy.copy(self._get_trainables())
-
-     def _get_trainables(self):
-         if self._trainables is None:
-             self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
-         return self._trainables
-
-     @property
-     def var_global_to_local(self):
-         """Mapping from variable global names to local names."""
-         return copy.copy(self._get_var_global_to_local())
-
-     def _get_var_global_to_local(self):
-         if self._var_global_to_local is None:
-             self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
-         return self._var_global_to_local
-
-     def reset_own_vars(self) -> None:
-         """Re-initialize all variables of this network, excluding sub-networks."""
-         if self._var_inits is None or self._components is None:
-             tfutil.run([var.initializer for var in self._get_own_vars().values()])
-         else:
-             self._var_inits.clear()
-             self._all_inits_known = False
-
-     def reset_vars(self) -> None:
-         """Re-initialize all variables of this network, including sub-networks."""
-         if self._var_inits is None:
-             tfutil.run([var.initializer for var in self._get_vars().values()])
-         else:
-             self._var_inits.clear()
-             self._all_inits_known = False
-             if self._components is not None:
-                 for comp in self._components.values():
-                     comp.reset_vars()
-
-     def reset_trainables(self) -> None:
-         """Re-initialize all trainable variables of this network, including sub-networks."""
-         tfutil.run([var.initializer for var in self._get_trainables().values()])
-
-     def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
-         """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).
-         The graph is placed on the current TensorFlow device."""
-         assert len(in_expr) == self.num_inputs
-         assert not all(expr is None for expr in in_expr)
-         self._get_vars() # ensure that all variables have been created
-
-         # Choose build func kwargs.
-         build_kwargs = dict(self.static_kwargs)
-         build_kwargs.update(dynamic_kwargs)
-         build_kwargs["is_template_graph"] = False
-         build_kwargs["components"] = self._components
-
-         # Build TensorFlow graph to evaluate the network.
-         with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
-             assert tf.get_variable_scope().name == self.scope
-             valid_inputs = [expr for expr in in_expr if expr is not None]
-             final_inputs = []
-             for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
-                 if expr is not None:
-                     expr = tf.identity(expr, name=name)
-                 else:
-                     expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
-                 final_inputs.append(expr)
-             out_expr = self._build_func(*final_inputs, **build_kwargs)
-
-         # Propagate input shapes back to the user-specified expressions.
-         for expr, final in zip(in_expr, final_inputs):
-             if isinstance(expr, tf.Tensor):
-                 expr.set_shape(final.shape)
-
-         # Express outputs in the desired format.
-         assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
-         if return_as_list:
-             out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
-         return out_expr
-
-     def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
-         """Get the local name of a given variable, without any surrounding name scopes."""
-         assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
-         global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
-         return self._get_var_global_to_local()[global_name]
-
-     def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
-         """Find variable by local or global name."""
-         assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
-         return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
-
-     def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
-         """Get the value of a given variable as NumPy array.
-         Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
-         return self.find_var(var_or_local_name).eval()
-
-     def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
-         """Set the value of a given variable based on the given NumPy array.
-         Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
-         tfutil.set_vars({self.find_var(var_or_local_name): new_value})
-
-     def __getstate__(self) -> dict:
-         """Pickle export."""
-         state = dict()
-         state["version"] = 5
-         state["name"] = self.name
-         state["static_kwargs"] = dict(self.static_kwargs)
-         state["components"] = dict(self.components)
-         state["build_module_src"] = self._build_module_src
-         state["build_func_name"] = self._build_func_name
-         state["variables"] = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values()))))
-         state["input_shapes"] = self.input_shapes
-         state["output_shapes"] = self.output_shapes
-         state["input_names"] = self.input_names
-         state["output_names"] = self.output_names
-         return state
-
-     def __setstate__(self, state: dict) -> None:
-         """Pickle import."""
-
-         # Execute custom import handlers.
-         for handler in _import_handlers:
-             state = handler(state)
-
-         # Get basic fields.
-         assert state["version"] in [2, 3, 4, 5]
-         name = state["name"]
-         static_kwargs = state["static_kwargs"]
-         build_module_src = state["build_module_src"]
-         build_func_name = state["build_func_name"]
-
-         # Create temporary module from the imported source code.
-         module_name = "_tflib_network_import_" + uuid.uuid4().hex
-         module = types.ModuleType(module_name)
-         sys.modules[module_name] = module
-         _import_module_src[module] = build_module_src
-         exec(build_module_src, module.__dict__) # pylint: disable=exec-used
-         build_func = util.get_obj_from_module(module, build_func_name)
-
-         # Initialize fields.
-         self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src)
-         self._var_inits.update(copy.deepcopy(state["variables"]))
-         self._all_inits_known = True
-         self._components = util.EasyDict(state.get("components", {}))
-         self._input_shapes = copy.deepcopy(state.get("input_shapes", None))
-         self._output_shapes = copy.deepcopy(state.get("output_shapes", None))
-         self._input_names = copy.deepcopy(state.get("input_names", None))
-         self._output_names = copy.deepcopy(state.get("output_names", None))
-
-     def clone(self, name: str = None, **new_static_kwargs) -> "Network":
-         """Create a clone of this network with its own copy of the variables."""
-         static_kwargs = dict(self.static_kwargs)
-         static_kwargs.update(new_static_kwargs)
-         net = object.__new__(Network)
-         net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src)
-         net.copy_vars_from(self)
-         return net
-
-     def copy_own_vars_from(self, src_net: "Network") -> None:
-         """Copy the values of all variables from the given network, excluding sub-networks."""
-
-         # Source has unknown variables or unknown components => init now.
-         if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
-             src_net._get_vars()
-
-         # Both networks are inited => copy directly.
-         if src_net._var_inits is None and self._var_inits is None:
-             names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()]
-             tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
-             return
-
-         # Read from source.
-         if src_net._var_inits is None:
-             value_dict = tfutil.run(src_net._get_own_vars())
-         else:
-             value_dict = src_net._var_inits
-
-         # Write to destination.
-         if self._var_inits is None:
-             tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()})
-         else:
-             self._var_inits.update(value_dict)
-
-     def copy_vars_from(self, src_net: "Network") -> None:
-         """Copy the values of all variables from the given network, including sub-networks."""
-
-         # Source has unknown variables or unknown components => init now.
-         if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
-             src_net._get_vars()
-
-         # Source is inited, but destination components have not been created yet => set as initial values.
-         if src_net._var_inits is None and self._components is None:
-             self._var_inits.update(tfutil.run(src_net._get_vars()))
-             return
-
-         # Destination has unknown components => init now.
-         if self._components is None:
-             self._get_vars()
-
-         # Both networks are inited => copy directly.
-         if src_net._var_inits is None and self._var_inits is None:
-             names = [name for name in self._get_vars().keys() if name in src_net._get_vars()]
-             tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
-             return
-
-         # Copy recursively, component by component.
-         self.copy_own_vars_from(src_net)
-         for name, src_comp in src_net._components.items():
-             if name in self._components:
-                 self._components[name].copy_vars_from(src_comp)
-
-     def copy_trainables_from(self, src_net: "Network") -> None:
-         """Copy the values of all trainable variables from the given network, including sub-networks."""
-         names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
-         tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
-
-     def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
-         """Create new network with the given parameters, and copy all variables from this network."""
-         if new_name is None:
-             new_name = self.name
-         static_kwargs = dict(self.static_kwargs)
-         static_kwargs.update(new_static_kwargs)
-         net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
-         net.copy_vars_from(self)
-         return net
-
-     def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
-         """Construct a TensorFlow op that updates the variables of this network
-         to be slightly closer to those of the given network."""
-         with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
-             ops = []
-             for name, var in self._get_vars().items():
-                 if name in src_net._get_vars():
-                     cur_beta = beta if var.trainable else beta_nontrainable
-                     new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
-                     ops.append(var.assign(new_value))
-             return tf.group(*ops)
-
-     def run(self,
-             *in_arrays: Tuple[Union[np.ndarray, None], ...],
-             input_transform: dict = None,
-             output_transform: dict = None,
-             return_as_list: bool = False,
-             print_progress: bool = False,
-             minibatch_size: int = None,
-             num_gpus: int = 1,
-             assume_frozen: bool = False,
-             **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
-         """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
-
-         Args:
-             input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
-                 The dict must contain a 'func' field that points to a top-level function. The function is called with the input
-                 TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
-             output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
-                 The dict must contain a 'func' field that points to a top-level function. The function is called with the output
-                 TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
-             return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
-             print_progress: Print progress to the console? Useful for very large input arrays.
-             minibatch_size: Maximum minibatch size to use, None = disable batching.
-             num_gpus: Number of GPUs to use.
-             assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
-             dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
-         """
-         assert len(in_arrays) == self.num_inputs
-         assert not all(arr is None for arr in in_arrays)
-         assert input_transform is None or util.is_top_level_function(input_transform["func"])
-         assert output_transform is None or util.is_top_level_function(output_transform["func"])
-         output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
-         num_items = in_arrays[0].shape[0]
-         if minibatch_size is None:
-             minibatch_size = num_items
-
-         # Construct unique hash key from all arguments that affect the TensorFlow graph.
-         key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
-         def unwind_key(obj):
-             if isinstance(obj, dict):
-                 return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
-             if callable(obj):
-                 return util.get_top_level_function_name(obj)
-             return obj
-         key = repr(unwind_key(key))
-
-         # Build graph.
-         if key not in self._run_cache:
-             with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
-                 with tf.device("/cpu:0"):
-                     in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
-                     in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
-
-                 out_split = []
-                 for gpu in range(num_gpus):
-                     with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu):
-                         net_gpu = self.clone() if assume_frozen else self
-                         in_gpu = in_split[gpu]
-
-                         if input_transform is not None:
-                             in_kwargs = dict(input_transform)
-                             in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
-                             in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
-
-                         assert len(in_gpu) == self.num_inputs
-                         out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
-
-                         if output_transform is not None:
-                             out_kwargs = dict(output_transform)
-                             out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
-                             out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
-
-                         assert len(out_gpu) == self.num_outputs
-                         out_split.append(out_gpu)
-
-                 with tf.device("/cpu:0"):
-                     out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
-                     self._run_cache[key] = in_expr, out_expr
-
-         # Run minibatches.
-         in_expr, out_expr = self._run_cache[key]
-         out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
-
-         for mb_begin in range(0, num_items, minibatch_size):
-             if print_progress:
-                 print("\r%d / %d" % (mb_begin, num_items), end="")
-
-             mb_end = min(mb_begin + minibatch_size, num_items)
-             mb_num = mb_end - mb_begin
-             mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
-             mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
-
-             for dst, src in zip(out_arrays, mb_out):
-                 dst[mb_begin: mb_end] = src
-
-         # Done.
-         if print_progress:
-             print("\r%d / %d" % (num_items, num_items))
-
-         if not return_as_list:
-             out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
-         return out_arrays
-
-     def list_ops(self) -> List[TfExpression]:
-         _ = self.output_templates # ensure that the template graph has been created
-         include_prefix = self.scope + "/"
-         exclude_prefix = include_prefix + "_"
-         ops = tf.get_default_graph().get_operations()
-         ops = [op for op in ops if op.name.startswith(include_prefix)]
-         ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
-         return ops
-
-     def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
-         """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
-         individual layers of the network. Mainly intended to be used for reporting."""
-         layers = []
-
-         def recurse(scope, parent_ops, parent_vars, level):
-             if len(parent_ops) == 0 and len(parent_vars) == 0:
-                 return
-
-             # Ignore specific patterns.
-             if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
-                 return
-
-             # Filter ops and vars by scope.
-             global_prefix = scope + "/"
-             local_prefix = global_prefix[len(self.scope) + 1:]
-             cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
-             cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
-             if not cur_ops and not cur_vars:
-                 return
-
-             # Filter out all ops related to variables.
-             for var in [op for op in cur_ops if op.type.startswith("Variable")]:
-                 var_prefix = var.name + "/"
-                 cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
-
-             # Scope does not contain ops as immediate children => recurse deeper.
-             contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
-             if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0):
-                 visited = set()
-                 for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
-                     token = rel_name.split("/")[0]
-                     if token not in visited:
-                         recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
-                         visited.add(token)
-                 return
-
-             # Report layer.
-             layer_name = scope[len(self.scope) + 1:]
-             layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
-             layer_trainables = [var for _name, var in cur_vars if var.trainable]
-             layers.append((layer_name, layer_output, layer_trainables))
-
-         recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0)
-         return layers
-
-     def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
-         """Print a summary table of the network structure."""
-         rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
-         rows += [["---"] * 4]
-         total_params = 0
-
-         for layer_name, layer_output, layer_trainables in self.list_layers():
-             num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
-             weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
-             weights.sort(key=lambda x: len(x.name))
-             if len(weights) == 0 and len(layer_trainables) == 1:
-                 weights = layer_trainables
-             total_params += num_params
-
-             if not hide_layers_with_no_params or num_params != 0:
-                 num_params_str = str(num_params) if num_params > 0 else "-"
-                 output_shape_str = str(layer_output.shape)
-                 weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
-                 rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
-
-         rows += [["---"] * 4]
-         rows += [["Total", str(total_params), "", ""]]
-
-         widths = [max(len(cell) for cell in column) for column in zip(*rows)]
-         print()
-         for row in rows:
-             print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
-         print()
-
-     def setup_weight_histograms(self, title: str = None) -> None:
-         """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
-         if title is None:
-             title = self.name
-
-         with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
-             for local_name, var in self._get_trainables().items():
-                 if "/" in local_name:
-                     p = local_name.split("/")
-                     name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
-                 else:
-                     name = title + "_toplevel/" + local_name
-
-                 tf.summary.histogram(name, var)
-
- #----------------------------------------------------------------------------
- # Backwards-compatible emulation of legacy output transformation in Network.run().
-
- _print_legacy_warning = True
-
- def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
-     global _print_legacy_warning
-     legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
-     if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
-         return output_transform, dynamic_kwargs
-
-     if _print_legacy_warning:
-         _print_legacy_warning = False
-         print()
-         print("WARNING: Old-style output transformations in Network.run() are deprecated.")
-         print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
-         print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
-         print()
-     assert output_transform is None
-
-     new_kwargs = dict(dynamic_kwargs)
-     new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
-     new_transform["func"] = _legacy_output_transform_func
-     return new_transform, new_kwargs
-
- def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
-     if out_mul != 1.0:
-         expr = [x * out_mul for x in expr]
-
-     if out_add != 0.0:
-         expr = [x + out_add for x in expr]
-
-     if out_shrink > 1:
-         ksize = [1, 1, out_shrink, out_shrink]
-         expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
-
-     if out_dtype is not None:
-         if tf.as_dtype(out_dtype).is_integer:
-             expr = [tf.round(x) for x in expr]
-         expr = [tf.saturate_cast(x, out_dtype) for x in expr]
-     return expr
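
For context, the Network class deleted above is normally driven through pickle import plus run(); the output_transform form below is the one its own deprecation warning recommends. A hedged sketch (the checkpoint filename and the (G, D, Gs) tuple layout follow StyleGAN conventions and are assumptions here, not something this commit defines):

    import pickle
    import numpy as np
    import dnnlib.tflib as tflib

    tflib.init_tf()
    with open('network-snapshot.pkl', 'rb') as f:  # hypothetical checkpoint
        _G, _D, Gs = pickle.load(f)                # Network.__setstate__ rebuilds each graph
    latents = np.random.randn(4, *Gs.input_shape[1:])
    # Passing None for the label input makes run() substitute zeros of the right shape.
    images = Gs.run(latents, None, minibatch_size=4,
                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
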
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py DELETED
@@ -1,9 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- # empty
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu DELETED
@@ -1,220 +0,0 @@
- // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- //
- // NVIDIA CORPORATION and its licensors retain all intellectual property
- // and proprietary rights in and to this software, related documentation
- // and any modifications thereto. Any use, reproduction, disclosure or
- // distribution of this software and related documentation without an express
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- #define EIGEN_USE_GPU
- #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
- #include "tensorflow/core/framework/op.h"
- #include "tensorflow/core/framework/op_kernel.h"
- #include "tensorflow/core/framework/shape_inference.h"
- #include <stdio.h>
-
- using namespace tensorflow;
- using namespace tensorflow::shape_inference;
-
- #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
-
- //------------------------------------------------------------------------
- // CUDA kernel.
-
- template <class T>
- struct FusedBiasActKernelParams
- {
-     const T* x; // [sizeX]
-     const T* b; // [sizeB] or NULL
-     const T* xref; // [sizeX] or NULL
-     const T* yref; // [sizeX] or NULL
-     T* y; // [sizeX]
-
-     int grad;
-     int axis;
-     int act;
-     float alpha;
-     float gain;
-     float clamp;
-
-     int sizeX;
-     int sizeB;
-     int stepB;
-     int loopX;
- };
-
- template <class T>
- static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
- {
-     const float expRange = 80.0f;
-     const float halfExpRange = 40.0f;
-     const float seluScale = 1.0507009873554804934193349852946f;
-     const float seluAlpha = 1.6732632423543772848170429916717f;
-
-     // Loop over elements.
-     int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
-     for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
-     {
-         // Load and apply bias.
-         float x = (float)p.x[xi];
-         if (p.b)
-             x += (float)p.b[(xi / p.stepB) % p.sizeB];
-         float xref = (p.xref) ? (float)p.xref[xi] : 0.0f;
-         float yref = (p.yref) ? (float)p.yref[xi] : 0.0f;
-         float yy = (p.gain != 0.0f) ? yref / p.gain : 0.0f;
-
-         // Evaluate activation func.
-         float y;
-         switch (p.act * 10 + p.grad)
-         {
-             // linear
-             default:
-             case 10: y = x; break;
-             case 11: y = x; break;
-             case 12: y = 0.0f; break;
-
-             // relu
-             case 20: y = (x > 0.0f) ? x : 0.0f; break;
-             case 21: y = (yy > 0.0f) ? x : 0.0f; break;
-             case 22: y = 0.0f; break;
-
-             // lrelu
-             case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
-             case 31: y = (yy > 0.0f) ? x : x * p.alpha; break;
-             case 32: y = 0.0f; break;
-
-             // tanh
-             case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
-             case 41: y = x * (1.0f - yy * yy); break;
-             case 42: y = x * (1.0f - yy * yy) * (-2.0f * yy); break;
-
-             // sigmoid
-             case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
-             case 51: y = x * yy * (1.0f - yy); break;
-             case 52: y = x * yy * (1.0f - yy) * (1.0f - 2.0f * yy); break;
-
-             // elu
-             case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
-             case 61: y = (yy >= 0.0f) ? x : x * (yy + 1.0f); break;
-             case 62: y = (yy >= 0.0f) ? 0.0f : x * (yy + 1.0f); break;
-
-             // selu
-             case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
-             case 71: y = (yy >= 0.0f) ? x * seluScale : x * (yy + seluScale * seluAlpha); break;
-             case 72: y = (yy >= 0.0f) ? 0.0f : x * (yy + seluScale * seluAlpha); break;
-
-             // softplus
-             case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
-             case 81: y = x * (1.0f - expf(-yy)); break;
-             case 82: { float c = expf(-yy); y = x * c * (1.0f - c); } break;
-
-             // swish
-             case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
-             case 91:
-             case 92:
-                 {
-                     float c = expf(xref);
-                     float d = c + 1.0f;
-                     if (p.grad == 1)
-                         y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
-                     else
-                         y = (xref > halfExpRange) ? 0.0f : x * c * (xref * (2.0f - d) + 2.0f * d) / (d * d * d);
-                     yref = (xref < -expRange) ? 0.0f : xref / (expf(-xref) + 1.0f) * p.gain;
-                 }
-                 break;
-         }
-
-         // Apply gain.
-         y *= p.gain;
-
-         // Clamp.
-         if (p.clamp >= 0.0f)
-         {
-             if (p.grad == 0)
-                 y = (fabsf(y) < p.clamp) ? y : (y >= 0.0f) ? p.clamp : -p.clamp;
-             else
-                 y = (fabsf(yref) < p.clamp) ? y : 0.0f;
-         }
-
-         // Store.
-         p.y[xi] = (T)y;
-     }
- }
-
- //------------------------------------------------------------------------
- // TensorFlow op.
-
- template <class T>
- struct FusedBiasActOp : public OpKernel
- {
-     FusedBiasActKernelParams<T> m_attribs;
-
-     FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
-     {
-         memset(&m_attribs, 0, sizeof(m_attribs));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("clamp", &m_attribs.clamp));
-         OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
-         OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
-         OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
-     }
-
-     void Compute(OpKernelContext* ctx)
-     {
-         FusedBiasActKernelParams<T> p = m_attribs;
-         cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
-
-         const Tensor& x = ctx->input(0); // [...]
-         const Tensor& b = ctx->input(1); // [sizeB] or [0]
-         const Tensor& xref = ctx->input(2); // x.shape or [0]
-         const Tensor& yref = ctx->input(3); // x.shape or [0]
-         p.x = x.flat<T>().data();
-         p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
-         p.xref = (xref.NumElements()) ? xref.flat<T>().data() : NULL;
-         p.yref = (yref.NumElements()) ? yref.flat<T>().data() : NULL;
-         OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
-         OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
-         OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
-         OP_REQUIRES(ctx, xref.NumElements() == 0 || xref.NumElements() == x.NumElements(), errors::InvalidArgument("xref has wrong number of elements"));
-         OP_REQUIRES(ctx, yref.NumElements() == 0 || yref.NumElements() == x.NumElements(), errors::InvalidArgument("yref has wrong number of elements"));
-         OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
-
-         p.sizeX = (int)x.NumElements();
-         p.sizeB = (int)b.NumElements();
-         p.stepB = 1;
-         for (int i = m_attribs.axis + 1; i < x.dims(); i++)
-             p.stepB *= (int)x.dim_size(i);
-
-         Tensor* y = NULL; // x.shape
-         OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
-         p.y = y->flat<T>().data();
-
-         p.loopX = 4;
-         int blockSize = 4 * 32;
-         int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
-         void* args[] = {&p};
-         OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
-     }
- };
-
- REGISTER_OP("FusedBiasAct")
-     .Input ("x: T")
-     .Input ("b: T")
-     .Input ("xref: T")
-     .Input ("yref: T")
-     .Output ("y: T")
-     .Attr ("T: {float, half}")
-     .Attr ("grad: int = 0")
-     .Attr ("axis: int = 1")
-     .Attr ("act: int = 0")
-     .Attr ("alpha: float = 0.0")
-     .Attr ("gain: float = 1.0")
-     .Attr ("clamp: float = -1.0");
- REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
- REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
-
- //------------------------------------------------------------------------
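
One detail worth noting in the kernel above: forward and gradient passes share a single switch keyed on act * 10 + grad, where act is the cuda_idx supplied by the Python wrapper and grad selects forward (0), first gradient (1), or second gradient (2). A small sketch of the encoding (the cuda_idx values come from activation_funcs in the wrapper below):

    # lrelu has cuda_idx=3, so its kernel cases are 30/31/32.
    def kernel_case(cuda_idx: int, grad: int) -> int:
        return cuda_idx * 10 + grad

    assert kernel_case(3, 0) == 30  # lrelu forward
    assert kernel_case(3, 1) == 31  # lrelu first gradient
    assert kernel_case(3, 2) == 32  # lrelu second gradient (zero: lrelu is piecewise-linear)
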
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py DELETED
@@ -1,211 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Custom TensorFlow ops for efficient bias and activation."""
-
- import os
- import numpy as np
- import tensorflow as tf
- from .. import custom_ops
- from ...util import EasyDict
-
- def _get_plugin():
-     return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
-
- #----------------------------------------------------------------------------
-
- activation_funcs = {
-     'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
-     'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
-     'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
-     'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
-     'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
-     'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
-     'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
-     'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
-     'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
- }
-
- #----------------------------------------------------------------------------
-
- def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
-     r"""Fused bias and activation function.
-
-     Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
-     and scales the result by `gain`. Each of the steps is optional. In most cases,
-     the fused op is considerably more efficient than performing the same calculation
-     using standard TensorFlow ops. It supports first and second order gradients,
-     but not third order gradients.
-
-     Args:
-         x: Input activation tensor. Can have any shape, but if `b` is defined, the
-             dimension corresponding to `axis`, as well as the rank, must be known.
-         b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
-             as `x`. The shape must be known, and it must match the dimension of `x`
-             corresponding to `axis`.
-         axis: The dimension in `x` corresponding to the elements of `b`.
-             The value of `axis` is ignored if `b` is not specified.
-         act: Name of the activation function to evaluate, or `"linear"` to disable.
-             Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
-             See `activation_funcs` for a full list. `None` is not allowed.
-         alpha: Shape parameter for the activation function, or `None` to use the default.
-         gain: Scaling factor for the output tensor, or `None` to use default.
-             See `activation_funcs` for the default scaling of each activation function.
-             If unsure, consider specifying `1.0`.
-         clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
-             the clamping (default).
-         impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the same shape and datatype as `x`.
-     """
-
-     impl_dict = {
-         'ref': _fused_bias_act_ref,
-         'cuda': _fused_bias_act_cuda,
-     }
-     return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp)
-
- #----------------------------------------------------------------------------
-
- def _fused_bias_act_ref(x, b, axis, act, alpha, gain, clamp):
-     """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
-
-     # Validate arguments.
-     x = tf.convert_to_tensor(x)
-     b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
-     act_spec = activation_funcs[act]
-     assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
-     assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
-     if alpha is None:
-         alpha = act_spec.def_alpha
-     if gain is None:
-         gain = act_spec.def_gain
-
-     # Add bias.
-     if b.shape[0] != 0:
-         x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
-
-     # Evaluate activation function.
-     x = act_spec.func(x, alpha=alpha)
-
-     # Scale by gain.
-     if gain != 1:
-         x *= gain
-
-     # Clamp.
-     if clamp is not None:
-         clamp = np.asarray(clamp, dtype=x.dtype.name)
-         assert clamp.shape == () and clamp >= 0
-         x = tf.clip_by_value(x, -clamp, clamp)
-     return x
-
- #----------------------------------------------------------------------------
-
- def _fused_bias_act_cuda(x, b, axis, act, alpha, gain, clamp):
-     """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
-
-     # Validate arguments.
-     x = tf.convert_to_tensor(x)
-     empty_tensor = tf.constant([], dtype=x.dtype)
-     b = tf.convert_to_tensor(b) if b is not None else empty_tensor
-     act_spec = activation_funcs[act]
-     assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
-     assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
-     if alpha is None:
-         alpha = act_spec.def_alpha
-     if gain is None:
-         gain = act_spec.def_gain
-
-     # Special cases.
-     if act == 'linear' and b is None and gain == 1.0:
-         return x
-     if act_spec.cuda_idx is None:
-         return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp)
-
-     # CUDA op.
-     cuda_op = _get_plugin().fused_bias_act
-     cuda_kwargs = dict(axis=int(axis), act=int(act_spec.cuda_idx), gain=float(gain))
-     if alpha is not None:
-         cuda_kwargs['alpha'] = float(alpha)
-     if clamp is not None:
-         clamp = np.asarray(clamp, dtype=x.dtype.name)
-         assert clamp.shape == () and clamp >= 0
-         cuda_kwargs['clamp'] = float(clamp.astype(np.float32))
-     def ref(tensor, name):
-         return tensor if act_spec.ref == name else empty_tensor
-
-     # Forward pass: y = func(x, b).
-     def func_y(x, b):
-         y = cuda_op(x=x, b=b, xref=empty_tensor, yref=empty_tensor, grad=0, **cuda_kwargs)
-         y.set_shape(x.shape)
-         return y
-
-     # Backward pass: dx, db = grad(dy, x, y)
-     def grad_dx(dy, x, y):
-         dx = cuda_op(x=dy, b=empty_tensor, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs)
-         dx.set_shape(x.shape)
-         return dx
-     def grad_db(dx):
-         if b.shape[0] == 0:
-             return empty_tensor
-         db = dx
-         if axis < x.shape.rank - 1:
-             db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
-         if axis > 0:
-             db = tf.reduce_sum(db, list(range(axis)))
-         db.set_shape(b.shape)
-         return db
-
-     # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
-     def grad2_d_dy(d_dx, d_db, x, y):
-         d_dy = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs)
-         d_dy.set_shape(x.shape)
-         return d_dy
-     def grad2_d_x(d_dx, d_db, x, y):
-         d_x = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=2, **cuda_kwargs)
-         d_x.set_shape(x.shape)
-         return d_x
-
-     # Fast version for piecewise-linear activation funcs.
-     @tf.custom_gradient
-     def func_zero_2nd_grad(x, b):
-         y = func_y(x, b)
-         @tf.custom_gradient
- def grad(dy):
181
- dx = grad_dx(dy, x, y)
182
- db = grad_db(dx)
183
- def grad2(d_dx, d_db):
184
- d_dy = grad2_d_dy(d_dx, d_db, x, y)
185
- return d_dy
186
- return (dx, db), grad2
187
- return y, grad
188
-
189
- # Slow version for general activation funcs.
190
- @tf.custom_gradient
191
- def func_nonzero_2nd_grad(x, b):
192
- y = func_y(x, b)
193
- def grad_wrap(dy):
194
- @tf.custom_gradient
195
- def grad_impl(dy, x):
196
- dx = grad_dx(dy, x, y)
197
- db = grad_db(dx)
198
- def grad2(d_dx, d_db):
199
- d_dy = grad2_d_dy(d_dx, d_db, x, y)
200
- d_x = grad2_d_x(d_dx, d_db, x, y)
201
- return d_dy, d_x
202
- return (dx, db), grad2
203
- return grad_impl(dy, x)
204
- return y, grad_wrap
205
-
206
- # Which version to use?
207
- if act_spec.zero_2nd_grad:
208
- return func_zero_2nd_grad(x, b)
209
- return func_nonzero_2nd_grad(x, b)
210
-
211
- #----------------------------------------------------------------------------
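For reference, the math of the deleted module's slow path reduces to y = clamp(gain * act(x + b)), applied elementwise with one bias per `axis` slice. A minimal NumPy sketch of that arithmetic for the 'lrelu' case follows; the helper name `fused_bias_act_np` and the sample shapes are illustrative, not part of the deleted API:

    import numpy as np

    def fused_bias_act_np(x, b=None, axis=1, alpha=0.2, gain=np.sqrt(2), clamp=None):
        # Mirrors the reference path above: add bias, apply leaky ReLU,
        # scale by gain, then optionally clamp to [-clamp, +clamp].
        if b is not None:
            x = x + b.reshape([-1 if i == axis else 1 for i in range(x.ndim)])
        x = np.where(x >= 0.0, x, x * alpha)  # 'lrelu' with shape parameter alpha
        x = x * gain                          # def_gain for 'lrelu' is sqrt(2)
        if clamp is not None:
            x = np.clip(x, -clamp, clamp)
        return x

    # One bias per channel of an [N, C, H, W] activation (axis=1):
    y = fused_bias_act_np(np.random.randn(2, 8, 4, 4), b=np.zeros(8))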
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu DELETED
@@ -1,359 +0,0 @@
- // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- //
- // NVIDIA CORPORATION and its licensors retain all intellectual property
- // and proprietary rights in and to this software, related documentation
- // and any modifications thereto. Any use, reproduction, disclosure or
- // distribution of this software and related documentation without an express
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- #define EIGEN_USE_GPU
- #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
- #include "tensorflow/core/framework/op.h"
- #include "tensorflow/core/framework/op_kernel.h"
- #include "tensorflow/core/framework/shape_inference.h"
- #include <stdio.h>
-
- using namespace tensorflow;
- using namespace tensorflow::shape_inference;
-
- //------------------------------------------------------------------------
- // Helpers.
-
- #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
-
- static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
- {
-     int t = 1 - a / b;
-     return (a + t * b) / b - t;
- }
-
- //------------------------------------------------------------------------
- // CUDA kernel params.
-
- template <class T>
- struct UpFirDn2DKernelParams
- {
-     const T* x;  // [majorDim, inH, inW, minorDim]
-     const T* k;  // [kernelH, kernelW]
-     T* y;        // [majorDim, outH, outW, minorDim]
-
-     int upx;
-     int upy;
-     int downx;
-     int downy;
-     int padx0;
-     int padx1;
-     int pady0;
-     int pady1;
-
-     int majorDim;
-     int inH;
-     int inW;
-     int minorDim;
-     int kernelH;
-     int kernelW;
-     int outH;
-     int outW;
-     int loopMajor;
-     int loopX;
- };
-
- //------------------------------------------------------------------------
- // General CUDA implementation for large filter kernels.
-
- template <class T>
- static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
- {
-     // Calculate thread index.
-     int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
-     int outY = minorIdx / p.minorDim;
-     minorIdx -= outY * p.minorDim;
-     int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
-     int majorIdxBase = blockIdx.z * p.loopMajor;
-     if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
-         return;
-
-     // Setup Y receptive field.
-     int midY = outY * p.downy + p.upy - 1 - p.pady0;
-     int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
-     int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
-     int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
-
-     // Loop over majorDim and outX.
-     for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
-     for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
-     {
-         // Setup X receptive field.
-         int midX = outX * p.downx + p.upx - 1 - p.padx0;
-         int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
-         int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
-         int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
-
-         // Initialize pointers.
-         const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
-         const T* kp = &p.k[kernelY * p.kernelW + kernelX];
-         int xpx = p.minorDim;
-         int kpx = -p.upx;
-         int xpy = p.inW * p.minorDim;
-         int kpy = -p.upy * p.kernelW;
-
-         // Inner loop.
-         float v = 0.0f;
-         for (int y = 0; y < h; y++)
-         {
-             for (int x = 0; x < w; x++)
-             {
-                 v += (float)(*xp) * (float)(*kp);
-                 xp += xpx;
-                 kp += kpx;
-             }
-             xp += xpy - w * xpx;
-             kp += kpy - w * kpx;
-         }
-
-         // Store result.
-         p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
-     }
- }
-
- //------------------------------------------------------------------------
- // Specialized CUDA implementation for small filter kernels.
-
- template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
- static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
- {
-     //assert(kernelW % upx == 0);
-     //assert(kernelH % upy == 0);
-     const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
-     const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
-     __shared__ volatile float sk[kernelH][kernelW];
-     __shared__ volatile float sx[tileInH][tileInW];
-
-     // Calculate tile index.
-     int minorIdx = blockIdx.x;
-     int tileOutY = minorIdx / p.minorDim;
-     minorIdx -= tileOutY * p.minorDim;
-     tileOutY *= tileOutH;
-     int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
-     int majorIdxBase = blockIdx.z * p.loopMajor;
-     if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
-         return;
-
-     // Load filter kernel (flipped).
-     for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
-     {
-         int ky = tapIdx / kernelW;
-         int kx = tapIdx - ky * kernelW;
-         float v = 0.0f;
-         if (kx < p.kernelW & ky < p.kernelH)
-             v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
-         sk[ky][kx] = v;
-     }
-
-     // Loop over majorDim and outX.
-     for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
-     for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
-     {
-         // Load input pixels.
-         int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
-         int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
-         int tileInX = floorDiv(tileMidX, upx);
-         int tileInY = floorDiv(tileMidY, upy);
-         __syncthreads();
-         for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
-         {
-             int relInY = inIdx / tileInW;
-             int relInX = inIdx - relInY * tileInW;
-             int inX = relInX + tileInX;
-             int inY = relInY + tileInY;
-             float v = 0.0f;
-             if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
-                 v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
-             sx[relInY][relInX] = v;
-         }
-
-         // Loop over output pixels.
-         __syncthreads();
-         for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
-         {
-             int relOutY = outIdx / tileOutW;
-             int relOutX = outIdx - relOutY * tileOutW;
-             int outX = relOutX + tileOutX;
-             int outY = relOutY + tileOutY;
-
-             // Setup receptive field.
-             int midX = tileMidX + relOutX * downx;
-             int midY = tileMidY + relOutY * downy;
-             int inX = floorDiv(midX, upx);
-             int inY = floorDiv(midY, upy);
-             int relInX = inX - tileInX;
-             int relInY = inY - tileInY;
-             int kernelX = (inX + 1) * upx - midX - 1; // flipped
-             int kernelY = (inY + 1) * upy - midY - 1; // flipped
-
-             // Inner loop.
-             float v = 0.0f;
-             #pragma unroll
-             for (int y = 0; y < kernelH / upy; y++)
-                 #pragma unroll
-                 for (int x = 0; x < kernelW / upx; x++)
-                     v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
-
-             // Store result.
-             if (outX < p.outW & outY < p.outH)
-                 p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
-         }
-     }
- }
-
- //------------------------------------------------------------------------
- // TensorFlow op.
-
- template <class T>
- struct UpFirDn2DOp : public OpKernel
- {
-     UpFirDn2DKernelParams<T> m_attribs;
-
-     UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
-     {
-         memset(&m_attribs, 0, sizeof(m_attribs));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
-         OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
-         OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
-         OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
-     }
-
-     void Compute(OpKernelContext* ctx)
-     {
-         UpFirDn2DKernelParams<T> p = m_attribs;
-         cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
-
-         const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
-         const Tensor& k = ctx->input(1); // [kernelH, kernelW]
-         p.x = x.flat<T>().data();
-         p.k = k.flat<T>().data();
-         OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
-         OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
-         OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
-         OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
-
-         p.majorDim = (int)x.dim_size(0);
-         p.inH = (int)x.dim_size(1);
-         p.inW = (int)x.dim_size(2);
-         p.minorDim = (int)x.dim_size(3);
-         p.kernelH = (int)k.dim_size(0);
-         p.kernelW = (int)k.dim_size(1);
-         OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
-
-         p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
-         p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
-         OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
-
-         Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
-         TensorShape ys;
-         ys.AddDim(p.majorDim);
-         ys.AddDim(p.outH);
-         ys.AddDim(p.outW);
-         ys.AddDim(p.minorDim);
-         OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
-         p.y = y->flat<T>().data();
-         OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
-
-         // Choose CUDA kernel to use.
-         void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
-         int tileOutW = -1;
-         int tileOutH = -1;
-
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7  && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5  && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3  && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 8,1,  128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,24, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,20, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,16, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,12, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,8,  32,32>; tileOutW = 32;  tileOutH = 32; }
-
-         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2,  64,16>; tileOutW = 64;  tileOutH = 16; }
-         if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 8,1,  128,8>; tileOutW = 128; tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,24, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,20, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,16, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,12, 32,32>; tileOutW = 32;  tileOutH = 32; }
-         if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,8,  32,32>; tileOutW = 32;  tileOutH = 32; }
-
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8,  32,8 >; tileOutW = 32;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6,  32,8 >; tileOutW = 32;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4,  32,8 >; tileOutW = 32;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2,  32,8 >; tileOutW = 32;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 24,1, 64,8 >; tileOutW = 64;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 20,1, 64,8 >; tileOutW = 64;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 16,1, 64,8 >; tileOutW = 64;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 12,1, 64,8 >; tileOutW = 64;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 8,1,  64,8 >; tileOutW = 64;  tileOutH = 8; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,24, 32,16>; tileOutW = 32;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,20, 32,16>; tileOutW = 32;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,16, 32,16>; tileOutW = 32;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,12, 32,16>; tileOutW = 32;  tileOutH = 16; }
-         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,8,  32,16>; tileOutW = 32;  tileOutH = 16; }
-
-         // Choose launch params.
-         dim3 blockSize;
-         dim3 gridSize;
-         if (tileOutW > 0 && tileOutH > 0) // small
-         {
-             p.loopMajor = (p.majorDim - 1) / 16384 + 1;
-             p.loopX = 1;
-             blockSize = dim3(32 * 8, 1, 1);
-             gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
-         }
-         else // large
-         {
-             p.loopMajor = (p.majorDim - 1) / 16384 + 1;
-             p.loopX = 4;
-             blockSize = dim3(4, 32, 1);
-             gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
-         }
-
-         // Launch CUDA kernel.
-         void* args[] = {&p};
-         OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
-     }
- };
-
- REGISTER_OP("UpFirDn2D")
-     .Input      ("x: T")
-     .Input      ("k: T")
-     .Output     ("y: T")
-     .Attr       ("T: {float, half}")
-     .Attr       ("upx: int = 1")
-     .Attr       ("upy: int = 1")
-     .Attr       ("downx: int = 1")
-     .Attr       ("downy: int = 1")
-     .Attr       ("padx0: int = 0")
-     .Attr       ("padx1: int = 0")
-     .Attr       ("pady0: int = 0")
-     .Attr       ("pady1: int = 0");
- REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
- REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
-
- //------------------------------------------------------------------------
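The output size computed in Compute() above follows the standard upfirdn relation, outW = (inW*upx + padx0 + padx1 - kernelW + downx) / downx with integer division. A small sketch of that arithmetic (the helper name is illustrative, not from the deleted code):

    def upfirdn_out_size(in_size, up, down, pad0, pad1, kernel):
        # Same integer arithmetic as the op's Compute() method above.
        return (in_size * up + pad0 + pad1 - kernel + down) // down

    # 2x upsampling of a 64-pixel row with a 4-tap filter and pads (2, 1):
    assert upfirdn_out_size(64, up=2, down=1, pad0=2, pad1=1, kernel=4) == 128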
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py DELETED
@@ -1,418 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Custom TensorFlow ops for efficient resampling of 2D images."""
-
- import os
- import numpy as np
- import tensorflow as tf
- from .. import custom_ops
-
- def _get_plugin():
-     return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
-
- #----------------------------------------------------------------------------
-
- def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
-     r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
-
-     Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
-     and performs the following operations for each image, batched across
-     `majorDim` and `minorDim`:
-
-     1. Upsample the image by inserting the zeros after each pixel (`upx`, `upy`).
-
-     2. Pad the image with zeros by the specified number of pixels on each side
-        (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
-        corresponds to cropping the image.
-
-     3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
-        image so that the footprint of all output pixels lies within the input image.
-
-     4. Downsample the image by throwing away pixels (`downx`, `downy`).
-
-     This sequence of operations bears close resemblance to scipy.signal.upfirdn().
-     The fused op is considerably more efficient than performing the same calculation
-     using standard TensorFlow ops. It supports gradients of arbitrary order.
-
-     Args:
-         x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
-         k:      2D FIR filter of the shape `[firH, firW]`.
-         upx:    Integer upsampling factor along the X-axis (default: 1).
-         upy:    Integer upsampling factor along the Y-axis (default: 1).
-         downx:  Integer downsampling factor along the X-axis (default: 1).
-         downy:  Integer downsampling factor along the Y-axis (default: 1).
-         padx0:  Number of pixels to pad on the left side (default: 0).
-         padx1:  Number of pixels to pad on the right side (default: 0).
-         pady0:  Number of pixels to pad on the top side (default: 0).
-         pady1:  Number of pixels to pad on the bottom side (default: 0).
-         impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
-     """
-
-     impl_dict = {
-         'ref':  _upfirdn_2d_ref,
-         'cuda': _upfirdn_2d_cuda,
-     }
-     return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
-
- #----------------------------------------------------------------------------
-
- def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
-     """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""
-
-     x = tf.convert_to_tensor(x)
-     k = np.asarray(k, dtype=np.float32)
-     assert x.shape.rank == 4
-     inH = x.shape[1].value
-     inW = x.shape[2].value
-     minorDim = _shape(x, 3)
-     kernelH, kernelW = k.shape
-     assert inW >= 1 and inH >= 1
-     assert kernelW >= 1 and kernelH >= 1
-     assert isinstance(upx, int) and isinstance(upy, int)
-     assert isinstance(downx, int) and isinstance(downy, int)
-     assert isinstance(padx0, int) and isinstance(padx1, int)
-     assert isinstance(pady0, int) and isinstance(pady1, int)
-
-     # Upsample (insert zeros).
-     x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
-     x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
-     x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])
-
-     # Pad (crop if negative).
-     x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
-     x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]
-
-     # Convolve with filter.
-     x = tf.transpose(x, [0, 3, 1, 2])
-     x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
-     w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
-     x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
-     x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
-     x = tf.transpose(x, [0, 2, 3, 1])
-
-     # Downsample (throw away pixels).
-     return x[:, ::downy, ::downx, :]
-
- #----------------------------------------------------------------------------
-
- def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
-     """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
-
-     x = tf.convert_to_tensor(x)
-     k = np.asarray(k, dtype=np.float32)
-     majorDim, inH, inW, minorDim = x.shape.as_list()
-     kernelH, kernelW = k.shape
-     assert inW >= 1 and inH >= 1
-     assert kernelW >= 1 and kernelH >= 1
-     assert isinstance(upx, int) and isinstance(upy, int)
-     assert isinstance(downx, int) and isinstance(downy, int)
-     assert isinstance(padx0, int) and isinstance(padx1, int)
-     assert isinstance(pady0, int) and isinstance(pady1, int)
-
-     outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
-     outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
-     assert outW >= 1 and outH >= 1
-
-     cuda_op = _get_plugin().up_fir_dn2d
-     kc = tf.constant(k, dtype=x.dtype)
-     gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
-     gpadx0 = kernelW - padx0 - 1
-     gpady0 = kernelH - pady0 - 1
-     gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
-     gpady1 = inH * upy - outH * downy + pady0 - upy + 1
-
-     @tf.custom_gradient
-     def func(x):
-         y = cuda_op(x=x, k=kc, upx=int(upx), upy=int(upy), downx=int(downx), downy=int(downy), padx0=int(padx0), padx1=int(padx1), pady0=int(pady0), pady1=int(pady1))
-         y.set_shape([majorDim, outH, outW, minorDim])
-         @tf.custom_gradient
-         def grad(dy):
-             dx = cuda_op(x=dy, k=gkc, upx=int(downx), upy=int(downy), downx=int(upx), downy=int(upy), padx0=int(gpadx0), padx1=int(gpadx1), pady0=int(gpady0), pady1=int(gpady1))
-             dx.set_shape([majorDim, inH, inW, minorDim])
-             return dx, func
-         return y, grad
-     return func(x)
-
- #----------------------------------------------------------------------------
-
- def filter_2d(x, k, gain=1, padding=0, data_format='NCHW', impl='cuda'):
-     r"""Filter a batch of 2D images with the given FIR filter.
-
-     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
-     and filters each image with the given filter. The filter is normalized so that
-     if the input pixels are constant, they will be scaled by the specified `gain`.
-     Pixels outside the image are assumed to be zero.
-
-     Args:
-         x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-         k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
-         gain:         Scaling factor for signal magnitude (default: 1.0).
-         padding:      Number of pixels to pad or crop the output on each side (default: 0).
-         data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
-         impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the same shape and datatype as `x`.
-     """
-
-     assert isinstance(padding, int)
-     k = _FilterKernel(k=k, gain=gain)
-     assert k.w == k.h
-     pad0 = k.w // 2 + padding
-     pad1 = (k.w - 1) // 2 + padding
-     return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
-
- #----------------------------------------------------------------------------
-
- def upsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
-     r"""Upsample a batch of 2D images with the given filter.
-
-     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
-     and upsamples each image with the given filter. The filter is normalized so that
-     if the input pixels are constant, they will be scaled by the specified `gain`.
-     Pixels outside the image are assumed to be zero, and the filter is padded with
-     zeros so that its shape is a multiple of the upsampling factor.
-
-     Args:
-         x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-         k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
-                       The default is `[1] * factor`, which corresponds to nearest-neighbor
-                       upsampling.
-         factor:       Integer upsampling factor (default: 2).
-         gain:         Scaling factor for signal magnitude (default: 1.0).
-         padding:      Number of pixels to pad or crop the output on each side (default: 0).
-         data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
-         impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the shape `[N, C, H * factor, W * factor]` or
-         `[N, H * factor, W * factor, C]`, and same datatype as `x`.
-     """
-
-     assert isinstance(factor, int) and factor >= 1
-     assert isinstance(padding, int)
-     k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2))
-     assert k.w == k.h
-     pad0 = (k.w + factor - 1) // 2 + padding
-     pad1 = (k.w - factor) // 2 + padding
-     return _simple_upfirdn_2d(x, k, up=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
-
- #----------------------------------------------------------------------------
-
- def downsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
-     r"""Downsample a batch of 2D images with the given filter.
-
-     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
-     and downsamples each image with the given filter. The filter is normalized so that
-     if the input pixels are constant, they will be scaled by the specified `gain`.
-     Pixels outside the image are assumed to be zero, and the filter is padded with
-     zeros so that its shape is a multiple of the downsampling factor.
-
-     Args:
-         x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-         k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
-                       The default is `[1] * factor`, which corresponds to average pooling.
-         factor:       Integer downsampling factor (default: 2).
-         gain:         Scaling factor for signal magnitude (default: 1.0).
-         padding:      Number of pixels to pad or crop the output on each side (default: 0).
-         data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
-         impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the shape `[N, C, H // factor, W // factor]` or
-         `[N, H // factor, W // factor, C]`, and same datatype as `x`.
-     """
-
-     assert isinstance(factor, int) and factor >= 1
-     assert isinstance(padding, int)
-     k = _FilterKernel(k if k is not None else [1] * factor, gain)
-     assert k.w == k.h
-     pad0 = (k.w - factor + 1) // 2 + padding * factor
-     pad1 = (k.w - factor) // 2 + padding * factor
-     return _simple_upfirdn_2d(x, k, down=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
-
- #----------------------------------------------------------------------------
-
- def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
-     r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
-
-     Padding is performed only once at the beginning, not between the operations.
-     The fused op is considerably more efficient than performing the same calculation
-     using standard TensorFlow ops. It supports gradients of arbitrary order.
-
-     Args:
-         x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-         w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
-                       Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-         k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
-                       The default is `[1] * factor`, which corresponds to nearest-neighbor
-                       upsampling.
-         factor:       Integer upsampling factor (default: 2).
-         gain:         Scaling factor for signal magnitude (default: 1.0).
-         padding:      Number of pixels to pad or crop the output on each side (default: 0).
-         data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
-         impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the shape `[N, C, H * factor, W * factor]` or
-         `[N, H * factor, W * factor, C]`, and same datatype as `x`.
-     """
-
-     assert isinstance(factor, int) and factor >= 1
-     assert isinstance(padding, int)
-
-     # Check weight shape.
-     w = tf.convert_to_tensor(w)
-     ch, cw, _inC, _outC = w.shape.as_list()
-     inC = _shape(w, 2)
-     outC = _shape(w, 3)
-     assert cw == ch
-
-     # Fast path for 1x1 convolution.
-     if cw == 1 and ch == 1:
-         x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID')
-         x = upsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl)
-         return x
-
-     # Setup filter kernel.
-     k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2))
-     assert k.w == k.h
-
-     # Determine data dimensions.
-     if data_format == 'NCHW':
-         stride = [1, 1, factor, factor]
-         output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + ch, (_shape(x, 3) - 1) * factor + cw]
-         num_groups = _shape(x, 1) // inC
-     else:
-         stride = [1, factor, factor, 1]
-         output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + ch, (_shape(x, 2) - 1) * factor + cw, outC]
-         num_groups = _shape(x, 3) // inC
-
-     # Transpose weights.
-     w = tf.reshape(w, [ch, cw, inC, num_groups, -1])
-     w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
-     w = tf.reshape(w, [ch, cw, -1, num_groups * inC])
-
-     # Execute.
-     x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
-     pad0 = (k.w + factor - cw) // 2 + padding
-     pad1 = (k.w - factor - cw + 3) // 2 + padding
-     return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
-
- #----------------------------------------------------------------------------
-
- def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
-     r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
-
-     Padding is performed only once at the beginning, not between the operations.
-     The fused op is considerably more efficient than performing the same calculation
-     using standard TensorFlow ops. It supports gradients of arbitrary order.
-
-     Args:
-         x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-         w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
-                       Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-         k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
-                       The default is `[1] * factor`, which corresponds to average pooling.
-         factor:       Integer downsampling factor (default: 2).
-         gain:         Scaling factor for signal magnitude (default: 1.0).
-         padding:      Number of pixels to pad or crop the output on each side (default: 0).
-         data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
-         impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
-     Returns:
-         Tensor of the shape `[N, C, H // factor, W // factor]` or
-         `[N, H // factor, W // factor, C]`, and same datatype as `x`.
-     """
-
-     assert isinstance(factor, int) and factor >= 1
-     assert isinstance(padding, int)
-
-     # Check weight shape.
-     w = tf.convert_to_tensor(w)
-     ch, cw, _inC, _outC = w.shape.as_list()
-     assert cw == ch
-
-     # Fast path for 1x1 convolution.
-     if cw == 1 and ch == 1:
-         x = downsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl)
-         x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID')
-         return x
-
-     # Setup filter kernel.
-     k = _FilterKernel(k if k is not None else [1] * factor, gain)
-     assert k.w == k.h
-
-     # Determine stride.
-     if data_format == 'NCHW':
-         s = [1, 1, factor, factor]
-     else:
-         s = [1, factor, factor, 1]
-
-     # Execute.
-     pad0 = (k.w - factor + cw) // 2 + padding * factor
-     pad1 = (k.w - factor + cw - 1) // 2 + padding * factor
-     x = _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
-     return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
-
- #----------------------------------------------------------------------------
- # Internal helpers.
-
- class _FilterKernel:
-     def __init__(self, k, gain=1):
-         k = np.asarray(k, dtype=np.float32)
-         k /= np.sum(k)
-
-         # Separable.
-         if k.ndim == 1 and k.size >= 8:
-             self.w = k.size
-             self.h = k.size
-             self.kx = k[np.newaxis, :]
-             self.ky = k[:, np.newaxis] * gain
-             self.kxy = None
-
-         # Non-separable.
-         else:
-             if k.ndim == 1:
-                 k = np.outer(k, k)
-             assert k.ndim == 2
-             self.w = k.shape[1]
-             self.h = k.shape[0]
-             self.kx = None
-             self.ky = None
-             self.kxy = k * gain
-
- def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
-     assert isinstance(k, _FilterKernel)
-     assert data_format in ['NCHW', 'NHWC']
-     assert x.shape.rank == 4
-     y = x
-     if data_format == 'NCHW':
-         y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
-     if k.kx is not None:
-         y = upfirdn_2d(y, k.kx, upx=up, downx=down, padx0=pad0, padx1=pad1, impl=impl)
-     if k.ky is not None:
-         y = upfirdn_2d(y, k.ky, upy=up, downy=down, pady0=pad0, pady1=pad1, impl=impl)
-     if k.kxy is not None:
-         y = upfirdn_2d(y, k.kxy, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
-     if data_format == 'NCHW':
-         y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
-     return y
-
- def _shape(tf_expr, dim_idx):
-     if tf_expr.shape.rank is not None:
-         dim = tf_expr.shape[dim_idx].value
-         if dim is not None:
-             return dim
-     return tf.shape(tf_expr)[dim_idx]
-
- #----------------------------------------------------------------------------
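The four steps in the `upfirdn_2d()` docstring are easiest to see in one dimension. A minimal NumPy sketch, assuming the same zero-stuff/pad/filter/decimate order as `_upfirdn_2d_ref()` above (the helper name is illustrative, not from the deleted code):

    import numpy as np

    def upfirdn_1d_np(x, k, up=1, down=1, pad0=0, pad1=0):
        y = np.zeros(len(x) * up, dtype=np.float32)
        y[::up] = x                          # 1. upsample by inserting zeros
        y = np.pad(y, (pad0, pad1))          # 2. pad (a negative pad would crop)
        y = np.convolve(y, k, mode='valid')  # 3. FIR filter (np.convolve flips k,
                                             #    matching the flipped-kernel conv2d above)
        return y[::down]                     # 4. downsample by discarding samples

    x = np.arange(4, dtype=np.float32)
    print(upfirdn_1d_np(x, k=np.array([0.5, 0.5], dtype=np.float32), up=2, pad0=1))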
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py DELETED
@@ -1,372 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Helper wrapper for a Tensorflow optimizer."""
-
- import platform
- import numpy as np
- import tensorflow as tf
-
- from collections import OrderedDict
- from typing import List, Union
-
- from . import autosummary
- from . import tfutil
- from .. import util
-
- from .tfutil import TfExpression, TfExpressionEx
-
- _collective_ops_warning_printed = False
- _collective_ops_group_key = 831766147
- _collective_ops_instance_key = 436340067
-
- class Optimizer:
-     """A Wrapper for tf.train.Optimizer.
-
-     Automatically takes care of:
-     - Gradient averaging for multi-GPU training.
-     - Gradient accumulation for arbitrarily large minibatches.
-     - Dynamic loss scaling and typecasts for FP16 training.
-     - Ignoring corrupted gradients that contain NaNs/Infs.
-     - Reporting statistics.
-     - Well-chosen default settings.
-     """
-
-     def __init__(self,
-         name: str                     = "Train",                  # Name string that will appear in TensorFlow graph.
-         tf_optimizer: str             = "tf.train.AdamOptimizer", # Underlying optimizer class.
-         learning_rate: TfExpressionEx = 0.001,                    # Learning rate. Can vary over time.
-         minibatch_multiplier: TfExpressionEx = None,              # Treat N consecutive minibatches as one by accumulating gradients.
-         share: "Optimizer"            = None,                     # Share internal state with a previously created optimizer?
-         use_loss_scaling: bool        = False,                    # Enable dynamic loss scaling for robust mixed-precision training?
-         loss_scaling_init: float      = 64.0,                     # Log2 of initial loss scaling factor.
-         loss_scaling_inc: float       = 0.0005,                   # Log2 of per-minibatch loss scaling increment when there is no overflow.
-         loss_scaling_dec: float       = 1.0,                      # Log2 of per-minibatch loss scaling decrement when there is an overflow.
-         report_mem_usage: bool        = False,                    # Report fine-grained memory usage statistics in TensorBoard?
-         **kwargs):
-
-         # Public fields.
-         self.name = name
-         self.learning_rate = learning_rate
-         self.minibatch_multiplier = minibatch_multiplier
-         self.id = self.name.replace("/", ".")
-         self.scope = tf.get_default_graph().unique_name(self.id)
-         self.optimizer_class = util.get_obj_by_name(tf_optimizer)
-         self.optimizer_kwargs = dict(kwargs)
-         self.use_loss_scaling = use_loss_scaling
-         self.loss_scaling_init = loss_scaling_init
-         self.loss_scaling_inc = loss_scaling_inc
-         self.loss_scaling_dec = loss_scaling_dec
-
-         # Private fields.
-         self._updates_applied = False
-         self._devices = OrderedDict()           # device_name => EasyDict()
-         self._shared_optimizers = OrderedDict() # device_name => optimizer_class
-         self._gradient_shapes = None            # [shape, ...]
-         self._report_mem_usage = report_mem_usage
-
-         # Validate arguments.
-         assert callable(self.optimizer_class)
-
-         # Share internal state if requested.
-         if share is not None:
-             assert isinstance(share, Optimizer)
-             assert self.optimizer_class is share.optimizer_class
-             assert self.learning_rate is share.learning_rate
-             assert self.optimizer_kwargs == share.optimizer_kwargs
-             self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
-
-     def _get_device(self, device_name: str):
-         """Get internal state for the given TensorFlow device."""
-         tfutil.assert_tf_initialized()
-         if device_name in self._devices:
-             return self._devices[device_name]
-
-         # Initialize fields.
-         device = util.EasyDict()
-         device.name = device_name
-         device.optimizer = None              # Underlying optimizer: optimizer_class
-         device.loss_scaling_var = None       # Log2 of loss scaling: tf.Variable
-         device.grad_raw = OrderedDict()      # Raw gradients: var => [grad, ...]
-         device.grad_clean = OrderedDict()    # Clean gradients: var => grad
-         device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
-         device.grad_acc_count = None         # Accumulation counter: tf.Variable
-         device.grad_acc = OrderedDict()      # Accumulated gradients: var => grad
-
-         # Setup TensorFlow objects.
-         with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
-             if device_name not in self._shared_optimizers:
-                 optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
-                 self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
-             device.optimizer = self._shared_optimizers[device_name]
-             if self.use_loss_scaling:
-                 device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
-
-         # Register device.
-         self._devices[device_name] = device
-         return device
-
-     def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
-         """Register the gradients of the given loss function with respect to the given variables.
-         Intended to be called once per GPU."""
-         tfutil.assert_tf_initialized()
-         assert not self._updates_applied
-         device = self._get_device(loss.device)
-
-         # Validate trainables.
-         if isinstance(trainable_vars, dict):
-             trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
-         assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
-         assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
-         assert all(var.device == device.name for var in trainable_vars)
-
-         # Validate shapes.
-         if self._gradient_shapes is None:
-             self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
-         assert len(trainable_vars) == len(self._gradient_shapes)
-         assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))
-
-         # Report memory usage if requested.
-         deps = [loss]
-         if self._report_mem_usage:
-             self._report_mem_usage = False
-             try:
-                 with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
-                     deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
-             except tf.errors.NotFoundError:
-                 pass
-
-         # Compute gradients.
-         with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
-             loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
-             gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
-             grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)
-
-         # Register gradients.
-         for grad, var in grad_list:
-             if var not in device.grad_raw:
-                 device.grad_raw[var] = []
-             device.grad_raw[var].append(grad)
-
-     def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
-         """Construct training op to update the registered variables based on their gradients."""
-         tfutil.assert_tf_initialized()
-         assert not self._updates_applied
-         self._updates_applied = True
-         all_ops = []
-
-         # Check for no-op.
-         if allow_no_op and len(self._devices) == 0:
-             with tfutil.absolute_name_scope(self.scope):
-                 return tf.no_op(name='TrainingOp')
-
-         # Clean up gradients.
-         for device_idx, device in enumerate(self._devices.values()):
-             with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
-                 for var, grad in device.grad_raw.items():
-
-                     # Filter out disconnected gradients and convert to float32.
-                     grad = [g for g in grad if g is not None]
-                     grad = [tf.cast(g, tf.float32) for g in grad]
-
-                     # Sum within the device.
-                     if len(grad) == 0:
-                         grad = tf.zeros(var.shape)  # No gradients => zero.
-                     elif len(grad) == 1:
-                         grad = grad[0]              # Single gradient => use as is.
-                     else:
-                         grad = tf.add_n(grad)       # Multiple gradients => sum.
-
-                     # Scale as needed.
-                     scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
-                     scale = tf.constant(scale, dtype=tf.float32, name="scale")
-                     if self.minibatch_multiplier is not None:
-                         scale /= tf.cast(self.minibatch_multiplier, tf.float32)
-                     scale = self.undo_loss_scaling(scale)
-                     device.grad_clean[var] = grad * scale
-
-         # Sum gradients across devices.
-         if len(self._devices) > 1:
-             with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
-                 if platform.system() == "Windows":    # Windows => NCCL ops are not available.
-                     self._broadcast_fallback()
-                 elif tf.VERSION.startswith("1.15."):  # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
-                     self._broadcast_fallback()
-                 else:                                 # Otherwise => NCCL ops are safe to use.
-                     self._broadcast_nccl()
-
-         # Apply updates separately on each device.
-         for device_idx, device in enumerate(self._devices.values()):
-             with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
-                 # pylint: disable=cell-var-from-loop
-
-                 # Accumulate gradients over time.
-                 if self.minibatch_multiplier is None:
-                     acc_ok = tf.constant(True, name='acc_ok')
-                     device.grad_acc = OrderedDict(device.grad_clean)
-                 else:
-                     # Create variables.
-                     with tf.control_dependencies(None):
-                         for var in device.grad_clean.keys():
-                             device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
-                         device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")
-
-                     # Track counter.
-                     count_cur = device.grad_acc_count + 1.0
-                     count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
-                     count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
-                     acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
-                     all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))
-
-                     # Track gradients.
-                     for var, grad in device.grad_clean.items():
-                         acc_var = device.grad_acc_vars[var]
-                         acc_cur = acc_var + grad
-                         device.grad_acc[var] = acc_cur
-                         with tf.control_dependencies([acc_cur]):
-                             acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
-                             acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
-                             all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))
-
-                 # No overflow => apply gradients.
-                 all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
-                 apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
-                 all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
-
-                 # Adjust loss scaling.
-                 if self.use_loss_scaling:
-                     ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
-                     ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
-                     ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
-                     all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
-
-                 # Last device => report statistics.
-                 if device_idx == len(self._devices) - 1:
-                     all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
-                     all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
-                     if self.use_loss_scaling:
-                         all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
-
-         # Initialize variables.
-         self.reset_optimizer_state()
-         if self.use_loss_scaling:
-             tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
-         if self.minibatch_multiplier is not None:
-             tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
-
-         # Group everything into a single op.
-         with tfutil.absolute_name_scope(self.scope):
-             return tf.group(*all_ops, name="TrainingOp")
-
-     def reset_optimizer_state(self) -> None:
-         """Reset internal state of the underlying optimizer."""
-         tfutil.assert_tf_initialized()
-         tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
-
-     def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
272
- """Get or create variable representing log2 of the current dynamic loss scaling factor."""
273
- return self._get_device(device).loss_scaling_var
274
-
275
- def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
276
- """Apply dynamic loss scaling for the given expression."""
277
- assert tfutil.is_tf_expression(value)
278
- if not self.use_loss_scaling:
279
- return value
280
- return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
281
-
282
- def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
283
- """Undo the effect of dynamic loss scaling for the given expression."""
284
- assert tfutil.is_tf_expression(value)
285
- if not self.use_loss_scaling:
286
- return value
287
- return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
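Taken together, apply_loss_scaling()/undo_loss_scaling() and the ls_inc_op/ls_dec_op conds in apply_updates() implement standard dynamic loss scaling for FP16 training: the loss is multiplied by 2**s before gradients are computed, the gradients are divided by 2**s afterwards, and s grows slowly while all gradients stay finite but drops sharply on overflow. A minimal Python sketch of just that control logic (illustrative only; the inc/dec defaults of 0.0005 and 1.0 are assumed from the class constructor, which is not shown in this hunk):

def update_loss_scale(log2_scale, grads_finite, inc=0.0005, dec=1.0):
    # grads_finite stands in for the tf.is_finite reduction in apply_updates().
    if grads_finite:
        return log2_scale + inc  # slow growth after each successful step
    return log2_scale - dec      # fast back-off; the weight update itself is skipped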
-
-    def _broadcast_nccl(self):
-        """Sum gradients across devices using NCCL ops (fast path)."""
-        from tensorflow.python.ops import nccl_ops  # pylint: disable=no-name-in-module
-        for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
-            if any(x.shape.num_elements() > 0 for x in all_vars):
-                all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
-                all_grads = nccl_ops.all_sum(all_grads)
-                for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
-                    device.grad_clean[var] = grad
-
-    def _broadcast_fallback(self):
-        """Sum gradients across devices using TensorFlow collective ops (slow fallback path)."""
-        from tensorflow.python.ops import collective_ops  # pylint: disable=no-name-in-module
-        global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
-        if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
-            return
-        if not _collective_ops_warning_printed:
-            print("------------------------------------------------------------------------")
-            print("WARNING: Using slow fallback implementation for inter-GPU communication.")
-            print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
-            print("------------------------------------------------------------------------")
-            _collective_ops_warning_printed = True
-        for device in self._devices.values():
-            with tf.device(device.name):
-                combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
-                combo = tf.concat(combo, axis=0)
-                combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
-                    group_size=len(self._devices), group_key=_collective_ops_group_key,
-                    instance_key=_collective_ops_instance_key)
-                cur_ofs = 0
-                for var, grad_old in device.grad_clean.items():
-                    grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
-                    cur_ofs += grad_old.shape.num_elements()
-                    device.grad_clean[var] = grad_new
-        _collective_ops_instance_key += 1
-
-
-class SimpleAdam:
-    """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
-
-    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
-        self.name = name
-        self.learning_rate = learning_rate
-        self.beta1 = beta1
-        self.beta2 = beta2
-        self.epsilon = epsilon
-        self.all_state_vars = []
-
-    def variables(self):
-        return self.all_state_vars
-
-    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
-        assert gate_gradients == tf.train.Optimizer.GATE_NONE
-        return list(zip(tf.gradients(loss, var_list), var_list))
-
-    def apply_gradients(self, grads_and_vars):
-        with tf.name_scope(self.name):
-            state_vars = []
-            update_ops = []
-
-            # Adjust learning rate to deal with startup bias.
-            with tf.control_dependencies(None):
-                b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
-                b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
-            state_vars += [b1pow_var, b2pow_var]
-            b1pow_new = b1pow_var * self.beta1
-            b2pow_new = b2pow_var * self.beta2
-            update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
-            lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
-
-            # Construct ops to update each variable.
-            for grad, var in grads_and_vars:
-                with tf.control_dependencies(None):
-                    m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
-                    v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
-                state_vars += [m_var, v_var]
-                m_new = self.beta1 * m_var + (1 - self.beta1) * grad
-                v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
-                var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
-                update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
-
-            # Group everything together.
-            self.all_state_vars += state_vars
-            return tf.group(*update_ops)
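Stripped of the TF graph plumbing, the update that SimpleAdam.apply_gradients() constructs is textbook bias-corrected Adam. A NumPy sketch of one step for comparison (illustrative only; m, v, b1pow, b2pow mirror the state variables above):

import numpy as np

def adam_step(param, grad, m, v, b1pow, b2pow,
              lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    b1pow *= beta1                                 # running product beta1**t
    b2pow *= beta2                                 # running product beta2**t
    lr_t = lr * np.sqrt(1 - b2pow) / (1 - b1pow)   # startup-bias correction (lr_new above)
    m = beta1 * m + (1 - beta1) * grad             # first-moment estimate
    v = beta2 * v + (1 - beta2) * np.square(grad)  # second-moment estimate
    param = param - lr_t * m / (np.sqrt(v) + epsilon)
    return param, m, v, b1pow, b2pow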
PTI/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py DELETED
@@ -1,262 +0,0 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Miscellaneous helper utils for Tensorflow."""
-
-import os
-import numpy as np
-import tensorflow as tf
-
-# Silence deprecation warnings from TensorFlow 1.13 onwards
-import logging
-logging.getLogger('tensorflow').setLevel(logging.ERROR)
-import tensorflow.contrib  # requires TensorFlow 1.x!
-tf.contrib = tensorflow.contrib
-
-from typing import Any, Iterable, List, Union
-
-TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
-"""A type that represents a valid Tensorflow expression."""
-
-TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
-"""A type that can be converted to a valid Tensorflow expression."""
-
-
-def run(*args, **kwargs) -> Any:
-    """Run the specified ops in the default session."""
-    assert_tf_initialized()
-    return tf.get_default_session().run(*args, **kwargs)
-
-
-def is_tf_expression(x: Any) -> bool:
-    """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
-    return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
-
-
-def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
-    """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
-    return [dim.value for dim in shape]
-
-
-def flatten(x: TfExpressionEx) -> TfExpression:
-    """Shortcut function for flattening a tensor."""
-    with tf.name_scope("Flatten"):
-        return tf.reshape(x, [-1])
-
-
-def log2(x: TfExpressionEx) -> TfExpression:
-    """Logarithm in base 2."""
-    with tf.name_scope("Log2"):
-        return tf.log(x) * np.float32(1.0 / np.log(2.0))
-
-
-def exp2(x: TfExpressionEx) -> TfExpression:
-    """Exponent in base 2."""
-    with tf.name_scope("Exp2"):
-        return tf.exp(x * np.float32(np.log(2.0)))
-
-
-def erfinv(y: TfExpressionEx) -> TfExpression:
-    """Inverse of the error function."""
-    # pylint: disable=no-name-in-module
-    from tensorflow.python.ops.distributions import special_math
-    return special_math.erfinv(y)
-
-
-def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
-    """Linear interpolation."""
-    with tf.name_scope("Lerp"):
-        return a + (b - a) * t
-
-
-def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
-    """Linear interpolation with clip."""
-    with tf.name_scope("LerpClip"):
-        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
-
-
-def absolute_name_scope(scope: str) -> tf.name_scope:
-    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
-    return tf.name_scope(scope + "/")
-
-
-def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
-    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
-    return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
-
-
-def _sanitize_tf_config(config_dict: dict = None) -> dict:
-    # Defaults.
-    cfg = dict()
-    cfg["rnd.np_random_seed"] = None  # Random seed for NumPy. None = keep as is.
-    cfg["rnd.tf_random_seed"] = "auto"  # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
-    cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1"  # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
-    cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE"  # Disable HDF5 file locking to avoid concurrency issues with network shares.
-    cfg["graph_options.place_pruned_graph"] = True  # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
-    cfg["gpu_options.allow_growth"] = True  # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
-
-    # Remove defaults for environment variables that are already set.
-    for key in list(cfg):
-        fields = key.split(".")
-        if fields[0] == "env":
-            assert len(fields) == 2
-            if fields[1] in os.environ:
-                del cfg[key]
-
-    # User overrides.
-    if config_dict is not None:
-        cfg.update(config_dict)
-    return cfg
-
-
-def init_tf(config_dict: dict = None) -> None:
-    """Initialize TensorFlow session using good default settings."""
-    # Skip if already initialized.
-    if tf.get_default_session() is not None:
-        return
-
-    # Setup config dict and random seeds.
-    cfg = _sanitize_tf_config(config_dict)
-    np_random_seed = cfg["rnd.np_random_seed"]
-    if np_random_seed is not None:
-        np.random.seed(np_random_seed)
-    tf_random_seed = cfg["rnd.tf_random_seed"]
-    if tf_random_seed == "auto":
-        tf_random_seed = np.random.randint(1 << 31)
-    if tf_random_seed is not None:
-        tf.set_random_seed(tf_random_seed)
-
-    # Setup environment variables.
-    for key, value in cfg.items():
-        fields = key.split(".")
-        if fields[0] == "env":
-            assert len(fields) == 2
-            os.environ[fields[1]] = str(value)
-
-    # Create default TensorFlow session.
-    create_session(cfg, force_as_default=True)
-
-
-def assert_tf_initialized():
-    """Check that TensorFlow session has been initialized."""
-    if tf.get_default_session() is None:
-        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
-
-
-def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
-    """Create tf.Session based on config dict."""
-    # Setup TensorFlow config proto.
-    cfg = _sanitize_tf_config(config_dict)
-    config_proto = tf.ConfigProto()
-    for key, value in cfg.items():
-        fields = key.split(".")
-        if fields[0] not in ["rnd", "env"]:
-            obj = config_proto
-            for field in fields[:-1]:
-                obj = getattr(obj, field)
-            setattr(obj, fields[-1], value)
-
-    # Create session.
-    session = tf.Session(config=config_proto)
-    if force_as_default:
-        # pylint: disable=protected-access
-        session._default_session = session.as_default()
-        session._default_session.enforce_nesting = False
-        session._default_session.__enter__()
-    return session
-
-
-def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
-    """Initialize all tf.Variables that have not already been initialized.
-
-    Equivalent to the following, but more efficient and does not bloat the tf graph:
-    tf.variables_initializer(tf.report_uninitialized_variables()).run()
-    """
-    assert_tf_initialized()
-    if target_vars is None:
-        target_vars = tf.global_variables()
-
-    test_vars = []
-    test_ops = []
-
-    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
-        for var in target_vars:
-            assert is_tf_expression(var)
-
-            try:
-                tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
-            except KeyError:
-                # Op does not exist => variable may be uninitialized.
-                test_vars.append(var)
-
-                with absolute_name_scope(var.name.split(":")[0]):
-                    test_ops.append(tf.is_variable_initialized(var))
-
-    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
-    run([var.initializer for var in init_vars])
-
-
-def set_vars(var_to_value_dict: dict) -> None:
-    """Set the values of given tf.Variables.
-
-    Equivalent to the following, but more efficient and does not bloat the tf graph:
-    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
-    """
-    assert_tf_initialized()
-    ops = []
-    feed_dict = {}
-
-    for var, value in var_to_value_dict.items():
-        assert is_tf_expression(var)
-
-        try:
-            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0"))  # look for existing op
-        except KeyError:
-            with absolute_name_scope(var.name.split(":")[0]):
-                with tf.control_dependencies(None):  # ignore surrounding control_dependencies
-                    setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter")  # create new setter
-
-        ops.append(setter)
-        feed_dict[setter.op.inputs[1]] = value
-
-    run(ops, feed_dict)
-
-
-def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
-    """Create tf.Variable with large initial value without bloating the tf graph."""
-    assert_tf_initialized()
-    assert isinstance(initial_value, np.ndarray)
-    zeros = tf.zeros(initial_value.shape, initial_value.dtype)
-    var = tf.Variable(zeros, *args, **kwargs)
-    set_vars({var: initial_value})
-    return var
-
-
-def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
-    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
-    Can be used as an input transformation for Network.run().
-    """
-    images = tf.cast(images, tf.float32)
-    if nhwc_to_nchw:
-        images = tf.transpose(images, [0, 3, 1, 2])
-    return images * ((drange[1] - drange[0]) / 255) + drange[0]
-
-
-def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
-    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
-    Can be used as an output transformation for Network.run().
-    """
-    images = tf.cast(images, tf.float32)
-    if shrink > 1:
-        ksize = [1, 1, shrink, shrink]
-        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
-    if nchw_to_nhwc:
-        images = tf.transpose(images, [0, 2, 3, 1])
-    scale = 255 / (drange[1] - drange[0])
-    images = images * scale + (0.5 - drange[0] * scale)
-    return tf.saturate_cast(images, tf.uint8)
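The two image-conversion helpers at the end of tfutil.py are plain affine range mappings. A NumPy sketch of the float32-to-uint8 direction (illustrative only): the + 0.5 makes the truncating cast behave like rounding, and the clip plays the role of tf.saturate_cast.

import numpy as np

def to_uint8(images, drange=(-1.0, 1.0)):
    scale = 255.0 / (drange[1] - drange[0])
    images = images * scale + (0.5 - drange[0] * scale)
    return np.clip(images, 0, 255).astype(np.uint8)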
PTI/models/StyleCLIP/global_directions/dnnlib/util.py DELETED
@@ -1,472 +0,0 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Miscellaneous utility classes and functions."""
-
-import ctypes
-import fnmatch
-import importlib
-import inspect
-import numpy as np
-import os
-import shutil
-import sys
-import types
-import io
-import pickle
-import re
-import requests
-import html
-import hashlib
-import glob
-import tempfile
-import urllib
-import urllib.request
-import uuid
-
-from distutils.util import strtobool
-from typing import Any, List, Tuple, Union
-
-
-# Util classes
-# ------------------------------------------------------------------------------------------
-
-
-class EasyDict(dict):
-    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
-
-    def __getattr__(self, name: str) -> Any:
-        try:
-            return self[name]
-        except KeyError:
-            raise AttributeError(name)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        self[name] = value
-
-    def __delattr__(self, name: str) -> None:
-        del self[name]
-
-
-class Logger(object):
-    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
-
-    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
-        self.file = None
-
-        if file_name is not None:
-            self.file = open(file_name, file_mode)
-
-        self.should_flush = should_flush
-        self.stdout = sys.stdout
-        self.stderr = sys.stderr
-
-        sys.stdout = self
-        sys.stderr = self
-
-    def __enter__(self) -> "Logger":
-        return self
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        self.close()
-
-    def write(self, text: str) -> None:
-        """Write text to stdout (and a file) and optionally flush."""
-        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
-            return
-
-        if self.file is not None:
-            self.file.write(text)
-
-        self.stdout.write(text)
-
-        if self.should_flush:
-            self.flush()
-
-    def flush(self) -> None:
-        """Flush written text to both stdout and a file, if open."""
-        if self.file is not None:
-            self.file.flush()
-
-        self.stdout.flush()
-
-    def close(self) -> None:
-        """Flush, close possible files, and remove stdout/stderr mirroring."""
-        self.flush()
-
-        # if using multiple loggers, prevent closing in wrong order
-        if sys.stdout is self:
-            sys.stdout = self.stdout
-        if sys.stderr is self:
-            sys.stderr = self.stderr
-
-        if self.file is not None:
-            self.file.close()
-
-
-# Cache directories
-# ------------------------------------------------------------------------------------------
-
-_dnnlib_cache_dir = None
-
-def set_cache_dir(path: str) -> None:
-    global _dnnlib_cache_dir
-    _dnnlib_cache_dir = path
-
-def make_cache_dir_path(*paths: str) -> str:
-    if _dnnlib_cache_dir is not None:
-        return os.path.join(_dnnlib_cache_dir, *paths)
-    if 'DNNLIB_CACHE_DIR' in os.environ:
-        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
-    if 'HOME' in os.environ:
-        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
-    if 'USERPROFILE' in os.environ:
-        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
-    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
-
-# Small util functions
-# ------------------------------------------------------------------------------------------
-
-
-def format_time(seconds: Union[int, float]) -> str:
-    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
-    s = int(np.rint(seconds))
-
-    if s < 60:
-        return "{0}s".format(s)
-    elif s < 60 * 60:
-        return "{0}m {1:02}s".format(s // 60, s % 60)
-    elif s < 24 * 60 * 60:
-        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
-    else:
-        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
-
-
-def ask_yes_no(question: str) -> bool:
-    """Ask the user the question until the user inputs a valid answer."""
-    while True:
-        try:
-            print("{0} [y/n]".format(question))
-            return strtobool(input().lower())
-        except ValueError:
-            pass
-
-
-def tuple_product(t: Tuple) -> Any:
-    """Calculate the product of the tuple elements."""
-    result = 1
-
-    for v in t:
-        result *= v
-
-    return result
-
-
-_str_to_ctype = {
-    "uint8": ctypes.c_ubyte,
-    "uint16": ctypes.c_uint16,
-    "uint32": ctypes.c_uint32,
-    "uint64": ctypes.c_uint64,
-    "int8": ctypes.c_byte,
-    "int16": ctypes.c_int16,
-    "int32": ctypes.c_int32,
-    "int64": ctypes.c_int64,
-    "float32": ctypes.c_float,
-    "float64": ctypes.c_double
-}
-
-
-def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
-    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
-    type_str = None
-
-    if isinstance(type_obj, str):
-        type_str = type_obj
-    elif hasattr(type_obj, "__name__"):
-        type_str = type_obj.__name__
-    elif hasattr(type_obj, "name"):
-        type_str = type_obj.name
-    else:
-        raise RuntimeError("Cannot infer type name from input")
-
-    assert type_str in _str_to_ctype.keys()
-
-    my_dtype = np.dtype(type_str)
-    my_ctype = _str_to_ctype[type_str]
-
-    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
-
-    return my_dtype, my_ctype
-
-
-def is_pickleable(obj: Any) -> bool:
-    try:
-        with io.BytesIO() as stream:
-            pickle.dump(obj, stream)
-        return True
-    except:
-        return False
-
-
-# Functionality to import modules/objects by name, and call functions by name
-# ------------------------------------------------------------------------------------------
-
-def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
-    """Searches for the underlying module behind the name to some python object.
-    Returns the module and the object name (original name with module part removed)."""
-
-    # allow convenience shorthands, substitute them by full names
-    obj_name = re.sub("^np.", "numpy.", obj_name)
-    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
-
-    # list alternatives for (module_name, local_obj_name)
-    parts = obj_name.split(".")
-    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
-
-    # try each alternative in turn
-    for module_name, local_obj_name in name_pairs:
-        try:
-            module = importlib.import_module(module_name)  # may raise ImportError
-            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
-            return module, local_obj_name
-        except:
-            pass
-
-    # maybe some of the modules themselves contain errors?
-    for module_name, _local_obj_name in name_pairs:
-        try:
-            importlib.import_module(module_name)  # may raise ImportError
-        except ImportError:
-            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
-                raise
-
-    # maybe the requested attribute is missing?
-    for module_name, local_obj_name in name_pairs:
-        try:
-            module = importlib.import_module(module_name)  # may raise ImportError
-            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
-        except ImportError:
-            pass
-
-    # we are out of luck, but we have no idea why
-    raise ImportError(obj_name)
-
-
-def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
-    """Traverses the object name and returns the last (rightmost) python object."""
-    if obj_name == '':
-        return module
-    obj = module
-    for part in obj_name.split("."):
-        obj = getattr(obj, part)
-    return obj
-
-
-def get_obj_by_name(name: str) -> Any:
-    """Finds the python object with the given name."""
-    module, obj_name = get_module_from_obj_name(name)
-    return get_obj_from_module(module, obj_name)
-
-
-def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
-    """Finds the python object with the given name and calls it as a function."""
-    assert func_name is not None
-    func_obj = get_obj_by_name(func_name)
-    assert callable(func_obj)
-    return func_obj(*args, **kwargs)
-
-
-def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
-    """Finds the python class with the given name and constructs it with the given arguments."""
-    return call_func_by_name(*args, func_name=class_name, **kwargs)
-
-
-def get_module_dir_by_obj_name(obj_name: str) -> str:
-    """Get the directory path of the module containing the given object name."""
-    module, _ = get_module_from_obj_name(obj_name)
-    return os.path.dirname(inspect.getfile(module))
-
-
-def is_top_level_function(obj: Any) -> bool:
-    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
-    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
-
-
-def get_top_level_function_name(obj: Any) -> str:
-    """Return the fully-qualified name of a top-level function."""
-    assert is_top_level_function(obj)
-    module = obj.__module__
-    if module == '__main__':
-        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
-    return module + "." + obj.__name__
-
-
-# File system helpers
-# ------------------------------------------------------------------------------------------
-
-def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
-    """List all files recursively in a given directory while ignoring given file and directory names.
-    Returns list of tuples containing both absolute and relative paths."""
-    assert os.path.isdir(dir_path)
-    base_name = os.path.basename(os.path.normpath(dir_path))
-
-    if ignores is None:
-        ignores = []
-
-    result = []
-
-    for root, dirs, files in os.walk(dir_path, topdown=True):
-        for ignore_ in ignores:
-            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
-
-            # dirs need to be edited in-place
-            for d in dirs_to_remove:
-                dirs.remove(d)
-
-            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
-
-        absolute_paths = [os.path.join(root, f) for f in files]
-        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
-
-        if add_base_to_relative:
-            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
-
-        assert len(absolute_paths) == len(relative_paths)
-        result += zip(absolute_paths, relative_paths)
-
-    return result
-
-
-def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
-    """Takes in a list of tuples of (src, dst) paths and copies files.
-    Will create all necessary directories."""
-    for file in files:
-        target_dir_name = os.path.dirname(file[1])
-
-        # will create all intermediate-level directories
-        if not os.path.exists(target_dir_name):
-            os.makedirs(target_dir_name)
-
-        shutil.copyfile(file[0], file[1])
-
-
-# URL helpers
-# ------------------------------------------------------------------------------------------
-
-def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
-    """Determine whether the given object is a valid URL string."""
-    if not isinstance(obj, str) or not "://" in obj:
-        return False
-    if allow_file_urls and obj.startswith('file://'):
-        return True
-    try:
-        res = requests.compat.urlparse(obj)
-        if not res.scheme or not res.netloc or not "." in res.netloc:
-            return False
-        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
-        if not res.scheme or not res.netloc or not "." in res.netloc:
-            return False
-    except:
-        return False
-    return True
-
-
-def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
-    """Download the given URL and return a binary-mode file object to access the data."""
-    assert num_attempts >= 1
-    assert not (return_filename and (not cache))
-
-    # Doesn't look like an URL scheme so interpret it as a local filename.
-    if not re.match('^[a-z]+://', url):
-        return url if return_filename else open(url, "rb")
-
-    # Handle file URLs. This code handles unusual file:// patterns that
-    # arise on Windows:
-    #
-    #   file:///c:/foo.txt
-    #
-    # which would translate to a local '/c:/foo.txt' filename that's
-    # invalid. Drop the forward slash for such pathnames.
-    #
-    # If you touch this code path, you should test it on both Linux and
-    # Windows.
-    #
-    # Some internet resources suggest using urllib.request.url2pathname(),
-    # but that converts forward slashes to backslashes and this causes
-    # its own set of problems.
-    if url.startswith('file://'):
-        filename = urllib.parse.urlparse(url).path
-        if re.match(r'^/[a-zA-Z]:', filename):
-            filename = filename[1:]
-        return filename if return_filename else open(filename, "rb")
-
-    assert is_url(url)
-
-    # Lookup from cache.
-    if cache_dir is None:
-        cache_dir = make_cache_dir_path('downloads')
-
-    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
-    if cache:
-        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
-        if len(cache_files) == 1:
-            filename = cache_files[0]
-            return filename if return_filename else open(filename, "rb")
-
-    # Download.
-    url_name = None
-    url_data = None
-    with requests.Session() as session:
-        if verbose:
-            print("Downloading %s ..." % url, end="", flush=True)
-        for attempts_left in reversed(range(num_attempts)):
-            try:
-                with session.get(url) as res:
-                    res.raise_for_status()
-                    if len(res.content) == 0:
-                        raise IOError("No data received")
-
-                    if len(res.content) < 8192:
-                        content_str = res.content.decode("utf-8")
-                        if "download_warning" in res.headers.get("Set-Cookie", ""):
-                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
-                            if len(links) == 1:
-                                url = requests.compat.urljoin(url, links[0])
-                                raise IOError("Google Drive virus checker nag")
-                        if "Google Drive - Quota exceeded" in content_str:
-                            raise IOError("Google Drive download quota exceeded -- please try again later")
-
-                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
-                    url_name = match[1] if match else url
-                    url_data = res.content
-                    if verbose:
-                        print(" done")
-                    break
-            except:
-                if not attempts_left:
-                    if verbose:
-                        print(" failed")
-                    raise
-                if verbose:
-                    print(".", end="", flush=True)
-
-    # Save to cache.
-    if cache:
-        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
-        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
-        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
-        os.makedirs(cache_dir, exist_ok=True)
-        with open(temp_file, "wb") as f:
-            f.write(url_data)
-        os.replace(temp_file, cache_file)  # atomic
-        if return_filename:
-            return cache_file
-
-    # Return data as file object.
-    assert not return_filename
-    return io.BytesIO(url_data)
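open_url() keys its download cache on the MD5 of the URL plus a sanitized remote filename, which is why the cache lookup above is a single glob over url_md5 + "_*". A standalone sketch of just that naming scheme; cache_file_path is a hypothetical helper, not part of the deleted file:

import hashlib
import os
import re

def cache_file_path(cache_dir, url, url_name):
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()  # stable per-URL key
    safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)   # drop filesystem-unsafe chars
    return os.path.join(cache_dir, url_md5 + "_" + safe_name)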
PTI/models/StyleCLIP/global_directions/manipulate.py DELETED
@@ -1,278 +0,0 @@
-
-
-import os
-import os.path
-import pickle
-import numpy as np
-import tensorflow as tf
-from dnnlib import tflib
-from global_directions.utils.visualizer import HtmlPageVisualizer
-
-
-def Vis(bname, suffix, out, rownames=None, colnames=None):
-    num_images = out.shape[0]
-    step = out.shape[1]
-
-    if colnames is None:
-        colnames = [f'Step {i:02d}' for i in range(1, step + 1)]
-    if rownames is None:
-        rownames = [str(i) for i in range(num_images)]
-
-    visualizer = HtmlPageVisualizer(
-        num_rows=num_images, num_cols=step + 1, viz_size=256)
-    visualizer.set_headers(
-        ['Name'] + colnames)
-
-    for i in range(num_images):
-        visualizer.set_cell(i, 0, text=rownames[i])
-
-    for i in range(num_images):
-        for k in range(step):
-            image = out[i, k, :, :, :]
-            visualizer.set_cell(i, 1 + k, image=image)
-
-    # Save results.
-    visualizer.save(f'./html/' + bname + '_' + suffix + '.html')
-
-
-def LoadData(img_path):
-    tmp = img_path + 'S'
-    with open(tmp, "rb") as fp:  # Unpickling
-        s_names, all_s = pickle.load(fp)
-    dlatents = all_s
-
-    pindexs = []
-    mindexs = []
-    for i in range(len(s_names)):
-        name = s_names[i]
-        if 'ToRGB' not in name:
-            mindexs.append(i)
-        else:
-            pindexs.append(i)
-
-    tmp = img_path + 'S_mean_std'
-    with open(tmp, "rb") as fp:  # Unpickling
-        m, std = pickle.load(fp)
-
-    return dlatents, s_names, mindexs, pindexs, m, std
-
-
-def LoadModel(model_path, model_name):
-    # Initialize TensorFlow.
-    tflib.init_tf()
-    tmp = os.path.join(model_path, model_name)
-    with open(tmp, 'rb') as f:
-        _, _, Gs = pickle.load(f)
-    Gs.print_layers()
-    return Gs
-
-
-def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False):
-    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
-    Can be used as an output transformation for Network.run().
-    """
-    if nchw_to_nhwc:
-        images = np.transpose(images, [0, 2, 3, 1])
-
-    scale = 255 / (drange[1] - drange[0])
-    images = images * scale + (0.5 - drange[0] * scale)
-
-    np.clip(images, 0, 255, out=images)
-    images = images.astype('uint8')
-    return images
-
-
-def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
-    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
-    Can be used as an input transformation for Network.run().
-    """
-    if nhwc_to_nchw:
-        images = np.rollaxis(images, 3, 1)
-    return images / 255 * (drange[1] - drange[0]) + drange[0]
-
-
-class Manipulator():
-    def __init__(self, dataset_name='ffhq'):
-        self.file_path = './'
-        self.img_path = self.file_path + 'npy/' + dataset_name + '/'
-        self.model_path = self.file_path + 'model/'
-        self.dataset_name = dataset_name
-        self.model_name = dataset_name + '.pkl'
-
-        self.alpha = [0]  # manipulation strength
-        self.num_images = 10
-        self.img_index = 0  # which image to start from
-        self.viz_size = 256
-        self.manipulate_layers = None  # which layers to manipulate, list
-
-        self.dlatents, self.s_names, self.mindexs, self.pindexs, self.code_mean, self.code_std = LoadData(self.img_path)
-
-        self.sess = tf.InteractiveSession()
-        init = tf.global_variables_initializer()
-        self.sess.run(init)
-        self.Gs = LoadModel(self.model_path, self.model_name)
-        self.num_layers = len(self.dlatents)
-
-        self.Vis = Vis
-        self.noise_constant = {}
-
-        for i in range(len(self.s_names)):
-            tmp1 = self.s_names[i].split('/')
-            if 'ToRGB' not in tmp1:
-                tmp1[-1] = 'random_normal:0'
-                size = int(tmp1[1].split('x')[0])
-                tmp1 = '/'.join(tmp1)
-                tmp = (1, 1, size, size)
-                self.noise_constant[tmp1] = np.random.random(tmp)
-
-        tmp = self.Gs.components.synthesis.input_shape[1]
-        d = {}
-        d['G_synthesis_1/dlatents_in:0'] = np.zeros([1, tmp, 512])
-        names = list(self.noise_constant.keys())
-        tmp = tflib.run(names, d)
-        for i in range(len(names)):
-            self.noise_constant[names[i]] = tmp[i]
-
-        self.fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
-        self.img_size = self.Gs.output_shape[-1]
-
-    def GenerateImg(self, codes):
-        num_images, step = codes[0].shape[:2]
-
-        out = np.zeros((num_images, step, self.img_size, self.img_size, 3), dtype='uint8')
-        for i in range(num_images):
-            for k in range(step):
-                d = {}
-                for m in range(len(self.s_names)):
-                    d[self.s_names[m]] = codes[m][i, k][None, :]  # need to change
-                d['G_synthesis_1/4x4/Const/Shape:0'] = np.array([1, 18, 512], dtype=np.int32)
-                d.update(self.noise_constant)
-                img = tflib.run('G_synthesis_1/images_out:0', d)
-                image = convert_images_to_uint8(img, nchw_to_nhwc=True)
-                out[i, k, :, :, :] = image[0]
-        return out
-
-    def MSCode(self, dlatent_tmp, boundary_tmp):
-        step = len(self.alpha)
-        dlatent_tmp1 = [tmp.reshape((self.num_images, -1)) for tmp in dlatent_tmp]
-        dlatent_tmp2 = [np.tile(tmp[:, None], (1, step, 1)) for tmp in dlatent_tmp1]  # (10, 7, 512)
-
-        l = np.array(self.alpha)
-        l = l.reshape(
-            [step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)])
-
-        if type(self.manipulate_layers) == int:
-            tmp = [self.manipulate_layers]
-        elif type(self.manipulate_layers) == list:
-            tmp = self.manipulate_layers
-        elif self.manipulate_layers is None:
-            tmp = np.arange(len(boundary_tmp))
-        else:
-            raise ValueError('manipulate_layers is wrong')
-
-        for i in tmp:
-            dlatent_tmp2[i] += l * boundary_tmp[i]
-
-        codes = []
-        for i in range(len(dlatent_tmp2)):
-            tmp = list(dlatent_tmp[i].shape)
-            tmp.insert(1, step)
-            codes.append(dlatent_tmp2[i].reshape(tmp))
-        return codes
-
-    def EditOne(self, bname, dlatent_tmp=None):
-        if dlatent_tmp is None:
-            dlatent_tmp = [tmp[self.img_index:(self.img_index + self.num_images)] for tmp in self.dlatents]
-
-        boundary_tmp = []
-        for i in range(len(self.boundary)):
-            tmp = self.boundary[i]
-            if len(tmp) <= bname:
-                boundary_tmp.append([])
-            else:
-                boundary_tmp.append(tmp[bname])
-
-        codes = self.MSCode(dlatent_tmp, boundary_tmp)
-
-        out = self.GenerateImg(codes)
-        return codes, out
-
-    def EditOneC(self, cindex, dlatent_tmp=None):
-        if dlatent_tmp is None:
-            dlatent_tmp = [tmp[self.img_index:(self.img_index + self.num_images)] for tmp in self.dlatents]
-
-        boundary_tmp = [[] for i in range(len(self.dlatents))]
-
-        # only manipulate one layer and one channel
-        assert len(self.manipulate_layers) == 1
-
-        ml = self.manipulate_layers[0]
-        tmp = dlatent_tmp[ml].shape[1]  # ada
-        tmp1 = np.zeros(tmp)
-        tmp1[cindex] = self.code_std[ml][cindex]  # 1
-        boundary_tmp[ml] = tmp1
-
-        codes = self.MSCode(dlatent_tmp, boundary_tmp)
-        out = self.GenerateImg(codes)
-        return codes, out
-
-    def W2S(self, dlatent_tmp):
-        all_s = self.sess.run(
-            self.s_names,
-            feed_dict={'G_synthesis_1/dlatents_in:0': dlatent_tmp})
-        return all_s
-
-
-#%%
-if __name__ == "__main__":
-    M = Manipulator(dataset_name='ffhq')
-
-    #%%
-    M.alpha = [-5, 0, 5]
-    M.num_images = 20
-    lindex, cindex = 6, 501
-
-    M.manipulate_layers = [lindex]
-    codes, out = M.EditOneC(cindex)  # dlatent_tmp
-    tmp = str(M.manipulate_layers) + '_' + str(cindex)
-    M.Vis(tmp, 'c', out)
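The heart of MSCode() is the global-directions edit s' = s + alpha * boundary, broadcast over a list of manipulation strengths. A NumPy sketch of that idea for a single style layer (illustrative shapes, not part of the deleted file):

import numpy as np

def manipulate_codes(codes, boundary, alphas):
    # codes: [num_images, dim], boundary: [dim], alphas: list of edit strengths.
    alphas = np.asarray(alphas, dtype=codes.dtype)
    out = codes[:, None, :] + alphas[None, :, None] * boundary[None, None, :]
    return out  # [num_images, len(alphas), dim]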
PTI/models/StyleCLIP/global_directions/utils/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/global_directions/utils/editor.py DELETED
@@ -1,507 +0,0 @@
1
- # python 3.7
2
- """Utility functions for image editing from latent space."""
3
-
4
- import os.path
5
- import numpy as np
6
-
7
- __all__ = [
8
- 'parse_indices', 'interpolate', 'mix_style',
9
- 'get_layerwise_manipulation_strength', 'manipulate', 'parse_boundary_list'
10
- ]
11
-
12
-
13
- def parse_indices(obj, min_val=None, max_val=None):
14
- """Parses indices.
15
-
16
- If the input is a list or tuple, this function has no effect.
17
-
18
- The input can also be a string, which is either a comma separated list of
19
- numbers 'a, b, c', or a dash separated range 'a - c'. Space in the string will
20
- be ignored.
21
-
22
- Args:
23
- obj: The input object to parse indices from.
24
- min_val: If not `None`, this function will check that all indices are equal
25
- to or larger than this value. (default: None)
26
- max_val: If not `None`, this function will check that all indices are equal
27
- to or smaller than this field. (default: None)
28
-
29
- Returns:
30
- A list of integers.
31
-
32
- Raises:
33
- If the input is invalid, i.e., neither a list or tuple, nor a string.
34
- """
35
- if obj is None or obj == '':
36
- indices = []
37
- elif isinstance(obj, int):
38
- indices = [obj]
39
- elif isinstance(obj, (list, tuple, np.ndarray)):
40
- indices = list(obj)
41
- elif isinstance(obj, str):
42
- indices = []
43
- splits = obj.replace(' ', '').split(',')
44
- for split in splits:
45
- numbers = list(map(int, split.split('-')))
46
- if len(numbers) == 1:
47
- indices.append(numbers[0])
48
- elif len(numbers) == 2:
49
- indices.extend(list(range(numbers[0], numbers[1] + 1)))
50
- else:
51
- raise ValueError(f'Invalid type of input: {type(obj)}!')
52
-
53
- assert isinstance(indices, list)
54
- indices = sorted(list(set(indices)))
55
- for idx in indices:
56
- assert isinstance(idx, int)
57
- if min_val is not None:
58
- assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
59
- if max_val is not None:
60
- assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'
61
-
62
- return indices
63
-
64
-
65
- def interpolate(src_codes, dst_codes, step=5):
66
- """Interpolates two sets of latent codes linearly.
67
-
68
- Args:
69
- src_codes: Source codes, with shape [num, *code_shape].
70
- dst_codes: Target codes, with shape [num, *code_shape].
71
- step: Number of interplolation steps, with source and target included. For
72
- example, if `step = 5`, three more samples will be inserted. (default: 5)
73
-
74
- Returns:
75
- Interpolated codes, with shape [num, step, *code_shape].
76
-
77
- Raises:
78
- ValueError: If the input two sets of latent codes are with different shapes.
79
- """
80
- if not (src_codes.ndim >= 2 and src_codes.shape == dst_codes.shape):
81
- raise ValueError(f'Shapes of source codes and target codes should both be '
82
- f'[num, *code_shape], but {src_codes.shape} and '
83
- f'{dst_codes.shape} are received!')
84
- num = src_codes.shape[0]
85
- code_shape = src_codes.shape[1:]
86
-
87
- a = src_codes[:, np.newaxis]
88
- b = dst_codes[:, np.newaxis]
89
- l = np.linspace(0.0, 1.0, step).reshape(
90
- [step if axis == 1 else 1 for axis in range(a.ndim)])
91
- results = a + l * (b - a)
92
- assert results.shape == (num, step, *code_shape)
93
-
94
- return results
95
-
96
-
97
- def mix_style(style_codes,
98
- content_codes,
99
- num_layers=1,
100
- mix_layers=None,
101
- is_style_layerwise=True,
102
- is_content_layerwise=True):
103
- """Mixes styles from style codes to those of content codes.
104
-
105
- Each style code or content code consists of `num_layers` codes, each of which
106
- is typically fed into a particular layer of the generator. This function mixes
107
- styles by partially replacing the codes of `content_codes` from some certain
108
- layers with those of `style_codes`.
109
-
110
- For example, if both style code and content code are with shape [10, 512],
111
- meaning to have 10 layers and each employs a 512-dimensional latent code. And
112
- the 1st, 2nd, and 3rd layers are the target layers to perform style mixing.
113
- Then the top half of the content code (with shape [3, 512]) will be replaced
114
- by the top half of the style code (also with shape [3, 512]).
115
-
116
- NOTE: This function also supports taking single-layer latent codes as inputs,
117
- i.e., setting `is_style_layerwise` or `is_content_layerwise` as False. In this
118
- case, the corresponding code will be first repeated for `num_layers` before
119
- performing style mixing.
120
-
121
- Args:
122
- style_codes: Style codes, with shape [num_styles, *code_shape] or
123
- [num_styles, num_layers, *code_shape].
124
- content_codes: Content codes, with shape [num_contents, *code_shape] or
125
- [num_contents, num_layers, *code_shape].
126
- num_layers: Total number of layers in the generative model. (default: 1)
127
- mix_layers: Indices of the layers to perform style mixing. `None` means to
128
- replace all layers, in which case the content code will be completely
129
- replaced by style code. (default: None)
130
- is_style_layerwise: Indicating whether the input `style_codes` are
131
- layer-wise codes. (default: True)
132
- is_content_layerwise: Indicating whether the input `content_codes` are
133
- layer-wise codes. (default: True)
134
- num_layers
135
-
136
- Returns:
137
- Codes after style mixing, with shape [num_styles, num_contents, num_layers,
138
- *code_shape].
139
-
140
- Raises:
141
- ValueError: If input `content_codes` or `style_codes` is with invalid shape.
142
- """
143
- if not is_style_layerwise:
144
- style_codes = style_codes[:, np.newaxis]
145
- style_codes = np.tile(
146
- style_codes,
147
- [num_layers if axis == 1 else 1 for axis in range(style_codes.ndim)])
148
- if not is_content_layerwise:
149
- content_codes = content_codes[:, np.newaxis]
150
- content_codes = np.tile(
151
- content_codes,
152
- [num_layers if axis == 1 else 1 for axis in range(content_codes.ndim)])
153
-
154
- if not (style_codes.ndim >= 3 and style_codes.shape[1] == num_layers and
155
- style_codes.shape[1:] == content_codes.shape[1:]):
156
- raise ValueError(f'Shapes of style codes and content codes should be '
157
- f'[num_styles, num_layers, *code_shape] and '
158
- f'[num_contents, num_layers, *code_shape] respectively, '
159
- f'but {style_codes.shape} and {content_codes.shape} are '
160
- f'received!')
161
-
162
- layer_indices = parse_indices(mix_layers, min_val=0, max_val=num_layers - 1)
163
- if not layer_indices:
164
- layer_indices = list(range(num_layers))
165
-
166
- num_styles = style_codes.shape[0]
167
- num_contents = content_codes.shape[0]
168
- code_shape = content_codes.shape[2:]
169
-
170
- s = style_codes[:, np.newaxis]
171
- s = np.tile(s, [num_contents if axis == 1 else 1 for axis in range(s.ndim)])
172
- c = content_codes[np.newaxis]
173
- c = np.tile(c, [num_styles if axis == 0 else 1 for axis in range(c.ndim)])
174
-
175
- from_style = np.zeros(s.shape, dtype=bool)
176
- from_style[:, :, layer_indices] = True
177
- results = np.where(from_style, s, c)
178
- assert results.shape == (num_styles, num_contents, num_layers, *code_shape)
179
-
180
- return results
181
-
182
-
183
- def get_layerwise_manipulation_strength(num_layers,
184
- truncation_psi,
185
- truncation_layers):
186
- """Gets layer-wise strength for manipulation.
187
-
188
- Recall the truncation trick played on layer [0, truncation_layers):
189
-
190
- w = truncation_psi * w + (1 - truncation_psi) * w_avg
191
-
192
- So, when using the same boundary to manipulate different layers, layer
193
- [0, truncation_layers) and layer [truncation_layers, num_layers) should use
194
- different strength to eliminate the effect from the truncation trick. More
195
- concretely, the strength for layer [0, truncation_layers) is set as
196
- `truncation_psi`, while that for other layers are set as 1.
197
- """
198
- strength = [1.0 for _ in range(num_layers)]
199
- if truncation_layers > 0:
200
- for layer_idx in range(0, truncation_layers):
201
- strength[layer_idx] = truncation_psi
202
- return strength
203
-
-
- def manipulate(latent_codes,
-                boundary,
-                start_distance=-5.0,
-                end_distance=5.0,
-                step=21,
-                layerwise_manipulation=False,
-                num_layers=1,
-                manipulate_layers=None,
-                is_code_layerwise=False,
-                is_boundary_layerwise=False,
-                layerwise_manipulation_strength=1.0):
-   """Manipulates the given latent codes with respect to a particular boundary.
-
-   Basically, this function takes a set of latent codes and a boundary as
-   inputs, and outputs a collection of manipulated latent codes.
-
-   For example, let `step` be 10, `latent_codes` have shape [num, *code_shape],
-   and `boundary` have shape [1, *code_shape] with unit norm. Then the output
-   will have shape [num, 10, *code_shape]. Within each 10-element sequence of
-   manipulated codes, the first code is `start_distance` away from the original
-   code (i.e., the input) along the `boundary` direction, while the last code
-   is `end_distance` away. The remaining codes are linearly interpolated. Here,
-   `distance` is sign sensitive.
-
-   NOTE: This function also supports layer-wise manipulation, in which case the
-   generator should be able to take layer-wise latent codes as inputs. For
-   example, if the generator has 18 convolutional layers in total, each of
-   which takes an independent latent code as input, it is possible, sometimes
-   with even better performance, to manipulate only the latent codes
-   corresponding to certain layers while keeping the others untouched.
-
-   NOTE: The boundary is assumed to be already normalized to unit norm.
-
-   Args:
-     latent_codes: The input latent codes for manipulation, with shape
-       [num, *code_shape] or [num, num_layers, *code_shape].
-     boundary: The semantic boundary as reference, with shape [1, *code_shape]
-       or [1, num_layers, *code_shape].
-     start_distance: Start point for manipulation. (default: -5.0)
-     end_distance: End point for manipulation. (default: 5.0)
-     step: Number of manipulation steps. (default: 21)
-     layerwise_manipulation: Whether to perform layer-wise manipulation.
-       (default: False)
-     num_layers: Number of layers. Only active when `layerwise_manipulation`
-       is set to `True`. Should be a positive integer. (default: 1)
-     manipulate_layers: Indices of the layers to manipulate. `None` means to
-       manipulate the latent codes of all layers. (default: None)
-     is_code_layerwise: Whether the input latent codes are layer-wise. If set
-       to `False`, the function will first repeat the input codes `num_layers`
-       times before performing manipulation. (default: False)
-     is_boundary_layerwise: Whether the input boundary is layer-wise. If set
-       to `False`, the function will first repeat the boundary `num_layers`
-       times before performing manipulation. (default: False)
-     layerwise_manipulation_strength: Manipulation strength for each layer.
-       Only active when `layerwise_manipulation` is set to `True`. This field
-       can be used to resolve the strength discrepancy across layers when the
-       truncation trick is on. See function
-       `get_layerwise_manipulation_strength()` for details. A tuple, list, or
-       `numpy.ndarray` is expected. If set to a single number, this strength
-       will be used for all layers. (default: 1.0)
-
-   Returns:
-     Manipulated codes, with shape [num, step, *code_shape] if
-       `layerwise_manipulation` is set to `False`, or shape [num, step,
-       num_layers, *code_shape] if `layerwise_manipulation` is set to `True`.
-
-   Raises:
-     ValueError: If the input latent codes, boundary, or strength have an
-       invalid shape.
-   """
-   if not (boundary.ndim >= 2 and boundary.shape[0] == 1):
-     raise ValueError(f'Boundary should be with shape [1, *code_shape] or '
-                      f'[1, num_layers, *code_shape], but '
-                      f'{boundary.shape} is received!')
-
-   if not layerwise_manipulation:
-     assert not is_code_layerwise
-     assert not is_boundary_layerwise
-     num_layers = 1
-     manipulate_layers = None
-     layerwise_manipulation_strength = 1.0
-
-   # Preprocessing for layer-wise manipulation.
-   # Parse indices of manipulation layers.
-   layer_indices = parse_indices(
-       manipulate_layers, min_val=0, max_val=num_layers - 1)
-   if not layer_indices:
-     layer_indices = list(range(num_layers))
-   # Make latent codes layer-wise if needed.
-   assert num_layers > 0
-   if not is_code_layerwise:
-     x = latent_codes[:, np.newaxis]
-     x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)])
-   else:
-     x = latent_codes
-     if x.shape[1] != num_layers:
-       raise ValueError(f'Latent codes should be with shape [num, num_layers, '
-                        f'*code_shape], where `num_layers` equals to '
-                        f'{num_layers}, but {x.shape} is received!')
-   # Make boundary layer-wise if needed.
-   if not is_boundary_layerwise:
-     b = boundary
-     b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)])
-   else:
-     b = boundary[0]
-     if b.shape[0] != num_layers:
-       raise ValueError(f'Boundary should be with shape [num_layers, '
-                        f'*code_shape], where `num_layers` equals to '
-                        f'{num_layers}, but {b.shape} is received!')
-   # Get layer-wise manipulation strength.
-   if isinstance(layerwise_manipulation_strength, (int, float)):
-     s = [float(layerwise_manipulation_strength) for _ in range(num_layers)]
-   elif isinstance(layerwise_manipulation_strength, (list, tuple)):
-     s = layerwise_manipulation_strength
-     if len(s) != num_layers:
-       raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` '
-                        f'mismatches number of layers `{num_layers}`!')
-   elif isinstance(layerwise_manipulation_strength, np.ndarray):
-     s = layerwise_manipulation_strength
-     if s.size != num_layers:
-       raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` '
-                        f'mismatches number of layers `{num_layers}`!')
-   else:
-     raise ValueError('Unsupported type of `layerwise_manipulation_strength`!')
-   s = np.array(s).reshape(
-       [num_layers if axis == 0 else 1 for axis in range(b.ndim)])
-   b = b * s
-
-   if x.shape[1:] != b.shape:
-     raise ValueError(f'Latent code shape {x.shape} and boundary shape '
-                      f'{b.shape} mismatch!')
-   num = x.shape[0]
-   code_shape = x.shape[2:]
-
-   x = x[:, np.newaxis]
-   b = b[np.newaxis, np.newaxis, :]
-   l = np.linspace(start_distance, end_distance, step).reshape(
-       [step if axis == 1 else 1 for axis in range(x.ndim)])
-   results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)])
-   is_manipulatable = np.zeros(results.shape, dtype=bool)
-   is_manipulatable[:, :, layer_indices] = True
-   results = np.where(is_manipulatable, x + l * b, results)
-   assert results.shape == (num, step, num_layers, *code_shape)
-
-   return results if layerwise_manipulation else results[:, :, 0]
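For reference, the core of `manipulate()` is a broadcasted linear walk along the boundary; a minimal numpy sketch of the non-layer-wise case (names are illustrative, not part of the deleted file):

    import numpy as np

    num, dim, step = 4, 512, 5
    codes = np.random.randn(num, dim)              # [num, *code_shape]
    boundary = np.random.randn(1, dim)
    boundary /= np.linalg.norm(boundary)           # unit norm, as assumed above
    l = np.linspace(-5.0, 5.0, step).reshape(1, step, 1)
    out = codes[:, np.newaxis] + l * boundary[np.newaxis]
    assert out.shape == (num, step, dim)           # [num, step, *code_shape]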
-
-
- def manipulate2(latent_codes,
-                 proj,
-                 mindex,
-                 start_distance=-5.0,
-                 end_distance=5.0,
-                 step=21,
-                 layerwise_manipulation=False,
-                 num_layers=1,
-                 manipulate_layers=None,
-                 is_code_layerwise=False,
-                 layerwise_manipulation_strength=1.0):
-   """Variant of `manipulate()` that moves the latent codes along principal
-   component `mindex` of the fitted projection `proj` (see `MPC()` below),
-   instead of along an externally supplied boundary."""
-   if not layerwise_manipulation:
-     assert not is_code_layerwise
-     # assert not is_boundary_layerwise
-     num_layers = 1
-     manipulate_layers = None
-     layerwise_manipulation_strength = 1.0
-
-   # Preprocessing for layer-wise manipulation.
-   # Parse indices of manipulation layers.
-   layer_indices = parse_indices(
-       manipulate_layers, min_val=0, max_val=num_layers - 1)
-   if not layer_indices:
-     layer_indices = list(range(num_layers))
-   # Make latent codes layer-wise if needed.
-   assert num_layers > 0
-   if not is_code_layerwise:
-     x = latent_codes[:, np.newaxis]
-     x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)])
-   else:
-     x = latent_codes
-     if x.shape[1] != num_layers:
-       raise ValueError(f'Latent codes should be with shape [num, num_layers, '
-                        f'*code_shape], where `num_layers` equals to '
-                        f'{num_layers}, but {x.shape} is received!')
-   # The boundary-related preprocessing from `manipulate()` is not needed
-   # here, since the direction comes from `proj` instead of a boundary.
-   # Get layer-wise manipulation strength.
-   if isinstance(layerwise_manipulation_strength, (int, float)):
-     s = [float(layerwise_manipulation_strength) for _ in range(num_layers)]
-   elif isinstance(layerwise_manipulation_strength, (list, tuple)):
-     s = layerwise_manipulation_strength
-     if len(s) != num_layers:
-       raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` '
-                        f'mismatches number of layers `{num_layers}`!')
-   elif isinstance(layerwise_manipulation_strength, np.ndarray):
-     s = layerwise_manipulation_strength
-     if s.size != num_layers:
-       raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` '
-                        f'mismatches number of layers `{num_layers}`!')
-   else:
-     raise ValueError('Unsupported type of `layerwise_manipulation_strength`!')
-
-   num = x.shape[0]
-   code_shape = x.shape[2:]
-
-   x = x[:, np.newaxis]
-   results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)])
-   is_manipulatable = np.zeros(results.shape, dtype=bool)
-   is_manipulatable[:, :, layer_indices] = True
-
-   tmp = MPC(proj, x, mindex, start_distance, end_distance, step)
-   tmp = tmp[:, :, np.newaxis]
-   tmp1 = np.tile(tmp, [num_layers if axis == 2 else 1 for axis in range(tmp.ndim)])
-
-   results = np.where(is_manipulatable, tmp1, results)
-   assert results.shape == (num, step, num_layers, *code_shape)
-   return results if layerwise_manipulation else results[:, :, 0]
-
-
- def MPC(proj, x, mindex, start_distance, end_distance, step):
-   """Manipulates codes along a single principal component.
-
-   `x` has shape (batch_size, 1, num_layers, feature). The codes are projected
-   into the PCA space of `proj`, component `mindex` is shifted linearly from
-   `start_distance` to `end_distance` over `step` steps, and the codes are
-   projected back.
-   """
-   x1 = proj.transform(x[:, 0, 0, :])  # (batch_size, num_pc)
-   # Optionally whiten with: x1 /= np.sqrt(proj.explained_variance_)
-
-   x1 = x1[:, np.newaxis]
-   x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)])
-
-   l = np.linspace(start_distance, end_distance, step)[None, :]
-   x1[:, :, mindex] += l
-
-   tmp = x1.reshape((-1, x1.shape[-1]))  # un-whiten here if whitened above
-   x2 = proj.inverse_transform(tmp)
-   x2 = x2.reshape((x1.shape[0], x1.shape[1], -1))
-
-   return x2
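`manipulate2()`/`MPC()` edit codes in a PCA basis instead of along a learned boundary. A minimal sketch of the same idea with scikit-learn (an assumed dependency here; names illustrative):

    import numpy as np
    from sklearn.decomposition import PCA

    codes = np.random.randn(100, 512)      # e.g. latent samples
    proj = PCA(n_components=32).fit(codes)
    z = proj.transform(codes[:1])          # (1, 32) PCA coordinates
    z[:, 3] += 2.5                         # move along component 3 (mindex)
    edited = proj.inverse_transform(z)     # back to (1, 512)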
-
-
-
-
- def parse_boundary_list(boundary_list_path):
-   """Parses a boundary list.
-
-   Sometimes, a text file containing a list of boundaries will significantly
-   simplify image manipulation with a large number of boundaries. This
-   function is used to parse boundary information from such a list file.
-
-   Basically, each item in the list should have the format
-   `($NAME, $SPACE_TYPE): $PATH`. `DISABLE` at the beginning of a line
-   disables that particular boundary.
-
-   Sample:
-
-   (age, z): $AGE_BOUNDARY_PATH
-   (gender, w): $GENDER_BOUNDARY_PATH
-   DISABLE(pose, wp): $POSE_BOUNDARY_PATH
-
-   Args:
-     boundary_list_path: Path to the boundary list.
-
-   Returns:
-     A dictionary, whose key is a two-element tuple (boundary_name,
-       space_type) and whose value is the corresponding boundary path.
-
-   Raises:
-     ValueError: If the given boundary list does not exist.
-   """
-   if not os.path.isfile(boundary_list_path):
-     raise ValueError(f'Boundary list `{boundary_list_path}` does not exist!')
-
-   boundaries = {}
-   with open(boundary_list_path, 'r') as f:
-     for line in f:
-       if line[:len('DISABLE')] == 'DISABLE':
-         continue
-       boundary_info, boundary_path = line.strip().split(':')
-       boundary_name, space_type = boundary_info.strip()[1:-1].split(',')
-       boundary_name = boundary_name.strip()
-       space_type = space_type.strip().lower()
-       boundary_path = boundary_path.strip()
-       boundaries[(boundary_name, space_type)] = boundary_path
-   return boundaries
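A quick usage sketch for the parser above (file contents and paths are illustrative):

    # boundary_list.txt contains, e.g.:
    #   (age, z): boundaries/age_z.npy
    #   DISABLE(pose, wp): boundaries/pose_wp.npy
    boundaries = parse_boundary_list('boundary_list.txt')
    # -> {('age', 'z'): 'boundaries/age_z.npy'}  (the DISABLE line is skipped)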
PTI/models/StyleCLIP/global_directions/utils/train_boundary.py DELETED
@@ -1,158 +0,0 @@
-
- import numpy as np
- from sklearn import svm
-
-
- def train_boundary(latent_codes,
-                    scores,
-                    chosen_num_or_ratio=0.02,
-                    split_ratio=0.7,
-                    invalid_value=None,
-                    logger=None,
-                    logger_name='train_boundary'):
-   """Trains a boundary in latent space with offline predicted attribute scores.
-
-   Given a collection of latent codes and the attribute scores predicted from
-   the corresponding images, this function will train a linear SVM by treating
-   it as a bi-classification problem. Basically, the samples with the highest
-   attribute scores are treated as positive samples, while those with the
-   lowest scores are treated as negative. For now, the latent code can ONLY be
-   1-dimensional.
-
-   NOTE: The returned boundary has shape (1, latent_space_dim) and is
-   normalized to unit norm.
-
-   Args:
-     latent_codes: Input latent codes as training data.
-     scores: Input attribute scores used to generate training labels.
-     chosen_num_or_ratio: How many samples will be chosen as positive
-       (negative) samples. If this field lies in range (0, 0.5],
-       `chosen_num_or_ratio * latent_codes_num` will be used. Otherwise,
-       `min(chosen_num_or_ratio, 0.5 * latent_codes_num)` will be used.
-       (default: 0.02)
-     split_ratio: Ratio to split training and validation sets. (default: 0.7)
-     invalid_value: This field is used to filter out data. (default: None)
-     logger: Logger for recording log messages. If set as `None`, a default
-       logger, which prints messages from all levels to screen, will be
-       created. (default: None)
-     logger_name: Name of the logger. (default: 'train_boundary')
-
-   Returns:
-     A decision boundary with type `numpy.ndarray`, and the validation
-       accuracy.
-
-   Raises:
-     ValueError: If the input `latent_codes` or `scores` are with invalid
-       format.
-   """
-   # if not logger:
-   #   logger = setup_logger(work_dir='', logger_name=logger_name)
-
-   if (not isinstance(latent_codes, np.ndarray) or
-       not len(latent_codes.shape) == 2):
-     raise ValueError(f'Input `latent_codes` should be with type '
-                      f'`numpy.ndarray`, and shape [num_samples, '
-                      f'latent_space_dim]!')
-   num_samples = latent_codes.shape[0]
-   latent_space_dim = latent_codes.shape[1]
-   if (not isinstance(scores, np.ndarray) or not len(scores.shape) == 2 or
-       not scores.shape[0] == num_samples or not scores.shape[1] == 1):
-     raise ValueError(f'Input `scores` should be with type `numpy.ndarray`, '
-                      f'and shape [num_samples, 1], where `num_samples` should '
-                      f'be exactly the same as that of input `latent_codes`!')
-   if chosen_num_or_ratio <= 0:
-     raise ValueError(f'Input `chosen_num_or_ratio` should be positive, '
-                      f'but {chosen_num_or_ratio} received!')
-
-   print('Filtering training data.')
-   if invalid_value is not None:
-     latent_codes = latent_codes[scores[:, 0] != invalid_value]
-     scores = scores[scores[:, 0] != invalid_value]
-
-   print('Sorting scores to get positive and negative samples.')
-   sorted_idx = np.argsort(scores, axis=0)[::-1, 0]
-   latent_codes = latent_codes[sorted_idx]
-   scores = scores[sorted_idx]
-   num_samples = latent_codes.shape[0]
-   if 0 < chosen_num_or_ratio <= 1:
-     chosen_num = int(num_samples * chosen_num_or_ratio)
-   else:
-     chosen_num = int(chosen_num_or_ratio)
-   chosen_num = min(chosen_num, num_samples // 2)
-
-   print('Splitting training and validation sets:')
-   train_num = int(chosen_num * split_ratio)
-   val_num = chosen_num - train_num
-   # Positive samples.
-   positive_idx = np.arange(chosen_num)
-   np.random.shuffle(positive_idx)
-   positive_train = latent_codes[:chosen_num][positive_idx[:train_num]]
-   positive_val = latent_codes[:chosen_num][positive_idx[train_num:]]
-   # Negative samples.
-   negative_idx = np.arange(chosen_num)
-   np.random.shuffle(negative_idx)
-   negative_train = latent_codes[-chosen_num:][negative_idx[:train_num]]
-   negative_val = latent_codes[-chosen_num:][negative_idx[train_num:]]
-   # Training set. (`int` replaces the deprecated `np.int` alias.)
-   train_data = np.concatenate([positive_train, negative_train], axis=0)
-   train_label = np.concatenate([np.ones(train_num, dtype=int),
-                                 np.zeros(train_num, dtype=int)], axis=0)
-   print(f'  Training: {train_num} positive, {train_num} negative.')
-   # Validation set.
-   val_data = np.concatenate([positive_val, negative_val], axis=0)
-   val_label = np.concatenate([np.ones(val_num, dtype=int),
-                               np.zeros(val_num, dtype=int)], axis=0)
-   print(f'  Validation: {val_num} positive, {val_num} negative.')
-
-   # Remaining set.
-   remaining_num = num_samples - chosen_num * 2
-   remaining_data = latent_codes[chosen_num:-chosen_num]
-   remaining_scores = scores[chosen_num:-chosen_num]
-   decision_value = (scores[0] + scores[-1]) / 2
-   remaining_label = np.ones(remaining_num, dtype=int)
-   remaining_label[remaining_scores.ravel() < decision_value] = 0
-   remaining_positive_num = np.sum(remaining_label == 1)
-   remaining_negative_num = np.sum(remaining_label == 0)
-   print(f'  Remaining: {remaining_positive_num} positive, '
-         f'{remaining_negative_num} negative.')
-   print('Training boundary.')
-
-   clf = svm.SVC(kernel='linear')
-   classifier = clf.fit(train_data, train_label)
-   print('Finish training.')
-
-   if val_num:
-     val_prediction = classifier.predict(val_data)
-     correct_num = np.sum(val_label == val_prediction)
-     print(f'Accuracy for validation set: '
-           f'{correct_num} / {val_num * 2} = '
-           f'{correct_num / (val_num * 2):.6f}')
-     vacc = correct_num / len(val_label)
-   '''
-   if remaining_num:
-     remaining_prediction = classifier.predict(remaining_data)
-     correct_num = np.sum(remaining_label == remaining_prediction)
-     logger.info(f'Accuracy for remaining set: '
-                 f'{correct_num} / {remaining_num} = '
-                 f'{correct_num / remaining_num:.6f}')
-   '''
-   a = classifier.coef_.reshape(1, latent_space_dim).astype(np.float32)
-   return a / np.linalg.norm(a), vacc
-
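A minimal end-to-end sketch of the helper above on synthetic data (purely illustrative; assumes numpy and scikit-learn are installed):

    import numpy as np

    codes = np.random.randn(1000, 512)
    scores = codes[:, :1] + 0.1 * np.random.randn(1000, 1)  # fake attribute
    boundary, val_acc = train_boundary(codes, scores, chosen_num_or_ratio=0.1)
    assert boundary.shape == (1, 512)
    assert abs(np.linalg.norm(boundary) - 1.0) < 1e-5       # unit norm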
PTI/models/StyleCLIP/global_directions/utils/visualizer.py DELETED
@@ -1,605 +0,0 @@
- # python 3.7
- """Utility functions for visualizing results on html page."""
-
- import base64
- import os.path
- import cv2
- import numpy as np
-
- __all__ = [
-     'get_grid_shape', 'get_blank_image', 'load_image', 'save_image',
-     'resize_image', 'add_text_to_image', 'fuse_images', 'HtmlPageVisualizer',
-     'VideoReader', 'VideoWriter', 'adjust_pixel_range'
- ]
-
-
- def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'):
-   """Adjusts the pixel range of the input images.
-
-   This function assumes the input array (image batch) has shape [batch_size,
-   channel, height, width] if `channel_order = NCHW`, or shape [batch_size,
-   height, width, channel] if `channel_order = NHWC`. The returned images have
-   shape [batch_size, height, width, channel] and pixel range [0, 255].
-
-   NOTE: Images under `NCHW` order are transposed to `NHWC` on return.
-
-   Args:
-     images: Input images to adjust pixel range.
-     min_val: Min value of the input images. (default: -1.0)
-     max_val: Max value of the input images. (default: 1.0)
-     channel_order: Channel order of the input array. (default: NCHW)
-
-   Returns:
-     The postprocessed images with dtype `numpy.uint8` and range [0, 255].
-
-   Raises:
-     ValueError: If the input `images` are not with type `numpy.ndarray` or
-       the shape is invalid according to `channel_order`.
-   """
-   if not isinstance(images, np.ndarray):
-     raise ValueError('Images should be with type `numpy.ndarray`!')
-
-   channel_order = channel_order.upper()
-   if channel_order not in ['NCHW', 'NHWC']:
-     raise ValueError(f'Invalid channel order `{channel_order}`!')
-
-   if images.ndim != 4:
-     raise ValueError(f'Input images are expected to be with shape `NCHW` or '
-                      f'`NHWC`, but `{images.shape}` is received!')
-   if channel_order == 'NCHW' and images.shape[1] not in [1, 3]:
-     raise ValueError('Input images should have 1 or 3 channels under `NCHW` '
-                      'channel order!')
-   if channel_order == 'NHWC' and images.shape[3] not in [1, 3]:
-     raise ValueError('Input images should have 1 or 3 channels under `NHWC` '
-                      'channel order!')
-
-   images = images.astype(np.float32)
-   images = (images - min_val) * 255 / (max_val - min_val)
-   images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
-   if channel_order == 'NCHW':
-     images = images.transpose(0, 2, 3, 1)
-
-   return images
-
-
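A short sketch of the typical use of `adjust_pixel_range()` on generator output (assumes a float batch in [-1, 1]):

    import numpy as np

    fake = np.random.uniform(-1.0, 1.0, size=(2, 3, 64, 64)).astype(np.float32)
    imgs = adjust_pixel_range(fake)        # NCHW float in, NHWC uint8 out
    assert imgs.shape == (2, 64, 64, 3) and imgs.dtype == np.uint8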
- def get_grid_shape(size, row=0, col=0, is_portrait=False):
-   """Gets the shape of a grid based on the size.
-
-   This function makes greatest effort on making the output grid square if
-   neither `row` nor `col` is set. If `is_portrait` is set as `False`, the
-   height will always be equal to or smaller than the width. For example, if
-   input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
-   output shape will be `(3, 5)`. Otherwise, the height will always be equal
-   to or larger than the width.
-
-   Args:
-     size: Size (height * width) of the target grid.
-     row: Expected number of rows. `0` means unspecified. (default: 0)
-     col: Expected number of columns. `0` means unspecified. (default: 0)
-     is_portrait: Whether to return a portrait size or a landscape size.
-       (default: False)
-
-   Returns:
-     A two-element tuple, representing height and width respectively.
-   """
-   assert isinstance(size, int)
-   assert isinstance(row, int)
-   assert isinstance(col, int)
-   if size == 0:
-     return (0, 0)
-
-   if row > 0 and col > 0 and row * col != size:
-     row = 0
-     col = 0
-
-   if row > 0 and size % row == 0:
-     return (row, size // row)
-   if col > 0 and size % col == 0:
-     return (size // col, col)
-
-   row = int(np.sqrt(size))
-   while row > 0:
-     if size % row == 0:
-       col = size // row
-       break
-     row = row - 1
-
-   return (col, row) if is_portrait else (row, col)
-
-
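For instance, the grid search above yields:

    get_grid_shape(16)                    # (4, 4)
    get_grid_shape(15)                    # (3, 5)
    get_grid_shape(15, is_portrait=True)  # (5, 3)
    get_grid_shape(12, row=3)             # (3, 4)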
- def get_blank_image(height, width, channels=3, is_black=True):
-   """Gets a blank image, either white or black.
-
-   NOTE: This function will always return an image with `RGB` channel order
-   for color images and pixel range [0, 255].
-
-   Args:
-     height: Height of the returned image.
-     width: Width of the returned image.
-     channels: Number of channels. (default: 3)
-     is_black: Whether to return a black image or a white image. (default:
-       True)
-   """
-   shape = (height, width, channels)
-   if is_black:
-     return np.zeros(shape, dtype=np.uint8)
-   return np.ones(shape, dtype=np.uint8) * 255
-
-
- def load_image(path):
-   """Loads an image from disk.
-
-   NOTE: This function will always return an image with `RGB` channel order
-   for color images and pixel range [0, 255].
-
-   Args:
-     path: Path to load the image from.
-
-   Returns:
-     An image with type `numpy.ndarray`, or `None` if input `path` does not
-       exist.
-   """
-   if not os.path.isfile(path):
-     return None
-
-   image = cv2.imread(path)
-   return image[:, :, ::-1]
-
-
- def save_image(path, image):
-   """Saves an image to disk.
-
-   NOTE: The input image (if colorful) is assumed to be with `RGB` channel
-   order and pixel range [0, 255].
-
-   Args:
-     path: Path to save the image to.
-     image: Image to save.
-   """
-   if image is None:
-     return
-
-   assert len(image.shape) == 3 and image.shape[2] in [1, 3]
-   cv2.imwrite(path, image[:, :, ::-1])
-
-
- def resize_image(image, *args, **kwargs):
-   """Resizes image.
-
-   This is a wrap of `cv2.resize()`.
-
-   NOTE: The channel order of the input image will not be changed.
-
-   Args:
-     image: Image to resize.
-   """
-   if image is None:
-     return None
-
-   assert image.ndim == 3 and image.shape[2] in [1, 3]
-   image = cv2.resize(image, *args, **kwargs)
-   if image.ndim == 2:
-     return image[:, :, np.newaxis]
-   return image
-
-
- def add_text_to_image(image,
-                       text='',
-                       position=None,
-                       font=cv2.FONT_HERSHEY_TRIPLEX,
-                       font_size=1.0,
-                       line_type=cv2.LINE_8,
-                       line_width=1,
-                       color=(255, 255, 255)):
-   """Overlays text on given image.
-
-   NOTE: The input image is assumed to be with `RGB` channel order.
-
-   Args:
-     image: The image to overlay text on.
-     text: Text content to overlay on the image. (default: '')
-     position: Target position (bottom-left corner) to add text. If not set,
-       center of the image will be used by default. (default: None)
-     font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
-     font_size: Font size of the text added. (default: 1.0)
-     line_type: Line type used to depict the text. (default: cv2.LINE_8)
-     line_width: Line width used to depict the text. (default: 1)
-     color: Color of the text added in `RGB` channel order. (default:
-       (255, 255, 255))
-
-   Returns:
-     An image with target text overlaid.
-   """
-   if image is None or not text:
-     return image
-
-   cv2.putText(img=image,
-               text=text,
-               org=position,
-               fontFace=font,
-               fontScale=font_size,
-               color=color,
-               thickness=line_width,
-               lineType=line_type,
-               bottomLeftOrigin=False)
-
-   return image
-
-
- def fuse_images(images,
-                 image_size=None,
-                 row=0,
-                 col=0,
-                 is_row_major=True,
-                 is_portrait=False,
-                 row_spacing=0,
-                 col_spacing=0,
-                 border_left=0,
-                 border_right=0,
-                 border_top=0,
-                 border_bottom=0,
-                 black_background=True):
-   """Fuses a collection of images into an entire image.
-
-   Args:
-     images: A collection of images to fuse. Should be with shape [num,
-       height, width, channels].
-     image_size: Int or two-element tuple. This field is used to resize the
-       image before fusing. `None` disables resizing. (default: None)
-     row: Number of rows used for image fusion. If not set, this field will be
-       automatically assigned based on `col` and the total number of images.
-       (default: 0)
-     col: Number of columns used for image fusion. If not set, this field will
-       be automatically assigned based on `row` and the total number of
-       images. (default: 0)
-     is_row_major: Whether the input images should be arranged row-major or
-       column-major. (default: True)
-     is_portrait: Only active when both `row` and `col` should be assigned
-       automatically. (default: False)
-     row_spacing: Space between rows. (default: 0)
-     col_spacing: Space between columns. (default: 0)
-     border_left: Width of left border. (default: 0)
-     border_right: Width of right border. (default: 0)
-     border_top: Width of top border. (default: 0)
-     border_bottom: Width of bottom border. (default: 0)
-     black_background: Whether to use a black background. (default: True)
-
-   Returns:
-     The fused image.
-
-   Raises:
-     ValueError: If the input `images` is not with shape [num, height, width,
-       channels].
-   """
-   if images is None:
-     return images
-
-   if not images.ndim == 4:
-     raise ValueError(f'Input `images` should be with shape [num, height, '
-                      f'width, channels], but {images.shape} is received!')
-
-   num, image_height, image_width, channels = images.shape
-   if image_size is not None:
-     if isinstance(image_size, int):
-       image_size = (image_size, image_size)
-     assert isinstance(image_size, (list, tuple)) and len(image_size) == 2
-     width, height = image_size
-   else:
-     height, width = image_height, image_width
-   row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait)
-   fused_height = (
-       height * row + row_spacing * (row - 1) + border_top + border_bottom)
-   fused_width = (
-       width * col + col_spacing * (col - 1) + border_left + border_right)
-   fused_image = get_blank_image(
-       fused_height, fused_width, channels=channels, is_black=black_background)
-   images = images.reshape(row, col, image_height, image_width, channels)
-   if not is_row_major:
-     images = images.transpose(1, 0, 2, 3, 4)
-
-   for i in range(row):
-     y = border_top + i * (height + row_spacing)
-     for j in range(col):
-       x = border_left + j * (width + col_spacing)
-       if image_size is not None:
-         image = cv2.resize(images[i, j], image_size)
-       else:
-         image = images[i, j]
-       fused_image[y:y + height, x:x + width] = image
-
-   return fused_image
-
-
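A small sketch of tiling a batch into one grid with the helpers above (shapes illustrative; only numpy is needed when no resizing is requested):

    import numpy as np

    batch = np.random.randint(0, 256, size=(6, 32, 32, 3), dtype=np.uint8)
    grid = fuse_images(batch, row=2, col=3, row_spacing=2, col_spacing=2)
    assert grid.shape == (2 * 32 + 2, 3 * 32 + 2 * 2, 3)   # (66, 100, 3)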
- def get_sortable_html_header(column_name_list, sort_by_ascending=False):
-   """Gets header for sortable html page.
-
-   Basically, the html page contains a sortable table, where the user can sort
-   the rows by a particular column by clicking the column head.
-
-   Example:
-
-   column_name_list = [name_1, name_2, name_3]
-   header = get_sortable_html_header(column_name_list)
-   footer = get_sortable_html_footer()
-   sortable_table = ...
-   html_page = header + sortable_table + footer
-
-   Args:
-     column_name_list: List of column header names.
-     sort_by_ascending: Default sorting order. If set as `True`, the html page
-       will be sorted by ascending order when the header is clicked for the
-       first time.
-
-   Returns:
-     A string, which represents the header for a sortable html page.
-   """
-   header = '\n'.join([
-       '<script type="text/javascript">',
-       'var column_idx;',
-       'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';',
-       '',
-       'function sorting(tbody, column_idx){',
-       '  this.column_idx = column_idx;',
-       '  Array.from(tbody.rows)',
-       '       .sort(compareCells)',
-       '       .forEach(function(row) { tbody.appendChild(row); })',
-       '  sort_by_ascending = !sort_by_ascending;',
-       '}',
-       '',
-       'function compareCells(row_a, row_b) {',
-       '  var val_a = row_a.cells[column_idx].innerText;',
-       '  var val_b = row_b.cells[column_idx].innerText;',
-       '  var flag = sort_by_ascending ? 1 : -1;',
-       '  return flag * (val_a > val_b ? 1 : -1);',
-       '}',
-       '</script>',
-       '',
-       '<html>',
-       '',
-       '<head>',
-       '<style>',
-       '  table {',
-       '    border-spacing: 0;',
-       '    border: 1px solid black;',
-       '  }',
-       '  th {',
-       '    cursor: pointer;',
-       '  }',
-       '  th, td {',
-       '    text-align: left;',
-       '    vertical-align: middle;',
-       '    border-collapse: collapse;',
-       '    border: 0.5px solid black;',
-       '    padding: 8px;',
-       '  }',
-       '  tr:nth-child(even) {',
-       '    background-color: #d2d2d2;',
-       '  }',
-       '</style>',
-       '</head>',
-       '',
-       '<body>',
-       '',
-       '<table>',
-       '<thead>',
-       '<tr>',
-       ''])
-   for idx, column_name in enumerate(column_name_list):
-     header += f'  <th onclick="sorting(tbody, {idx})">{column_name}</th>\n'
-   header += '</tr>\n'
-   header += '</thead>\n'
-   header += '<tbody id="tbody">\n'
-
-   return header
-
-
- def get_sortable_html_footer():
-   """Gets footer for sortable html page.
-
-   Check function `get_sortable_html_header()` for more details.
-   """
-   return '</tbody>\n</table>\n\n</body>\n</html>\n'
-
-
- def encode_image_to_html_str(image, image_size=None):
-   """Encodes an image to an html string.
-
-   Args:
-     image: The input image to encode. Should be with `RGB` channel order.
-     image_size: Int or two-element tuple. This field is used to resize the
-       image before encoding. `None` disables resizing. (default: None)
-
-   Returns:
-     A string which represents the encoded image.
-   """
-   if image is None:
-     return ''
-
-   assert len(image.shape) == 3 and image.shape[2] in [1, 3]
-
-   # Change channel order to `BGR`, which is opencv-friendly.
-   image = image[:, :, ::-1]
-
-   # Resize the image if needed.
-   if image_size is not None:
-     if isinstance(image_size, int):
-       image_size = (image_size, image_size)
-     assert isinstance(image_size, (list, tuple)) and len(image_size) == 2
-     image = cv2.resize(image, image_size)
-
-   # Encode the image to an html-format string.
-   # `tobytes()` replaces the long-deprecated `tostring()`.
-   encoded_image = cv2.imencode(".jpg", image)[1].tobytes()
-   encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
-   html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'
-
-   return html_str
-
-
- class HtmlPageVisualizer(object):
-   """Defines the html page visualizer.
-
-   This class can be used to visualize image results on an html page.
-   Basically, it is based on a sortable html table with helper functions
-   `get_sortable_html_header()`, `get_sortable_html_footer()`, and
-   `encode_image_to_html_str()`. To simplify the usage, specifying the
-   following fields is enough to create a visualization page:
-
-   (1) num_rows: Number of rows of the table (header-row exclusive).
-   (2) num_cols: Number of columns of the table.
-   (3) header contents (optional): Title of each column.
-
-   NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
-   automatically.
-
-   Example:
-
-   html = HtmlPageVisualizer(num_rows, num_cols)
-   html.set_headers([...])
-   for i in range(num_rows):
-     for j in range(num_cols):
-       html.set_cell(i, j, text=..., image=...)
-   html.save('visualize.html')
-   """
-
-   def __init__(self,
-                num_rows=0,
-                num_cols=0,
-                grid_size=0,
-                is_portrait=False,
-                viz_size=None):
-     if grid_size > 0:
-       num_rows, num_cols = get_grid_shape(
-           grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
-     assert num_rows > 0 and num_cols > 0
-
-     self.num_rows = num_rows
-     self.num_cols = num_cols
-     self.viz_size = viz_size
-     self.headers = ['' for _ in range(self.num_cols)]
-     self.cells = [[{
-         'text': '',
-         'image': '',
-     } for _ in range(self.num_cols)] for _ in range(self.num_rows)]
-
-   def set_header(self, column_idx, content):
-     """Sets the content of a particular header by column index."""
-     self.headers[column_idx] = content
-
-   def set_headers(self, contents):
-     """Sets the contents of all headers."""
-     if isinstance(contents, str):
-       contents = [contents]
-     assert isinstance(contents, (list, tuple))
-     assert len(contents) == self.num_cols
-     for column_idx, content in enumerate(contents):
-       self.set_header(column_idx, content)
-
-   def set_cell(self, row_idx, column_idx, text='', image=None):
-     """Sets the content of a particular cell.
-
-     Basically, a cell contains some text as well as an image. Both text and
-     image can be empty.
-
-     Args:
-       row_idx: Row index of the cell to edit.
-       column_idx: Column index of the cell to edit.
-       text: Text to add into the target cell.
-       image: Image to show in the target cell. Should be with `RGB` channel
-         order.
-     """
-     self.cells[row_idx][column_idx]['text'] = text
-     self.cells[row_idx][column_idx]['image'] = encode_image_to_html_str(
-         image, self.viz_size)
-
-   def save(self, save_path):
-     """Saves the html page."""
-     html = ''
-     for i in range(self.num_rows):
-       html += '<tr>\n'
-       for j in range(self.num_cols):
-         text = self.cells[i][j]['text']
-         image = self.cells[i][j]['image']
-         if text:
-           html += f'  <td>{text}<br><br>{image}</td>\n'
-         else:
-           html += f'  <td>{image}</td>\n'
-       html += '</tr>\n'
-
-     header = get_sortable_html_header(self.headers)
-     footer = get_sortable_html_footer()
-
-     with open(save_path, 'w') as f:
-       f.write(header + html + footer)
-
-
- class VideoReader(object):
-   """Defines the video reader.
-
-   This class can be used to read frames from a given video.
-   """
-
-   def __init__(self, path):
-     """Initializes the video reader by loading the video from disk."""
-     if not os.path.isfile(path):
-       raise ValueError(f'Video `{path}` does not exist!')
-
-     self.path = path
-     self.video = cv2.VideoCapture(path)
-     assert self.video.isOpened()
-     self.position = 0
-
-     self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
-     self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
-     self.fps = self.video.get(cv2.CAP_PROP_FPS)
-
-   def __del__(self):
-     """Releases the opened video."""
-     self.video.release()
-
-   def read(self, position=None):
-     """Reads a certain frame.
-
-     NOTE: The returned frame is with `RGB` channel order.
-
-     Args:
-       position: Optional. If set, the reader will read the frame at the
-         exact position. Otherwise, the reader will read the next frame.
-         (default: None)
-     """
-     if position is not None and position < self.length:
-       self.video.set(cv2.CAP_PROP_POS_FRAMES, position)
-       self.position = position
-
-     success, frame = self.video.read()
-     self.position = self.position + 1
-
-     return frame[:, :, ::-1] if success else None
-
-
- class VideoWriter(object):
-   """Defines the video writer.
-
-   This class can be used to create a video.
-
-   NOTE: `.avi` with codec `DIVX` is the most recommended format, since it
-   does not rely on other dependencies.
-   """
-
-   def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'):
-     """Creates the video writer."""
-     self.path = path
-     self.frame_height = frame_height
-     self.frame_width = frame_width
-     self.fps = fps
-     self.codec = codec
-
-     self.video = cv2.VideoWriter(filename=path,
-                                  fourcc=cv2.VideoWriter_fourcc(*codec),
-                                  fps=fps,
-                                  frameSize=(frame_width, frame_height))
-
-   def __del__(self):
-     """Releases the opened video."""
-     self.video.release()
-
-   def write(self, frame):
-     """Writes a target frame.
-
-     NOTE: The input frame is assumed to be with `RGB` channel order.
-     """
-     self.video.write(frame[:, :, ::-1])
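A minimal round-trip sketch with the two video classes above (path and frame content are illustrative):

    import numpy as np

    writer = VideoWriter('demo.avi', frame_height=64, frame_width=64, fps=24)
    for _ in range(10):
        writer.write(np.zeros((64, 64, 3), dtype=np.uint8))  # RGB frames
    del writer                                               # releases the file
    reader = VideoReader('demo.avi')
    frame = reader.read()                                    # RGB frame or None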
PTI/models/StyleCLIP/mapper/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/mapper/datasets/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/mapper/datasets/latents_dataset.py DELETED
@@ -1,15 +0,0 @@
- from torch.utils.data import Dataset
-
-
- class LatentsDataset(Dataset):
-
-     def __init__(self, latents, opts):
-         self.latents = latents
-         self.opts = opts
-
-     def __len__(self):
-         return self.latents.shape[0]
-
-     def __getitem__(self, index):
-         return self.latents[index]
PTI/models/StyleCLIP/mapper/latent_mappers.py DELETED
@@ -1,81 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import Module
-
- from models.StyleCLIP.models.stylegan2.model import EqualLinear, PixelNorm
-
-
- class Mapper(Module):
-
-     def __init__(self, opts):
-         super(Mapper, self).__init__()
-
-         self.opts = opts
-         layers = [PixelNorm()]
-
-         for i in range(4):
-             layers.append(
-                 EqualLinear(
-                     512, 512, lr_mul=0.01, activation='fused_lrelu'
-                 )
-             )
-
-         self.mapping = nn.Sequential(*layers)
-
-     def forward(self, x):
-         x = self.mapping(x)
-         return x
-
-
- class SingleMapper(Module):
-
-     def __init__(self, opts):
-         super(SingleMapper, self).__init__()
-
-         self.opts = opts
-
-         self.mapping = Mapper(opts)
-
-     def forward(self, x):
-         out = self.mapping(x)
-         return out
-
-
- class LevelsMapper(Module):
-
-     def __init__(self, opts):
-         super(LevelsMapper, self).__init__()
-
-         self.opts = opts
-
-         # NOTE: `course_mapping` (sic) is kept as-is for checkpoint
-         # compatibility.
-         if not opts.no_coarse_mapper:
-             self.course_mapping = Mapper(opts)
-         if not opts.no_medium_mapper:
-             self.medium_mapping = Mapper(opts)
-         if not opts.no_fine_mapper:
-             self.fine_mapping = Mapper(opts)
-
-     def forward(self, x):
-         x_coarse = x[:, :4, :]
-         x_medium = x[:, 4:8, :]
-         x_fine = x[:, 8:, :]
-
-         if not self.opts.no_coarse_mapper:
-             x_coarse = self.course_mapping(x_coarse)
-         else:
-             x_coarse = torch.zeros_like(x_coarse)
-         if not self.opts.no_medium_mapper:
-             x_medium = self.medium_mapping(x_medium)
-         else:
-             x_medium = torch.zeros_like(x_medium)
-         if not self.opts.no_fine_mapper:
-             x_fine = self.fine_mapping(x_fine)
-         else:
-             x_fine = torch.zeros_like(x_fine)
-
-         out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
-
-         return out
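The slicing above follows the usual StyleGAN coarse (layers 0-3), medium (4-7), fine (8+) grouping; a quick shape check (a sketch that assumes an 18-layer W+ code and that the repo's StyleGAN2 ops are importable):

    from argparse import Namespace
    import torch

    opts = Namespace(no_coarse_mapper=False, no_medium_mapper=False,
                     no_fine_mapper=False)
    mapper = LevelsMapper(opts)
    w = torch.randn(2, 18, 512)
    delta = mapper(w)                  # per-layer residual edit
    assert delta.shape == (2, 18, 512)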
PTI/models/StyleCLIP/mapper/options/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/mapper/options/test_options.py DELETED
@@ -1,42 +0,0 @@
- from argparse import ArgumentParser
-
-
- class TestOptions:
-
-     def __init__(self):
-         self.parser = ArgumentParser()
-         self.initialize()
-
-     def initialize(self):
-         # arguments for inference script
-         self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
-         self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint')
-         self.parser.add_argument('--couple_outputs', action='store_true',
-                                  help='Whether to also save inputs + outputs side-by-side')
-
-         self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
-         self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
-         self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
-         self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
-         self.parser.add_argument('--stylegan_size', default=1024, type=int)
-
-         self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
-         self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation")
-         self.parser.add_argument('--test_workers', default=2, type=int,
-                                  help='Number of test/inference dataloader workers')
-
-         self.parser.add_argument('--n_images', type=int, default=None,
-                                  help='Number of images to output. If None, run on all data')
-
-         self.parser.add_argument('--run_id', type=str, default='PKNWUQRQRKXQ',
-                                  help='The generator id to use')
-
-         self.parser.add_argument('--image_name', type=str, default='',
-                                  help='Image to run on')
-
-         self.parser.add_argument('--edit_name', type=str, default='',
-                                  help='Edit type')
-
-     def parse(self):
-         opts = self.parser.parse_args()
-         return opts
PTI/models/StyleCLIP/mapper/options/train_options.py DELETED
@@ -1,49 +0,0 @@
- from argparse import ArgumentParser
-
-
- class TrainOptions:
-
-     def __init__(self):
-         self.parser = ArgumentParser()
-         self.initialize()
-
-     def initialize(self):
-         self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
-         self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
-         self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
-         self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
-         self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
-         self.parser.add_argument('--latents_train_path', default="train_faces.pt", type=str, help="The latents for the training")
-         self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation")
-         self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given")
-         self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given")
-
-         self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training')
-         self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference')
-         self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
-         self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
-
-         self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate')
-         self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
-
-         self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
-         self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor')
-         self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor')
-
-         self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
-         self.parser.add_argument('--stylegan_size', default=1024, type=int)
-         self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
-         self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint')
-
-         self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps')
-         self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
-         self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
-         self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval')
-         self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval')
-
-         self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt')
-
-     def parse(self):
-         opts = self.parser.parse_args()
-         return opts
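For reference, a typical invocation of these training options (paths and prompt are illustrative, not taken from this repo's docs):

    python mapper/scripts/train.py \
        --exp_dir results/mohawk_mapper \
        --description "a face with a mohawk" \
        --latents_train_path train_faces.pt \
        --latents_test_path test_faces.pt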
PTI/models/StyleCLIP/mapper/scripts/inference.py DELETED
@@ -1,80 +0,0 @@
- import os
- import pickle
- from argparse import Namespace
- import torchvision
- import torch
- import sys
- import time
-
- from configs import paths_config, global_config
- from models.StyleCLIP.mapper.styleclip_mapper import StyleCLIPMapper
- from utils.models_utils import load_tuned_G, load_old_G
-
- sys.path.append(".")
- sys.path.append("..")
-
-
- def run(test_opts, model_id, image_name, use_multi_id_G):
-     out_path_results = os.path.join(test_opts.exp_dir, test_opts.data_dir_name)
-     os.makedirs(out_path_results, exist_ok=True)
-     out_path_results = os.path.join(out_path_results, test_opts.image_name)
-     os.makedirs(out_path_results, exist_ok=True)
-
-     # update test configs with configs used during training
-     ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
-     opts = ckpt['opts']
-     opts.update(vars(test_opts))
-     opts = Namespace(**opts)
-
-     net = StyleCLIPMapper(opts, test_opts.run_id)
-     net.eval()
-     net.to(global_config.device)
-
-     generator_type = paths_config.multi_id_model_type if use_multi_id_G else image_name
-
-     new_G = load_tuned_G(model_id, generator_type)
-     old_G = load_old_G()
-
-     run_styleclip(net, new_G, opts, paths_config.pti_results_keyword, out_path_results, test_opts)
-     run_styleclip(net, old_G, opts, paths_config.e4e_results_keyword, out_path_results, test_opts)
-
-
- def run_styleclip(net, G, opts, method, out_path_results, test_opts):
-     net.set_G(G)
-
-     out_path_results = os.path.join(out_path_results, method)
-     os.makedirs(out_path_results, exist_ok=True)
-
-     latent = torch.load(opts.latents_test_path)
-
-     global_i = 0
-     global_time = []
-     with torch.no_grad():
-         input_cuda = latent.cuda().float()
-         tic = time.time()
-         result_batch = run_on_batch(input_cuda, net, test_opts.couple_outputs)
-         toc = time.time()
-         global_time.append(toc - tic)
-
-     for i in range(opts.test_batch_size):
-         im_path = f'{test_opts.image_name}_{test_opts.edit_name}'
-         if test_opts.couple_outputs:
-             couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)])
-             torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"),
-                                          normalize=True, range=(-1, 1))
-         else:
-             torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"),
-                                          normalize=True, range=(-1, 1))
-         torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt"))
-
-
- def run_on_batch(inputs, net, couple_outputs=False):
-     w = inputs
-     with torch.no_grad():
-         w_hat = w + 0.06 * net.mapper(w)
-         x_hat = net.decoder.synthesis(w_hat, noise_mode='const', force_fp32=True)
-         result_batch = (x_hat, w_hat)
-         if couple_outputs:
-             x = net.decoder.synthesis(w, noise_mode='const', force_fp32=True)
-             result_batch = (x_hat, w_hat, x)
-     return result_batch
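A minimal sketch of driving `run_on_batch()` directly (assumes a constructed StyleCLIPMapper `net` whose decoder has been set via `set_G()`, and an inverted latent `w` of shape [1, 18, 512]):

    with torch.no_grad():
        x_hat, w_hat, x = run_on_batch(w.cuda().float(), net, couple_outputs=True)
    # x / x_hat are the pre-/post-edit renders; w_hat = w + 0.06 * net.mapper(w)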
PTI/models/StyleCLIP/mapper/scripts/train.py DELETED
@@ -1,32 +0,0 @@
- """
- This file runs the main training/val loop
- """
- import os
- import json
- import sys
- import pprint
-
- sys.path.append(".")
- sys.path.append("..")
-
- from mapper.options.train_options import TrainOptions
- from mapper.training.coach import Coach
-
-
- def main(opts):
-     if os.path.exists(opts.exp_dir):
-         raise Exception('Oops... {} already exists'.format(opts.exp_dir))
-     os.makedirs(opts.exp_dir, exist_ok=True)
-
-     opts_dict = vars(opts)
-     pprint.pprint(opts_dict)
-     with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
-         json.dump(opts_dict, f, indent=4, sort_keys=True)
-
-     coach = Coach(opts)
-     coach.train()
-
-
- if __name__ == '__main__':
-     opts = TrainOptions().parse()
-     main(opts)
PTI/models/StyleCLIP/mapper/styleclip_mapper.py DELETED
@@ -1,76 +0,0 @@
- import torch
- from torch import nn
- from models.StyleCLIP.mapper import latent_mappers
- from models.StyleCLIP.models.stylegan2.model import Generator
-
-
- def get_keys(d, name):
-     if 'state_dict' in d:
-         d = d['state_dict']
-     d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
-     return d_filt
-
-
- class StyleCLIPMapper(nn.Module):
-
-     def __init__(self, opts, run_id):
-         super(StyleCLIPMapper, self).__init__()
-         self.opts = opts
-         # Define architecture
-         self.mapper = self.set_mapper()
-         self.run_id = run_id
-
-         self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
-         # Load weights if needed
-         self.load_weights()
-
-     def set_mapper(self):
-         if self.opts.mapper_type == 'SingleMapper':
-             mapper = latent_mappers.SingleMapper(self.opts)
-         elif self.opts.mapper_type == 'LevelsMapper':
-             mapper = latent_mappers.LevelsMapper(self.opts)
-         else:
-             raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type))
-         return mapper
-
-     def load_weights(self):
-         if self.opts.checkpoint_path is not None:
-             print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
-             ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
-             self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)
-
-     def set_G(self, new_G):
-         self.decoder = new_G
-
-     def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
-                 inject_latent=None, return_latents=False, alpha=None):
-         if input_code:
-             codes = x
-         else:
-             codes = self.mapper(x)
-
-         if latent_mask is not None:
-             for i in latent_mask:
-                 if inject_latent is not None:
-                     if alpha is not None:
-                         codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
-                     else:
-                         codes[:, i] = inject_latent[:, i]
-                 else:
-                     codes[:, i] = 0
-
-         input_is_latent = not input_code
-         images = self.decoder.synthesis(codes, noise_mode='const')
-         result_latent = None
-         # images, result_latent = self.decoder([codes],
-         #                                      input_is_latent=input_is_latent,
-         #                                      randomize_noise=randomize_noise,
-         #                                      return_latents=return_latents)
-
-         if resize:
-             images = self.face_pool(images)
-
-         if return_latents:
-             return images, result_latent
-         else:
-             return images
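`get_keys()` simply strips a submodule prefix from checkpoint keys; for example:

    ckpt = {'state_dict': {'mapper.course_mapping.0.weight': 0}}
    get_keys(ckpt, 'mapper')   # -> {'course_mapping.0.weight': 0}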
PTI/models/StyleCLIP/mapper/training/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/mapper/training/coach.py DELETED
@@ -1,242 +0,0 @@
-import os
-
-import clip
-import torch
-import torchvision
-from torch import nn
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-
-import criteria.clip_loss as clip_loss
-from criteria import id_loss
-from mapper.datasets.latents_dataset import LatentsDataset
-from mapper.styleclip_mapper import StyleCLIPMapper
-from mapper.training.ranger import Ranger
-from mapper.training import train_utils
-
-
-class Coach:
-    def __init__(self, opts):
-        self.opts = opts
-
-        self.global_step = 0
-
-        self.device = 'cuda:0'
-        self.opts.device = self.device
-
-        # Initialize network
-        self.net = StyleCLIPMapper(self.opts).to(self.device)
-
-        # Initialize loss
-        if self.opts.id_lambda > 0:
-            self.id_loss = id_loss.IDLoss(self.opts).to(self.device).eval()
-        if self.opts.clip_lambda > 0:
-            self.clip_loss = clip_loss.CLIPLoss(opts)
-        if self.opts.latent_l2_lambda > 0:
-            self.latent_l2_loss = nn.MSELoss().to(self.device).eval()
-
-        # Initialize optimizer
-        self.optimizer = self.configure_optimizers()
-
-        # Initialize dataset
-        self.train_dataset, self.test_dataset = self.configure_datasets()
-        self.train_dataloader = DataLoader(self.train_dataset,
-                                           batch_size=self.opts.batch_size,
-                                           shuffle=True,
-                                           num_workers=int(self.opts.workers),
-                                           drop_last=True)
-        self.test_dataloader = DataLoader(self.test_dataset,
-                                          batch_size=self.opts.test_batch_size,
-                                          shuffle=False,
-                                          num_workers=int(self.opts.test_workers),
-                                          drop_last=True)
-
-        self.text_inputs = torch.cat([clip.tokenize(self.opts.description)]).cuda()
-
-        # Initialize logger
-        log_dir = os.path.join(opts.exp_dir, 'logs')
-        os.makedirs(log_dir, exist_ok=True)
-        self.log_dir = log_dir
-        self.logger = SummaryWriter(log_dir=log_dir)
-
-        # Initialize checkpoint dir
-        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
-        os.makedirs(self.checkpoint_dir, exist_ok=True)
-        self.best_val_loss = None
-        if self.opts.save_interval is None:
-            self.opts.save_interval = self.opts.max_steps
-
-    def train(self):
-        self.net.train()
-        while self.global_step < self.opts.max_steps:
-            for batch_idx, batch in enumerate(self.train_dataloader):
-                self.optimizer.zero_grad()
-                w = batch
-                w = w.to(self.device)
-                with torch.no_grad():
-                    x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1)
-                w_hat = w + 0.1 * self.net.mapper(w)
-                x_hat, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
-                loss, loss_dict = self.calc_loss(w, x, w_hat, x_hat)
-                loss.backward()
-                self.optimizer.step()
-
-                # Logging related
-                if self.global_step % self.opts.image_interval == 0 or (
-                        self.global_step < 1000 and self.global_step % 1000 == 0):
-                    self.parse_and_log_images(x, x_hat, title='images_train')
-                if self.global_step % self.opts.board_interval == 0:
-                    self.print_metrics(loss_dict, prefix='train')
-                    self.log_metrics(loss_dict, prefix='train')
-
-                # Validation related
-                val_loss_dict = None
-                if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
-                    val_loss_dict = self.validate()
-                    if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
-                        self.best_val_loss = val_loss_dict['loss']
-                        self.checkpoint_me(val_loss_dict, is_best=True)
-
-                if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
-                    if val_loss_dict is not None:
-                        self.checkpoint_me(val_loss_dict, is_best=False)
-                    else:
-                        self.checkpoint_me(loss_dict, is_best=False)
-
-                if self.global_step == self.opts.max_steps:
-                    print('OMG, finished training!')
-                    break
-
-                self.global_step += 1
-
-    def validate(self):
-        self.net.eval()
-        agg_loss_dict = []
-        for batch_idx, batch in enumerate(self.test_dataloader):
-            if batch_idx > 200:
-                break
-
-            w = batch
-
-            with torch.no_grad():
-                w = w.to(self.device).float()
-                x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=True, truncation=1)
-                w_hat = w + 0.1 * self.net.mapper(w)
-                x_hat, _ = self.net.decoder([w_hat], input_is_latent=True, randomize_noise=True, truncation=1)
-                loss, cur_loss_dict = self.calc_loss(w, x, w_hat, x_hat)
-            agg_loss_dict.append(cur_loss_dict)
-
-            # Logging related
-            self.parse_and_log_images(x, x_hat, title='images_val', index=batch_idx)
-
-            # For first step just do sanity test on small amount of data
-            if self.global_step == 0 and batch_idx >= 4:
-                self.net.train()
-                return None  # Do not log, inaccurate in first batch
-
-        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
-        self.log_metrics(loss_dict, prefix='test')
-        self.print_metrics(loss_dict, prefix='test')
-
-        self.net.train()
-        return loss_dict
-
-    def checkpoint_me(self, loss_dict, is_best):
-        save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
-        save_dict = self.__get_save_dict()
-        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
-        torch.save(save_dict, checkpoint_path)
-        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
-            if is_best:
-                f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
-            else:
-                f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))
-
-    def configure_optimizers(self):
-        params = list(self.net.mapper.parameters())
-        if self.opts.optim_name == 'adam':
-            optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
-        else:
-            optimizer = Ranger(params, lr=self.opts.learning_rate)
-        return optimizer
-
-    def configure_datasets(self):
-        if self.opts.latents_train_path:
-            train_latents = torch.load(self.opts.latents_train_path)
-        else:
-            train_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda()
-            train_latents = []
-            for b in range(self.opts.train_dataset_size // self.opts.batch_size):
-                with torch.no_grad():
-                    _, train_latents_b = self.net.decoder([train_latents_z[b: b + self.opts.batch_size]],
-                                                          truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True)
-                train_latents.append(train_latents_b)
-            train_latents = torch.cat(train_latents)
-
-        if self.opts.latents_test_path:
-            test_latents = torch.load(self.opts.latents_test_path)
-        else:
-            test_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda()
-            test_latents = []
-            for b in range(self.opts.test_dataset_size // self.opts.test_batch_size):
-                with torch.no_grad():
-                    _, test_latents_b = self.net.decoder([test_latents_z[b: b + self.opts.test_batch_size]],
-                                                         truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True)
-                test_latents.append(test_latents_b)
-            test_latents = torch.cat(test_latents)
-
-        train_dataset_celeba = LatentsDataset(latents=train_latents.cpu(),
-                                              opts=self.opts)
-        test_dataset_celeba = LatentsDataset(latents=test_latents.cpu(),
-                                             opts=self.opts)
-        train_dataset = train_dataset_celeba
-        test_dataset = test_dataset_celeba
-        print("Number of training samples: {}".format(len(train_dataset)))
-        print("Number of test samples: {}".format(len(test_dataset)))
-        return train_dataset, test_dataset
-
-    def calc_loss(self, w, x, w_hat, x_hat):
-        loss_dict = {}
-        loss = 0.0
-        if self.opts.id_lambda > 0:
-            loss_id, sim_improvement = self.id_loss(x_hat, x)
-            loss_dict['loss_id'] = float(loss_id)
-            loss_dict['id_improve'] = float(sim_improvement)
-            loss = loss_id * self.opts.id_lambda
-        if self.opts.clip_lambda > 0:
-            loss_clip = self.clip_loss(x_hat, self.text_inputs).mean()
-            loss_dict['loss_clip'] = float(loss_clip)
-            loss += loss_clip * self.opts.clip_lambda
-        if self.opts.latent_l2_lambda > 0:
-            loss_l2_latent = self.latent_l2_loss(w_hat, w)
-            loss_dict['loss_l2_latent'] = float(loss_l2_latent)
-            loss += loss_l2_latent * self.opts.latent_l2_lambda
-        loss_dict['loss'] = float(loss)
-        return loss, loss_dict
-
-    def log_metrics(self, metrics_dict, prefix):
-        for key, value in metrics_dict.items():
-            # pass
-            print(f"step: {self.global_step} \t metric: {prefix}/{key} \t value: {value}")
-            self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step)
-
-    def print_metrics(self, metrics_dict, prefix):
-        print('Metrics for {}, step {}'.format(prefix, self.global_step))
-        for key, value in metrics_dict.items():
-            print('\t{} = '.format(key), value)
-
-    def parse_and_log_images(self, x, x_hat, title, index=None):
-        if index is None:
-            path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}.jpg')
-        else:
-            path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}_{str(index).zfill(5)}.jpg')
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        torchvision.utils.save_image(torch.cat([x.detach().cpu(), x_hat.detach().cpu()]), path,
-                                     normalize=True, scale_each=True, range=(-1, 1), nrow=self.opts.batch_size)
-
-    def __get_save_dict(self):
-        save_dict = {
-            'state_dict': self.net.state_dict(),
-            'opts': vars(self.opts)
-        }
-        return save_dict
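Note: the heart of Coach.train() above is the residual edit w_hat = w + 0.1 * mapper(w), followed by a weighted sum of ID, CLIP, and latent-L2 losses. A toy sketch of that step, with a plain Linear layer standing in for the real latent mapper and an illustrative lambda value (not the repository's default):

import torch

mapper = torch.nn.Linear(512, 512)   # stand-in for the real latent mapper
w = torch.randn(4, 512)              # stand-in for a batch of latent codes
w_hat = w + 0.1 * mapper(w)          # same residual edit rule as coach.py

latent_l2_lambda = 0.8               # assumed weight, set by the deleted train_options.py
loss = latent_l2_lambda * torch.nn.functional.mse_loss(w_hat, w)
loss.backward()                      # gradients flow into the mapper only
print(float(loss))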
 
PTI/models/StyleCLIP/mapper/training/ranger.py DELETED
@@ -1,164 +0,0 @@
-# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
-
-# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
-# and/or
-# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
-
-# Ranger has now been used to capture 12 records on the FastAI leaderboard.
-
-# This version = 20.4.11
-
-# Credits:
-# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
-# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
-# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
-# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
-
-# summary of changes:
-# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
-# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
-# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
-# changes 8/31/19 - fix references to *self*.N_sma_threshold;
-# changed eps to 1e-5 as better default than 1e-8.
-
-import math
-import torch
-from torch.optim.optimizer import Optimizer
-
-
-class Ranger(Optimizer):
-
-    def __init__(self, params, lr=1e-3,  # lr
-                 alpha=0.5, k=6, N_sma_threshhold=5,  # Ranger configs
-                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam configs
-                 use_gc=True, gc_conv_only=False
-                 # Gradient centralization on or off, applied to conv layers only or conv + fc layers
-                 ):
-
-        # parameter checks
-        if not 0.0 <= alpha <= 1.0:
-            raise ValueError(f'Invalid slow update rate: {alpha}')
-        if not 1 <= k:
-            raise ValueError(f'Invalid lookahead steps: {k}')
-        if not lr > 0:
-            raise ValueError(f'Invalid Learning Rate: {lr}')
-        if not eps > 0:
-            raise ValueError(f'Invalid eps: {eps}')
-
-        # parameter comments:
-        # beta1 (momentum) of .95 seems to work better than .90...
-        # N_sma_threshold of 5 seems better in testing than 4.
-        # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
-
-        # prep defaults and init torch.optim base
-        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
-                        eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults)
-
-        # adjustable threshold
-        self.N_sma_threshhold = N_sma_threshhold
-
-        # look ahead params
-
-        self.alpha = alpha
-        self.k = k
-
-        # radam buffer for state
-        self.radam_buffer = [[None, None, None] for ind in range(10)]
-
-        # gc on or off
-        self.use_gc = use_gc
-
-        # level of gradient centralization
-        self.gc_gradient_threshold = 3 if gc_conv_only else 1
-
-    def __setstate__(self, state):
-        super(Ranger, self).__setstate__(state)
-
-    def step(self, closure=None):
-        loss = None
-
-        # Evaluate averages and grad, update param tensors
-        for group in self.param_groups:
-
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad.data.float()
-
-                if grad.is_sparse:
-                    raise RuntimeError('Ranger optimizer does not support sparse gradients')
-
-                p_data_fp32 = p.data.float()
-
-                state = self.state[p]  # get state dict for this param
-
-                if len(state) == 0:  # if first time to run...init dictionary with our desired entries
-                    # if self.first_run_check==0:
-                    #     self.first_run_check=1
-                    #     print("Initializing slow buffer...should not see this at load from saved model!")
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
-
-                    # look ahead weight storage now in state dict
-                    state['slow_buffer'] = torch.empty_like(p.data)
-                    state['slow_buffer'].copy_(p.data)
-
-                else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
-
-                # begin computations
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                # GC operation for Conv layers and FC layers
-                if grad.dim() > self.gc_gradient_threshold:
-                    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
-
-                state['step'] += 1
-
-                # compute variance mov avg
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                # compute mean moving avg
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-
-                buffered = self.radam_buffer[int(state['step'] % 10)]
-
-                if state['step'] == buffered[0]:
-                    N_sma, step_size = buffered[1], buffered[2]
-                else:
-                    buffered[0] = state['step']
-                    beta2_t = beta2 ** state['step']
-                    N_sma_max = 2 / (1 - beta2) - 1
-                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-                    buffered[1] = N_sma
-                    if N_sma > self.N_sma_threshhold:
-                        step_size = math.sqrt(
-                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
-                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
-                    else:
-                        step_size = 1.0 / (1 - beta1 ** state['step'])
-                    buffered[2] = step_size
-
-                if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-
-                # apply lr
-                if N_sma > self.N_sma_threshhold:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
-                else:
-                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
-
-                p.data.copy_(p_data_fp32)
-
-                # integrated look ahead...
-                # we do it at the param level instead of group level
-                if state['step'] % group['k'] == 0:
-                    slow_p = state['slow_buffer']  # get access to slow param tensor
-                    slow_p.add_(self.alpha, p.data - slow_p)  # (fast weights - slow weights) * alpha
-                    p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor
-
-        return loss
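Note: Ranger wraps an RAdam update with Lookahead: every k steps the slow weights are pulled toward the fast (RAdam-updated) weights by alpha, and the fast weights are reset to that interpolation. A minimal numeric sketch of just the Lookahead half, with a random perturbation standing in for the RAdam step:

import torch

alpha, k = 0.5, 6                   # the deleted file's defaults
fast = torch.randn(3)
slow = fast.clone()                 # slow_buffer is initialized from the params

for step in range(1, 13):
    fast += 0.01 * torch.randn(3)   # stand-in for one RAdam update
    if step % k == 0:
        slow += alpha * (fast - slow)   # slow_p.add_(alpha, p.data - slow_p)
        fast.copy_(slow)                # p.data.copy_(slow_p)
print(slow, fast)                   # equal after an interpolation step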
 
PTI/models/StyleCLIP/mapper/training/train_utils.py DELETED
@@ -1,13 +0,0 @@
-
-def aggregate_loss_dict(agg_loss_dict):
-    mean_vals = {}
-    for output in agg_loss_dict:
-        for key in output:
-            mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
-    for key in mean_vals:
-        if len(mean_vals[key]) > 0:
-            mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
-        else:
-            print('{} has no value'.format(key))
-            mean_vals[key] = 0
-    return mean_vals
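Note: aggregate_loss_dict simply averages each metric key across the per-batch loss dicts collected during validation. A quick usage sketch with the same logic copied inline:

batches = [{'loss': 1.0, 'loss_clip': 0.4},
           {'loss': 3.0, 'loss_clip': 0.6}]

def aggregate_loss_dict(agg_loss_dict):
    # Collect every value seen for each key, then average per key.
    mean_vals = {}
    for output in agg_loss_dict:
        for key in output:
            mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
    for key in mean_vals:
        mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) if mean_vals[key] else 0
    return mean_vals

print(aggregate_loss_dict(batches))  # {'loss': 2.0, 'loss_clip': 0.5}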
 
PTI/models/StyleCLIP/models/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/models/facial_recognition/__init__.py DELETED
File without changes
PTI/models/StyleCLIP/models/facial_recognition/helpers.py DELETED
@@ -1,119 +0,0 @@
-from collections import namedtuple
-import torch
-from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
-
-"""
-ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
-"""
-
-
-class Flatten(Module):
-    def forward(self, input):
-        return input.view(input.size(0), -1)
-
-
-def l2_norm(input, axis=1):
-    norm = torch.norm(input, 2, axis, True)
-    output = torch.div(input, norm)
-    return output
-
-
-class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
-    """ A named tuple describing a ResNet block. """
-
-
-def get_block(in_channel, depth, num_units, stride=2):
-    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
-
-
-def get_blocks(num_layers):
-    if num_layers == 50:
-        blocks = [
-            get_block(in_channel=64, depth=64, num_units=3),
-            get_block(in_channel=64, depth=128, num_units=4),
-            get_block(in_channel=128, depth=256, num_units=14),
-            get_block(in_channel=256, depth=512, num_units=3)
-        ]
-    elif num_layers == 100:
-        blocks = [
-            get_block(in_channel=64, depth=64, num_units=3),
-            get_block(in_channel=64, depth=128, num_units=13),
-            get_block(in_channel=128, depth=256, num_units=30),
-            get_block(in_channel=256, depth=512, num_units=3)
-        ]
-    elif num_layers == 152:
-        blocks = [
-            get_block(in_channel=64, depth=64, num_units=3),
-            get_block(in_channel=64, depth=128, num_units=8),
-            get_block(in_channel=128, depth=256, num_units=36),
-            get_block(in_channel=256, depth=512, num_units=3)
-        ]
-    else:
-        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
-    return blocks
-
-
-class SEModule(Module):
-    def __init__(self, channels, reduction):
-        super(SEModule, self).__init__()
-        self.avg_pool = AdaptiveAvgPool2d(1)
-        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
-        self.relu = ReLU(inplace=True)
-        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
-        self.sigmoid = Sigmoid()
-
-    def forward(self, x):
-        module_input = x
-        x = self.avg_pool(x)
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.sigmoid(x)
-        return module_input * x
-
-
-class bottleneck_IR(Module):
-    def __init__(self, in_channel, depth, stride):
-        super(bottleneck_IR, self).__init__()
-        if in_channel == depth:
-            self.shortcut_layer = MaxPool2d(1, stride)
-        else:
-            self.shortcut_layer = Sequential(
-                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
-                BatchNorm2d(depth)
-            )
-        self.res_layer = Sequential(
-            BatchNorm2d(in_channel),
-            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
-            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
-        )
-
-    def forward(self, x):
-        shortcut = self.shortcut_layer(x)
-        res = self.res_layer(x)
-        return res + shortcut
-
-
-class bottleneck_IR_SE(Module):
-    def __init__(self, in_channel, depth, stride):
-        super(bottleneck_IR_SE, self).__init__()
-        if in_channel == depth:
-            self.shortcut_layer = MaxPool2d(1, stride)
-        else:
-            self.shortcut_layer = Sequential(
-                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
-                BatchNorm2d(depth)
-            )
-        self.res_layer = Sequential(
-            BatchNorm2d(in_channel),
-            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
-            PReLU(depth),
-            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
-            BatchNorm2d(depth),
-            SEModule(depth, 16)
-        )
-
-    def forward(self, x):
-        shortcut = self.shortcut_layer(x)
-        res = self.res_layer(x)
-        return res + shortcut
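Note: the l2_norm helper above scales each embedding row to unit length, which is how ArcFace features are prepared for cosine-similarity comparison. A quick numeric check:

import torch

def l2_norm(input, axis=1):
    # Divide each row by its L2 norm (keepdim=True so shapes broadcast).
    norm = torch.norm(input, 2, axis, True)
    return torch.div(input, norm)

x = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
print(l2_norm(x))              # rows become [0.6, 0.8] and [0.0, 1.0]
print(l2_norm(x).norm(dim=1))  # both row norms are ~1.0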