hamzapehlivan commited on
Commit
6709fc9
·
1 Parent(s): bcd0ad0

Intial Commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +27 -0
  2. README.md +1 -0
  3. app.py +72 -0
  4. datasets/__init__.py +0 -0
  5. datasets/demo_dataset.py +11 -0
  6. datasets/inference_dataset.py +49 -0
  7. datasets/process_image.py +166 -0
  8. editings/__init__.py +0 -0
  9. editings/base_encoder_runner.py +861 -0
  10. editings/editor.py +23 -0
  11. editings/ganspace.py +49 -0
  12. editings/ganspace_pca/ganspace_configs.csv +42 -0
  13. editings/interfacegan.py +27 -0
  14. editings/styleclip.py +49 -0
  15. editings/styleclip_directions/__init__.py +0 -0
  16. editings/styleclip_directions/styleclip_directions/global_directions/ffhq/S_mean_std +0 -0
  17. editings/styleclip_directions/styleclip_directions/global_directions/templates.txt +79 -0
  18. editings/styleclip_directions/styleclip_mapper_network.py +121 -0
  19. editings/styleclip_directions/styleclip_mapping_configs.csv +14 -0
  20. inference.py +71 -0
  21. models/dnnlib/__init__.py +9 -0
  22. models/dnnlib/util.py +29 -0
  23. models/e4e.py +348 -0
  24. models/stylegan2.py +965 -0
  25. models/styleres.py +75 -0
  26. models/torch_utils/__init__.py +9 -0
  27. models/torch_utils/custom_ops.py +157 -0
  28. models/torch_utils/misc.py +265 -0
  29. models/torch_utils/ops/__init__.py +9 -0
  30. models/torch_utils/ops/bias_act.cpp +99 -0
  31. models/torch_utils/ops/bias_act.cu +173 -0
  32. models/torch_utils/ops/bias_act.h +38 -0
  33. models/torch_utils/ops/bias_act.py +209 -0
  34. models/torch_utils/ops/conv2d_gradfix.py +198 -0
  35. models/torch_utils/ops/conv2d_resample.py +143 -0
  36. models/torch_utils/ops/filtered_lrelu.cpp +300 -0
  37. models/torch_utils/ops/filtered_lrelu.cu +1284 -0
  38. models/torch_utils/ops/filtered_lrelu.h +90 -0
  39. models/torch_utils/ops/filtered_lrelu.py +274 -0
  40. models/torch_utils/ops/filtered_lrelu_ns.cu +27 -0
  41. models/torch_utils/ops/filtered_lrelu_rd.cu +27 -0
  42. models/torch_utils/ops/filtered_lrelu_wr.cu +27 -0
  43. models/torch_utils/ops/fma.py +60 -0
  44. models/torch_utils/ops/grid_sample_gradfix.py +77 -0
  45. models/torch_utils/ops/upfirdn2d.cpp +107 -0
  46. models/torch_utils/ops/upfirdn2d.cu +384 -0
  47. models/torch_utils/ops/upfirdn2d.h +59 -0
  48. models/torch_utils/ops/upfirdn2d.py +389 -0
  49. models/torch_utils/persistence.py +251 -0
  50. models/torch_utils/training_stats.py +268 -0
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ /.vscode/
4
+ /.idea/
5
+ *.sw[pon]
6
+
7
+ /results/
8
+ *.jpg
9
+ *.png
10
+ *.jpeg
11
+ *.gif
12
+ *.avi
13
+ *.mp4
14
+
15
+ *.npy
16
+ *.json
17
+ *.log
18
+ *.html
19
+ *.tar
20
+ *.zip
21
+ events.*
22
+
23
+ *.pth
24
+ *.pt
25
+ *.pkl
26
+ *.h5
27
+ *.dat
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.24.1
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.24.1
8
+ python_version: 3.7
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import initialize_styleres
3
+ from utils import AppUtils
4
+ from datasets.process_image import ImageProcessor
5
+ from argparse import ArgumentParser
6
+
7
+ EXAMPLES = [
8
+ ["test_sample/1.jpg", "InterfaceGAN", "Smile", "2.0", False],
9
+ ]
10
+
11
+ parser = ArgumentParser()
12
+ parser.add_argument('--device', type=str, default='cpu', help='Which device to use')
13
+ args = parser.parse_args()
14
+
15
+ utils = AppUtils()
16
+ methods = utils.get_methods()
17
+ styleres = initialize_styleres('checkpoints/styleres_ffhq.pth', args.device)
18
+ image_processor = ImageProcessor('checkpoints/shape_predictor_68_face_landmarks.dat')
19
+
20
+ def process_image(image, method, edit, factor, is_align_checked):
21
+ cfg = utils.args_to_cfg(method, edit, factor)
22
+ if is_align_checked:
23
+ image = image_processor.align_face(image)
24
+ image = image_processor.preprocess_image(image, is_batch=False)
25
+ image = styleres.edit_images(image, cfg)
26
+ image = image_processor.postprocess_image(image.detach().cpu().numpy(), is_batch=False)
27
+ return image
28
+
29
+ def update_edit_dropdown(method):
30
+ choices = utils.get_edits(method)
31
+ return gr.Dropdown.update(choices=choices, value=choices[0])
32
+
33
+ def update_slider(method):
34
+ minimum, maximum, step= utils.get_range(method)
35
+ return gr.Slider.update(minimum=minimum, maximum=maximum, value=0, step=step, label=f"Strength [{minimum}, {maximum}]")
36
+
37
+
38
+ with gr.Blocks() as demo:
39
+ gr.Markdown(
40
+ """
41
+ # StyleRes: Transforming the Residuals for Real Image Editing with StyleGAN (CVPR2023)
42
+ """)
43
+ with gr.Row():
44
+ image_input = gr.Image(type="pil", shape=(256,256), label='Input Image', value="test_sample/116.jpg")
45
+ image_output = gr.Image(type="pil", shape=(256,256), label='Output Image')
46
+
47
+ with gr.Row():
48
+ with gr.Column(scale=0.25, min_width=50):
49
+ methods_drowdown = gr.Dropdown(methods, label="Choose Method", value=methods[0])
50
+ with gr.Column(scale=0.25, min_width=50):
51
+ edits_dropdown = gr.Dropdown(utils.get_edits(methods[0]), label="Choose Edit", value=utils.get_edits(methods[0])[0])
52
+
53
+
54
+ with gr.Row():
55
+ with gr.Column(scale=0.1, min_width=50):
56
+ is_align_checked = gr.Checkbox(label="Crop + Align")
57
+ with gr.Column(scale=0.4, min_width=50):
58
+ factor_slider = gr.Slider(-5, 5, value=0, label="Strength [-5, 5]")
59
+
60
+ gr.Examples(
61
+ examples=EXAMPLES,
62
+ inputs=[image_input, methods_drowdown, edits_dropdown, factor_slider, is_align_checked],
63
+ outputs=image_output,
64
+ fn=process_image,
65
+ cache_examples=True,
66
+ )
67
+ methods_drowdown.change(update_edit_dropdown, inputs=methods_drowdown, outputs=edits_dropdown )
68
+ methods_drowdown.change(update_slider, inputs=methods_drowdown, outputs=factor_slider)
69
+ factor_slider.release(process_image, inputs=[image_input, methods_drowdown, edits_dropdown, factor_slider, is_align_checked], outputs=image_output)
70
+
71
+ demo.launch(debug=True)
72
+
datasets/__init__.py ADDED
File without changes
datasets/demo_dataset.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ def preprocess_demo_image(image):
5
+ image = image.resize( (256, 256))
6
+ image = np.asarray(image).transpose(2, 0, 1).astype(np.float32) # C,H,W -> H,W,C
7
+ image = torch.FloatTensor(image.copy())
8
+ image = (image - 127.5) / 127.5 # Normalize
9
+ return image.unsqueeze(0)
10
+
11
+
datasets/inference_dataset.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the class of dataset."""
3
+
4
+ import os
5
+ from PIL import Image
6
+ from .process_image import ImageProcessor
7
+ from torch.utils.data import Dataset
8
+
9
+ class InferenceDataset(Dataset):
10
+
11
+ def __init__(self,
12
+ root_dir,
13
+ resolution=256,
14
+ aligner_path=None
15
+ ):
16
+ """Initializes the dataset.
17
+
18
+ Args:
19
+ root_dir: Root directory containing the dataset.
20
+ resolution: The resolution of the returned image.
21
+ transform: The transform function for pre-processing.
22
+ (default: `datasets.transforms.normalize_image()`)
23
+ """
24
+
25
+ self.root_dir = root_dir
26
+ self.resolution = resolution
27
+ self.image_paths = sorted(os.listdir(self.root_dir))
28
+ self.num_samples = len(self.image_paths)
29
+ self.processor = ImageProcessor(aligner_path)
30
+
31
+ def __len__(self):
32
+ return self.num_samples
33
+
34
+ def __getitem__(self, idx):
35
+ data = dict()
36
+
37
+ image_path = self.image_paths[idx]
38
+ image = Image.open(os.path.join(self.root_dir, image_path))
39
+ image = self.processor.align_face(image)
40
+ image = self.processor.preprocess_image(image)
41
+ # image = image.resize( (self.resolution, self.resolution))
42
+ # image = np.asarray(image).transpose(2, 0, 1).astype(np.float32) # C,H,W -> H,W,C
43
+ # image = torch.FloatTensor(image.copy())
44
+ # image = (image - 127.5) / 127.5 # Normalize
45
+
46
+ data.update({'image': image})
47
+ data.update({'name': image_path})
48
+ return data
49
+
datasets/process_image.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ import dlib
6
+ import numpy as np
7
+ import PIL
8
+ import PIL.Image
9
+ import scipy
10
+ import scipy.ndimage
11
+
12
+ class ImageProcessor():
13
+ def __init__(self, predictor_path=None) -> None:
14
+ self.predictor = None
15
+ if predictor_path:
16
+ self.predictor = dlib.shape_predictor(predictor_path)
17
+
18
+ @staticmethod
19
+ def preprocess_image(image, is_batch=True):
20
+ image = image.resize( (256, 256))
21
+ image = np.asarray(image).transpose(2, 0, 1).astype(np.float32) # C,H,W -> H,W,C
22
+ image = torch.FloatTensor(image.copy())
23
+ image = (image - 127.5) / 127.5 # Normalize
24
+ if not is_batch:
25
+ image = image.unsqueeze(0)
26
+ return image
27
+
28
+ """
29
+ Input: A numpy image with shape NxCxHxW.
30
+ Output: Output image with NxHxWxC with values between 0-255
31
+ """
32
+ @staticmethod
33
+ def postprocess_image(image, min_val=-1.0, max_val=1.0, is_batch=True):
34
+ image = image.astype(np.float64)
35
+ image = (image - min_val) * 255 / (max_val - min_val)
36
+ image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
37
+ image = image.transpose(0, 2, 3, 1)
38
+ if not is_batch:
39
+ image = Image.fromarray(image[0]).resize((256,256))
40
+ return image
41
+
42
+ """
43
+ brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
44
+ author: lzhbrian (https://lzhbrian.me)
45
+ date: 2020.1.5
46
+ note: code is heavily borrowed from
47
+ https://github.com/NVlabs/ffhq-dataset
48
+ http://dlib.net/face_landmark_detection.py.html
49
+ requirements:
50
+ apt install cmake
51
+ conda install Pillow numpy scipy
52
+ pip install dlib
53
+ # download face landmark model from:
54
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
55
+ """
56
+
57
+ def get_landmark(self, image):
58
+ """get landmark with dlib
59
+ :return: np.array shape=(68, 2)
60
+ """
61
+ detector = dlib.get_frontal_face_detector()
62
+
63
+ # img = dlib.load_rgb_image(filepath)
64
+ img = np.asarray(image)
65
+ dets = detector(img, 1)
66
+
67
+ for k, d in enumerate(dets):
68
+ shape = self.predictor(img, d)
69
+
70
+ t = list(shape.parts())
71
+ a = []
72
+ for tt in t:
73
+ a.append([tt.x, tt.y])
74
+ lm = np.array(a)
75
+ return lm
76
+
77
+ def align_face(self, img):
78
+ """
79
+ :param image: PIL image
80
+ :return: PIL Image
81
+ """
82
+ if self.predictor is None:
83
+ return img
84
+
85
+ lm = self.get_landmark(img)
86
+
87
+ lm_chin = lm[0: 17] # left-right
88
+ lm_eyebrow_left = lm[17: 22] # left-right
89
+ lm_eyebrow_right = lm[22: 27] # left-right
90
+ lm_nose = lm[27: 31] # top-down
91
+ lm_nostrils = lm[31: 36] # top-down
92
+ lm_eye_left = lm[36: 42] # left-clockwise
93
+ lm_eye_right = lm[42: 48] # left-clockwise
94
+ lm_mouth_outer = lm[48: 60] # left-clockwise
95
+ lm_mouth_inner = lm[60: 68] # left-clockwise
96
+
97
+ # Calculate auxiliary vectors.
98
+ eye_left = np.mean(lm_eye_left, axis=0)
99
+ eye_right = np.mean(lm_eye_right, axis=0)
100
+ eye_avg = (eye_left + eye_right) * 0.5
101
+ eye_to_eye = eye_right - eye_left
102
+ mouth_left = lm_mouth_outer[0]
103
+ mouth_right = lm_mouth_outer[6]
104
+ mouth_avg = (mouth_left + mouth_right) * 0.5
105
+ eye_to_mouth = mouth_avg - eye_avg
106
+
107
+ # Choose oriented crop rectangle.
108
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
109
+ x /= np.hypot(*x)
110
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
111
+ y = np.flipud(x) * [-1, 1]
112
+ c = eye_avg + eye_to_mouth * 0.1
113
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
114
+ qsize = np.hypot(*x) * 2
115
+
116
+ # read image
117
+ # img = PIL.Image.open(filepath)
118
+
119
+ output_size = 512
120
+ transform_size = 1024
121
+ enable_padding = True
122
+
123
+ # Shrink.
124
+ shrink = int(np.floor(qsize / output_size * 0.5))
125
+ if shrink > 1:
126
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
127
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
128
+ quad /= shrink
129
+ qsize /= shrink
130
+
131
+ # Crop.
132
+ border = max(int(np.rint(qsize * 0.1)), 3)
133
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
134
+ int(np.ceil(max(quad[:, 1]))))
135
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
136
+ min(crop[3] + border, img.size[1]))
137
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
138
+ img = img.crop(crop)
139
+ quad -= crop[0:2]
140
+
141
+ # Pad.
142
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
143
+ int(np.ceil(max(quad[:, 1]))))
144
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
145
+ max(pad[3] - img.size[1] + border, 0))
146
+ if enable_padding and max(pad) > border - 4:
147
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
148
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
149
+ h, w, _ = img.shape
150
+ y, x, _ = np.ogrid[:h, :w, :1]
151
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
152
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
153
+ blur = qsize * 0.02
154
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
155
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
156
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
157
+ quad += pad[:2]
158
+
159
+ # Transform.
160
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
161
+ if output_size < transform_size:
162
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
163
+
164
+ # Save aligned image.
165
+ img.save('aligned.png')
166
+ return img
editings/__init__.py ADDED
File without changes
editings/base_encoder_runner.py ADDED
@@ -0,0 +1,861 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the base class for Encoder (GAN Inversion) runner."""
3
+
4
+ import os
5
+ import shutil
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torchvision.transforms as T
10
+
11
+ from utils.visualizer import HtmlPageVisualizer
12
+ from utils.visualizer import get_grid_shape
13
+ from utils.visualizer import postprocess_image
14
+ from utils.visualizer import save_image
15
+ from utils.visualizer import load_image
16
+ from utils.visualizer import postprocess_tensor
17
+
18
+ from metrics.inception import build_inception_model
19
+ from metrics.fid import extract_feature
20
+ from metrics.fid import compute_fid
21
+ from metrics.MSSIM import MSSSIM
22
+ from metrics.LPIPS import LPIPS
23
+ import numpy as np
24
+ from .base_runner import BaseRunner
25
+ from datasets import BaseDataset
26
+ from torch.utils.data import DataLoader
27
+ from PIL import Image
28
+ from runners.controllers.summary_writer import log_image
29
+ import torchvision
30
+ from editings.latent_editor import LatentEditor
31
+ from editings.styleclip.edit_hfgi import styleclip_edit, load_stylegan_generator,load_direction_calculator
32
+ from editings.GradCtrl.manipulate import main as gradctrl
33
+ import torch.nn.functional as F
34
+ import time
35
+
36
+
37
+ __all__ = ['BaseEncoderRunner']
38
+
39
+
40
+ class BaseEncoderRunner(BaseRunner):
41
+ """Defines the base class for Encoder runner."""
42
+
43
+ def __init__(self, config, logger):
44
+ super().__init__(config, logger)
45
+ self.inception_model = None
46
+
47
+ def build_models(self):
48
+ super().build_models()
49
+ assert 'encoder' in self.models
50
+ assert 'generator_smooth' in self.models
51
+ assert 'discriminator' in self.models
52
+
53
+ self.resolution = self.models['generator_smooth'].resolution
54
+ self.G_kwargs_train = self.config.modules['generator_smooth'].get(
55
+ 'kwargs_train', dict())
56
+ self.G_kwargs_val = self.config.modules['generator_smooth'].get(
57
+ 'kwargs_val', dict())
58
+ self.D_kwargs_train = self.config.modules['discriminator'].get(
59
+ 'kwargs_train', dict())
60
+ self.D_kwargs_val = self.config.modules['discriminator'].get(
61
+ 'kwargs_val', dict())
62
+ if self.config.use_disc2:
63
+ self.D2_kwargs_train = self.config.modules['discriminator2'].get(
64
+ 'kwargs_train', dict())
65
+ self.D2_kwargs_val = self.config.modules['discriminator2'].get(
66
+ 'kwargs_val', dict())
67
+ if self.config.mapping_method != 'pretrained':
68
+ self.M_kwargs_train = self.config.modules['mapping'].get(
69
+ 'kwargs_train', dict())
70
+ self.M_kwargs_val = self.config.modules['mapping'].get(
71
+ 'kwargs_val', dict())
72
+ if self.config.create_mixing_network:
73
+ self.MIX_kwargs_train = self.config.modules['mixer'].get(
74
+ 'kwargs_train', dict())
75
+ self.MIX_kwargs_val = self.config.modules['mixer'].get(
76
+ 'kwargs_val', dict())
77
+
78
+
79
+
80
+ def train_step(self, data, **train_kwargs):
81
+ raise NotImplementedError('Should be implemented in derived class.')
82
+
83
+ def mse(self, mse_num):
84
+
85
+ if mse_num == 0:
86
+ return -1
87
+ self.set_mode('val')
88
+
89
+ if self.val_loader is None:
90
+ self.build_dataset('val')
91
+ if mse_num == "auto":
92
+ mse_num = len(self.val_loader.dataset)
93
+
94
+ indices = list(range(self.rank, mse_num, self.world_size))
95
+ self.logger.init_pbar()
96
+ task1 = self.logger.add_pbar_task('MSE-LPIPS-SSIM', total=mse_num)
97
+
98
+ lpips = LPIPS()
99
+ ssim = MSSSIM(size_average=False)
100
+ n_evals = 3
101
+ gather_list = [torch.zeros( (self.val_batch_size, n_evals), device=torch.cuda.current_device()) for i in range(self.world_size)]
102
+ all_errors = np.zeros( (mse_num, n_evals), dtype=np.float64)
103
+ shared_tensor = torch.zeros((self.val_batch_size, n_evals), device=torch.cuda.current_device())
104
+ gather_idx = 0
105
+ for batch_idx in range(0, len(indices), self.val_batch_size):
106
+ sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
107
+ batch_size = len(sub_indices)
108
+ data = next(self.val_loader)
109
+ for key in data:
110
+ if key != 'name':
111
+ data[key] = data[key][:batch_size].cuda(
112
+ torch.cuda.current_device(), non_blocking=True)
113
+
114
+ with torch.no_grad():
115
+ real_images = data['image']
116
+ return_dict = self.forward_pass(data, return_vals='fakes, wp_mixed')
117
+ fakes = return_dict['fakes']
118
+ shared_tensor[:, 0] = torch.mean((fakes - real_images)**2, dim=(1,2,3)) #MSE Error
119
+ shared_tensor[:, 1]= lpips(real_images, fakes)
120
+ shared_tensor[:, 2]= ssim(real_images, fakes) #ssim (real_images[0].unsqueeze(0), fakes[0].unsqueeze(0) )
121
+
122
+ dist.all_gather(gather_list, shared_tensor)
123
+ if self.rank == 0:
124
+ for t in gather_list:
125
+ all_errors[gather_idx:gather_idx+batch_size, 0] = t[:,0].cpu().numpy()
126
+ all_errors[gather_idx:gather_idx+batch_size, 1] = t[:,1].cpu().numpy()
127
+ all_errors[gather_idx:gather_idx+batch_size, 2] = t[:,2].cpu().numpy()
128
+ gather_idx = gather_idx+batch_size
129
+ self.logger.update_pbar(task1, batch_size * self.world_size)
130
+ self.logger.close_pbar()
131
+ mean_lst, std_lst = np.mean(all_errors, axis=0), np.std(all_errors, axis=0)
132
+ mse_mean, lpips_mean, ssim_mean = mean_lst[0].item(), mean_lst[1].item(), mean_lst[2].item()
133
+ mse_std, lpips_std, ssim_std = std_lst[0].item(), std_lst[1].item(), std_lst[2].item()
134
+ return_vals = {'mse': (mse_mean,mse_std), 'lpips': (lpips_mean, lpips_std), 'ssim':(ssim_mean, ssim_std)}
135
+ return return_vals
136
+
137
+ def fid_attribute(self,
138
+ fid_num,
139
+ z=None,
140
+ ignore_cache=False,
141
+ align_tf=True,
142
+ attribute='smile', factor=1, direction=None):
143
+ """Computes the FID metric."""
144
+ self.set_mode('val')
145
+ direction = torch.load(f'editings/interfacegan_directions/{attribute}.pt').cuda()
146
+ if factor < 0:
147
+ self.config.data['smile']['root_dir'] = f'/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original'
148
+ elif factor > 0:
149
+ self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original"
150
+
151
+ fake_loader = self.build_dataset(f"smile")
152
+
153
+ #fid_num = min(fid_num, len(self.val_loader.dataset))
154
+ fid_num = len(fake_loader.dataset)
155
+ if self.inception_model is None:
156
+ if align_tf:
157
+ self.logger.info(f'Building inception model '
158
+ f'(aligned with TensorFlow) ...')
159
+ else:
160
+ self.logger.info(f'Building inception model '
161
+ f'(using torchvision) ...')
162
+ self.inception_model = build_inception_model(align_tf).cuda()
163
+ self.logger.info(f'Finish building inception model.')
164
+
165
+ if z is not None:
166
+ assert isinstance(z, np.ndarray)
167
+ assert z.ndim == 2 and z.shape[1] == self.z_space_dim
168
+ fid_num = min(fid_num, z.shape[0])
169
+ z = torch.from_numpy(z).type(torch.FloatTensor)
170
+ if not fid_num:
171
+ return -1
172
+
173
+ indices = list(range(self.rank, fid_num, self.world_size))
174
+
175
+ self.logger.init_pbar()
176
+ # Extract features from fake images.
177
+ fake_feature_list = []
178
+ task1 = self.logger.add_pbar_task(f'FID-{attribute}_fake', total=fid_num)
179
+ for batch_idx in range(0, len(indices), self.val_batch_size):
180
+ sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
181
+ batch_size = len(sub_indices)
182
+ data = next(fake_loader)
183
+ for key in data:
184
+ if key != 'name':
185
+ data[key] = data[key][:batch_size].cuda(
186
+ torch.cuda.current_device(), non_blocking=True)
187
+ with torch.no_grad():
188
+ real_images = data['image']
189
+ #valids = data['valid']
190
+ return_dict = self.forward_pass(data, return_vals='all', only_enc = True)
191
+ wp = return_dict['wp_mixed']
192
+ eouts = return_dict['eouts']
193
+ edit_wp = wp + factor * direction
194
+ edited_images, _ = self.runG(edit_wp, "synthesis", highres_outs=eouts)
195
+ fake_feature_list.append(
196
+ extract_feature(self.inception_model, edited_images))
197
+ self.logger.update_pbar(task1, batch_size * self.world_size)
198
+
199
+ np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
200
+ np.concatenate(fake_feature_list, axis=0))
201
+ self.logger.close_pbar()
202
+
203
+
204
+ #Extract features from real images if needed.
205
+ cached_fid_file = f'{self.work_dir}/real_{attribute}_{factor}_fid.npy'
206
+ do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
207
+ if do_real_test:
208
+ real_feature_list = []
209
+ self.logger.init_pbar()
210
+
211
+ if factor < 0:
212
+ self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original"
213
+ elif factor > 0:
214
+ self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original"
215
+ real_loader = self.build_dataset(f"smile")
216
+
217
+ fid_num = len(real_loader.dataset)
218
+ indices = list(range(self.rank, fid_num, self.world_size))
219
+ task2 = self.logger.add_pbar_task(f"{attribute}_real", total=fid_num)
220
+ for batch_idx in range(0, len(indices), self.val_batch_size):
221
+ sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
222
+ batch_size = len(sub_indices)
223
+ data = next(real_loader)
224
+ for key in data:
225
+ if key != 'name':
226
+ data[key] = data[key][:batch_size].cuda(
227
+ torch.cuda.current_device(), non_blocking=True)
228
+ with torch.no_grad():
229
+ real_images = data['image']
230
+ real_feature_list.append(
231
+ extract_feature(self.inception_model, real_images))
232
+ self.logger.update_pbar(task2, batch_size * self.world_size)
233
+ np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
234
+ np.concatenate(real_feature_list, axis=0))
235
+
236
+ dist.barrier()
237
+ if self.rank != 0:
238
+ return -1
239
+ self.logger.close_pbar()
240
+
241
+ # Collect fake features.
242
+ fake_feature_list.clear()
243
+ for rank in range(self.world_size):
244
+ fake_feature_list.append(
245
+ np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
246
+ os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
247
+ fake_features = np.concatenate(fake_feature_list, axis=0)
248
+ # assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
249
+ feature_dim = fake_features.shape[1]
250
+ feature_num = fake_features.shape[0]
251
+ pad = feature_num % self.world_size #feature_dim.shape[0]
252
+ if pad:
253
+ pad = self.world_size - pad
254
+ fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
255
+ fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
256
+ fake_features = fake_features.transpose(1, 0, 2)
257
+ fake_features = fake_features.reshape(-1, feature_dim)[:feature_num]
258
+
259
+ # Collect (or load) real features.
260
+ if do_real_test:
261
+ real_feature_list.clear()
262
+ for rank in range(self.world_size):
263
+ real_feature_list.append(
264
+ np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
265
+ os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
266
+ real_features = np.concatenate(real_feature_list, axis=0)
267
+ # assert real_features.shape == (fid_num, feature_dim)
268
+ feature_dim = real_features.shape[1]
269
+ feature_num = real_features.shape[0]
270
+ pad = feature_num % self.world_size
271
+ if pad:
272
+ pad = self.world_size - pad
273
+ real_features = np.pad(real_features, ((0, pad), (0, 0)))
274
+ real_features = real_features.reshape(
275
+ self.world_size, -1, feature_dim)
276
+ real_features = real_features.transpose(1, 0, 2)
277
+ real_features = real_features.reshape(-1, feature_dim)[:feature_num]
278
+ np.save(cached_fid_file, real_features)
279
+ else:
280
+ real_features = np.load(cached_fid_file)
281
+ # assert real_features.shape == (fid_num, feature_dim)
282
+
283
+ fid_value = compute_fid(fake_features, real_features)
284
+ return fid_value
285
+
286
+ def fid(self,
287
+ fid_num,
288
+ z=None,
289
+ ignore_cache=False,
290
+ align_tf=True):
291
+ """Computes the FID metric."""
292
+ self.set_mode('val')
293
+
294
+ if self.val_loader is None:
295
+ self.build_dataset('val')
296
+ fid_num = min(fid_num, len(self.val_loader.dataset))
297
+
298
+ if self.inception_model is None:
299
+ if align_tf:
300
+ self.logger.info(f'Building inception model '
301
+ f'(aligned with TensorFlow) ...')
302
+ else:
303
+ self.logger.info(f'Building inception model '
304
+ f'(using torchvision) ...')
305
+ self.inception_model = build_inception_model(align_tf).cuda()
306
+ self.logger.info(f'Finish building inception model.')
307
+
308
+ if z is not None:
309
+ assert isinstance(z, np.ndarray)
310
+ assert z.ndim == 2 and z.shape[1] == self.z_space_dim
311
+ fid_num = min(fid_num, z.shape[0])
312
+ z = torch.from_numpy(z).type(torch.FloatTensor)
313
+ if not fid_num:
314
+ return -1
315
+
316
+ indices = list(range(self.rank, fid_num, self.world_size))
317
+
318
+ self.logger.init_pbar()
319
+ #generator = self.run_with_optim if run_with_optim else self.run_without_optim
320
+ # Extract features from fake images.
321
+ fake_feature_list = []
322
+ real_feature_list = []
323
+ task1 = self.logger.add_pbar_task('FID', total=fid_num)
324
+ for batch_idx in range(0, len(indices), self.val_batch_size):
325
+ sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
326
+ batch_size = len(sub_indices)
327
+ data = next(self.val_loader)
328
+ for key in data:
329
+ if key != 'name':
330
+ data[key] = data[key][:batch_size].cuda(
331
+ torch.cuda.current_device(), non_blocking=True)
332
+
333
+ # z_rand = torch.randn((batch_size,self.config.z_count,self.config.latent_dim)
334
+ # , device=torch.cuda.current_device())
335
+ # data['z_rand'] = z_rand
336
+ with torch.no_grad():
337
+ real_images = data['image']
338
+ #valids = data['valid']
339
+ return_dict = self.forward_pass(data, return_vals='fakes, wp_mixed')
340
+ fakes = return_dict['fakes']
341
+ if self.config.test_time_optims != 0:
342
+ wp_mixed = return_dict['wp_mixed']
343
+ fakes = self.optimize(data, wp_mixed)
344
+ with torch.no_grad():
345
+ #final_out = real_images * valids + fakes * (1.0-valids) #Final output is the mixed one.
346
+ fake_feature_list.append(
347
+ extract_feature(self.inception_model, fakes))
348
+ # Extract features from real images if needed.
349
+ cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
350
+ do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
351
+ if do_real_test:
352
+ with torch.no_grad():
353
+ real_feature_list.append(
354
+ extract_feature(self.inception_model, real_images))
355
+ self.logger.update_pbar(task1, batch_size * self.world_size)
356
+
357
+ np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
358
+ np.concatenate(fake_feature_list, axis=0))
359
+ if (do_real_test):
360
+ np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
361
+ np.concatenate(real_feature_list, axis=0))
362
+
363
+ # Extract features from real images if needed.
364
+ # cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
365
+ # do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
366
+ # if do_real_test:
367
+ # real_feature_list = []
368
+ # task2 = self.logger.add_pbar_task("Real", total=fid_num)
369
+ # for batch_idx in range(0, len(indices), self.val_batch_size):
370
+ # sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
371
+ # batch_size = len(sub_indices)
372
+ # data = next(self.val_loader)
373
+ # for key in data:
374
+ # data[key] = data[key][:batch_size].cuda(
375
+ # torch.cuda.current_device(), non_blocking=True)
376
+ # with torch.no_grad():
377
+ # real_images = data['image']
378
+ # real_feature_list.append(
379
+ # extract_feature(self.inception_model, real_images))
380
+ # self.logger.update_pbar(task2, batch_size * self.world_size)
381
+ # np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
382
+ # np.concatenate(real_feature_list, axis=0))
383
+
384
+ dist.barrier()
385
+ if self.rank != 0:
386
+ return -1
387
+ self.logger.close_pbar()
388
+
389
+ # Collect fake features.
390
+ fake_feature_list.clear()
391
+ for rank in range(self.world_size):
392
+ fake_feature_list.append(
393
+ np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
394
+ os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
395
+ fake_features = np.concatenate(fake_feature_list, axis=0)
396
+ assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
397
+ feature_dim = fake_features.shape[1]
398
+ pad = fid_num % self.world_size
399
+ if pad:
400
+ pad = self.world_size - pad
401
+ fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
402
+ fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
403
+ fake_features = fake_features.transpose(1, 0, 2)
404
+ fake_features = fake_features.reshape(-1, feature_dim)[:fid_num]
405
+
406
+ # Collect (or load) real features.
407
+ if do_real_test:
408
+ real_feature_list.clear()
409
+ for rank in range(self.world_size):
410
+ real_feature_list.append(
411
+ np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
412
+ os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
413
+ real_features = np.concatenate(real_feature_list, axis=0)
414
+ assert real_features.shape == (fid_num, feature_dim)
415
+ real_features = np.pad(real_features, ((0, pad), (0, 0)))
416
+ real_features = real_features.reshape(
417
+ self.world_size, -1, feature_dim)
418
+ real_features = real_features.transpose(1, 0, 2)
419
+ real_features = real_features.reshape(-1, feature_dim)[:fid_num]
420
+ np.save(cached_fid_file, real_features)
421
+ else:
422
+ real_features = np.load(cached_fid_file)
423
+ assert real_features.shape == (fid_num, feature_dim)
424
+
425
+ fid_value = compute_fid(fake_features, real_features)
426
+ return fid_value
427
+
428
+ def val(self, **val_kwargs):
429
+ self.synthesize(**val_kwargs)
430
+
431
+ def synthesize(self,
432
+ num,
433
+ html_name=None,
434
+ save_raw_synthesis=False):
435
+ """Synthesizes images.
436
+
437
+ Args:
438
+ num: Number of images to synthesize.
439
+ z: Latent codes used for generation. If not specified, this function
440
+ will sample latent codes randomly. (default: None)
441
+ html_name: Name of the output html page for visualization. If not
442
+ specified, no visualization page will be saved. (default: None)
443
+ save_raw_synthesis: Whether to save raw synthesis on the disk.
444
+ (default: False)
445
+ """
446
+
447
+ dist.barrier()
448
+ if self.rank != 0:
449
+ return
450
+
451
+ if not html_name and not save_raw_synthesis:
452
+ return
453
+
454
+ self.set_mode('val')
455
+
456
+ if self.val_loader is None:
457
+ self.build_dataset('val')
458
+
459
+ # temp_dir = os.path.join(self.work_dir, 'synthesize_results')
460
+ # os.makedirs(temp_dir, exist_ok=True)
461
+
462
+ if not num:
463
+ return
464
+ # if num % self.val_batch_size != 0:
465
+ # num = (num //self.val_batch_size +1)*self.val_batch_size
466
+ # TODO: Use same z during the entire training process.
467
+
468
+ self.logger.init_pbar()
469
+ task = self.logger.add_pbar_task('Synthesis', total=num)
470
+ for i in range(num):
471
+ data = next(self.val_loader)
472
+ for key in data:
473
+ if key != 'name':
474
+ data[key] = data[key].cuda(
475
+ torch.cuda.current_device(), non_blocking=True)
476
+
477
+ with torch.no_grad():
478
+ real_images = data['image']
479
+ return_dict = self.forward_pass(data, return_vals='all')
480
+ fakes = return_dict['fakes']
481
+ wp_mixed = return_dict['wp_mixed']
482
+ eouts = return_dict['eouts']
483
+
484
+ log_list_gpu = {"real": real_images, "fake": fakes}
485
+
486
+ # Add editings to log_list
487
+ editings = ['age', 'pose', 'smile']
488
+ for edit in editings:
489
+ direction = torch.load(f'editings/interfacegan_directions/{edit}.pt').cuda()
490
+ factors = [+3, -3]
491
+ for factor in factors:
492
+ name = f"{edit}_{factor}"
493
+ edit_wp = wp_mixed + factor * direction
494
+ edited_images, _ = self.runG(edit_wp, "synthesis", highres_outs=eouts)
495
+ # if edit == 'smile' and factor == -3:
496
+ # res = gouts['gates'].shape[-1]
497
+ # log_list_gpu[f'smile_-3_gate'] = ( torch.mean((gouts_edits['gates']) , dim=1, keepdim=True), 0)
498
+ #edited_images = F.adaptive_avg_pool2d(edited_images, 256)
499
+ log_list_gpu[name] = edited_images
500
+ #log_list_gpu[f'{name}_gate'] = ( torch.mean((temp['gates']) , dim=1, keepdim=True), 0)
501
+
502
+ #Add gate to log_list
503
+ # res = gouts['gates'].shape[-1]
504
+ # log_list_gpu[f'gate{res}x{res}'] = ( torch.mean((gouts['gates']) , dim=1, keepdim=True), 0)
505
+
506
+ #Log images
507
+ for log_name, log_val in log_list_gpu.items():
508
+ log_im = log_val[0] if type(log_val) is tuple else log_val
509
+ min_val = log_val[1] if type(log_val) is tuple else -1
510
+ cpu_img = postprocess_tensor(log_im.detach().cpu(), min_val=min_val)
511
+ grid = torchvision.utils.make_grid(cpu_img, nrow=5)
512
+ log_image( name = f"image/{log_name}", grid=grid, iter=self.iter)
513
+ self.logger.update_pbar(task, 1)
514
+ self.logger.close_pbar()
515
+
516
+ def save_edited_images(self, opts):
517
+ dist.barrier()
518
+ if self.rank != 0:
519
+ return
520
+ self.set_mode('val')
521
+
522
+ if opts.method == 'inversion':
523
+ pass
524
+ elif opts.method == 'interfacegan':
525
+ direction = torch.load(f'editings/interfacegan_directions/{opts.edit}.pt').cuda()
526
+ elif opts.method == 'ganspace':
527
+ ganspace_pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')
528
+ direction = {
529
+ 'eye_openness': (54, 7, 8, 20),
530
+ 'smile': (46, 4, 5, -20),
531
+ 'beard': (58, 7, 9, -20),
532
+ 'white_hair': (57, 7, 10, -24),
533
+ 'lipstick': (34, 10, 11, 20),
534
+ 'overexposed': (27, 8, 18, 15),
535
+ 'screaming': (35, 3, 7, -10),
536
+ 'head_angle_up': (11, 1, 4, 10),
537
+ }
538
+ editor = LatentEditor()
539
+ elif opts.method == 'styleclip':
540
+ #model_path = '/media/hdd2/adundar/hamza/hyperstyle/pretrained_models/stylegan2-ffhq-config-f.pt'
541
+ # calculator_args = {
542
+ # 'delta_i_c': 'editings/styleclip/global_directions/ffhq/fs3.npy',
543
+ # 's_statistics': 'editings/styleclip/global_directions/ffhq/S_mean_std',
544
+ # 'text_prompt_templates': 'editings/styleclip/global_directions/templates.txt'
545
+ # }
546
+ stylegan_model = load_stylegan_generator(opts.model_path)
547
+ global_direction_calculator = load_direction_calculator(opts.calculator_args)
548
+ #Eyeglasses 5, bangs 2, bobcut 5
549
+ # edit_args = {'alpha_min': 2, 'alpha_max': 2, 'num_alphas':1, 'beta_min':0.11, 'beta_max':0.11, 'num_betas': 1,
550
+ # 'neutral_text':'face', 'target_text': 'face with bangs'}
551
+
552
+
553
+ self.config.data['val']['root_dir'] = opts.dataset
554
+
555
+ dataset = BaseDataset(**self.config.data['val'])
556
+ val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
557
+
558
+ temp_dir = os.path.join(self.work_dir, opts.output)
559
+ os.makedirs(temp_dir, exist_ok=True)
560
+
561
+
562
+ self.logger.init_pbar()
563
+ task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
564
+ global_i = 0
565
+ all_latents = {}
566
+ for idx, data in enumerate(val_loader):
567
+ for key in data:
568
+ if key != 'name':
569
+ data[key] = data[key].cuda(
570
+ torch.cuda.current_device(), non_blocking=True)
571
+ with torch.no_grad():
572
+ return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
573
+ wp_mixed = return_dict['wp_mixed']
574
+ eouts = return_dict['eouts']
575
+ #fakes = return_dict['fakes']
576
+ factors = np.linspace(0, 3, 100)
577
+ global_i = 0
578
+
579
+ # for factor in factors:
580
+ if opts.method == 'interfacegan':
581
+ wp_mixed = wp_mixed + opts.factor * direction
582
+ if opts.method == 'ganspace':
583
+ #interpolate_dir = direction[opts.edit][0:3] + (factor,)
584
+ wp_mixed = editor.apply_ganspace_(wp_mixed, ganspace_pca, [direction[opts.edit]])
585
+ #wp_edit = editor.apply_ganspace_(wp_mixed, ganspace_pca, [interpolate_dir])
586
+
587
+
588
+ # z = torch.randn((1,self.config.latent_dim), device=torch.cuda.current_device())
589
+ # z = self.runM(z)
590
+ # diff = z - wp_mixed
591
+ # edit = (diff * 3.5) / 10
592
+ # wp_mixed = wp_mixed + edit
593
+ edited_images, gouts_edits = self.runG(wp_mixed, "synthesis", highres_outs=eouts, resize=False)
594
+ if opts.method == 'styleclip':
595
+ # opts.edit_args['alpha_min'] = factor
596
+ edited_images = styleclip_edit(wp_mixed, gouts_edits['additions'], stylegan_model, global_direction_calculator, opts.edit_args)
597
+ edited_images = T.Resize((256,256))(edited_images)
598
+ edited_images = postprocess_image(edited_images.detach().cpu().numpy())
599
+ for j in range(edited_images.shape[0]):
600
+ # dir_name = data['name'][j][:-4]
601
+ # os.makedirs(os.path.join(temp_dir, dir_name), exist_ok=True)
602
+ # save_name = f'{global_i:03d}_' + data['name'][j]
603
+ save_name = data['name'][j]
604
+ pil_img = Image.fromarray(edited_images[j]) #.resize((256,256))
605
+ #pil_img.save(os.path.join(temp_dir, dir_name, save_name ))
606
+ pil_img.save(os.path.join(temp_dir, save_name ))
607
+ global_i += 1
608
+ # if global_i >= 1000:
609
+ # break
610
+ # if global_i % 100 == 0:
611
+ # print(f"{global_i}/1000")
612
+
613
+ self.logger.update_pbar(task, 1)
614
+ self.logger.close_pbar()
615
+
616
+ # def forward_pass(self, data, return_vals='all', only_enc=False):
617
+ # encoder_type = self.config.encoder_type
618
+ # forward_func = getattr(self, f'{encoder_type}_forward')
619
+ # return_dict = forward_func(data,only_enc)
620
+ # return return_dict
621
+ # # if return_vals == 'all':
622
+ # # return return_dict
623
+ # # requested = return_vals.split(',')
624
+ # # modified_dict = {}
625
+ # # for request in requested:
626
+ # # stripped_request = request.strip()
627
+ # # modified_dict[stripped_request] = return_dict[stripped_request]
628
+ # # return modified_dict
629
+
630
+ # def base_forward(self, data):
631
+ # reals = data['image']
632
+ # valids = data['valid']
633
+ # z_rand = data['z_rand']
634
+
635
+ # wp_rand = self.runM(z_rand)
636
+ # wp_enc, blender = self.runE(reals, valids)
637
+ # wp_mixed = self.mix(wp_enc, wp_rand, blender)
638
+ # fakes = self.runG(wp_mixed, 'synthesis')
639
+ # return_dict = {'fakes': fakes, 'wp_enc': wp_enc, 'blender': blender, 'wp_mixed':wp_mixed}
640
+ # return return_dict
641
+
642
+ # def pSp_forward(self, data, only_enc):
643
+ # return self.e4e_forward(data, only_enc)
644
+
645
+ # def train_forward(self, data, iscycle=False):
646
+ # reals = data['image']
647
+ # direction = data['direction']
648
+ # edit_name = data['edit_name']
649
+ # factor = data['factor']
650
+ # E = self.models['encoder']
651
+ # with torch.no_grad():
652
+ # wp, eouts = E(reals)
653
+ # #wp = wp + self.meanw.repeat(reals.shape[0], 1, 1)
654
+ # edit = torch.zeros_like(wp)
655
+ # for i in range (edit.shape[0]):
656
+ # if edit_name[i] is None:
657
+ # edit[i] = 0
658
+ # elif edit_name[i] == 'randw':
659
+ # diff = direction[i] - wp[i]
660
+ # # one_hot = [1] * 8 + [0] * 10
661
+ # # one_hot = torch.tensor(one_hot, device=diff.device).unsqueeze(1)
662
+ # # diff = diff * one_hot
663
+ # #norm = torch.linalg.norm(diff, dim=1, keepdim=True)
664
+ # edit[i] = (diff * factor[i]) / 10
665
+ # elif edit_name[i] == 'interface':
666
+ # edit[i] = (factor[i] * direction[i])
667
+
668
+ # # # Debug
669
+ # # with torch.no_grad():
670
+ # # fakes,_ =self.runG(wp, 'synthesis', highres_outs=None)
671
+ # # fakes = postprocess_image(fakes.detach().cpu().numpy())
672
+ # # for i in range(fakes.shape[0]):
673
+ # # pil_img = Image.fromarray(fakes[i]).resize((256,256))
674
+ # # pil_img.save(f'{self.iter}_orig.png')
675
+
676
+ # # fakes,_ =self.runG(wp+edit, 'synthesis', highres_outs=None)
677
+ # # fakes = postprocess_image(fakes.detach().cpu().numpy())
678
+ # # for i in range(fakes.shape[0]):
679
+ # # pil_img = Image.fromarray(fakes[i]).resize((256,256))
680
+ # # pil_img.save(f'{self.iter}_edit.png')
681
+
682
+ # # fakes,_ =self.runG(direction.unsqueeze(1).repeat(1,18,1), 'synthesis', highres_outs=None)
683
+ # # fakes = postprocess_image(fakes.detach().cpu().numpy())
684
+ # # for i in range(fakes.shape[0]):
685
+ # # pil_img = Image.fromarray(fakes[i]).resize((256,256))
686
+ # # pil_img.save(f'{self.iter}_rand.png')
687
+ # with torch.no_grad():
688
+ # eouts['inversion'] = self.runG(wp, 'synthesis', highres_outs=None, return_f=True)
689
+ # wp = wp + edit
690
+ # fakes, gouts = self.runG(wp, 'synthesis', highres_outs=eouts)
691
+ # #fakes = F.adaptive_avg_pool2d(fakes, (256,256))
692
+ # fakes_cycle = None
693
+ # if iscycle:
694
+ # # wp_cycle = wp_cycle + self.meanw.repeat(reals.shape[0], 1, 1)
695
+ # with torch.no_grad():
696
+ # wp_cycle, eout_cycle = E(fakes)
697
+ # eout_cycle['inversion'] = self.runG(wp_cycle, 'synthesis', highres_outs=None, return_f=True)
698
+
699
+ # #wp_cycle = wp_cycle - edit
700
+ # wp_cycle = wp_cycle - edit
701
+ # #wp_cycle = wp_cycle - (data['factor'] * data['direction']).unsqueeze(1)
702
+ # fakes_cycle, _ = self.runG(wp_cycle, 'synthesis', highres_outs=eout_cycle)
703
+ # #fakes_cycle = F.adaptive_avg_pool2d(fakes, (256,256))
704
+ # #cycle = F.mse_loss(fakes_cycle, reals, reduction='mean')
705
+ # return_dict = {'fakes': fakes, 'wp_mixed':wp, 'gouts':gouts, 'eouts': eouts, 'cycle': fakes_cycle}
706
+ # return return_dict
707
+
708
+ # def e4e_forward(self, data, only_enc=False):
709
+ # #return self.base_forward(data)
710
+ # reals = data['image']
711
+ # #valids = data['valid']
712
+ # E = self.models['encoder']
713
+ # wp_mixed, eouts = E(reals)
714
+ # #wp_mixed = wp_mixed + self.meanw.repeat(reals.shape[0], 1, 1)
715
+ # eouts['inversion'] = self.runG(wp_mixed, 'synthesis', highres_outs=None, return_f=True)
716
+ # if only_enc:
717
+ # return_dict = {'wp_mixed':wp_mixed,'eouts': eouts}
718
+ # return return_dict
719
+ # fakes, gouts = self.runG(wp_mixed, 'synthesis', highres_outs=eouts)
720
+ # #fakes = self.runG(wp_mixed, 'synthesis', highres_outs=None)
721
+ # #fakes = F.adaptive_avg_pool2d(fakes, (256,256))
722
+ # return_dict = {'fakes': fakes, 'wp_mixed':wp_mixed, 'gouts':gouts, 'eouts': eouts}
723
+ # return return_dict
724
+
725
+
726
+ # def hyperstyle_forward(self, data):
727
+ # return_dict = self.base_forward(data)
728
+
729
+ # E = self.models['encoder']
730
+ # reals = data['image']
731
+ # valids = data['valid']
732
+
733
+ # #HyperNetwork
734
+ # weight_deltas = E(reals, valids, mode='hyper', gouts=return_dict['fakes'])
735
+ # fakes = self.runG(return_dict['wp_mixed'], 'synthesis', weight_deltas=weight_deltas)
736
+ # return_dict['fakes'] = fakes
737
+
738
+ # return return_dict
739
+
740
+ def interface_generate(self, num, edit, factor):
741
+ direction = torch.load(f'editings/interfacegan_directions/{edit}.pt').cuda()
742
+ indices = list(range(self.rank, num, self.world_size))
743
+ gt_path = os.path.join(self.work_dir, f'interfacegan_gt')
744
+ smile_add_path = os.path.join(self.work_dir, f'interfacegan_{edit}_{factor}')
745
+ smile_rm_path = os.path.join(self.work_dir, f'interfacegan_{edit}_-{factor}')
746
+ if self.rank == 0:
747
+ os.makedirs(gt_path, exist_ok=True)
748
+ os.makedirs(smile_add_path, exist_ok=True)
749
+ os.makedirs(smile_rm_path, exist_ok=True)
750
+ dist.barrier()
751
+
752
+ self.logger.init_pbar()
753
+ task = self.logger.add_pbar_task('Interfacegan', total=num)
754
+
755
+
756
+ for batch_idx in range(0, len(indices), self.val_batch_size):
757
+ sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
758
+ batch_size = len(sub_indices)
759
+
760
+ z = torch.randn((batch_size,512), device=torch.cuda.current_device())
761
+ w_r = self.runM(z, repeat_w=True)
762
+ gt_imgs,_ = self.runG(w_r, resize=False)
763
+ gt_imgs = postprocess_image(gt_imgs.detach().cpu().numpy())
764
+ for i in range(gt_imgs.shape[0]):
765
+ save_name = str(sub_indices[i]) + ".png"
766
+ pil_img = Image.fromarray(gt_imgs[i]).resize((256,256))
767
+ pil_img.save(os.path.join(gt_path, save_name ))
768
+
769
+ smile_added, _ = self.runG(w_r + factor*direction, resize=False)
770
+ smile_added = postprocess_image(smile_added.detach().cpu().numpy())
771
+ for i in range(gt_imgs.shape[0]):
772
+ save_name = str(sub_indices[i]) + ".png"
773
+ pil_img = Image.fromarray(smile_added[i]).resize((256,256))
774
+ pil_img.save(os.path.join(smile_add_path, save_name ))
775
+
776
+ smile_removed, _= self.runG(w_r - factor*direction, resize=False)
777
+ smile_removed = postprocess_image(smile_removed.detach().cpu().numpy())
778
+ for i in range(gt_imgs.shape[0]):
779
+ save_name = str(sub_indices[i]) + ".png"
780
+ pil_img = Image.fromarray(smile_removed[i]).resize((256,256))
781
+ pil_img.save(os.path.join(smile_rm_path, save_name ))
782
+
783
+ self.logger.update_pbar(task, batch_size * self.world_size)
784
+ self.logger.close_pbar()
785
+
786
+ def grad_edit(self, edit, factor, dataset=None):
787
+ dist.barrier()
788
+ if self.rank != 0:
789
+ return
790
+ self.set_mode('val')
791
+ edit_name = edit
792
+ edit = 'val'
793
+ self.config.data[edit]['root_dir'] = dataset
794
+
795
+ dataset = BaseDataset(**self.config.data[edit])
796
+ val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
797
+
798
+ temp_dir = os.path.join(self.work_dir, f'fakes_{edit_name}_{factor}')
799
+ os.makedirs(temp_dir, exist_ok=True)
800
+
801
+ self.logger.init_pbar()
802
+ task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
803
+ global_i = 0
804
+ args = {'model': 'ffhq', 'model_dir': '/media/hdd2/adundar/hamza/genforce/editings/GradCtrl/model_ffhq',
805
+ 'attribute': edit_name, 'exclude': 'default', 'top_channels': 'default', 'layerwise': 'default' }
806
+ for idx, data in enumerate(val_loader):
807
+ for key in data:
808
+ if key != 'name':
809
+ data[key] = data[key].cuda(
810
+ torch.cuda.current_device(), non_blocking=True)
811
+ with torch.no_grad():
812
+ return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
813
+ wp_mixed = return_dict['wp_mixed']
814
+ eouts = return_dict['eouts']
815
+
816
+ #fakes = return_dict['fakes']
817
+ edit_wp = gradctrl(args, wp_mixed, factor)
818
+ edited_images, gouts_edits = self.runG(edit_wp, "synthesis", highres_outs=eouts, resize=False)
819
+ #edited_images, gouts_edits = self.runG(wp_mixed, "synthesis", highres_outs=eouts, resize=False)
820
+ edited_images = postprocess_image(edited_images.detach().cpu().numpy())
821
+ for j in range(edited_images.shape[0]):
822
+ save_name = data['name'][j]
823
+ pil_img = Image.fromarray(edited_images[j]).resize((256,256))
824
+ pil_img.save(os.path.join(temp_dir, save_name ))
825
+ global_i += 1
826
+
827
+ self.logger.update_pbar(task, 1)
828
+ self.logger.close_pbar()
829
+
830
+ def measure_time(self, edit, factor, dataset=None, save_latents=False):
831
+ dist.barrier()
832
+ if self.rank != 0:
833
+ return
834
+ self.set_mode('val')
835
+ edit_name = edit
836
+ if dataset is not None:
837
+ edit = 'val'
838
+ self.config.data[edit]['root_dir'] = dataset
839
+
840
+ dataset = BaseDataset(**self.config.data[edit])
841
+ val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
842
+
843
+ global_i = 0
844
+ time_list = []
845
+ for idx, data in enumerate(val_loader):
846
+ for key in data:
847
+ if key != 'name':
848
+ data[key] = data[key].cuda(
849
+ torch.cuda.current_device(), non_blocking=True)
850
+ with torch.no_grad():
851
+ start = time.time()
852
+ return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
853
+ wp_mixed = return_dict['wp_mixed']
854
+ eouts = return_dict['eouts']
855
+ edited_images, gouts_edits = self.runG(wp_mixed, "synthesis", highres_outs=eouts, resize=False)
856
+ end = time.time()
857
+ time_list.append(end-start)
858
+ print(np.mean(time_list))
859
+ print(np.mean(time_list[1:]))
860
+
861
+
editings/editor.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .interfacegan import InterFaceGAN
2
+ from .ganspace import GanSpace
3
+ from .styleclip import StyleClip
4
+ from options import Settings
5
+
6
+ """
7
+ Entry class for all the edits.
8
+ """
9
+ class Editor():
10
+ def __init__(self) -> None:
11
+ self.interfacegan_editor = InterFaceGAN()
12
+ self.ganspace_editor = GanSpace()
13
+ self.styleclip_editor = StyleClip()
14
+
15
+ def edit(self, latent, cfg):
16
+ # Finds the corresponding function using method name
17
+ if cfg.method == 'inversion':
18
+ return latent
19
+
20
+ editor = getattr(self, f'{cfg.method}_editor')
21
+ return editor.edit(latent, cfg)
22
+
23
+
editings/ganspace.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import csv
3
+ from options import Settings
4
+ import os
5
+
6
+ class GanSpace():
7
+ def __init__(self) -> None:
8
+
9
+ self.gan_space_configs = {}
10
+
11
+ with open(os.path.join(Settings.ganspace_directions, 'ganspace_configs.csv'), "r") as f:
12
+ reader = csv.reader(f, delimiter="\t")
13
+ for row in reader:
14
+ key = row.pop(0)
15
+ self.gan_space_configs[key] = list(map(int, row))
16
+
17
+ def edit(self, latent, cfg):
18
+ with torch.no_grad():
19
+ self.load_ganspace_pca()
20
+ gan_space_config = self.gan_space_configs[cfg.edit]
21
+ gan_space_config[-1] = cfg.strength
22
+ return self.edit_ganspace(latent, gan_space_config)
23
+
24
+ def load_ganspace_pca(self):
25
+ try: # Check if loaded
26
+ getattr(self, f"pca")
27
+ except:
28
+ pca = torch.load(os.path.join(Settings.ganspace_directions, 'ffhq_pca.pt'))
29
+ setattr(self, f"pca", pca)
30
+
31
+
32
+ def edit_ganspace(self, latents, config):
33
+ edit_latents = []
34
+ pca_idx, start, end, strength = config
35
+ for latent in latents:
36
+ delta = self.get_delta( latent, pca_idx, strength)
37
+ delta_padded = torch.zeros(latent.shape).to(Settings.device)
38
+ delta_padded[start:end] += delta.repeat(end - start, 1)
39
+ edit_latents.append(latent + delta_padded)
40
+ return torch.stack(edit_latents)
41
+
42
+ def get_delta(self, latent, idx, strength):
43
+ # pca: ganspace checkpoint. latent: (16, 512) w+
44
+ w_centered = latent - self.pca['mean'].to(Settings.device)
45
+ lat_comp = self.pca['comp'].to(Settings.device)
46
+ lat_std = self.pca['std'].to(Settings.device)
47
+ w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
48
+ delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
49
+ return delta
editings/ganspace_pca/ganspace_configs.csv ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ frizzy_hair 31 2 6 20
2
+ background_blur 49 6 9 20
3
+ bald 21 2 5 20
4
+ big_smile 19 4 5 20
5
+ caricature_smile 26 3 8 13
6
+ scary_eyes 33 6 8 20
7
+ curly_hair 47 3 6 20
8
+ dark_bg_shiny_hair 13 8 9 20
9
+ dark_hair_and_light_pos 14 8 9 20
10
+ dark_hair 16 8 9 20
11
+ disgusted 43 6 8 -30
12
+ displeased 36 4 7 20
13
+ eye_openness 54 7 8 20
14
+ eye_wrinkles 28 6 8 20
15
+ eyebrow_thickness 37 8 9 20
16
+ face_roundness 37 0 5 20
17
+ fearful_eyes 54 4 10 20
18
+ hairline 21 4 5 -20
19
+ happy_frizzy_hair 30 0 8 20
20
+ happy_elderly_lady 27 4 7 20
21
+ head_angle_up 11 1 4 20
22
+ huge_grin 28 4 6 20
23
+ in_awe 23 3 6 -15
24
+ wide_smile 23 3 6 20
25
+ large_jaw 22 3 6 20
26
+ light_lr 15 8 9 10
27
+ lipstick_and_age 34 6 11 20
28
+ lipstick 34 10 11 20
29
+ mascara_vs_beard 41 6 9 20
30
+ nose_length 51 4 5 -20
31
+ elderly_woman 34 6 7 20
32
+ overexposed 27 8 18 15
33
+ screaming 35 3 7 -15
34
+ short_face 32 2 6 -20
35
+ show_front_teeth 59 4 5 40
36
+ smile 46 4 5 -20
37
+ straight_bowl_cut 20 4 5 -20
38
+ sunlight_in_face 10 8 9 10
39
+ trimmed_beard 58 7 9 20
40
+ white_hair 57 7 10 -24
41
+ wrinkles 20 6 7 -18
42
+ boyishness 8 2 5 20
editings/interfacegan.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from options import Settings
3
+ import os
4
+
5
+ class InterFaceGAN():
6
+ def __init__(self) -> None:
7
+ pass
8
+
9
+ def edit(self, latent, cfg):
10
+ with torch.no_grad():
11
+ return latent + cfg.strength * self.get_direction(cfg.edit)
12
+
13
+ def get_direction(self, editname):
14
+ try:
15
+ direction = getattr(self, f"{editname}_direction")
16
+ except:
17
+ direction = self.load_direction(editname)
18
+ if Settings.device != 'cpu':
19
+ direction = direction.to(Settings.device)
20
+ setattr(self, f"{editname}_direction", direction.clone())
21
+ return direction
22
+
23
+ def load_direction(self, editname):
24
+ direction = torch.load(os.path.join( Settings.interfacegan_directions, f'{editname}.pt'))
25
+ if Settings.device != 'cpu':
26
+ direction = direction.cuda()
27
+ return direction
editings/styleclip.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from editings.styleclip_directions.styleclip_mapper_network import LevelsMapper
3
+ import torch
4
+ import csv
5
+ from options import Settings
6
+ import os
7
+
8
+ class Options():
9
+ def __init__(self, no_coarse_mapper, no_medium_mapper, no_fine_mapper) -> None:
10
+ self.no_coarse_mapper = no_coarse_mapper
11
+ self.no_medium_mapper = no_medium_mapper
12
+ self.no_fine_mapper = no_fine_mapper
13
+
14
+ class StyleClip():
15
+ def __init__(self) -> None:
16
+ self.styleclip_mapping_configs = {}
17
+
18
+ with open(os.path.join(Settings.styleclip_settings, 'styleclip_mapping_configs.csv'), "r") as f:
19
+ reader = csv.reader(f)
20
+ for row in reader:
21
+ key = row.pop(0)
22
+ self.styleclip_mapping_configs[key] = list(map(lambda x: True if x == "True" else False, row))
23
+
24
+ def edit(self, latent, cfg):
25
+ with torch.no_grad():
26
+ if cfg.type == 'mapper':
27
+ mapper = self.build_mapper(cfg.edit)
28
+ return latent + cfg.strength * mapper(latent)
29
+ if cfg.type == 'global':
30
+
31
+ return latent + 10 * torch.load(os.path.join(Settings.styleclip_global_directions, 'makeup.pt'))
32
+
33
+ # def load_global_direction(self, editname):
34
+ # pass
35
+
36
+ def build_mapper(self, editname):
37
+ try: # Check if loaded
38
+ mapper = getattr(self, f"{editname}_mapper")
39
+ except:
40
+ opts = Options(*self.styleclip_mapping_configs[editname])
41
+ mapper = LevelsMapper(opts)
42
+ ckpt = torch.load(os.path.join(Settings.styleclip_mapper_directions, f'{editname}.pt'))
43
+ mapper.load_state_dict(ckpt, strict=True)
44
+ mapper.to(device=Settings.device)
45
+ for param in mapper.parameters():
46
+ param.requires_grad = False
47
+ mapper.eval()
48
+ setattr(self, f"{editname}_mapper", mapper)
49
+ return mapper
editings/styleclip_directions/__init__.py ADDED
File without changes
editings/styleclip_directions/styleclip_directions/global_directions/ffhq/S_mean_std ADDED
Binary file (75.1 kB). View file
 
editings/styleclip_directions/styleclip_directions/global_directions/templates.txt ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a bad photo of a {}.
2
+ a sculpture of a {}.
3
+ a photo of the hard to see {}.
4
+ a low resolution photo of the {}.
5
+ a rendering of a {}.
6
+ graffiti of a {}.
7
+ a bad photo of the {}.
8
+ a cropped photo of the {}.
9
+ a tattoo of a {}.
10
+ the embroidered {}.
11
+ a photo of a hard to see {}.
12
+ a bright photo of a {}.
13
+ a photo of a clean {}.
14
+ a photo of a dirty {}.
15
+ a dark photo of the {}.
16
+ a drawing of a {}.
17
+ a photo of my {}.
18
+ the plastic {}.
19
+ a photo of the cool {}.
20
+ a close-up photo of a {}.
21
+ a black and white photo of the {}.
22
+ a painting of the {}.
23
+ a painting of a {}.
24
+ a pixelated photo of the {}.
25
+ a sculpture of the {}.
26
+ a bright photo of the {}.
27
+ a cropped photo of a {}.
28
+ a plastic {}.
29
+ a photo of the dirty {}.
30
+ a jpeg corrupted photo of a {}.
31
+ a blurry photo of the {}.
32
+ a photo of the {}.
33
+ a good photo of the {}.
34
+ a rendering of the {}.
35
+ a {} in a video game.
36
+ a photo of one {}.
37
+ a doodle of a {}.
38
+ a close-up photo of the {}.
39
+ a photo of a {}.
40
+ the origami {}.
41
+ the {} in a video game.
42
+ a sketch of a {}.
43
+ a doodle of the {}.
44
+ a origami {}.
45
+ a low resolution photo of a {}.
46
+ the toy {}.
47
+ a rendition of the {}.
48
+ a photo of the clean {}.
49
+ a photo of a large {}.
50
+ a rendition of a {}.
51
+ a photo of a nice {}.
52
+ a photo of a weird {}.
53
+ a blurry photo of a {}.
54
+ a cartoon {}.
55
+ art of a {}.
56
+ a sketch of the {}.
57
+ a embroidered {}.
58
+ a pixelated photo of a {}.
59
+ itap of the {}.
60
+ a jpeg corrupted photo of the {}.
61
+ a good photo of a {}.
62
+ a plushie {}.
63
+ a photo of the nice {}.
64
+ a photo of the small {}.
65
+ a photo of the weird {}.
66
+ the cartoon {}.
67
+ art of the {}.
68
+ a drawing of the {}.
69
+ a photo of the large {}.
70
+ a black and white photo of a {}.
71
+ the plushie {}.
72
+ a dark photo of a {}.
73
+ itap of a {}.
74
+ graffiti of the {}.
75
+ a toy {}.
76
+ itap of my {}.
77
+ a photo of a cool {}.
78
+ a photo of a small {}.
79
+ a tattoo of the {}.
editings/styleclip_directions/styleclip_mapper_network.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Module
4
+ from torch.nn import functional as F
5
+ import math
6
+
7
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
8
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
9
+ input = input #.cuda()
10
+ if input.ndim == 3:
11
+ return (
12
+ F.leaky_relu(
13
+ input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
14
+ )
15
+ * scale
16
+ )
17
+ else:
18
+ return (
19
+ F.leaky_relu(
20
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
21
+ )
22
+ * scale
23
+ )
24
+
25
+
26
+ class PixelNorm(nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+
30
+ def forward(self, input):
31
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
32
+
33
+ class EqualLinear(nn.Module):
34
+ def __init__(
35
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
36
+ ):
37
+ super().__init__()
38
+
39
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
40
+
41
+ if bias:
42
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
43
+
44
+ else:
45
+ self.bias = None
46
+
47
+ self.activation = activation
48
+
49
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
50
+ self.lr_mul = lr_mul
51
+
52
+ def forward(self, input):
53
+ if self.activation:
54
+ out = F.linear(input, self.weight * self.scale)
55
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
56
+
57
+ else:
58
+ out = F.linear(
59
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
60
+ )
61
+
62
+ return out
63
+
64
+ class Mapper(Module):
65
+
66
+ def __init__(self, latent_dim=512):
67
+ super(Mapper, self).__init__()
68
+
69
+ layers = [PixelNorm()]
70
+
71
+ for i in range(4):
72
+ layers.append(
73
+ EqualLinear(
74
+ latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
75
+ )
76
+ )
77
+
78
+ self.mapping = nn.Sequential(*layers)
79
+
80
+
81
+ def forward(self, x):
82
+ x = self.mapping(x)
83
+ return x
84
+
85
+
86
+ class LevelsMapper(Module):
87
+
88
+ def __init__(self, opts):
89
+ super(LevelsMapper, self).__init__()
90
+
91
+ self.opts = opts
92
+
93
+ if not opts.no_coarse_mapper:
94
+ self.course_mapping = Mapper()
95
+ if not opts.no_medium_mapper:
96
+ self.medium_mapping = Mapper()
97
+ if not opts.no_fine_mapper:
98
+ self.fine_mapping = Mapper()
99
+
100
+ def forward(self, x):
101
+ x_coarse = x[:, :4, :]
102
+ x_medium = x[:, 4:8, :]
103
+ x_fine = x[:, 8:, :]
104
+
105
+ if not self.opts.no_coarse_mapper:
106
+ x_coarse = self.course_mapping(x_coarse)
107
+ else:
108
+ x_coarse = torch.zeros_like(x_coarse)
109
+ if not self.opts.no_medium_mapper:
110
+ x_medium = self.medium_mapping(x_medium)
111
+ else:
112
+ x_medium = torch.zeros_like(x_medium)
113
+ if not self.opts.no_fine_mapper:
114
+ x_fine = self.fine_mapping(x_fine)
115
+ else:
116
+ x_fine = torch.zeros_like(x_fine)
117
+
118
+
119
+ out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
120
+
121
+ return out
editings/styleclip_directions/styleclip_mapping_configs.csv ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ afro,False,False,True
2
+ angry,False,False,True
3
+ beyonce,False,False,False
4
+ bobcut,False,False,True
5
+ bowlcut,False,False,True
6
+ curly_hair,False,False,True
7
+ hilary_clinton,False,False,False
8
+ depp,False,False,False
9
+ mohawk,False,False,True
10
+ purple_hair,False,False,False
11
+ surprised,False,False,True
12
+ taylor_swift,False,False,False
13
+ trump,False,False,False
14
+ zuckerberg,False,False,False
inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from PIL import Image
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+
7
+ from datasets.inference_dataset import InferenceDataset
8
+ from datasets.process_image import ImageProcessor
9
+ from models.styleres import StyleRes
10
+ from options.inference_options import InferenceOptions
11
+ from options import Settings
12
+ from utils import parse_config
13
+ from tqdm import tqdm
14
+
15
+ def initialize_styleres(checkpoint_path, device):
16
+ Settings.device = device
17
+ model = StyleRes()
18
+ model.load_ckpt(checkpoint_path)
19
+ model.send_to_device()
20
+ model.eval()
21
+ for param in model.parameters():
22
+ param.requires_grad = False
23
+ return model
24
+
25
+ def run():
26
+ args = InferenceOptions().parse()
27
+ edit_configs = parse_config(args.edit_configs)
28
+ if torch.cuda.is_available():
29
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+
31
+ dataset = InferenceDataset(args.datadir, aligner_path=args.aligner_path)
32
+ print(f"Dataset is created. Number of images is {len(dataset)}")
33
+ dataloader = DataLoader(dataset, batch_size = args.test_batch_size,
34
+ shuffle=False,
35
+ num_workers=int(args.test_workers),
36
+ drop_last=False)
37
+
38
+ if args.n_images == None:
39
+ args.n_images = len(dataset)
40
+
41
+ # Create output directories
42
+ output_dir = args.outdir
43
+ os.makedirs(output_dir, exist_ok=True)
44
+ for edit_config in edit_configs:
45
+ cfg_vals = edit_config.values()
46
+ edit_config.outdir = '_'.join( str(i) for i in cfg_vals)
47
+ os.makedirs( os.path.join(output_dir, edit_config.outdir), exist_ok=True)
48
+
49
+ resize_amount = (1024, 1024)
50
+ if args.resize_outputs:
51
+ resize_amount = (256,256)
52
+
53
+ # Setup model
54
+ model = initialize_styleres(args.checkpoint_path, device)
55
+
56
+ n_images = 0
57
+ for data in tqdm(dataloader):
58
+ if n_images >= args.n_images:
59
+ break
60
+ n_images = n_images + data['image'].shape[0]
61
+ for edit_config in edit_configs:
62
+ images = model.edit_images( data['image'], edit_config)
63
+ images = ImageProcessor.postprocess_image(images.detach().cpu().numpy())
64
+ for j in range( images.shape[0]):
65
+ save_name = data['name'][j]
66
+ pil_img = Image.fromarray(images[j]).resize(resize_amount)
67
+ pil_img.save(os.path.join(output_dir, edit_config.outdir, save_name))
68
+
69
+
70
+ if __name__ == '__main__':
71
+ run()
models/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict
models/dnnlib/util.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ # Util classes
12
+ # ------------------------------------------------------------------------------------------
13
+
14
+ class EasyDict(dict):
15
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
16
+
17
+ def __getattr__(self, name: str):
18
+ try:
19
+ return self[name]
20
+ except KeyError:
21
+ raise AttributeError(name)
22
+
23
+ def __setattr__(self, name, value):
24
+ self[name] = value
25
+
26
+ def __delattr__(self, name):
27
+ del self[name]
28
+
29
+
models/e4e.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from collections import namedtuple
7
+
8
+ def _upsample_add(x, y):
9
+ _, _, H, W = y.size()
10
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
11
+
12
+
13
+
14
+ class EqualLinear(nn.Module):
15
+ def __init__(
16
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
17
+ ):
18
+ super().__init__()
19
+
20
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
21
+
22
+ if bias:
23
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
24
+
25
+ else:
26
+ self.bias = None
27
+
28
+ self.activation = activation
29
+
30
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
31
+ self.lr_mul = lr_mul
32
+
33
+ def forward(self, input):
34
+ # if self.activation:
35
+ # out = F.linear(input, self.weight * self.scale)
36
+ # out = fused_leaky_relu(out, self.bias * self.lr_mul)
37
+
38
+ # else:
39
+ out = F.linear(
40
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
41
+ )
42
+
43
+ return out
44
+
45
+ def __repr__(self):
46
+ return (
47
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
48
+ )
49
+
50
+ class GradualStyleBlock(nn.Module):
51
+ def __init__(self, in_c, out_c, spatial):
52
+ super(GradualStyleBlock, self).__init__()
53
+ self.out_c = out_c
54
+ self.spatial = spatial
55
+ num_pools = int(np.log2(spatial))
56
+ modules = []
57
+ modules += [nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
58
+ nn.LeakyReLU()]
59
+ for i in range(num_pools - 1):
60
+ modules += [
61
+ nn.Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
62
+ nn.LeakyReLU()
63
+ ]
64
+ self.convs = nn.Sequential(*modules)
65
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
66
+
67
+ def forward(self, x):
68
+ x = self.convs(x)
69
+ x = x.view(-1, self.out_c)
70
+ x = self.linear(x)
71
+ return x
72
+
73
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
74
+ """ A named tuple describing a ResNet block. """
75
+
76
+ class bottleneck_IR(nn.Module):
77
+ def __init__(self, in_channel, depth, stride):
78
+ super(bottleneck_IR, self).__init__()
79
+ if in_channel == depth:
80
+ self.shortcut_layer = nn.MaxPool2d(1, stride)
81
+ else:
82
+ self.shortcut_layer = nn.Sequential(
83
+ nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
84
+ nn.BatchNorm2d(depth)
85
+ )
86
+ self.res_layer = nn.Sequential(
87
+ nn.BatchNorm2d(in_channel),
88
+ nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), nn.PReLU(depth),
89
+ nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False), nn.BatchNorm2d(depth)
90
+ )
91
+
92
+ def forward(self, x):
93
+ shortcut = self.shortcut_layer(x)
94
+ res = self.res_layer(x)
95
+ return res + shortcut
96
+
97
+ class SEModule(nn.Module):
98
+ def __init__(self, channels, reduction):
99
+ super(SEModule, self).__init__()
100
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
101
+ self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
102
+ self.relu = nn.ReLU(inplace=True)
103
+ self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
104
+ self.sigmoid = nn.Sigmoid()
105
+
106
+ def forward(self, x):
107
+ module_input = x
108
+ x = self.avg_pool(x)
109
+ x = self.fc1(x)
110
+ x = self.relu(x)
111
+ x = self.fc2(x)
112
+ x = self.sigmoid(x)
113
+ return module_input * x
114
+
115
+ class bottleneck_IR_SE(nn.Module):
116
+ def __init__(self, in_channel, depth, stride):
117
+ super(bottleneck_IR_SE, self).__init__()
118
+ if in_channel == depth:
119
+ self.shortcut_layer = nn.MaxPool2d(1, stride)
120
+ else:
121
+ self.shortcut_layer = nn.Sequential(
122
+ nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
123
+ nn.BatchNorm2d(depth)
124
+ )
125
+ self.res_layer = nn.Sequential(
126
+ nn.BatchNorm2d(in_channel),
127
+ nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
128
+ nn.PReLU(depth),
129
+ nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
130
+ nn.BatchNorm2d(depth),
131
+ SEModule(depth, 16)
132
+ )
133
+
134
+ def forward(self, x):
135
+ shortcut = self.shortcut_layer(x)
136
+ res = self.res_layer(x)
137
+ return res + shortcut
138
+
139
+
140
+ def get_block(in_channel, depth, num_units, stride=2):
141
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
142
+
143
+ def get_blocks(num_layers):
144
+ if num_layers == 50:
145
+ blocks = [
146
+ get_block(in_channel=64, depth=64, num_units=3),
147
+ get_block(in_channel=64, depth=128, num_units=4),
148
+ get_block(in_channel=128, depth=256, num_units=14),
149
+ get_block(in_channel=256, depth=512, num_units=3)
150
+ ]
151
+ elif num_layers == 100:
152
+ blocks = [
153
+ get_block(in_channel=64, depth=64, num_units=3),
154
+ get_block(in_channel=64, depth=128, num_units=13),
155
+ get_block(in_channel=128, depth=256, num_units=30),
156
+ get_block(in_channel=256, depth=512, num_units=3)
157
+ ]
158
+ elif num_layers == 152:
159
+ blocks = [
160
+ get_block(in_channel=64, depth=64, num_units=3),
161
+ get_block(in_channel=64, depth=128, num_units=8),
162
+ get_block(in_channel=128, depth=256, num_units=36),
163
+ get_block(in_channel=256, depth=512, num_units=3)
164
+ ]
165
+ else:
166
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
167
+ return blocks
168
+
169
+ class Encoder4Editing(nn.Module):
170
+ def __init__(self, num_layers, mode='ir', stylegan_size=1024, out_res=64):
171
+ super(Encoder4Editing, self).__init__()
172
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
173
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
174
+ blocks = get_blocks(num_layers)
175
+ if mode == 'ir':
176
+ unit_module = bottleneck_IR
177
+ elif mode == 'ir_se':
178
+ unit_module = bottleneck_IR_SE
179
+ self.out_res = out_res
180
+ self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
181
+ nn.BatchNorm2d(64),
182
+ nn.PReLU(64))
183
+ modules = []
184
+ for block in blocks:
185
+ for bottleneck in block:
186
+ modules.append(unit_module(bottleneck.in_channel,
187
+ bottleneck.depth,
188
+ bottleneck.stride))
189
+ self.body = nn.Sequential(*modules)
190
+
191
+ self.styles = nn.ModuleList()
192
+ log_size = int(math.log(stylegan_size, 2))
193
+ self.style_count = 2 * log_size - 2
194
+ self.coarse_ind = 3
195
+ self.middle_ind = 7
196
+
197
+ for i in range(self.style_count):
198
+ if i < self.coarse_ind:
199
+ style = GradualStyleBlock(512, 512, 16)
200
+ elif i < self.middle_ind:
201
+ style = GradualStyleBlock(512, 512, 32)
202
+ else:
203
+ style = GradualStyleBlock(512, 512, 64)
204
+ self.styles.append(style)
205
+
206
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
207
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
208
+
209
+ def forward(self, x):
210
+ x = self.input_layer(x)
211
+
212
+ modulelist = list(self.body._modules.values())
213
+ for i, l in enumerate(modulelist):
214
+ x = l(x)
215
+ if i == 2:
216
+ c0 = x
217
+ if i == 6:
218
+ c1 = x
219
+ elif i == 20:
220
+ c2 = x
221
+ elif i == 23:
222
+ c3 = x
223
+
224
+ # Infer main W and duplicate it
225
+ w0 = self.styles[0](c3)
226
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
227
+
228
+ features = c3
229
+ for i in range(1, self.style_count): # Infer additional deltas
230
+ if i == self.coarse_ind:
231
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
232
+ features = p2
233
+ elif i == self.middle_ind:
234
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
235
+ features = p1
236
+ delta_i = self.styles[i](features)
237
+ w[:, i] += delta_i
238
+
239
+ c = { 128: c0,
240
+ 64: c1,
241
+ 32: c2,
242
+ 16: c3
243
+ }.get(self.out_res)
244
+ return w, c
245
+
246
+ class EqualConv2d(nn.Module):
247
+ def __init__(
248
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
249
+ ):
250
+ super().__init__()
251
+
252
+ self.weight = nn.Parameter(
253
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
254
+ )
255
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
256
+
257
+ self.stride = stride
258
+ self.padding = padding
259
+
260
+ if bias:
261
+ self.bias = nn.Parameter(torch.zeros(out_channel))
262
+
263
+ else:
264
+ self.bias = None
265
+
266
+ def forward(self, input):
267
+ out = F.conv2d(
268
+ input,
269
+ self.weight * self.scale,
270
+ bias=self.bias,
271
+ stride=self.stride,
272
+ padding=self.padding,
273
+ )
274
+
275
+ return out
276
+
277
+ def __repr__(self):
278
+ return (
279
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
280
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
281
+ )
282
+
283
+ class ScaledLeakyReLU(nn.Module):
284
+ def __init__(self, negative_slope=0.2):
285
+ super().__init__()
286
+
287
+ self.negative_slope = negative_slope
288
+
289
+ def forward(self, input):
290
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
291
+
292
+ return out * math.sqrt(2)
293
+
294
+ class HighResFeat(nn.Module):
295
+ def __init__(self, in_channels, out_channels):
296
+ super(HighResFeat, self).__init__()
297
+
298
+ self.shared = EqualConv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)
299
+
300
+ self.conv1 = EqualConv2d(out_channels, 1, kernel_size=3, padding=1, bias=True)
301
+ self.conv2 = EqualConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
302
+ self.activation = ScaledLeakyReLU(0.2)
303
+
304
+ self.sigmoid = nn.Sigmoid()
305
+
306
+ self.skip = None
307
+ if in_channels != out_channels:
308
+ self.skip = EqualConv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False)
309
+
310
+ def forward(self, x):
311
+
312
+ shared_feats = self.shared(x)
313
+ shared_feats = self.activation(shared_feats)
314
+
315
+ gate = self.conv1(shared_feats)
316
+ gate = self.sigmoid(gate)
317
+
318
+ addition = self.conv2(shared_feats)
319
+ addition = self.activation(addition)
320
+
321
+ if self.skip is not None:
322
+ x = self.skip(x)
323
+ return gate, addition+x
324
+
325
+ class E4E_Inversion(nn.Module):
326
+ def __init__(self, resolution, num_layers = 50, mode='ir_se', out_res=64):
327
+ super(E4E_Inversion, self).__init__()
328
+ self.out_res = out_res
329
+ resolution = 1024
330
+ self.basic_encoder = Encoder4Editing(num_layers, mode, resolution, self.out_res)
331
+ self.latent_avg = None
332
+ # ckpt = torch.load(e4e_path, map_location='cpu')
333
+ # self.latent_avg = ckpt['latent_avg'].cuda()
334
+ # ckpt = {k[k.find(".")+1:]: v for k, v in ckpt['state_dict'].items() if "decoder" not in k}
335
+ # self.basic_encoder.load_state_dict(ckpt, strict=True)
336
+
337
+ def freeze_basic_encoder(self):
338
+ self.basic_encoder.eval() #Basic Encoder always in eval mode.
339
+ #No backprop to basic Encoder
340
+ for param in self.basic_encoder.parameters():
341
+ param.requires_grad = False
342
+
343
+ def forward(self, reals):
344
+ self.freeze_basic_encoder()
345
+ w, c = self.basic_encoder(reals)
346
+ w = w + self.latent_avg
347
+ highres_outs = {f"{self.out_res}x{self.out_res}": c} #{"gates": gates, "additions": additions}
348
+ return w, highres_outs
models/stylegan2.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Network architectures from the paper
10
+ "Analyzing and Improving the Image Quality of StyleGAN".
11
+ Matches the original implementation of configs E-F by Karras et al. at
12
+ https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py"""
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from .torch_utils import misc
19
+ from .torch_utils.ops import conv2d_resample
20
+ from .torch_utils.ops import upfirdn2d
21
+ from .torch_utils.ops import bias_act
22
+ from .torch_utils.ops import fma
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+
27
+ def normalize_2nd_moment(x, dim=1, eps=1e-8):
28
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
29
+
30
+ #----------------------------------------------------------------------------
31
+
32
+
33
+ def modulated_conv2d(
34
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
35
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
36
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
37
+ noise = None, # Optional noise tensor to add to the output activations.
38
+ up = 1, # Integer upsampling factor.
39
+ down = 1, # Integer downsampling factor.
40
+ padding = 0, # Padding with respect to the upsampled image.
41
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
42
+ demodulate = True, # Apply weight demodulation?
43
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
44
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
45
+ weigth_deltas = None
46
+ ):
47
+ batch_size = x.shape[0]
48
+ out_channels, in_channels, kh, kw = weight.shape
49
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
50
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
51
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
52
+
53
+ # Pre-normalize inputs to avoid FP16 overflow.
54
+ if x.dtype == torch.float16 and demodulate:
55
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
56
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
57
+
58
+ # Calculate per-sample weights and demodulation coefficients.
59
+ w = None
60
+ dcoefs = None
61
+ if demodulate or fused_modconv:
62
+ w = weight.unsqueeze(0) # [NOIkk]
63
+ #HyperStyle Addition for the Generator
64
+ if weigth_deltas is None:
65
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
66
+ else:
67
+ w = w * (1 + weigth_deltas) * styles.reshape(batch_size, 1, -1, 1, 1)
68
+ if demodulate:
69
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
70
+ if demodulate and fused_modconv:
71
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
72
+
73
+ # Execute by scaling the activations before and after the convolution.
74
+ if not fused_modconv:
75
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
76
+ x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
77
+ if demodulate and noise is not None:
78
+ x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
79
+ elif demodulate:
80
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
81
+ elif noise is not None:
82
+ x = x.add_(noise.to(x.dtype))
83
+ return x
84
+
85
+ # Execute as one fused op using grouped convolution.
86
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
87
+ batch_size = int(batch_size)
88
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
89
+ x = x.reshape(1, -1, *x.shape[2:])
90
+ w = w.reshape(-1, in_channels, kh, kw)
91
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
92
+ x = x.reshape(batch_size, -1, *x.shape[2:])
93
+ if noise is not None:
94
+ x = x.add_(noise)
95
+ return x
96
+
97
+ #----------------------------------------------------------------------------
98
+
99
+
100
+ class FullyConnectedLayer(torch.nn.Module):
101
+ def __init__(self,
102
+ in_features, # Number of input features.
103
+ out_features, # Number of output features.
104
+ bias = True, # Apply additive bias before the activation function?
105
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
106
+ lr_multiplier = 1, # Learning rate multiplier.
107
+ bias_init = 0, # Initial value for the additive bias.
108
+ ):
109
+ super().__init__()
110
+ self.in_features = in_features
111
+ self.out_features = out_features
112
+ self.activation = activation
113
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
114
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
115
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
116
+ self.bias_gain = lr_multiplier
117
+
118
+ def forward(self, x):
119
+ w = self.weight.to(x.dtype) * self.weight_gain
120
+ b = self.bias
121
+ if b is not None:
122
+ b = b.to(x.dtype)
123
+ if self.bias_gain != 1:
124
+ b = b * self.bias_gain
125
+
126
+ if self.activation == 'linear' and b is not None:
127
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
128
+ else:
129
+ x = x.matmul(w.t())
130
+ x = bias_act.bias_act(x, b, act=self.activation)
131
+ return x
132
+
133
+ def extra_repr(self):
134
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
135
+
136
+ #----------------------------------------------------------------------------
137
+
138
+
139
+ class Conv2dLayer(torch.nn.Module):
140
+ def __init__(self,
141
+ in_channels, # Number of input channels.
142
+ out_channels, # Number of output channels.
143
+ kernel_size, # Width and height of the convolution kernel.
144
+ bias = True, # Apply additive bias before the activation function?
145
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
146
+ up = 1, # Integer upsampling factor.
147
+ down = 1, # Integer downsampling factor.
148
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
149
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
150
+ channels_last = False, # Expect the input to have memory_format=channels_last?
151
+ trainable = True, # Update the weights of this layer during training?
152
+ ):
153
+ super().__init__()
154
+ self.in_channels = in_channels
155
+ self.out_channels = out_channels
156
+ self.activation = activation
157
+ self.up = up
158
+ self.down = down
159
+ self.conv_clamp = conv_clamp
160
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
161
+ self.padding = kernel_size // 2
162
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
163
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
164
+
165
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
166
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
167
+ bias = torch.zeros([out_channels]) if bias else None
168
+ if trainable:
169
+ self.weight = torch.nn.Parameter(weight)
170
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
171
+ else:
172
+ self.register_buffer('weight', weight)
173
+ if bias is not None:
174
+ self.register_buffer('bias', bias)
175
+ else:
176
+ self.bias = None
177
+
178
+ def forward(self, x, gain=1):
179
+ w = self.weight * self.weight_gain
180
+ b = self.bias.to(x.dtype) if self.bias is not None else None
181
+ flip_weight = (self.up == 1) # slightly faster
182
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
183
+
184
+ act_gain = self.act_gain * gain
185
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
186
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
187
+ return x
188
+
189
+ def extra_repr(self):
190
+ return ' '.join([
191
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
192
+ f'up={self.up}, down={self.down}'])
193
+
194
+ #----------------------------------------------------------------------------
195
+
196
+
197
+ class MappingNetwork(torch.nn.Module):
198
+ def __init__(self,
199
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
200
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
201
+ w_dim, # Intermediate latent (W) dimensionality.
202
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
203
+ num_layers = 8, # Number of mapping layers.
204
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
205
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
206
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
207
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
208
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track.
209
+ ):
210
+ super().__init__()
211
+ self.z_dim = z_dim
212
+ self.c_dim = c_dim
213
+ self.w_dim = w_dim
214
+ self.num_ws = num_ws
215
+ self.num_layers = num_layers
216
+ self.w_avg_beta = w_avg_beta
217
+
218
+ if embed_features is None:
219
+ embed_features = w_dim
220
+ if c_dim == 0:
221
+ embed_features = 0
222
+ if layer_features is None:
223
+ layer_features = w_dim
224
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
225
+
226
+ if c_dim > 0:
227
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
228
+ for idx in range(num_layers):
229
+ in_features = features_list[idx]
230
+ out_features = features_list[idx + 1]
231
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
232
+ setattr(self, f'fc{idx}', layer)
233
+
234
+ if num_ws is not None and w_avg_beta is not None:
235
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
236
+
237
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, repeat_w = False):
238
+ # Embed, normalize, and concat inputs.
239
+ x = None
240
+ with torch.autograd.profiler.record_function('input'):
241
+ if self.z_dim > 0:
242
+ misc.assert_shape(z, [None, self.z_dim])
243
+ x = normalize_2nd_moment(z.to(torch.float32))
244
+ if self.c_dim > 0:
245
+ misc.assert_shape(c, [None, self.c_dim])
246
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
247
+ x = torch.cat([x, y], dim=1) if x is not None else y
248
+
249
+ # Main layers.
250
+ for idx in range(self.num_layers):
251
+ layer = getattr(self, f'fc{idx}')
252
+ x = layer(x)
253
+
254
+ # Update moving average of W.
255
+ if update_emas and self.w_avg_beta is not None:
256
+ with torch.autograd.profiler.record_function('update_w_avg'):
257
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
258
+
259
+ # Broadcast.
260
+ #if self.num_ws is not None:
261
+ if repeat_w:
262
+ with torch.autograd.profiler.record_function('broadcast'):
263
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
264
+
265
+ # Apply truncation.
266
+ if truncation_psi != 1:
267
+ with torch.autograd.profiler.record_function('truncate'):
268
+ assert self.w_avg_beta is not None
269
+ if self.num_ws is None or truncation_cutoff is None:
270
+ x = self.w_avg.lerp(x, truncation_psi)
271
+ else:
272
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
273
+ return x
274
+
275
+ def extra_repr(self):
276
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
277
+
278
+ #----------------------------------------------------------------------------
279
+
280
+
281
+ class SynthesisLayer(torch.nn.Module):
282
+ def __init__(self,
283
+ in_channels, # Number of input channels.
284
+ out_channels, # Number of output channels.
285
+ w_dim, # Intermediate latent (W) dimensionality.
286
+ resolution, # Resolution of this layer.
287
+ kernel_size = 3, # Convolution kernel size.
288
+ up = 1, # Integer upsampling factor.
289
+ use_noise = True, # Enable noise input?
290
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
291
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
292
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
293
+ channels_last = False, # Use channels_last format for the weights?
294
+ ):
295
+ super().__init__()
296
+ self.in_channels = in_channels
297
+ self.out_channels = out_channels
298
+ self.w_dim = w_dim
299
+ self.resolution = resolution
300
+ self.up = up
301
+ self.use_noise = use_noise
302
+ self.activation = activation
303
+ self.conv_clamp = conv_clamp
304
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
305
+ self.padding = kernel_size // 2
306
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
307
+
308
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
309
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
310
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
311
+ if use_noise:
312
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
313
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
314
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
315
+
316
+ def forward(self, x, w, noise_mode='random', n = None, weight_deltas = None,fused_modconv=True, gain=1):
317
+ assert noise_mode in ['random', 'const', 'none']
318
+ in_resolution = self.resolution // self.up
319
+ misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
320
+ styles = self.affine(w)
321
+
322
+ noise = None
323
+ if self.use_noise and noise_mode == 'random':
324
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
325
+ if self.use_noise and noise_mode == 'const':
326
+ if n is not None:
327
+ noise = n * self.noise_strength
328
+ else:
329
+ noise = self.noise_const * self.noise_strength
330
+
331
+ flip_weight = (self.up == 1) # slightly faster
332
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
333
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv, weigth_deltas=weight_deltas)
334
+
335
+ act_gain = self.act_gain * gain
336
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
337
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
338
+ return x
339
+
340
+ def extra_repr(self):
341
+ return ' '.join([
342
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
343
+ f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
344
+
345
+ #----------------------------------------------------------------------------
346
+
347
+
348
+ class ToRGBLayer(torch.nn.Module):
349
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
350
+ super().__init__()
351
+ self.in_channels = in_channels
352
+ self.out_channels = out_channels
353
+ self.w_dim = w_dim
354
+ self.conv_clamp = conv_clamp
355
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
356
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
357
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
358
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
359
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
360
+
361
+ def forward(self, x, w, fused_modconv=True):
362
+ styles = self.affine(w) * self.weight_gain
363
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
364
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
365
+ return x
366
+
367
+ def extra_repr(self):
368
+ return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
369
+
370
+ #----------------------------------------------------------------------------
371
+ class ResLayers(nn.Module):
372
+ def __init__(self, in_channels, out_channels, stride=1):
373
+ super().__init__()
374
+
375
+ if (in_channels == out_channels) and stride==1:
376
+ self.shortcut_layer = nn.Identity()
377
+ else:
378
+ self.shortcut_layer = nn.Sequential(
379
+ nn.Conv2d(in_channels, out_channels, (1, 1), stride, bias=False))
380
+
381
+ self.res_layer = nn.Sequential(
382
+ nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1), 1, bias=True), nn.LeakyReLU(0.2),
383
+ nn.Conv2d(out_channels, out_channels, (3, 3), stride, 1, bias=True) )
384
+
385
+ def forward(self, x):
386
+ shortcut = self.shortcut_layer(x)
387
+ res = self.res_layer(x)
388
+ return res + shortcut
389
+
390
+ class FeatureEdit(nn.Module):
391
+ def __init__(self, in_channels, out_channels):
392
+ super().__init__()
393
+ self.convs = nn.ModuleList()
394
+ iter_num = in_channels // out_channels
395
+ i_c = in_channels
396
+ for i in range(iter_num-1):
397
+ out_c = i_c - out_channels
398
+ self.convs.append( ResLayers(i_c,out_c,1) )
399
+ i_c = out_c
400
+ def forward(self, diff):
401
+ for block in self.convs:
402
+ diff = block(diff)
403
+ return diff
404
+
405
+ class FeatureAlignment(nn.Module):
406
+ def __init__(self, in_channels, out_channels):
407
+ super().__init__()
408
+ t_channel = 512
409
+ self.first_layer = nn.Conv2d(in_channels, t_channel, kernel_size=1, padding=0, bias=True)
410
+
411
+ self.conv1 = nn.Sequential(*[ResLayers(t_channel,t_channel,1)])
412
+ self.conv2 = nn.Sequential(*[ResLayers(t_channel,t_channel,2), ResLayers(t_channel,t_channel,1)])
413
+ self.conv3 = nn.Sequential(*[ResLayers(t_channel,t_channel,2), ResLayers(t_channel,t_channel,1)])
414
+
415
+ self.dconv1 = nn.Sequential(*[ResLayers(t_channel,t_channel,1), ResLayers(t_channel,t_channel,1)])
416
+ self.dconv2 = nn.Sequential(*[ResLayers(t_channel,t_channel,1), ResLayers(t_channel,t_channel,1)])
417
+ self.dconv3 = nn.Sequential(*[ResLayers(t_channel,t_channel,1), ResLayers(t_channel,t_channel,1)])
418
+
419
+ self.out_layer = nn.Conv2d(t_channel, out_channels, kernel_size=1, padding=0, bias=True)
420
+
421
+ def forward(self, encoder_feats, generator_feats):
422
+
423
+ x = torch.cat((encoder_feats,generator_feats), dim=1)
424
+ x = self.first_layer(x)
425
+
426
+ f1 = self.conv1(x)
427
+ f2 = self.conv2(f1)
428
+ f3 = self.conv3(f2)
429
+ shape = f3.shape[-1]
430
+ df1 = F.interpolate(f3, size=(shape*2,shape*2) , mode='bilinear', align_corners=True)
431
+ df2 = self.dconv1(df1 + f2)
432
+ df2 = F.interpolate(df2, size=(shape*4,shape*4) , mode='bilinear', align_corners=True)
433
+ df3 = self.dconv2(df2 + f1)
434
+
435
+ aligned_feats = self.out_layer(df3)
436
+
437
+ return aligned_feats
438
+
439
+ class FeatureExtraction(nn.Module):
440
+ def __init__(self,in_channels, out_channels ):
441
+ super().__init__()
442
+ t_channel = 512
443
+ self.first_layer = nn.Conv2d(in_channels, t_channel, kernel_size=1, padding=0, bias=True)
444
+ self.convs = nn.Sequential(*[ResLayers(t_channel,t_channel,1), ResLayers(t_channel,out_channels,1), ResLayers(out_channels,out_channels,1) ])
445
+ #self.out_layer = nn.Conv2d(t_channel, out_channels, kernel_size=1, padding=0, bias=False)
446
+
447
+ def forward(self, aligned_feats):
448
+ #x = aligned_feats - generator_feats
449
+ y = self.first_layer(aligned_feats)
450
+ y = self.convs(y)
451
+ #deltaF = self.out_layer(x)
452
+ return y
453
+
454
+ class GateNetwork(nn.Module):
455
+ def __init__(self, in_channels, out_channels):
456
+ super().__init__()
457
+ t_channel = 256
458
+ self.down1 = nn.Conv2d(in_channels, t_channel, kernel_size=3, padding=1, bias=True)
459
+ self.down2 = nn.Conv2d(in_channels, t_channel, kernel_size=3, padding=1, bias=True)
460
+ self.sigmoid = nn.Sigmoid()
461
+ self.convs = nn.Sequential(*[ResLayers(in_channels,in_channels,1), ResLayers(in_channels,out_channels,1), ResLayers(out_channels,out_channels,1) ])
462
+ self.convs2 = nn.Sequential(*[ResLayers(in_channels,in_channels,1), ResLayers(in_channels,out_channels,1), ResLayers(out_channels,1,1) ])
463
+
464
+
465
+ def forward(self, generator_feats, y):
466
+ generator_feats = self.down1(generator_feats)
467
+ y = self.down2(y)
468
+ x = torch.cat((generator_feats, y), dim=1)
469
+ deltaF = self.convs(x)
470
+ gate = self.convs2(x)
471
+ gate = self.sigmoid(gate)
472
+ return deltaF, gate
473
+
474
+ g_e_concat_shape={64: 640, 32:768}
475
+ e_shape = {64: 128, 32:256}
476
+
477
+ class SynthesisBlock(torch.nn.Module):
478
+ def __init__(self,
479
+ in_channels, # Number of input channels, 0 = first block.
480
+ out_channels, # Number of output channels.
481
+ w_dim, # Intermediate latent (W) dimensionality.
482
+ resolution, # Resolution of this block.
483
+ img_channels, # Number of output color channels.
484
+ is_last, # Is this the last block?
485
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
486
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
487
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
488
+ use_fp16 = False, # Use FP16 for this block?
489
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
490
+ fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
491
+ embed_res = 64, # Which resolution we embed the images
492
+ **layer_kwargs, # Arguments for SynthesisLayer.
493
+ ):
494
+ assert architecture in ['orig', 'skip', 'resnet']
495
+ super().__init__()
496
+ self.in_channels = in_channels
497
+ self.w_dim = w_dim
498
+ self.resolution = resolution
499
+ self.img_channels = img_channels
500
+ self.is_last = is_last
501
+ self.architecture = architecture
502
+ self.use_fp16 = use_fp16
503
+ self.channels_last = (use_fp16 and fp16_channels_last)
504
+ self.fused_modconv_default = fused_modconv_default
505
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
506
+ self.num_conv = 0
507
+ self.num_torgb = 0
508
+
509
+ if in_channels == 0:
510
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
511
+
512
+ if in_channels != 0:
513
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
514
+ resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
515
+ self.num_conv += 1
516
+
517
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
518
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
519
+ self.num_conv += 1
520
+
521
+ if is_last or architecture == 'skip':
522
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
523
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
524
+ self.num_torgb += 1
525
+
526
+ if in_channels != 0 and architecture == 'resnet':
527
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
528
+ resample_filter=resample_filter, channels_last=self.channels_last)
529
+ if resolution == embed_res:
530
+ in_c = g_e_concat_shape.get(embed_res)
531
+ #self.modify_feature_edit = FeatureEdit(in_channels=512, out_channels=e_shape.get(embed_res))
532
+ self.modify_feature_alignment = FeatureAlignment(in_channels=in_c, out_channels=512)
533
+ self.modify_feature_extraction = FeatureExtraction(in_channels=512, out_channels=512)
534
+ self.modify_feature_gates = GateNetwork(in_channels=512, out_channels=512)
535
+ self.embed_res = embed_res
536
+
537
+ def forward(self, x, img, ws, conditions=None, noise=None, weight_deltas = None, highres_outs=None, return_f = False,
538
+ force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
539
+ _ = update_emas # unused
540
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
541
+ w_iter = iter(ws.unbind(dim=1))
542
+ if ws.device.type != 'cuda':
543
+ force_fp32 = True
544
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
545
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
546
+ if fused_modconv is None:
547
+ fused_modconv = self.fused_modconv_default
548
+ if fused_modconv == 'inference_only':
549
+ fused_modconv = (not self.training)
550
+
551
+ # Input.
552
+ if self.in_channels == 0:
553
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
554
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
555
+ else:
556
+ misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
557
+ x = x.to(dtype=dtype, memory_format=memory_format)
558
+ gouts = {}
559
+ # Main layers.
560
+ if self.in_channels == 0:
561
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, n=noise[0], weight_deltas=weight_deltas[0], **layer_kwargs)
562
+ elif self.architecture == 'resnet':
563
+ y = self.skip(x, gain=np.sqrt(0.5))
564
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
565
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
566
+ x = y.add_(x)
567
+ else:
568
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, n=noise[0], weight_deltas=weight_deltas[0], **layer_kwargs)
569
+ #HFGI Generator Modification
570
+ if x.shape[-1] == 64 and conditions is not None:
571
+ x = x*(1+conditions[0]) + conditions[1]
572
+ if x.shape[-1] == self.embed_res and return_f:
573
+ return x, None, None
574
+ #HighResFeat Generator Modification
575
+ if x.shape[-1] == self.embed_res and highres_outs is not None:
576
+ #feature_edit = self.modify_feature_edit(x - highres_outs['inversion'])
577
+ #high_res = highres_outs[f'{self.embed_res}x{self.embed_res}'] + feature_edit
578
+ aligned_feats = self.modify_feature_alignment(highres_outs[f'{self.embed_res}x{self.embed_res}'], highres_outs['inversion'])
579
+ aligned_feats = self.modify_feature_extraction(aligned_feats)
580
+ #x = self.modify_feature_gates(x, deltaF)
581
+ deltaF, gate = self.modify_feature_gates(x, aligned_feats)
582
+ x = (x * (1-gate) ) + ( (x + deltaF) * gate )
583
+ gouts['gates'] = gate
584
+ gouts['additions'] = deltaF
585
+ gouts['aligned_feats'] = aligned_feats
586
+
587
+ #gouts['aligned_loss'] =F.mse_loss(aligned_feats, x, reduction='mean')
588
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv,n=noise[1], weight_deltas=weight_deltas[1], **layer_kwargs)
589
+
590
+ # ToRGB.
591
+ if img is not None:
592
+ misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
593
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
594
+ if self.is_last or self.architecture == 'skip':
595
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
596
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
597
+ img = img.add_(y) if img is not None else y
598
+
599
+ assert x.dtype == dtype
600
+ assert img is None or img.dtype == torch.float32
601
+ return x, img, gouts
602
+
603
+ def extra_repr(self):
604
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
605
+
606
+ #----------------------------------------------------------------------------
607
+
608
+
609
+ class SynthesisNetwork(torch.nn.Module):
610
+ def __init__(self,
611
+ w_dim, # Intermediate latent (W) dimensionality.
612
+ img_resolution, # Output image resolution.
613
+ img_channels, # Number of color channels.
614
+ channel_base = 32768, # Overall multiplier for the number of channels.
615
+ channel_max = 512, # Maximum number of channels in any layer.
616
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
617
+ **block_kwargs, # Arguments for SynthesisBlock.
618
+ ):
619
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
620
+ super().__init__()
621
+ self.w_dim = w_dim
622
+ self.img_resolution = img_resolution
623
+ self.img_resolution_log2 = int(np.log2(img_resolution))
624
+ self.img_channels = img_channels
625
+ self.num_fp16_res = num_fp16_res
626
+ self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
627
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
628
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
629
+
630
+ self.num_ws = 0
631
+ for res in self.block_resolutions:
632
+ in_channels = channels_dict[res // 2] if res > 4 else 0
633
+ out_channels = channels_dict[res]
634
+ use_fp16 = (res >= fp16_resolution)
635
+ is_last = (res == self.img_resolution)
636
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
637
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
638
+ self.num_ws += block.num_conv
639
+ if is_last:
640
+ self.num_ws += block.num_torgb
641
+ setattr(self, f'b{res}', block)
642
+
643
+ def forward(self, ws, conditions=None, noise=None, weight_deltas=None, highres_outs=None, return_f = False, **block_kwargs):
644
+ block_ws = []
645
+ with torch.autograd.profiler.record_function('split_ws'):
646
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
647
+ ws = ws.to(torch.float32)
648
+ w_idx = 0
649
+ for res in self.block_resolutions:
650
+ block = getattr(self, f'b{res}')
651
+ block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
652
+ w_idx += block.num_conv
653
+
654
+ x = img = None
655
+ conv_idx = 0
656
+ gouts = {}
657
+ for res, cur_ws in zip(self.block_resolutions, block_ws):
658
+ block = getattr(self, f'b{res}')
659
+ if noise is not None:
660
+ noise_input = noise[conv_idx: conv_idx + block.num_conv]
661
+ else:
662
+ noise_input = [None] * block.num_conv
663
+ if weight_deltas is not None:
664
+ delta_input = weight_deltas[conv_idx: conv_idx + block.num_conv]
665
+ else:
666
+ delta_input = [None] * block.num_conv
667
+ x, img, gouts_per_res = block(x, img, cur_ws, conditions, noise_input, delta_input, highres_outs, return_f, **block_kwargs)
668
+ if return_f and img is None:
669
+ return x, None
670
+ if gouts_per_res:
671
+ gouts.update(gouts_per_res)
672
+
673
+ conv_idx += block.num_conv
674
+ return img, gouts
675
+
676
+ def extra_repr(self):
677
+ return ' '.join([
678
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
679
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
680
+ f'num_fp16_res={self.num_fp16_res:d}'])
681
+
682
+ #----------------------------------------------------------------------------
683
+
684
+
685
+ class Generator(torch.nn.Module):
686
+ def __init__(self,
687
+ z_dim, # Input latent (Z) dimensionality.
688
+ c_dim, # Conditioning label (C) dimensionality.
689
+ w_dim, # Intermediate latent (W) dimensionality.
690
+ resolution, # Output resolution.
691
+ img_channels, # Number of output color channels.
692
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
693
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
694
+ ):
695
+ super().__init__()
696
+ self.z_dim = z_dim
697
+ self.c_dim = c_dim
698
+ self.w_dim = w_dim
699
+ self.resolution = resolution
700
+ self.img_channels = img_channels
701
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=resolution, img_channels=img_channels, **synthesis_kwargs)
702
+ self.num_ws = self.synthesis.num_ws
703
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
704
+
705
+ # self.freeze_non_trainable_layers()
706
+
707
+ def forward(self, lat, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, mode='synthesis', return_f = False, **synthesis_kwargs):
708
+ # self.freeze_non_trainable_layers()
709
+ if mode == 'mapping':
710
+ ws = self.mapping(lat, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
711
+ return ws
712
+ if mode == 'synthesis':
713
+ img = self.synthesis(lat, highres_outs = c, return_f=return_f, update_emas=False, **synthesis_kwargs)
714
+ return img
715
+
716
+ # def freeze_non_trainable_layers(self):
717
+ # for param in self.mapping.parameters():
718
+ # param.requires_grad = False
719
+ # for name, param in self.synthesis.named_parameters():
720
+ # if 'modify' not in name:
721
+ # param.requires_grad = False
722
+
723
+ #----------------------------------------------------------------------------
724
+
725
+
726
+ class DiscriminatorBlock(torch.nn.Module):
727
+ def __init__(self,
728
+ in_channels, # Number of input channels, 0 = first block.
729
+ tmp_channels, # Number of intermediate channels.
730
+ out_channels, # Number of output channels.
731
+ resolution, # Resolution of this block.
732
+ img_channels, # Number of input color channels.
733
+ first_layer_idx, # Index of the first layer.
734
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
735
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
736
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
737
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
738
+ use_fp16 = False, # Use FP16 for this block?
739
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
740
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
741
+ ):
742
+ assert in_channels in [0, tmp_channels]
743
+ assert architecture in ['orig', 'skip', 'resnet']
744
+ super().__init__()
745
+ self.in_channels = in_channels
746
+ self.resolution = resolution
747
+ self.img_channels = img_channels
748
+ self.first_layer_idx = first_layer_idx
749
+ self.architecture = architecture
750
+ self.use_fp16 = use_fp16
751
+ self.channels_last = (use_fp16 and fp16_channels_last)
752
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
753
+
754
+ self.num_layers = 0
755
+ def trainable_gen():
756
+ while True:
757
+ layer_idx = self.first_layer_idx + self.num_layers
758
+ trainable = (layer_idx >= freeze_layers)
759
+ self.num_layers += 1
760
+ yield trainable
761
+ trainable_iter = trainable_gen()
762
+
763
+ if in_channels == 0 or architecture == 'skip':
764
+ self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
765
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
766
+
767
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
768
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
769
+
770
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
771
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
772
+
773
+ if architecture == 'resnet':
774
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
775
+ trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
776
+
777
+ def forward(self, x, img, force_fp32=False):
778
+ if (x if x is not None else img).device.type != 'cuda':
779
+ force_fp32 = True
780
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
781
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
782
+
783
+ # Input.
784
+ if x is not None:
785
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
786
+ x = x.to(dtype=dtype, memory_format=memory_format)
787
+
788
+ # FromRGB.
789
+ if self.in_channels == 0 or self.architecture == 'skip':
790
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
791
+ img = img.to(dtype=dtype, memory_format=memory_format)
792
+ y = self.fromrgb(img)
793
+ x = x + y if x is not None else y
794
+ img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
795
+
796
+ # Main layers.
797
+ if self.architecture == 'resnet':
798
+ y = self.skip(x, gain=np.sqrt(0.5))
799
+ x = self.conv0(x)
800
+ x = self.conv1(x, gain=np.sqrt(0.5))
801
+ x = y.add_(x)
802
+ else:
803
+ x = self.conv0(x)
804
+ x = self.conv1(x)
805
+
806
+ assert x.dtype == dtype
807
+ return x, img
808
+
809
+ def extra_repr(self):
810
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
811
+
812
+ #----------------------------------------------------------------------------
813
+
814
+
815
+ class MinibatchStdLayer(torch.nn.Module):
816
+ def __init__(self, group_size, num_channels=1):
817
+ super().__init__()
818
+ self.group_size = group_size
819
+ self.num_channels = num_channels
820
+
821
+ def forward(self, x):
822
+ N, C, H, W = x.shape
823
+ with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
824
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
825
+ F = self.num_channels
826
+ c = C // F
827
+
828
+ y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
829
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
830
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
831
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
832
+ y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
833
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
834
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
835
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
836
+ return x
837
+
838
+ def extra_repr(self):
839
+ return f'group_size={self.group_size}, num_channels={self.num_channels:d}'
840
+
841
+ #----------------------------------------------------------------------------
842
+
843
+
844
+ class DiscriminatorEpilogue(torch.nn.Module):
845
+ def __init__(self,
846
+ in_channels, # Number of input channels.
847
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
848
+ resolution, # Resolution of this block.
849
+ img_channels, # Number of input color channels.
850
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
851
+ mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
852
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
853
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
854
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
855
+ ):
856
+ assert architecture in ['orig', 'skip', 'resnet']
857
+ super().__init__()
858
+ self.in_channels = in_channels
859
+ self.cmap_dim = cmap_dim
860
+ self.resolution = resolution
861
+ self.img_channels = img_channels
862
+ self.architecture = architecture
863
+
864
+ if architecture == 'skip':
865
+ self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
866
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
867
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
868
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
869
+ self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
870
+
871
+ def forward(self, x, img, cmap, force_fp32=False):
872
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
873
+ _ = force_fp32 # unused
874
+ dtype = torch.float32
875
+ memory_format = torch.contiguous_format
876
+
877
+ # FromRGB.
878
+ x = x.to(dtype=dtype, memory_format=memory_format)
879
+ if self.architecture == 'skip':
880
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
881
+ img = img.to(dtype=dtype, memory_format=memory_format)
882
+ x = x + self.fromrgb(img)
883
+
884
+ # Main layers.
885
+ if self.mbstd is not None:
886
+ x = self.mbstd(x)
887
+ x = self.conv(x)
888
+ x = self.fc(x.flatten(1))
889
+ x = self.out(x)
890
+
891
+ # Conditioning.
892
+ if self.cmap_dim > 0:
893
+ misc.assert_shape(cmap, [None, self.cmap_dim])
894
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
895
+
896
+ assert x.dtype == dtype
897
+ return x
898
+
899
+ def extra_repr(self):
900
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
901
+
902
+ #----------------------------------------------------------------------------
903
+
904
+
905
+ class Discriminator(torch.nn.Module):
906
+ def __init__(self,
907
+ c_dim, # Conditioning label (C) dimensionality.
908
+ img_resolution, # Input resolution.
909
+ img_channels, # Number of input color channels.
910
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
911
+ channel_base = 32768, # Overall multiplier for the number of channels.
912
+ channel_max = 512, # Maximum number of channels in any layer.
913
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
914
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
915
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
916
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
917
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
918
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
919
+ ):
920
+ super().__init__()
921
+ self.c_dim = c_dim
922
+ self.img_resolution = img_resolution
923
+ self.img_resolution_log2 = int(np.log2(img_resolution))
924
+ self.img_channels = img_channels
925
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
926
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
927
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
928
+
929
+ if cmap_dim is None:
930
+ cmap_dim = channels_dict[4]
931
+ if c_dim == 0:
932
+ cmap_dim = 0
933
+
934
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
935
+ cur_layer_idx = 0
936
+ for res in self.block_resolutions:
937
+ in_channels = channels_dict[res] if res < img_resolution else 0
938
+ tmp_channels = channels_dict[res]
939
+ out_channels = channels_dict[res // 2]
940
+ use_fp16 = (res >= fp16_resolution)
941
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
942
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
943
+ setattr(self, f'b{res}', block)
944
+ cur_layer_idx += block.num_layers
945
+ if c_dim > 0:
946
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
947
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
948
+
949
+ def forward(self, img, c, update_emas=False, **block_kwargs):
950
+ _ = update_emas # unused
951
+ x = None
952
+ for res in self.block_resolutions:
953
+ block = getattr(self, f'b{res}')
954
+ x, img = block(x, img, **block_kwargs)
955
+
956
+ cmap = None
957
+ if self.c_dim > 0:
958
+ cmap = self.mapping(None, c)
959
+ x = self.b4(x, img, cmap)
960
+ return x
961
+
962
+ def extra_repr(self):
963
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
964
+
965
+ #----------------------------------------------------------------------------
models/styleres.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from models.e4e import E4E_Inversion
6
+ from models.stylegan2 import Generator
7
+ from editings.editor import Editor
8
+ from options import Settings
9
+
10
+ class StyleRes(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.encoder = E4E_Inversion(resolution=256, num_layers = 50, mode='ir_se', out_res=64)
14
+ self.generator = Generator(z_dim=512, w_dim=512, c_dim=0, resolution=1024, img_channels=3,
15
+ fused_modconv_default='inference_only', embed_res=64)
16
+ # Set Generator arguments for eval mode
17
+ self.G_kwargs_val = {'noise_mode':'const', 'force_fp32':True}
18
+ self.device = Settings.device
19
+ self.editor = Editor()
20
+
21
+ def load_ckpt(self, ckpt_path):
22
+ ckpt = torch.load(ckpt_path, map_location='cpu')
23
+ self.encoder.basic_encoder.load_state_dict(ckpt['e4e'], strict=True)
24
+ self.encoder.latent_avg = ckpt['latent_avg']
25
+ self.generator.load_state_dict(ckpt['generator_smooth'], strict=True)
26
+ print("Model succesfully loaded")
27
+
28
+ def send_to_device(self):
29
+ self.encoder.to(self.device)
30
+ self.generator.to(self.device)
31
+ if self.device != 'cpu':
32
+ self.encoder.latent_avg = self.encoder.latent_avg.cuda()
33
+
34
+ """
35
+ Inputs: Input images and edit configs
36
+ Returns: Edited images together with the randomly generated image when the edit is interpolation.
37
+ """
38
+ def edit_images(self, image, cfg):
39
+ image = image.to(self.device)
40
+ with torch.no_grad():
41
+ latents, skips = self.encoder(image)
42
+
43
+ # GradCtrl requires gradients, others do not
44
+ latents_edited = self.editor.edit(latents, cfg)
45
+
46
+ with torch.no_grad():
47
+ # Get F space features F_orig, for the original image
48
+ skips['inversion'], _ = self.generator(latents, skips, return_f = True, **self.G_kwargs_val)
49
+ # Transform F_orig to incoming image
50
+ images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
51
+
52
+ return images
53
+
54
+ # def edit_demo_image(self, image, edit, factor):
55
+ # from utils import AttrDict
56
+ # cfg = AttrDict()
57
+ # edit = edit.lower()
58
+ # if edit in ['pose', 'age', 'smile']:
59
+ # cfg.method = 'interfacegan'
60
+ # cfg.edit = edit
61
+ # cfg.strength = factor
62
+ # image = image.to(self.device)
63
+ # with torch.no_grad():
64
+ # latents, skips = self.encoder(image)
65
+ # latents_edited = self.editor.edit(latents, cfg)
66
+ # with torch.no_grad():
67
+ # # Get F space features F_orig, for the original image
68
+ # skips['inversion'], _ = self.generator(latents, skips, return_f = True, **self.G_kwargs_val)
69
+ # # Transform F_orig to incoming image
70
+ # images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
71
+
72
+ # return images
73
+
74
+
75
+
models/torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
models/torch_utils/custom_ops.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import glob
10
+ import hashlib
11
+ import importlib
12
+ import os
13
+ import re
14
+ import shutil
15
+ import uuid
16
+
17
+ import torch
18
+ import torch.utils.cpp_extension
19
+ from torch.utils.file_baton import FileBaton
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Internal helper funcs.
28
+
29
+ def _find_compiler_bindir():
30
+ patterns = [
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35
+ ]
36
+ for pattern in patterns:
37
+ matches = sorted(glob.glob(pattern))
38
+ if len(matches):
39
+ return matches[-1]
40
+ return None
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ def _get_mangled_gpu_name():
45
+ name = torch.cuda.get_device_name().lower()
46
+ out = []
47
+ for c in name:
48
+ if re.match('[a-z0-9_-]+', c):
49
+ out.append(c)
50
+ else:
51
+ out.append('-')
52
+ return ''.join(out)
53
+
54
+ #----------------------------------------------------------------------------
55
+ # Main entry point for compiling and loading C++/CUDA plugins.
56
+
57
+ _cached_plugins = dict()
58
+
59
+ def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60
+ assert verbosity in ['none', 'brief', 'full']
61
+ if headers is None:
62
+ headers = []
63
+ if source_dir is not None:
64
+ sources = [os.path.join(source_dir, fname) for fname in sources]
65
+ headers = [os.path.join(source_dir, fname) for fname in headers]
66
+
67
+ # Already cached?
68
+ if module_name in _cached_plugins:
69
+ return _cached_plugins[module_name]
70
+
71
+ # Print status.
72
+ if verbosity == 'full':
73
+ print(f'Setting up PyTorch plugin "{module_name}"...')
74
+ elif verbosity == 'brief':
75
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76
+ verbose_build = (verbosity == 'full')
77
+
78
+ # Compile and load.
79
+ try: # pylint: disable=too-many-nested-blocks
80
+ # Make sure we can find the necessary compiler binaries.
81
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82
+ compiler_bindir = _find_compiler_bindir()
83
+ if compiler_bindir is None:
84
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85
+ os.environ['PATH'] += ';' + compiler_bindir
86
+
87
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88
+ # break the build or unnecessarily restrict what's available to nvcc.
89
+ # Unset it to let nvcc decide based on what's available on the
90
+ # machine.
91
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92
+
93
+ # Incremental build md5sum trickery. Copies all the input source files
94
+ # into a cached build directory under a combined md5 digest of the input
95
+ # source files. Copying is done only if the combined digest has changed.
96
+ # This keeps input file timestamps and filenames the same as in previous
97
+ # extension builds, allowing for fast incremental rebuilds.
98
+ #
99
+ # This optimization is done only in case all the source files reside in
100
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101
+ # environment variable is set (we take this as a signal that the user
102
+ # actually cares about this.)
103
+ #
104
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
105
+ # around the *.cu dependency bug in ninja config.
106
+ #
107
+ all_source_files = sorted(sources + headers)
108
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+
111
+ # Compute combined hash digest for all source files.
112
+ hash_md5 = hashlib.md5()
113
+ for src in all_source_files:
114
+ with open(src, 'rb') as f:
115
+ hash_md5.update(f.read())
116
+
117
+ # Select cached build directory name.
118
+ source_digest = hash_md5.hexdigest()
119
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121
+
122
+ if not os.path.isdir(cached_build_dir):
123
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124
+ os.makedirs(tmpdir)
125
+ for src in all_source_files:
126
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127
+ try:
128
+ os.replace(tmpdir, cached_build_dir) # atomic
129
+ except OSError:
130
+ # source directory already exists, delete tmpdir and its contents.
131
+ shutil.rmtree(tmpdir)
132
+ if not os.path.isdir(cached_build_dir): raise
133
+
134
+ # Compile.
135
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
138
+ else:
139
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140
+
141
+ # Load.
142
+ module = importlib.import_module(module_name)
143
+
144
+ except:
145
+ if verbosity == 'brief':
146
+ print('Failed!')
147
+ raise
148
+
149
+ # Print status and add to cache dict.
150
+ if verbosity == 'full':
151
+ print(f'Done setting up PyTorch plugin "{module_name}".')
152
+ elif verbosity == 'brief':
153
+ print('Done.')
154
+ _cached_plugins[module_name] = module
155
+ return module
156
+
157
+ #----------------------------------------------------------------------------
models/torch_utils/misc.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+
15
+ #----------------------------------------------------------------------------
16
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17
+ # same constant is used multiple times.
18
+
19
+ _constant_cache = dict()
20
+
21
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22
+ value = np.asarray(value)
23
+ if shape is not None:
24
+ shape = tuple(shape)
25
+ if dtype is None:
26
+ dtype = torch.get_default_dtype()
27
+ if device is None:
28
+ device = torch.device('cpu')
29
+ if memory_format is None:
30
+ memory_format = torch.contiguous_format
31
+
32
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33
+ tensor = _constant_cache.get(key, None)
34
+ if tensor is None:
35
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36
+ if shape is not None:
37
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38
+ tensor = tensor.contiguous(memory_format=memory_format)
39
+ _constant_cache[key] = tensor
40
+ return tensor
41
+
42
+ #----------------------------------------------------------------------------
43
+ # Replace NaN/Inf with specified numerical values.
44
+
45
+ try:
46
+ nan_to_num = torch.nan_to_num # 1.8.0a0
47
+ except AttributeError:
48
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49
+ assert isinstance(input, torch.Tensor)
50
+ if posinf is None:
51
+ posinf = torch.finfo(input.dtype).max
52
+ if neginf is None:
53
+ neginf = torch.finfo(input.dtype).min
54
+ assert nan == 0
55
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56
+
57
+ #----------------------------------------------------------------------------
58
+ # Symbolic assert.
59
+
60
+ try:
61
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62
+ except AttributeError:
63
+ symbolic_assert = torch.Assert # 1.7.0
64
+
65
+ #----------------------------------------------------------------------------
66
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
67
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68
+
69
+ @contextlib.contextmanager
70
+ def suppress_tracer_warnings():
71
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
72
+ warnings.filters.insert(0, flt)
73
+ yield
74
+ warnings.filters.remove(flt)
75
+
76
+ #----------------------------------------------------------------------------
77
+ # Assert that the shape of a tensor matches the given list of integers.
78
+ # None indicates that the size of a dimension is allowed to vary.
79
+ # Performs symbolic assertion when used in torch.jit.trace().
80
+
81
+ def assert_shape(tensor, ref_shape):
82
+ if tensor.ndim != len(ref_shape):
83
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85
+ if ref_size is None:
86
+ pass
87
+ elif isinstance(ref_size, torch.Tensor):
88
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
89
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90
+ elif isinstance(size, torch.Tensor):
91
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
92
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93
+ elif size != ref_size:
94
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95
+
96
+ #----------------------------------------------------------------------------
97
+ # Function decorator that calls torch.autograd.profiler.record_function().
98
+
99
+ def profiled_function(fn):
100
+ def decorator(*args, **kwargs):
101
+ with torch.autograd.profiler.record_function(fn.__name__):
102
+ return fn(*args, **kwargs)
103
+ decorator.__name__ = fn.__name__
104
+ return decorator
105
+
106
+ #----------------------------------------------------------------------------
107
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
108
+ # indefinitely, shuffling items as it goes.
109
+
110
+ class InfiniteSampler(torch.utils.data.Sampler):
111
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112
+ assert len(dataset) > 0
113
+ assert num_replicas > 0
114
+ assert 0 <= rank < num_replicas
115
+ assert 0 <= window_size <= 1
116
+ super().__init__(dataset)
117
+ self.dataset = dataset
118
+ self.rank = rank
119
+ self.num_replicas = num_replicas
120
+ self.shuffle = shuffle
121
+ self.seed = seed
122
+ self.window_size = window_size
123
+
124
+ def __iter__(self):
125
+ order = np.arange(len(self.dataset))
126
+ rnd = None
127
+ window = 0
128
+ if self.shuffle:
129
+ rnd = np.random.RandomState(self.seed)
130
+ rnd.shuffle(order)
131
+ window = int(np.rint(order.size * self.window_size))
132
+
133
+ idx = 0
134
+ while True:
135
+ i = idx % order.size
136
+ if idx % self.num_replicas == self.rank:
137
+ yield order[i]
138
+ if window >= 2:
139
+ j = (i - rnd.randint(window)) % order.size
140
+ order[i], order[j] = order[j], order[i]
141
+ idx += 1
142
+
143
+ #----------------------------------------------------------------------------
144
+ # Utilities for operating with torch.nn.Module parameters and buffers.
145
+
146
+ def params_and_buffers(module):
147
+ assert isinstance(module, torch.nn.Module)
148
+ return list(module.parameters()) + list(module.buffers())
149
+
150
+ def named_params_and_buffers(module):
151
+ assert isinstance(module, torch.nn.Module)
152
+ return list(module.named_parameters()) + list(module.named_buffers())
153
+
154
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
155
+ assert isinstance(src_module, torch.nn.Module)
156
+ assert isinstance(dst_module, torch.nn.Module)
157
+ src_tensors = dict(named_params_and_buffers(src_module))
158
+ for name, tensor in named_params_and_buffers(dst_module):
159
+ assert (name in src_tensors) or (not require_all)
160
+ if name in src_tensors:
161
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
162
+
163
+ #----------------------------------------------------------------------------
164
+ # Context manager for easily enabling/disabling DistributedDataParallel
165
+ # synchronization.
166
+
167
+ @contextlib.contextmanager
168
+ def ddp_sync(module, sync):
169
+ assert isinstance(module, torch.nn.Module)
170
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
171
+ yield
172
+ else:
173
+ with module.no_sync():
174
+ yield
175
+
176
+ #----------------------------------------------------------------------------
177
+ # Check DistributedDataParallel consistency across processes.
178
+
179
+ def check_ddp_consistency(module, ignore_regex=None):
180
+ assert isinstance(module, torch.nn.Module)
181
+ for name, tensor in named_params_and_buffers(module):
182
+ fullname = type(module).__name__ + '.' + name
183
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
184
+ continue
185
+ tensor = tensor.detach()
186
+ if tensor.is_floating_point():
187
+ tensor = nan_to_num(tensor)
188
+ other = tensor.clone()
189
+ torch.distributed.broadcast(tensor=other, src=0)
190
+ assert (tensor == other).all(), fullname
191
+
192
+ #----------------------------------------------------------------------------
193
+ # Print summary table of module hierarchy.
194
+
195
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
196
+ assert isinstance(module, torch.nn.Module)
197
+ assert not isinstance(module, torch.jit.ScriptModule)
198
+ assert isinstance(inputs, (tuple, list))
199
+
200
+ # Register hooks.
201
+ entries = []
202
+ nesting = [0]
203
+ def pre_hook(_mod, _inputs):
204
+ nesting[0] += 1
205
+ def post_hook(mod, _inputs, outputs):
206
+ nesting[0] -= 1
207
+ if nesting[0] <= max_nesting:
208
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
209
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
210
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
211
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
212
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
213
+
214
+ # Run module.
215
+ outputs = module(*inputs)
216
+ for hook in hooks:
217
+ hook.remove()
218
+
219
+ # Identify unique outputs, parameters, and buffers.
220
+ tensors_seen = set()
221
+ for e in entries:
222
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
223
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
224
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
225
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
226
+
227
+ # Filter out redundant entries.
228
+ if skip_redundant:
229
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
230
+
231
+ # Construct table.
232
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
233
+ rows += [['---'] * len(rows[0])]
234
+ param_total = 0
235
+ buffer_total = 0
236
+ submodule_names = {mod: name for name, mod in module.named_modules()}
237
+ for e in entries:
238
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
239
+ param_size = sum(t.numel() for t in e.unique_params)
240
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
241
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
242
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
243
+ rows += [[
244
+ name + (':0' if len(e.outputs) >= 2 else ''),
245
+ str(param_size) if param_size else '-',
246
+ str(buffer_size) if buffer_size else '-',
247
+ (output_shapes + ['-'])[0],
248
+ (output_dtypes + ['-'])[0],
249
+ ]]
250
+ for idx in range(1, len(e.outputs)):
251
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
252
+ param_total += param_size
253
+ buffer_total += buffer_size
254
+ rows += [['---'] * len(rows[0])]
255
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
256
+
257
+ # Print table.
258
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
259
+ print()
260
+ for row in rows:
261
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
262
+ print()
263
+ return outputs
264
+
265
+ #----------------------------------------------------------------------------
models/torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
models/torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
models/torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
models/torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
models/torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+ import models.dnnlib as dnnlib
15
+
16
+ from .. import custom_ops
17
+ from .. import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ activation_funcs = {
22
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _plugin = None
36
+ _null_tensor = torch.empty([0])
37
+
38
+ def _init():
39
+ global _plugin
40
+ if _plugin is None:
41
+ _plugin = custom_ops.get_plugin(
42
+ module_name='bias_act_plugin',
43
+ sources=['bias_act.cpp', 'bias_act.cu'],
44
+ headers=['bias_act.h'],
45
+ source_dir=os.path.dirname(__file__),
46
+ extra_cuda_cflags=['--use_fast_math'],
47
+ )
48
+ return True
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53
+ r"""Fused bias and activation function.
54
+
55
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
57
+ the fused op is considerably more efficient than performing the same calculation
58
+ using standard PyTorch ops. It supports first and second order gradients,
59
+ but not third order gradients.
60
+
61
+ Args:
62
+ x: Input activation tensor. Can be of any shape.
63
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64
+ as `x`. The shape must be known, and it must match the dimension of `x`
65
+ corresponding to `dim`.
66
+ dim: The dimension in `x` corresponding to the elements of `b`.
67
+ The value of `dim` is ignored if `b` is not specified.
68
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
69
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70
+ See `activation_funcs` for a full list. `None` is not allowed.
71
+ alpha: Shape parameter for the activation function, or `None` to use the default.
72
+ gain: Scaling factor for the output tensor, or `None` to use default.
73
+ See `activation_funcs` for the default scaling of each activation function.
74
+ If unsure, consider specifying 1.
75
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76
+ the clamping (default).
77
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78
+
79
+ Returns:
80
+ Tensor of the same shape and datatype as `x`.
81
+ """
82
+ assert isinstance(x, torch.Tensor)
83
+ assert impl in ['ref', 'cuda']
84
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
85
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ @misc.profiled_function
91
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
93
+ """
94
+ assert isinstance(x, torch.Tensor)
95
+ assert clamp is None or clamp >= 0
96
+ spec = activation_funcs[act]
97
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
98
+ gain = float(gain if gain is not None else spec.def_gain)
99
+ clamp = float(clamp if clamp is not None else -1)
100
+
101
+ # Add bias.
102
+ if b is not None:
103
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
104
+ assert 0 <= dim < x.ndim
105
+ assert b.shape[0] == x.shape[dim]
106
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107
+
108
+ # Evaluate activation function.
109
+ alpha = float(alpha)
110
+ x = spec.func(x, alpha=alpha)
111
+
112
+ # Scale by gain.
113
+ gain = float(gain)
114
+ if gain != 1:
115
+ x = x * gain
116
+
117
+ # Clamp.
118
+ if clamp >= 0:
119
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120
+ return x
121
+
122
+ #----------------------------------------------------------------------------
123
+
124
+ _bias_act_cuda_cache = dict()
125
+
126
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127
+ """Fast CUDA implementation of `bias_act()` using custom ops.
128
+ """
129
+ # Parse arguments.
130
+ assert clamp is None or clamp >= 0
131
+ spec = activation_funcs[act]
132
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
133
+ gain = float(gain if gain is not None else spec.def_gain)
134
+ clamp = float(clamp if clamp is not None else -1)
135
+
136
+ # Lookup from cache.
137
+ key = (dim, act, alpha, gain, clamp)
138
+ if key in _bias_act_cuda_cache:
139
+ return _bias_act_cuda_cache[key]
140
+
141
+ # Forward op.
142
+ class BiasActCuda(torch.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
145
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146
+ x = x.contiguous(memory_format=ctx.memory_format)
147
+ b = b.contiguous() if b is not None else _null_tensor
148
+ y = x
149
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151
+ ctx.save_for_backward(
152
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154
+ y if 'y' in spec.ref else _null_tensor)
155
+ return y
156
+
157
+ @staticmethod
158
+ def backward(ctx, dy): # pylint: disable=arguments-differ
159
+ dy = dy.contiguous(memory_format=ctx.memory_format)
160
+ x, b, y = ctx.saved_tensors
161
+ dx = None
162
+ db = None
163
+
164
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165
+ dx = dy
166
+ if act != 'linear' or gain != 1 or clamp >= 0:
167
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
168
+
169
+ if ctx.needs_input_grad[1]:
170
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
171
+
172
+ return dx, db
173
+
174
+ # Backward op.
175
+ class BiasActCudaGrad(torch.autograd.Function):
176
+ @staticmethod
177
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180
+ ctx.save_for_backward(
181
+ dy if spec.has_2nd_grad else _null_tensor,
182
+ x, b, y)
183
+ return dx
184
+
185
+ @staticmethod
186
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
187
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188
+ dy, x, b, y = ctx.saved_tensors
189
+ d_dy = None
190
+ d_x = None
191
+ d_b = None
192
+ d_y = None
193
+
194
+ if ctx.needs_input_grad[0]:
195
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196
+
197
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199
+
200
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202
+
203
+ return d_dy, d_x, d_b, d_y
204
+
205
+ # Add to cache.
206
+ _bias_act_cuda_cache[key] = BiasActCuda
207
+ return BiasActCuda
208
+
209
+ #----------------------------------------------------------------------------
models/torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import contextlib
13
+ import torch
14
+
15
+ # pylint: disable=redefined-builtin
16
+ # pylint: disable=arguments-differ
17
+ # pylint: disable=protected-access
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ enabled = False # Enable the custom op by setting this to true.
22
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
23
+
24
+ @contextlib.contextmanager
25
+ def no_weight_gradients(disable=True):
26
+ global weight_gradients_disabled
27
+ old = weight_gradients_disabled
28
+ if disable:
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ return True
54
+
55
+ def _tuple_of_ints(xs, ndim):
56
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
57
+ assert len(xs) == ndim
58
+ assert all(isinstance(x, int) for x in xs)
59
+ return xs
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
+ _conv2d_gradfix_cache = dict()
64
+ _null_tensor = torch.empty([0])
65
+
66
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
67
+ # Parse arguments.
68
+ ndim = 2
69
+ weight_shape = tuple(weight_shape)
70
+ stride = _tuple_of_ints(stride, ndim)
71
+ padding = _tuple_of_ints(padding, ndim)
72
+ output_padding = _tuple_of_ints(output_padding, ndim)
73
+ dilation = _tuple_of_ints(dilation, ndim)
74
+
75
+ # Lookup from cache.
76
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
77
+ if key in _conv2d_gradfix_cache:
78
+ return _conv2d_gradfix_cache[key]
79
+
80
+ # Validate arguments.
81
+ assert groups >= 1
82
+ assert len(weight_shape) == ndim + 2
83
+ assert all(stride[i] >= 1 for i in range(ndim))
84
+ assert all(padding[i] >= 0 for i in range(ndim))
85
+ assert all(dilation[i] >= 0 for i in range(ndim))
86
+ if not transpose:
87
+ assert all(output_padding[i] == 0 for i in range(ndim))
88
+ else: # transpose
89
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
90
+
91
+ # Helpers.
92
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
93
+ def calc_output_padding(input_shape, output_shape):
94
+ if transpose:
95
+ return [0, 0]
96
+ return [
97
+ input_shape[i + 2]
98
+ - (output_shape[i + 2] - 1) * stride[i]
99
+ - (1 - 2 * padding[i])
100
+ - dilation[i] * (weight_shape[i + 2] - 1)
101
+ for i in range(ndim)
102
+ ]
103
+
104
+ # Forward & backward.
105
+ class Conv2d(torch.autograd.Function):
106
+ @staticmethod
107
+ def forward(ctx, input, weight, bias):
108
+ assert weight.shape == weight_shape
109
+ ctx.save_for_backward(
110
+ input if weight.requires_grad else _null_tensor,
111
+ weight if input.requires_grad else _null_tensor,
112
+ )
113
+ ctx.input_shape = input.shape
114
+
115
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
116
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
117
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
118
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
119
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
120
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
121
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
122
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
123
+
124
+ # General case => cuDNN.
125
+ if transpose:
126
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
127
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ input, weight = ctx.saved_tensors
132
+ input_shape = ctx.input_shape
133
+ grad_input = None
134
+ grad_weight = None
135
+ grad_bias = None
136
+
137
+ if ctx.needs_input_grad[0]:
138
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
139
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
140
+ grad_input = op.apply(grad_output, weight, None)
141
+ assert grad_input.shape == input_shape
142
+
143
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
144
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
145
+ assert grad_weight.shape == weight_shape
146
+
147
+ if ctx.needs_input_grad[2]:
148
+ grad_bias = grad_output.sum([0, 2, 3])
149
+
150
+ return grad_input, grad_weight, grad_bias
151
+
152
+ # Gradient with respect to the weights.
153
+ class Conv2dGradWeight(torch.autograd.Function):
154
+ @staticmethod
155
+ def forward(ctx, grad_output, input):
156
+ ctx.save_for_backward(
157
+ grad_output if input.requires_grad else _null_tensor,
158
+ input if grad_output.requires_grad else _null_tensor,
159
+ )
160
+ ctx.grad_output_shape = grad_output.shape
161
+ ctx.input_shape = input.shape
162
+
163
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
164
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
165
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
166
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
167
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
168
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
169
+
170
+ # General case => cuDNN.
171
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
172
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
173
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
174
+
175
+ @staticmethod
176
+ def backward(ctx, grad2_grad_weight):
177
+ grad_output, input = ctx.saved_tensors
178
+ grad_output_shape = ctx.grad_output_shape
179
+ input_shape = ctx.input_shape
180
+ grad2_grad_output = None
181
+ grad2_input = None
182
+
183
+ if ctx.needs_input_grad[0]:
184
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
185
+ assert grad2_grad_output.shape == grad_output_shape
186
+
187
+ if ctx.needs_input_grad[1]:
188
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
189
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
190
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
191
+ assert grad2_input.shape == input_shape
192
+
193
+ return grad2_grad_output, grad2_input
194
+
195
+ _conv2d_gradfix_cache[key] = Conv2d
196
+ return Conv2d
197
+
198
+ #----------------------------------------------------------------------------
models/torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ if not flip_weight and (kw > 1 or kh > 1):
37
+ w = w.flip([2, 3])
38
+
39
+ # Execute using conv2d_gradfix.
40
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41
+ return op(x, w, stride=stride, padding=padding, groups=groups)
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ @misc.profiled_function
46
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47
+ r"""2D convolution with optional up/downsampling.
48
+
49
+ Padding is performed only once at the beginning, not between the operations.
50
+
51
+ Args:
52
+ x: Input tensor of shape
53
+ `[batch_size, in_channels, in_height, in_width]`.
54
+ w: Weight tensor of shape
55
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57
+ calling upfirdn2d.setup_filter(). None = identity (default).
58
+ up: Integer upsampling factor (default: 1).
59
+ down: Integer downsampling factor (default: 1).
60
+ padding: Padding with respect to the upsampled image. Can be a single number
61
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62
+ (default: 0).
63
+ groups: Split input channels into N groups (default: 1).
64
+ flip_weight: False = convolution, True = correlation (default: True).
65
+ flip_filter: False = convolution, True = correlation (default: False).
66
+
67
+ Returns:
68
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69
+ """
70
+ # Validate arguments.
71
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74
+ assert isinstance(up, int) and (up >= 1)
75
+ assert isinstance(down, int) and (down >= 1)
76
+ assert isinstance(groups, int) and (groups >= 1)
77
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78
+ fw, fh = _get_filter_size(f)
79
+ px0, px1, py0, py1 = _parse_padding(padding)
80
+
81
+ # Adjust padding to account for up/downsampling.
82
+ if up > 1:
83
+ px0 += (fw + up - 1) // 2
84
+ px1 += (fw - up) // 2
85
+ py0 += (fh + up - 1) // 2
86
+ py1 += (fh - up) // 2
87
+ if down > 1:
88
+ px0 += (fw - down + 1) // 2
89
+ px1 += (fw - down) // 2
90
+ py0 += (fh - down + 1) // 2
91
+ py1 += (fh - down) // 2
92
+
93
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
95
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97
+ return x
98
+
99
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
101
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103
+ return x
104
+
105
+ # Fast path: downsampling only => use strided convolution.
106
+ if down > 1 and up == 1:
107
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109
+ return x
110
+
111
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112
+ if up > 1:
113
+ if groups == 1:
114
+ w = w.transpose(0, 1)
115
+ else:
116
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117
+ w = w.transpose(1, 2)
118
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119
+ px0 -= kw - 1
120
+ px1 -= kw - up
121
+ py0 -= kh - 1
122
+ py1 -= kh - up
123
+ pxt = max(min(-px0, -px1), 0)
124
+ pyt = max(min(-py0, -py1), 0)
125
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127
+ if down > 1:
128
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129
+ return x
130
+
131
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132
+ if up == 1 and down == 1:
133
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135
+
136
+ # Fallback: Generic reference implementation.
137
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139
+ if down > 1:
140
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141
+ return x
142
+
143
+ #----------------------------------------------------------------------------
models/torch_utils/ops/filtered_lrelu.cpp ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "filtered_lrelu.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
+ {
20
+ // Set CUDA device.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
+
24
+ // Validate arguments.
25
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
+ TORCH_CHECK(x.numel() > 0, "x is empty");
32
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
+
40
+ // Figure out how much shared memory is available on the device.
41
+ int maxSharedBytes = 0;
42
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
+ int sharedKB = maxSharedBytes >> 10;
44
+
45
+ // Populate enough launch parameters to check if a CUDA kernel exists.
46
+ filtered_lrelu_kernel_params p;
47
+ p.up = up;
48
+ p.down = down;
49
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
+ if (!test_spec.exec)
53
+ {
54
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
+ }
57
+
58
+ // Input/output element size.
59
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
+
61
+ // Input sizes.
62
+ int64_t xw = (int)x.size(3);
63
+ int64_t xh = (int)x.size(2);
64
+ int64_t fut_w = (int)fu.size(-1) - 1;
65
+ int64_t fut_h = (int)fu.size(0) - 1;
66
+ int64_t fdt_w = (int)fd.size(-1) - 1;
67
+ int64_t fdt_h = (int)fd.size(0) - 1;
68
+
69
+ // Logical size of upsampled buffer.
70
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
71
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
72
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
+
75
+ // Compute output size and allocate.
76
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
+
82
+ // Allocate sign tensor.
83
+ torch::Tensor so;
84
+ torch::Tensor s = si;
85
+ bool readSigns = !!s.numel();
86
+ int64_t sw_active = 0; // Active width of sign tensor.
87
+ if (writeSigns)
88
+ {
89
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
+ }
95
+ else if (readSigns)
96
+ sw_active = s.size(3) << 2;
97
+
98
+ // Validate sign tensor if in use.
99
+ if (readSigns || writeSigns)
100
+ {
101
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
+ }
108
+
109
+ // Populate rest of CUDA kernel parameters.
110
+ p.x = x.data_ptr();
111
+ p.y = y.data_ptr();
112
+ p.b = b.data_ptr();
113
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
+ p.fu = fu.data_ptr<float>();
115
+ p.fd = fd.data_ptr<float>();
116
+ p.pad0 = make_int2(px0, py0);
117
+ p.gain = gain;
118
+ p.slope = slope;
119
+ p.clamp = clamp;
120
+ p.flip = (flip_filters) ? 1 : 0;
121
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
+ p.sOfs = make_int2(sx, sy);
125
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
+
127
+ // x, y, b strides are in bytes.
128
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
+ p.bStride = sz * b.stride(0);
131
+
132
+ // fu, fd strides are in elements.
133
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
+
136
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
137
+ bool index64b = false;
138
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
+ if (s.numel() > INT_MAX) index64b = true;
144
+
145
+ // Choose CUDA kernel.
146
+ filtered_lrelu_kernel_spec spec = { 0 };
147
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
+ {
149
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
+ {
151
+ // Choose kernel based on index type, datatype and sign read/write modes.
152
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
+ }
159
+ });
160
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
161
+
162
+ // Launch CUDA kernel.
163
+ void* args[] = {&p};
164
+ int bx = spec.numWarps * 32;
165
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
+ int gz = p.yShape.z * p.yShape.w;
168
+
169
+ // Repeat multiple horizontal tiles in a CTA?
170
+ if (spec.xrep)
171
+ {
172
+ p.tilesXrep = spec.xrep;
173
+ p.tilesXdim = gx;
174
+
175
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
+ std::swap(gx, gy);
177
+ }
178
+ else
179
+ {
180
+ p.tilesXrep = 0;
181
+ p.tilesXdim = 0;
182
+ }
183
+
184
+ // Launch filter setup kernel.
185
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
+
187
+ // Copy kernels to constant memory.
188
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
+
192
+ // Set cache and shared memory configurations for main kernel.
193
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
+
198
+ // Launch main kernel.
199
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
+ {
202
+ p.blockZofs = zofs;
203
+ int subGz = std::min(maxSubGz, gz - zofs);
204
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
+ }
206
+
207
+ // Done.
208
+ return std::make_tuple(y, so, 0);
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+
213
+ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
+ {
215
+ // Set CUDA device.
216
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
+
219
+ // Validate arguments.
220
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
+ TORCH_CHECK(x.numel() > 0, "x is empty");
223
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
+
225
+ // Output signs if we don't have sign input.
226
+ torch::Tensor so;
227
+ torch::Tensor s = si;
228
+ bool readSigns = !!s.numel();
229
+ if (writeSigns)
230
+ {
231
+ int64_t sw = x.size(3);
232
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
+ }
235
+
236
+ // Validate sign tensor if in use.
237
+ if (readSigns || writeSigns)
238
+ {
239
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
+ }
246
+
247
+ // Initialize CUDA kernel parameters.
248
+ filtered_lrelu_act_kernel_params p;
249
+ p.x = x.data_ptr();
250
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
+ p.gain = gain;
252
+ p.slope = slope;
253
+ p.clamp = clamp;
254
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
+ p.sOfs = make_int2(sx, sy);
258
+
259
+ // Choose CUDA kernel.
260
+ void* func = 0;
261
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
+ {
263
+ if (writeSigns)
264
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
+ else if (readSigns)
266
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
+ else
268
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
+ });
270
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
+
272
+ // Launch CUDA kernel.
273
+ void* args[] = {&p};
274
+ int bx = 128; // 4 warps per block.
275
+
276
+ // Logical size of launch = writeSigns ? p.s : p.x
277
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
+ gx = (gx - 1) / bx + 1;
281
+
282
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
+ const uint32_t gmax = 65535;
284
+ gy = std::min(gy, gmax);
285
+ gz = std::min(gz, gmax);
286
+
287
+ // Launch.
288
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
+ return so;
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+
294
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
+ {
296
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
+ }
299
+
300
+ //------------------------------------------------------------------------
models/torch_utils/ops/filtered_lrelu.cu ADDED
@@ -0,0 +1,1284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "filtered_lrelu.h"
11
+ #include <cstdint>
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ enum // Filter modes.
17
+ {
18
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
22
+ };
23
+
24
+ template <class T> struct InternalType;
25
+ template <> struct InternalType<double>
26
+ {
27
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
+ };
32
+ template <> struct InternalType<float>
33
+ {
34
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
+ };
39
+ template <> struct InternalType<c10::Half>
40
+ {
41
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
+ };
46
+
47
+ #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
+ #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
+ #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
51
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
52
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
53
+
54
+ // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
55
+ template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
+ {
57
+ if ((N & (N-1)) && N <= 256)
58
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
+ else
60
+ y = i/N;
61
+
62
+ x = i - y*N;
63
+ }
64
+
65
+ // Type cast stride before reading it.
66
+ template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
+ {
68
+ return *reinterpret_cast<const T*>(&x);
69
+ }
70
+
71
+ //------------------------------------------------------------------------
72
+ // Filters, setup kernel, copying function.
73
+
74
+ #define MAX_FILTER_SIZE 32
75
+
76
+ // Combined up/down filter buffers so that transfer can be done with one copy.
77
+ __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
+ __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
+
80
+ // Accessors to combined buffers to index up/down filters individually.
81
+ #define c_fu (c_fbuf)
82
+ #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
+ #define g_fu (g_fbuf)
84
+ #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
+
86
+ // Set up filters into global memory buffer.
87
+ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
+ {
89
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
+ {
91
+ int x, y;
92
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
+
94
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
+ if (p.fuShape.y > 0)
97
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
+ else
99
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
+
101
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
+ if (p.fdShape.y > 0)
104
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
+ else
106
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
+ }
108
+ }
109
+
110
+ // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
+ template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
+ {
113
+ void* src = 0;
114
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
+ if (err) return err;
116
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
+ }
118
+
119
+ //------------------------------------------------------------------------
120
+ // Coordinate spaces:
121
+ // - Relative to input tensor: inX, inY, tileInX, tileInY
122
+ // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
+ // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
+ // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
+ // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
+ //
127
+ // Relationships between coordinate spaces:
128
+ // - inX = tileInX + relInX
129
+ // - inY = tileInY + relInY
130
+ // - relUpX = relInX * up + phaseInX
131
+ // - relUpY = relInY * up + phaseInY
132
+ // - relUpX = relOutX * down
133
+ // - relUpY = relOutY * down
134
+ // - outX = tileOutX + relOutX
135
+ // - outY = tileOutY + relOutY
136
+
137
+ extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
+
139
+ template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
+ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
+ {
142
+ // Check that we don't try to support non-existing filter modes.
143
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
148
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
149
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
+
155
+ // Static definitions.
156
+ typedef typename InternalType<T>::scalar_t scalar_t;
157
+ typedef typename InternalType<T>::vec2_t vec2_t;
158
+ typedef typename InternalType<T>::vec4_t vec4_t;
159
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
165
+
166
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
+
169
+ // Sizes of logical buffers.
170
+ const int szIn = tileInH_up * tileInW;
171
+ const int szUpX = tileInH_up * tileUpW;
172
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
+ const int szDownX = tileUpH * tileOutW;
174
+
175
+ // Sizes for shared memory arrays.
176
+ const int s_buf0_size_base =
177
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
+ (filterMode == MODE_FUFD) ? szIn :
181
+ -1;
182
+ const int s_buf1_size_base =
183
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
+ (filterMode == MODE_FUSD) ? szUpXY :
185
+ (filterMode == MODE_SUFD) ? szUpX :
186
+ (filterMode == MODE_FUFD) ? szUpXY :
187
+ -1;
188
+
189
+ // Ensure U128 alignment.
190
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
+
193
+ // Check at compile time that we don't use too much shared memory.
194
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
+
196
+ // Declare shared memory arrays.
197
+ scalar_t* s_buf0;
198
+ scalar_t* s_buf1;
199
+ if (sharedKB <= 48)
200
+ {
201
+ // Allocate shared memory arrays here.
202
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
+ s_buf0 = s_buf0_st;
204
+ s_buf1 = s_buf0 + s_buf0_size;
205
+ }
206
+ else
207
+ {
208
+ // Use the dynamically allocated shared memory array.
209
+ s_buf0 = (scalar_t*)s_buf_raw;
210
+ s_buf1 = s_buf0 + s_buf0_size;
211
+ }
212
+
213
+ // Pointers to the buffers.
214
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
+ if (filterMode == MODE_SUSD)
219
+ {
220
+ s_tileIn = s_buf0;
221
+ s_tileUpX = s_buf1;
222
+ s_tileUpXY = s_buf0;
223
+ s_tileDownX = s_buf1;
224
+ }
225
+ else if (filterMode == MODE_FUSD)
226
+ {
227
+ s_tileIn = s_buf0;
228
+ s_tileUpXY = s_buf1;
229
+ s_tileDownX = s_buf0;
230
+ }
231
+ else if (filterMode == MODE_SUFD)
232
+ {
233
+ s_tileIn = s_buf0;
234
+ s_tileUpX = s_buf1;
235
+ s_tileUpXY = s_buf0;
236
+ }
237
+ else if (filterMode == MODE_FUFD)
238
+ {
239
+ s_tileIn = s_buf0;
240
+ s_tileUpXY = s_buf1;
241
+ }
242
+
243
+ // Allow large grids in z direction via per-launch offset.
244
+ int channelIdx = blockIdx.z + p.blockZofs;
245
+ int batchIdx = channelIdx / p.yShape.z;
246
+ channelIdx -= batchIdx * p.yShape.z;
247
+
248
+ // Offset to output feature map. In bytes.
249
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
+
251
+ // Sign shift amount.
252
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
+
254
+ // Inner tile loop.
255
+ #pragma unroll 1
256
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
+ {
258
+ // Locate output tile.
259
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
+ int tileOutX = tileX * tileOutW;
261
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
+
263
+ // Locate input tile.
264
+ int tmpX = tileOutX * down - p.pad0.x;
265
+ int tmpY = tileOutY * down - p.pad0.y;
266
+ int tileInX = CEIL_DIV(tmpX, up);
267
+ int tileInY = CEIL_DIV(tmpY, up);
268
+ const int phaseInX = tileInX * up - tmpX;
269
+ const int phaseInY = tileInY * up - tmpY;
270
+
271
+ // Extra sync if input and output buffers are the same and we are not on first tile.
272
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
+ __syncthreads();
274
+
275
+ // Load input tile & apply bias. Unrolled.
276
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
+ int idx = threadIdx.x;
279
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
+ #pragma unroll
281
+ for (int loop = 0; loop < loopCountIN; loop++)
282
+ {
283
+ int relInX, relInY;
284
+ fast_div_mod<tileInW>(relInX, relInY, idx);
285
+ int inX = tileInX + relInX;
286
+ int inY = tileInY + relInY;
287
+ scalar_t v = 0;
288
+
289
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
+
292
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
+ if (!skip)
294
+ s_tileIn[idx] = v;
295
+
296
+ idx += threadsPerBlock;
297
+ }
298
+
299
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
+ {
301
+ // Horizontal upsampling.
302
+ __syncthreads();
303
+ if (up == 4)
304
+ {
305
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
+ {
307
+ int relUpX0, relInY;
308
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
+ int relInX0 = relUpX0 / up;
310
+ int src0 = relInX0 + tileInW * relInY;
311
+ int dst = relInY * tileUpW + relUpX0;
312
+ vec4_t v = InternalType<T>::zero_vec4();
313
+ scalar_t a = s_tileIn[src0];
314
+ if (phaseInX == 0)
315
+ {
316
+ #pragma unroll
317
+ for (int step = 0; step < fuSize / up; step++)
318
+ {
319
+ v.x += a * (scalar_t)c_fu[step * up + 0];
320
+ a = s_tileIn[src0 + step + 1];
321
+ v.y += a * (scalar_t)c_fu[step * up + 3];
322
+ v.z += a * (scalar_t)c_fu[step * up + 2];
323
+ v.w += a * (scalar_t)c_fu[step * up + 1];
324
+ }
325
+ }
326
+ else if (phaseInX == 1)
327
+ {
328
+ #pragma unroll
329
+ for (int step = 0; step < fuSize / up; step++)
330
+ {
331
+ v.x += a * (scalar_t)c_fu[step * up + 1];
332
+ v.y += a * (scalar_t)c_fu[step * up + 0];
333
+ a = s_tileIn[src0 + step + 1];
334
+ v.z += a * (scalar_t)c_fu[step * up + 3];
335
+ v.w += a * (scalar_t)c_fu[step * up + 2];
336
+ }
337
+ }
338
+ else if (phaseInX == 2)
339
+ {
340
+ #pragma unroll
341
+ for (int step = 0; step < fuSize / up; step++)
342
+ {
343
+ v.x += a * (scalar_t)c_fu[step * up + 2];
344
+ v.y += a * (scalar_t)c_fu[step * up + 1];
345
+ v.z += a * (scalar_t)c_fu[step * up + 0];
346
+ a = s_tileIn[src0 + step + 1];
347
+ v.w += a * (scalar_t)c_fu[step * up + 3];
348
+ }
349
+ }
350
+ else // (phaseInX == 3)
351
+ {
352
+ #pragma unroll
353
+ for (int step = 0; step < fuSize / up; step++)
354
+ {
355
+ v.x += a * (scalar_t)c_fu[step * up + 3];
356
+ v.y += a * (scalar_t)c_fu[step * up + 2];
357
+ v.z += a * (scalar_t)c_fu[step * up + 1];
358
+ v.w += a * (scalar_t)c_fu[step * up + 0];
359
+ a = s_tileIn[src0 + step + 1];
360
+ }
361
+ }
362
+ s_tileUpX[dst+0] = v.x;
363
+ s_tileUpX[dst+1] = v.y;
364
+ s_tileUpX[dst+2] = v.z;
365
+ s_tileUpX[dst+3] = v.w;
366
+ }
367
+ }
368
+ else if (up == 2)
369
+ {
370
+ bool p0 = (phaseInX == 0);
371
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
+ {
373
+ int relUpX0, relInY;
374
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
+ int relInX0 = relUpX0 / up;
376
+ int src0 = relInX0 + tileInW * relInY;
377
+ int dst = relInY * tileUpW + relUpX0;
378
+ vec2_t v = InternalType<T>::zero_vec2();
379
+ scalar_t a = s_tileIn[src0];
380
+ if (p0) // (phaseInX == 0)
381
+ {
382
+ #pragma unroll
383
+ for (int step = 0; step < fuSize / up; step++)
384
+ {
385
+ v.x += a * (scalar_t)c_fu[step * up + 0];
386
+ a = s_tileIn[src0 + step + 1];
387
+ v.y += a * (scalar_t)c_fu[step * up + 1];
388
+ }
389
+ }
390
+ else // (phaseInX == 1)
391
+ {
392
+ #pragma unroll
393
+ for (int step = 0; step < fuSize / up; step++)
394
+ {
395
+ v.x += a * (scalar_t)c_fu[step * up + 1];
396
+ v.y += a * (scalar_t)c_fu[step * up + 0];
397
+ a = s_tileIn[src0 + step + 1];
398
+ }
399
+ }
400
+ s_tileUpX[dst+0] = v.x;
401
+ s_tileUpX[dst+1] = v.y;
402
+ }
403
+ }
404
+
405
+ // Vertical upsampling & nonlinearity.
406
+
407
+ __syncthreads();
408
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
+ if (up == 4)
412
+ {
413
+ minY -= 3; // Adjust according to block height.
414
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
+ {
416
+ int relUpX, relInY0;
417
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
+ int relUpY0 = relInY0 * up;
419
+ int src0 = relInY0 * tileUpW + relUpX;
420
+ int dst = relUpY0 * tileUpW + relUpX;
421
+ vec4_t v = InternalType<T>::zero_vec4();
422
+
423
+ scalar_t a = s_tileUpX[src0];
424
+ if (phaseInY == 0)
425
+ {
426
+ #pragma unroll
427
+ for (int step = 0; step < fuSize / up; step++)
428
+ {
429
+ v.x += a * (scalar_t)c_fu[step * up + 0];
430
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
+ v.y += a * (scalar_t)c_fu[step * up + 3];
432
+ v.z += a * (scalar_t)c_fu[step * up + 2];
433
+ v.w += a * (scalar_t)c_fu[step * up + 1];
434
+ }
435
+ }
436
+ else if (phaseInY == 1)
437
+ {
438
+ #pragma unroll
439
+ for (int step = 0; step < fuSize / up; step++)
440
+ {
441
+ v.x += a * (scalar_t)c_fu[step * up + 1];
442
+ v.y += a * (scalar_t)c_fu[step * up + 0];
443
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
+ v.z += a * (scalar_t)c_fu[step * up + 3];
445
+ v.w += a * (scalar_t)c_fu[step * up + 2];
446
+ }
447
+ }
448
+ else if (phaseInY == 2)
449
+ {
450
+ #pragma unroll
451
+ for (int step = 0; step < fuSize / up; step++)
452
+ {
453
+ v.x += a * (scalar_t)c_fu[step * up + 2];
454
+ v.y += a * (scalar_t)c_fu[step * up + 1];
455
+ v.z += a * (scalar_t)c_fu[step * up + 0];
456
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
+ v.w += a * (scalar_t)c_fu[step * up + 3];
458
+ }
459
+ }
460
+ else // (phaseInY == 3)
461
+ {
462
+ #pragma unroll
463
+ for (int step = 0; step < fuSize / up; step++)
464
+ {
465
+ v.x += a * (scalar_t)c_fu[step * up + 3];
466
+ v.y += a * (scalar_t)c_fu[step * up + 2];
467
+ v.z += a * (scalar_t)c_fu[step * up + 1];
468
+ v.w += a * (scalar_t)c_fu[step * up + 0];
469
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
+ }
471
+ }
472
+
473
+ int x = tileOutX * down + relUpX;
474
+ int y = tileOutY * down + relUpY0;
475
+ int signX = x + p.sOfs.x;
476
+ int signY = y + p.sOfs.y;
477
+ int signZ = blockIdx.z + p.blockZofs;
478
+ int signXb = signX >> 2;
479
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
+ index_t si1 = si0 + p.sShape.x;
481
+ index_t si2 = si0 + p.sShape.x * 2;
482
+ index_t si3 = si0 + p.sShape.x * 3;
483
+
484
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
+
489
+ if (signWrite)
490
+ {
491
+ if (!enableWriteSkip)
492
+ {
493
+ // Determine and write signs.
494
+ int sx = __float_as_uint(v.x) >> 31 << 0;
495
+ int sy = __float_as_uint(v.y) >> 31 << 8;
496
+ int sz = __float_as_uint(v.z) >> 31 << 16;
497
+ int sw = __float_as_uint(v.w) >> 31 << 24;
498
+ if (sx) v.x *= p.slope;
499
+ if (sy) v.y *= p.slope;
500
+ if (sz) v.z *= p.slope;
501
+ if (sw) v.w *= p.slope;
502
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
+
507
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
+ {
509
+ // Combine signs.
510
+ uint32_t s = sx + sy + sw + sz;
511
+ s <<= (signX & 3) << 1;
512
+ s |= __shfl_xor_sync(groupMask, s, 1);
513
+ s |= __shfl_xor_sync(groupMask, s, 2);
514
+
515
+ // Write signs.
516
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
+ }
521
+ }
522
+ else
523
+ {
524
+ // Determine and write signs.
525
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
+ {
527
+ int sx = __float_as_uint(v.x) >> 31 << 0;
528
+ int sy = __float_as_uint(v.y) >> 31 << 8;
529
+ int sz = __float_as_uint(v.z) >> 31 << 16;
530
+ int sw = __float_as_uint(v.w) >> 31 << 24;
531
+ if (sx) v.x *= p.slope;
532
+ if (sy) v.y *= p.slope;
533
+ if (sz) v.z *= p.slope;
534
+ if (sw) v.w *= p.slope;
535
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
+
540
+ // Combine signs.
541
+ uint32_t s = sx + sy + sw + sz;
542
+ s <<= (signX & 3) << 1;
543
+ s |= __shfl_xor_sync(groupMask, s, 1);
544
+ s |= __shfl_xor_sync(groupMask, s, 2);
545
+
546
+ // Write signs.
547
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
+ }
552
+ else
553
+ {
554
+ // Just compute the values.
555
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
+ }
560
+ }
561
+ }
562
+ else if (signRead) // Read signs and apply.
563
+ {
564
+ if ((uint32_t)signXb < p.swLimit)
565
+ {
566
+ int ss = (signX & 3) << 1;
567
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
+ }
572
+ }
573
+ else // Forward pass with no sign write.
574
+ {
575
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
+ }
580
+
581
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
+ }
586
+ }
587
+ else if (up == 2)
588
+ {
589
+ minY -= 1; // Adjust according to block height.
590
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
+ {
592
+ int relUpX, relInY0;
593
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
+ int relUpY0 = relInY0 * up;
595
+ int src0 = relInY0 * tileUpW + relUpX;
596
+ int dst = relUpY0 * tileUpW + relUpX;
597
+ vec2_t v = InternalType<T>::zero_vec2();
598
+
599
+ scalar_t a = s_tileUpX[src0];
600
+ if (phaseInY == 0)
601
+ {
602
+ #pragma unroll
603
+ for (int step = 0; step < fuSize / up; step++)
604
+ {
605
+ v.x += a * (scalar_t)c_fu[step * up + 0];
606
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
+ v.y += a * (scalar_t)c_fu[step * up + 1];
608
+ }
609
+ }
610
+ else // (phaseInY == 1)
611
+ {
612
+ #pragma unroll
613
+ for (int step = 0; step < fuSize / up; step++)
614
+ {
615
+ v.x += a * (scalar_t)c_fu[step * up + 1];
616
+ v.y += a * (scalar_t)c_fu[step * up + 0];
617
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
+ }
619
+ }
620
+
621
+ int x = tileOutX * down + relUpX;
622
+ int y = tileOutY * down + relUpY0;
623
+ int signX = x + p.sOfs.x;
624
+ int signY = y + p.sOfs.y;
625
+ int signZ = blockIdx.z + p.blockZofs;
626
+ int signXb = signX >> 2;
627
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
+ index_t si1 = si0 + p.sShape.x;
629
+
630
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
+
633
+ if (signWrite)
634
+ {
635
+ if (!enableWriteSkip)
636
+ {
637
+ // Determine and write signs.
638
+ int sx = __float_as_uint(v.x) >> 31 << 0;
639
+ int sy = __float_as_uint(v.y) >> 31 << 8;
640
+ if (sx) v.x *= p.slope;
641
+ if (sy) v.y *= p.slope;
642
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
+
645
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
+ {
647
+ // Combine signs.
648
+ int s = sx + sy;
649
+ s <<= signXo;
650
+ s |= __shfl_xor_sync(groupMask, s, 1);
651
+ s |= __shfl_xor_sync(groupMask, s, 2);
652
+
653
+ // Write signs.
654
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
+ }
657
+ }
658
+ else
659
+ {
660
+ // Determine and write signs.
661
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
+ {
663
+ int sx = __float_as_uint(v.x) >> 31 << 0;
664
+ int sy = __float_as_uint(v.y) >> 31 << 8;
665
+ if (sx) v.x *= p.slope;
666
+ if (sy) v.y *= p.slope;
667
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
+
670
+ // Combine signs.
671
+ int s = sx + sy;
672
+ s <<= signXo;
673
+ s |= __shfl_xor_sync(groupMask, s, 1);
674
+ s |= __shfl_xor_sync(groupMask, s, 2);
675
+
676
+ // Write signs.
677
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
+ }
680
+ else
681
+ {
682
+ // Just compute the values.
683
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
+ }
686
+ }
687
+ }
688
+ else if (signRead) // Read signs and apply.
689
+ {
690
+ if ((uint32_t)signXb < p.swLimit)
691
+ {
692
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
+ }
695
+ }
696
+ else // Forward pass with no sign write.
697
+ {
698
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
+ }
701
+
702
+ if (!downInline)
703
+ {
704
+ // Write into temporary buffer.
705
+ s_tileUpXY[dst] = v.x;
706
+ if (relUpY0 < tileUpH - 1)
707
+ s_tileUpXY[dst + tileUpW] = v.y;
708
+ }
709
+ else
710
+ {
711
+ // Write directly into output buffer.
712
+ if ((uint32_t)x < p.yShape.x)
713
+ {
714
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
+ }
719
+ }
720
+ }
721
+ }
722
+ }
723
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
+ {
725
+ // Full upsampling filter.
726
+
727
+ if (up == 2)
728
+ {
729
+ // 2 x 2-wide.
730
+ __syncthreads();
731
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
+ {
734
+ int relUpX0, relUpY0;
735
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
+ int src0 = relInX0 + tileInW * relInY0;
739
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
+
741
+ #define X_LOOP(TAPY, PX) \
742
+ for (int sx = 0; sx < fuSize / up; sx++) \
743
+ { \
744
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
+ }
749
+
750
+ vec4_t v = InternalType<T>::zero_vec4();
751
+ if (tap0y == 0 && phaseInX == 0)
752
+ #pragma unroll
753
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
+ #pragma unroll
755
+ X_LOOP(0, 0) }
756
+ if (tap0y == 0 && phaseInX == 1)
757
+ #pragma unroll
758
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
+ #pragma unroll
760
+ X_LOOP(0, 1) }
761
+ if (tap0y == 1 && phaseInX == 0)
762
+ #pragma unroll
763
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
+ #pragma unroll
765
+ X_LOOP(1, 0) }
766
+ if (tap0y == 1 && phaseInX == 1)
767
+ #pragma unroll
768
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
+ #pragma unroll
770
+ X_LOOP(1, 1) }
771
+
772
+ #undef X_LOOP
773
+
774
+ int x = tileOutX * down + relUpX0;
775
+ int y = tileOutY * down + relUpY0;
776
+ int signX = x + p.sOfs.x;
777
+ int signY = y + p.sOfs.y;
778
+ int signZ = blockIdx.z + p.blockZofs;
779
+ int signXb = signX >> 2;
780
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
+
782
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
+
787
+ if (signWrite)
788
+ {
789
+ if (!enableWriteSkip)
790
+ {
791
+ // Determine and write signs.
792
+ int sx = __float_as_uint(v.x) >> 31;
793
+ int sy = __float_as_uint(v.y) >> 31;
794
+ int sz = __float_as_uint(v.z) >> 31;
795
+ int sw = __float_as_uint(v.w) >> 31;
796
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
+
801
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
+ {
803
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
+ }
805
+ }
806
+ else
807
+ {
808
+ // Determine and write signs.
809
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
+ {
811
+ int sx = __float_as_uint(v.x) >> 31;
812
+ int sy = __float_as_uint(v.y) >> 31;
813
+ int sz = __float_as_uint(v.z) >> 31;
814
+ int sw = __float_as_uint(v.w) >> 31;
815
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
+
820
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
+ }
822
+ else
823
+ {
824
+ // Just compute the values.
825
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
+ }
830
+ }
831
+ }
832
+ else if (signRead) // Read sign and apply.
833
+ {
834
+ if ((uint32_t)signY < p.sShape.y)
835
+ {
836
+ int s = 0;
837
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
+ s >>= (signX & 3) << 1;
840
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
+ }
845
+ }
846
+ else // Forward pass with no sign write.
847
+ {
848
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
+ }
853
+
854
+ s_tileUpXY[idx + 0] = v.x;
855
+ s_tileUpXY[idx + 1] = v.y;
856
+ s_tileUpXY[idx + 2] = v.z;
857
+ s_tileUpXY[idx + 3] = v.w;
858
+ }
859
+ }
860
+ else if (up == 1)
861
+ {
862
+ __syncthreads();
863
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
+ {
867
+ int relUpX0, relUpY0;
868
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
+
871
+ int x = tileOutX * down + relUpX0;
872
+ int y = tileOutY * down + relUpY0;
873
+ int signX = x + p.sOfs.x;
874
+ int signY = y + p.sOfs.y;
875
+ int signZ = blockIdx.z + p.blockZofs;
876
+ int signXb = signX >> 2;
877
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
+ v *= (scalar_t)((float)up * (float)up * p.gain);
879
+
880
+ if (signWrite)
881
+ {
882
+ if (!enableWriteSkip)
883
+ {
884
+ // Determine and write sign.
885
+ uint32_t s = 0;
886
+ uint32_t signXbit = (1u << signXo);
887
+ if (v < 0.f)
888
+ {
889
+ s = signXbit;
890
+ v *= p.slope;
891
+ }
892
+ if (fabsf(v) > p.clamp)
893
+ {
894
+ s = signXbit * 2;
895
+ v = InternalType<T>::clamp(v, p.clamp);
896
+ }
897
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
+ {
899
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
+ p.s[si] = s; // Write.
902
+ }
903
+ }
904
+ else
905
+ {
906
+ // Determine and write sign.
907
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
+ {
909
+ uint32_t s = 0;
910
+ uint32_t signXbit = (1u << signXo);
911
+ if (v < 0.f)
912
+ {
913
+ s = signXbit;
914
+ v *= p.slope;
915
+ }
916
+ if (fabsf(v) > p.clamp)
917
+ {
918
+ s = signXbit * 2;
919
+ v = InternalType<T>::clamp(v, p.clamp);
920
+ }
921
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
+ p.s[si] = s; // Write.
924
+ }
925
+ else
926
+ {
927
+ // Just compute the value.
928
+ if (v < 0.f) v *= p.slope;
929
+ v = InternalType<T>::clamp(v, p.clamp);
930
+ }
931
+ }
932
+ }
933
+ else if (signRead)
934
+ {
935
+ // Read sign and apply if within sign tensor bounds.
936
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
+ {
938
+ int s = p.s[si];
939
+ s >>= signXo;
940
+ if (s & 1) v *= p.slope;
941
+ if (s & 2) v = 0.f;
942
+ }
943
+ }
944
+ else // Forward pass with no sign write.
945
+ {
946
+ if (v < 0.f) v *= p.slope;
947
+ v = InternalType<T>::clamp(v, p.clamp);
948
+ }
949
+
950
+ if (!downInline) // Write into temporary buffer.
951
+ s_tileUpXY[idx] = v;
952
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
+ }
955
+ }
956
+ }
957
+
958
+ // Downsampling.
959
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
+ {
961
+ // Horizontal downsampling.
962
+ __syncthreads();
963
+ if (down == 4 && tileOutW % 4 == 0)
964
+ {
965
+ // Calculate 4 pixels at a time.
966
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
+ {
968
+ int relOutX0, relUpY;
969
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
+ int relUpX0 = relOutX0 * down;
971
+ int src0 = relUpY * tileUpW + relUpX0;
972
+ vec4_t v = InternalType<T>::zero_vec4();
973
+ #pragma unroll
974
+ for (int step = 0; step < fdSize; step++)
975
+ {
976
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
+ }
981
+ s_tileDownX[idx+0] = v.x;
982
+ s_tileDownX[idx+1] = v.y;
983
+ s_tileDownX[idx+2] = v.z;
984
+ s_tileDownX[idx+3] = v.w;
985
+ }
986
+ }
987
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
+ {
989
+ // Calculate 2 pixels at a time.
990
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
+ {
992
+ int relOutX0, relUpY;
993
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
+ int relUpX0 = relOutX0 * down;
995
+ int src0 = relUpY * tileUpW + relUpX0;
996
+ vec2_t v = InternalType<T>::zero_vec2();
997
+ #pragma unroll
998
+ for (int step = 0; step < fdSize; step++)
999
+ {
1000
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
+ }
1003
+ s_tileDownX[idx+0] = v.x;
1004
+ s_tileDownX[idx+1] = v.y;
1005
+ }
1006
+ }
1007
+ else
1008
+ {
1009
+ // Calculate 1 pixel at a time.
1010
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
+ {
1012
+ int relOutX0, relUpY;
1013
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
+ int relUpX0 = relOutX0 * down;
1015
+ int src = relUpY * tileUpW + relUpX0;
1016
+ scalar_t v = 0.f;
1017
+ #pragma unroll
1018
+ for (int step = 0; step < fdSize; step++)
1019
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
+ s_tileDownX[idx] = v;
1021
+ }
1022
+ }
1023
+
1024
+ // Vertical downsampling & store output tile.
1025
+ __syncthreads();
1026
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
+ {
1028
+ int relOutX, relOutY0;
1029
+ fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
+ int relUpY0 = relOutY0 * down;
1031
+ int src0 = relUpY0 * tileOutW + relOutX;
1032
+ scalar_t v = 0;
1033
+ #pragma unroll
1034
+ for (int step = 0; step < fdSize; step++)
1035
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
+
1037
+ int outX = tileOutX + relOutX;
1038
+ int outY = tileOutY + relOutY0;
1039
+
1040
+ if (outX < p.yShape.x & outY < p.yShape.y)
1041
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
+ }
1043
+ }
1044
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
+ {
1046
+ // Full downsampling filter.
1047
+ if (down == 2)
1048
+ {
1049
+ // 2-wide.
1050
+ __syncthreads();
1051
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
+ {
1053
+ int relOutX0, relOutY0;
1054
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
+ int relUpX0 = relOutX0 * down;
1056
+ int relUpY0 = relOutY0 * down;
1057
+ int src0 = relUpY0 * tileUpW + relUpX0;
1058
+ vec2_t v = InternalType<T>::zero_vec2();
1059
+ #pragma unroll
1060
+ for (int sy = 0; sy < fdSize; sy++)
1061
+ #pragma unroll
1062
+ for (int sx = 0; sx < fdSize; sx++)
1063
+ {
1064
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
+ }
1067
+
1068
+ int outX = tileOutX + relOutX0;
1069
+ int outY = tileOutY + relOutY0;
1070
+ if ((uint32_t)outY < p.yShape.y)
1071
+ {
1072
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
+ }
1076
+ }
1077
+ }
1078
+ else if (down == 1 && !downInline)
1079
+ {
1080
+ // Thread per pixel.
1081
+ __syncthreads();
1082
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
+ {
1084
+ int relOutX0, relOutY0;
1085
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
+
1088
+ int outX = tileOutX + relOutX0;
1089
+ int outY = tileOutY + relOutY0;
1090
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
+ }
1093
+ }
1094
+ }
1095
+
1096
+ if (!enableXrep)
1097
+ break;
1098
+ }
1099
+ }
1100
+
1101
+ //------------------------------------------------------------------------
1102
+ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
+ // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
+
1105
+ template <class T, bool signWrite, bool signRead>
1106
+ static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
+ {
1108
+ typedef typename InternalType<T>::scalar_t scalar_t;
1109
+
1110
+ // Indexing.
1111
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
+
1115
+ // Loop to accommodate oversized tensors.
1116
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
+ {
1119
+ // Extract z and w (channel, minibatch index).
1120
+ int32_t w = q / p.xShape.z;
1121
+ int32_t z = q - w * p.xShape.z;
1122
+
1123
+ // Choose behavior based on sign read/write mode.
1124
+ if (signWrite)
1125
+ {
1126
+ // Process value if in p.x.
1127
+ uint32_t s = 0;
1128
+ if (x < p.xShape.x && y < p.xShape.y)
1129
+ {
1130
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
+ T* pv = ((T*)p.x) + ix;
1132
+ scalar_t v = (scalar_t)(*pv);
1133
+
1134
+ // Gain, LReLU, clamp.
1135
+ v *= p.gain;
1136
+ if (v < 0.f)
1137
+ {
1138
+ v *= p.slope;
1139
+ s = 1; // Sign.
1140
+ }
1141
+ if (fabsf(v) > p.clamp)
1142
+ {
1143
+ v = InternalType<T>::clamp(v, p.clamp);
1144
+ s = 2; // Clamp.
1145
+ }
1146
+
1147
+ *pv = (T)v; // Write value.
1148
+ }
1149
+
1150
+ // Coalesce into threads 0 and 16 of warp.
1151
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
+ s |= __shfl_xor_sync(m, s, 2);
1155
+ s |= __shfl_xor_sync(m, s, 4);
1156
+ s |= __shfl_xor_sync(m, s, 8);
1157
+
1158
+ // Write signs if leader and in p.s.
1159
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
+ {
1161
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
+ ((uint32_t*)p.s)[is >> 4] = s;
1163
+ }
1164
+ }
1165
+ else if (signRead)
1166
+ {
1167
+ // Process value if in p.x.
1168
+ if (x < p.xShape.x) // y is always in.
1169
+ {
1170
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
+ T* pv = ((T*)p.x) + ix;
1172
+ scalar_t v = (scalar_t)(*pv);
1173
+ v *= p.gain;
1174
+
1175
+ // Apply sign buffer offset.
1176
+ uint32_t sx = x + p.sOfs.x;
1177
+ uint32_t sy = y + p.sOfs.y;
1178
+
1179
+ // Read and apply signs if we land inside valid region of sign buffer.
1180
+ if (sx < p.sShape.x && sy < p.sShape.y)
1181
+ {
1182
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
+ unsigned char s = p.s[is];
1184
+ s >>= (sx & 3) << 1; // Shift into place.
1185
+ if (s & 1) // Sign?
1186
+ v *= p.slope;
1187
+ if (s & 2) // Clamp?
1188
+ v = 0.f;
1189
+ }
1190
+
1191
+ *pv = (T)v; // Write value.
1192
+ }
1193
+ }
1194
+ else
1195
+ {
1196
+ // Forward pass with no sign write. Process value if in p.x.
1197
+ if (x < p.xShape.x) // y is always in.
1198
+ {
1199
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
+ T* pv = ((T*)p.x) + ix;
1201
+ scalar_t v = (scalar_t)(*pv);
1202
+ v *= p.gain;
1203
+ if (v < 0.f)
1204
+ v *= p.slope;
1205
+ if (fabsf(v) > p.clamp)
1206
+ v = InternalType<T>::clamp(v, p.clamp);
1207
+ *pv = (T)v; // Write value.
1208
+ }
1209
+ }
1210
+ }
1211
+ }
1212
+
1213
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
+ {
1215
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
+ }
1217
+
1218
+ //------------------------------------------------------------------------
1219
+ // CUDA kernel selection.
1220
+
1221
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
+ {
1223
+ filtered_lrelu_kernel_spec s = { 0 };
1224
+
1225
+ // Return the first matching kernel.
1226
+ #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
+ if (sharedKB >= SH) \
1228
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
+ { \
1232
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
+ s.setup = (void*)setup_filters_kernel; \
1236
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
+ s.tileOut = make_int2(TW, TH); \
1238
+ s.numWarps = W; \
1239
+ s.xrep = XR; \
1240
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
+ return s; \
1242
+ }
1243
+
1244
+ // Launch parameters for various kernel specializations.
1245
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
1246
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
+
1248
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
+
1280
+ #undef CASE
1281
+ return s; // No kernel found.
1282
+ }
1283
+
1284
+ //------------------------------------------------------------------------
models/torch_utils/ops/filtered_lrelu.h ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct filtered_lrelu_kernel_params
15
+ {
16
+ // These parameters decide which kernel to use.
17
+ int up; // upsampling ratio (1, 2, 4)
18
+ int down; // downsampling ratio (1, 2, 4)
19
+ int2 fuShape; // [size, 1] | [size, size]
20
+ int2 fdShape; // [size, 1] | [size, size]
21
+
22
+ int _dummy; // Alignment.
23
+
24
+ // Rest of the parameters.
25
+ const void* x; // Input tensor.
26
+ void* y; // Output tensor.
27
+ const void* b; // Bias tensor.
28
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
29
+ const float* fu; // Upsampling filter.
30
+ const float* fd; // Downsampling filter.
31
+
32
+ int2 pad0; // Left/top padding.
33
+ float gain; // Additional gain factor.
34
+ float slope; // Leaky ReLU slope on negative side.
35
+ float clamp; // Clamp after nonlinearity.
36
+ int flip; // Filter kernel flip for gradient computation.
37
+
38
+ int tilesXdim; // Original number of horizontal output tiles.
39
+ int tilesXrep; // Number of horizontal tiles per CTA.
40
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
+
42
+ int4 xShape; // [width, height, channel, batch]
43
+ int4 yShape; // [width, height, channel, batch]
44
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
+ int swLimit; // Active width of sign tensor in bytes.
47
+
48
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
+ longlong4 yStride; //
50
+ int64_t bStride; //
51
+ longlong3 fuStride; //
52
+ longlong3 fdStride; //
53
+ };
54
+
55
+ struct filtered_lrelu_act_kernel_params
56
+ {
57
+ void* x; // Input/output, modified in-place.
58
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
59
+
60
+ float gain; // Additional gain factor.
61
+ float slope; // Leaky ReLU slope on negative side.
62
+ float clamp; // Clamp after nonlinearity.
63
+
64
+ int4 xShape; // [width, height, channel, batch]
65
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
+ };
69
+
70
+ //------------------------------------------------------------------------
71
+ // CUDA kernel specialization.
72
+
73
+ struct filtered_lrelu_kernel_spec
74
+ {
75
+ void* setup; // Function for filter kernel setup.
76
+ void* exec; // Function for main operation.
77
+ int2 tileOut; // Width/height of launch tile.
78
+ int numWarps; // Number of warps per thread block, determines launch block size.
79
+ int xrep; // For processing multiple horizontal tiles per thread block.
80
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
+ };
82
+
83
+ //------------------------------------------------------------------------
84
+ // CUDA kernel selection.
85
+
86
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
+ template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
+
90
+ //------------------------------------------------------------------------
models/torch_utils/ops/filtered_lrelu.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+
14
+ from .. import custom_ops
15
+ from .. import misc
16
+ from . import upfirdn2d
17
+ from . import bias_act
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='filtered_lrelu_plugin',
28
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
29
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math'],
32
+ )
33
+ return True
34
+
35
+ def _get_filter_size(f):
36
+ if f is None:
37
+ return 1, 1
38
+ assert isinstance(f, torch.Tensor)
39
+ assert 1 <= f.ndim <= 2
40
+ return f.shape[-1], f.shape[0] # width, height
41
+
42
+ def _parse_padding(padding):
43
+ if isinstance(padding, int):
44
+ padding = [padding, padding]
45
+ assert isinstance(padding, (list, tuple))
46
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
47
+ padding = [int(x) for x in padding]
48
+ if len(padding) == 2:
49
+ px, py = padding
50
+ padding = [px, px, py, py]
51
+ px0, px1, py0, py1 = padding
52
+ return px0, px1, py0, py1
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
57
+ r"""Filtered leaky ReLU for a batch of 2D images.
58
+
59
+ Performs the following sequence of operations for each channel:
60
+
61
+ 1. Add channel-specific bias if provided (`b`).
62
+
63
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
64
+
65
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
66
+ Negative padding corresponds to cropping the image.
67
+
68
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
69
+ so that the footprint of all output pixels lies within the input image.
70
+
71
+ 5. Multiply each value by the provided gain factor (`gain`).
72
+
73
+ 6. Apply leaky ReLU activation function to each value.
74
+
75
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
76
+
77
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
78
+ it so that the footprint of all output pixels lies within the input image.
79
+
80
+ 9. Downsample the image by keeping every Nth pixel (`down`).
81
+
82
+ The fused op is considerably more efficient than performing the same calculation
83
+ using standard PyTorch ops. It supports gradients of arbitrary order.
84
+
85
+ Args:
86
+ x: Float32/float16/float64 input tensor of the shape
87
+ `[batch_size, num_channels, in_height, in_width]`.
88
+ fu: Float32 upsampling FIR filter of the shape
89
+ `[filter_height, filter_width]` (non-separable),
90
+ `[filter_taps]` (separable), or
91
+ `None` (identity).
92
+ fd: Float32 downsampling FIR filter of the shape
93
+ `[filter_height, filter_width]` (non-separable),
94
+ `[filter_taps]` (separable), or
95
+ `None` (identity).
96
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
97
+ as `x`. The length of vector must must match the channel dimension of `x`.
98
+ up: Integer upsampling factor (default: 1).
99
+ down: Integer downsampling factor. (default: 1).
100
+ padding: Padding with respect to the upsampled image. Can be a single number
101
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
102
+ (default: 0).
103
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
104
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
105
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
106
+ flip_filter: False = convolution, True = correlation (default: False).
107
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
108
+
109
+ Returns:
110
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
111
+ """
112
+ assert isinstance(x, torch.Tensor)
113
+ assert impl in ['ref', 'cuda']
114
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
115
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
116
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ @misc.profiled_function
121
+ def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
122
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
123
+ existing `upfirdn2n()` and `bias_act()` ops.
124
+ """
125
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
126
+ fu_w, fu_h = _get_filter_size(fu)
127
+ fd_w, fd_h = _get_filter_size(fd)
128
+ if b is not None:
129
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
130
+ misc.assert_shape(b, [x.shape[1]])
131
+ assert isinstance(up, int) and up >= 1
132
+ assert isinstance(down, int) and down >= 1
133
+ px0, px1, py0, py1 = _parse_padding(padding)
134
+ assert gain == float(gain) and gain > 0
135
+ assert slope == float(slope) and slope >= 0
136
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
137
+
138
+ # Calculate output size.
139
+ batch_size, channels, in_h, in_w = x.shape
140
+ in_dtype = x.dtype
141
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
142
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
143
+
144
+ # Compute using existing ops.
145
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
146
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
147
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
148
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
149
+
150
+ # Check output shape & dtype.
151
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
152
+ assert x.dtype == in_dtype
153
+ return x
154
+
155
+ #----------------------------------------------------------------------------
156
+
157
+ _filtered_lrelu_cuda_cache = dict()
158
+
159
+ def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
160
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
161
+ """
162
+ assert isinstance(up, int) and up >= 1
163
+ assert isinstance(down, int) and down >= 1
164
+ px0, px1, py0, py1 = _parse_padding(padding)
165
+ assert gain == float(gain) and gain > 0
166
+ gain = float(gain)
167
+ assert slope == float(slope) and slope >= 0
168
+ slope = float(slope)
169
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
170
+ clamp = float(clamp if clamp is not None else 'inf')
171
+
172
+ # Lookup from cache.
173
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
174
+ if key in _filtered_lrelu_cuda_cache:
175
+ return _filtered_lrelu_cuda_cache[key]
176
+
177
+ # Forward op.
178
+ class FilteredLReluCuda(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
181
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
182
+
183
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
184
+ if fu is None:
185
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
186
+ if fd is None:
187
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
188
+ assert 1 <= fu.ndim <= 2
189
+ assert 1 <= fd.ndim <= 2
190
+
191
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
192
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
193
+ fu = fu.square()[None]
194
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
195
+ fd = fd.square()[None]
196
+
197
+ # Missing sign input tensor.
198
+ if si is None:
199
+ si = torch.empty([0])
200
+
201
+ # Missing bias tensor.
202
+ if b is None:
203
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
204
+
205
+ # Construct internal sign tensor only if gradients are needed.
206
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
207
+
208
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
209
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
210
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
211
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
212
+
213
+ # Call C++/Cuda plugin if datatype is supported.
214
+ if x.dtype in [torch.float16, torch.float32]:
215
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
216
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
217
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
218
+ else:
219
+ return_code = -1
220
+
221
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
222
+ # only the bit-packed sign tensor is retained for gradient computation.
223
+ if return_code < 0:
224
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
225
+
226
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
227
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
228
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
229
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
230
+
231
+ # Prepare for gradient computation.
232
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
233
+ ctx.x_shape = x.shape
234
+ ctx.y_shape = y.shape
235
+ ctx.s_ofs = sx, sy
236
+ return y
237
+
238
+ @staticmethod
239
+ def backward(ctx, dy): # pylint: disable=arguments-differ
240
+ fu, fd, si = ctx.saved_tensors
241
+ _, _, xh, xw = ctx.x_shape
242
+ _, _, yh, yw = ctx.y_shape
243
+ sx, sy = ctx.s_ofs
244
+ dx = None # 0
245
+ dfu = None; assert not ctx.needs_input_grad[1]
246
+ dfd = None; assert not ctx.needs_input_grad[2]
247
+ db = None # 3
248
+ dsi = None; assert not ctx.needs_input_grad[4]
249
+ dsx = None; assert not ctx.needs_input_grad[5]
250
+ dsy = None; assert not ctx.needs_input_grad[6]
251
+
252
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
253
+ pp = [
254
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
255
+ xw * up - yw * down + px0 - (up - 1),
256
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
257
+ xh * up - yh * down + py0 - (up - 1),
258
+ ]
259
+ gg = gain * (up ** 2) / (down ** 2)
260
+ ff = (not flip_filter)
261
+ sx = sx - (fu.shape[-1] - 1) + px0
262
+ sy = sy - (fu.shape[0] - 1) + py0
263
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
264
+
265
+ if ctx.needs_input_grad[3]:
266
+ db = dx.sum([0, 2, 3])
267
+
268
+ return dx, dfu, dfd, db, dsi, dsx, dsy
269
+
270
+ # Add to cache.
271
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
272
+ return FilteredLReluCuda
273
+
274
+ #----------------------------------------------------------------------------
models/torch_utils/ops/filtered_lrelu_ns.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for no signs mode (no gradients required).
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, false>(cudaStream_t stream);
models/torch_utils/ops/filtered_lrelu_rd.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign read mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, true>(cudaStream_t stream);
models/torch_utils/ops/filtered_lrelu_wr.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign write mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<true, false>(cudaStream_t stream);
models/torch_utils/ops/fma.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def fma(a, b, c): # => a * b + c
16
+ return _FusedMultiplyAdd.apply(a, b, c)
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
+ @staticmethod
22
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
+ out = torch.addcmul(c, a, b)
24
+ ctx.save_for_backward(a, b)
25
+ ctx.c_shape = c.shape
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx, dout): # pylint: disable=arguments-differ
30
+ a, b = ctx.saved_tensors
31
+ c_shape = ctx.c_shape
32
+ da = None
33
+ db = None
34
+ dc = None
35
+
36
+ if ctx.needs_input_grad[0]:
37
+ da = _unbroadcast(dout * b, a.shape)
38
+
39
+ if ctx.needs_input_grad[1]:
40
+ db = _unbroadcast(dout * a, b.shape)
41
+
42
+ if ctx.needs_input_grad[2]:
43
+ dc = _unbroadcast(dout, c_shape)
44
+
45
+ return da, db, dc
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _unbroadcast(x, shape):
50
+ extra_dims = x.ndim - len(shape)
51
+ assert extra_dims >= 0
52
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
+ if len(dim):
54
+ x = x.sum(dim=dim, keepdim=True)
55
+ if extra_dims:
56
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
57
+ assert x.shape == shape
58
+ return x
59
+
60
+ #----------------------------------------------------------------------------
models/torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+ def grid_sample(input, grid):
27
+ if _should_use_custom_op():
28
+ return _GridSample2dForward.apply(input, grid)
29
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ def _should_use_custom_op():
34
+ return enabled
35
+
36
+ #----------------------------------------------------------------------------
37
+
38
+ class _GridSample2dForward(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(ctx, input, grid):
41
+ assert input.ndim == 4
42
+ assert grid.ndim == 4
43
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44
+ ctx.save_for_backward(input, grid)
45
+ return output
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ input, grid = ctx.saved_tensors
50
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51
+ return grad_input, grad_grid
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ class _GridSample2dBackward(torch.autograd.Function):
56
+ @staticmethod
57
+ def forward(ctx, grad_output, input, grid):
58
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60
+ ctx.save_for_backward(grid)
61
+ return grad_input, grad_grid
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
65
+ _ = grad2_grad_grid # unused
66
+ grid, = ctx.saved_tensors
67
+ grad2_grad_output = None
68
+ grad2_input = None
69
+ grad2_grid = None
70
+
71
+ if ctx.needs_input_grad[0]:
72
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73
+
74
+ assert not ctx.needs_input_grad[2]
75
+ return grad2_grad_output, grad2_input, grad2_grid
76
+
77
+ #----------------------------------------------------------------------------
models/torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
25
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
26
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32
+
33
+ // Create output tensor.
34
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41
+
42
+ // Initialize CUDA kernel parameters.
43
+ upfirdn2d_kernel_params p;
44
+ p.x = x.data_ptr();
45
+ p.f = f.data_ptr<float>();
46
+ p.y = y.data_ptr();
47
+ p.up = make_int2(upx, upy);
48
+ p.down = make_int2(downx, downy);
49
+ p.pad0 = make_int2(padx0, pady0);
50
+ p.flip = (flip) ? 1 : 0;
51
+ p.gain = gain;
52
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60
+
61
+ // Choose CUDA kernel.
62
+ upfirdn2d_kernel_spec spec;
63
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64
+ {
65
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
66
+ });
67
+
68
+ // Set looping options.
69
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70
+ p.loopMinor = spec.loopMinor;
71
+ p.loopX = spec.loopX;
72
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74
+
75
+ // Compute grid size.
76
+ dim3 blockSize, gridSize;
77
+ if (spec.tileOutW < 0) // large
78
+ {
79
+ blockSize = dim3(4, 32, 1);
80
+ gridSize = dim3(
81
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83
+ p.launchMajor);
84
+ }
85
+ else // small
86
+ {
87
+ blockSize = dim3(256, 1, 1);
88
+ gridSize = dim3(
89
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91
+ p.launchMajor);
92
+ }
93
+
94
+ // Launch CUDA kernel.
95
+ void* args[] = {&p};
96
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97
+ return y;
98
+ }
99
+
100
+ //------------------------------------------------------------------------
101
+
102
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103
+ {
104
+ m.def("upfirdn2d", &upfirdn2d);
105
+ }
106
+
107
+ //------------------------------------------------------------------------
models/torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
209
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
210
+
211
+ // No up/downsampling.
212
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
213
+ {
214
+ // contiguous
215
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
216
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
217
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
218
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
219
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
220
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
221
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
222
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
223
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
225
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
226
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
228
+ // channels_last
229
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
230
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
231
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
232
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
233
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
234
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
236
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
237
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
238
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
239
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
240
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
241
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
242
+ }
243
+
244
+ // 2x upsampling.
245
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
246
+ {
247
+ // contiguous
248
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
249
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
250
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ // channels_last
255
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
256
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
257
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
263
+ {
264
+ // contiguous
265
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
266
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
268
+ // channels_last
269
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
270
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
271
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
272
+ }
273
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
274
+ {
275
+ // contiguous
276
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
277
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
278
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
279
+ // channels_last
280
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
281
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
282
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
283
+ }
284
+
285
+ // 2x downsampling.
286
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
287
+ {
288
+ // contiguous
289
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
290
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
291
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
292
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
293
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
294
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
295
+ // channels_last
296
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
297
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
298
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
299
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
300
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
301
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
302
+ }
303
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
304
+ {
305
+ // contiguous
306
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
307
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
308
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
309
+ // channels_last
310
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
311
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
312
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
313
+ }
314
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
315
+ {
316
+ // contiguous
317
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
318
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
319
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
320
+ // channels_last
321
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
322
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
323
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
324
+ }
325
+
326
+ // 4x upsampling.
327
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
328
+ {
329
+ // contiguous
330
+ if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
331
+ if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
332
+ // channels_last
333
+ if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
334
+ if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
335
+ }
336
+ if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
337
+ {
338
+ // contiguous
339
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
340
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
341
+ // channels_last
342
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
343
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
344
+ }
345
+ if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
346
+ {
347
+ // contiguous
348
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
349
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
350
+ // channels_last
351
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
352
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
353
+ }
354
+
355
+ // 4x downsampling (inefficient).
356
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
357
+ {
358
+ // contiguous
359
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
360
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
361
+ // channels_last
362
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
363
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
364
+ }
365
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
366
+ {
367
+ // contiguous
368
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
369
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
370
+ // channels_last
371
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
372
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
373
+ }
374
+ return spec;
375
+ }
376
+
377
+ //------------------------------------------------------------------------
378
+ // Template specializations.
379
+
380
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
381
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
382
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
383
+
384
+ //------------------------------------------------------------------------
models/torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
models/torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+
15
+ from .. import custom_ops
16
+ from .. import misc
17
+ from . import conv2d_gradfix
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='upfirdn2d_plugin',
28
+ sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
29
+ headers=['upfirdn2d.h'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math'],
32
+ )
33
+ return True
34
+
35
+ def _parse_scaling(scaling):
36
+ if isinstance(scaling, int):
37
+ scaling = [scaling, scaling]
38
+ assert isinstance(scaling, (list, tuple))
39
+ assert all(isinstance(x, int) for x in scaling)
40
+ sx, sy = scaling
41
+ assert sx >= 1 and sy >= 1
42
+ return sx, sy
43
+
44
+ def _parse_padding(padding):
45
+ if isinstance(padding, int):
46
+ padding = [padding, padding]
47
+ assert isinstance(padding, (list, tuple))
48
+ assert all(isinstance(x, int) for x in padding)
49
+ if len(padding) == 2:
50
+ padx, pady = padding
51
+ padding = [padx, padx, pady, pady]
52
+ padx0, padx1, pady0, pady1 = padding
53
+ return padx0, padx1, pady0, pady1
54
+
55
+ def _get_filter_size(f):
56
+ if f is None:
57
+ return 1, 1
58
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
59
+ fw = f.shape[-1]
60
+ fh = f.shape[0]
61
+ with misc.suppress_tracer_warnings():
62
+ fw = int(fw)
63
+ fh = int(fh)
64
+ misc.assert_shape(f, [fh, fw][:f.ndim])
65
+ assert fw >= 1 and fh >= 1
66
+ return fw, fh
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
71
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
72
+
73
+ Args:
74
+ f: Torch tensor, numpy array, or python list of the shape
75
+ `[filter_height, filter_width]` (non-separable),
76
+ `[filter_taps]` (separable),
77
+ `[]` (impulse), or
78
+ `None` (identity).
79
+ device: Result device (default: cpu).
80
+ normalize: Normalize the filter so that it retains the magnitude
81
+ for constant input signal (DC)? (default: True).
82
+ flip_filter: Flip the filter? (default: False).
83
+ gain: Overall scaling factor for signal magnitude (default: 1).
84
+ separable: Return a separable filter? (default: select automatically).
85
+
86
+ Returns:
87
+ Float32 tensor of the shape
88
+ `[filter_height, filter_width]` (non-separable) or
89
+ `[filter_taps]` (separable).
90
+ """
91
+ # Validate.
92
+ if f is None:
93
+ f = 1
94
+ f = torch.as_tensor(f, dtype=torch.float32)
95
+ assert f.ndim in [0, 1, 2]
96
+ assert f.numel() > 0
97
+ if f.ndim == 0:
98
+ f = f[np.newaxis]
99
+
100
+ # Separable?
101
+ if separable is None:
102
+ separable = (f.ndim == 1 and f.numel() >= 8)
103
+ if f.ndim == 1 and not separable:
104
+ f = f.ger(f)
105
+ assert f.ndim == (1 if separable else 2)
106
+
107
+ # Apply normalize, flip, gain, and device.
108
+ if normalize:
109
+ f /= f.sum()
110
+ if flip_filter:
111
+ f = f.flip(list(range(f.ndim)))
112
+ f = f * (gain ** (f.ndim / 2))
113
+ f = f.to(device=device)
114
+ return f
115
+
116
+ #----------------------------------------------------------------------------
117
+
118
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
119
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
120
+
121
+ Performs the following sequence of operations for each channel:
122
+
123
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
124
+
125
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
126
+ Negative padding corresponds to cropping the image.
127
+
128
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
129
+ so that the footprint of all output pixels lies within the input image.
130
+
131
+ 4. Downsample the image by keeping every Nth pixel (`down`).
132
+
133
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
134
+ The fused op is considerably more efficient than performing the same calculation
135
+ using standard PyTorch ops. It supports gradients of arbitrary order.
136
+
137
+ Args:
138
+ x: Float32/float64/float16 input tensor of the shape
139
+ `[batch_size, num_channels, in_height, in_width]`.
140
+ f: Float32 FIR filter of the shape
141
+ `[filter_height, filter_width]` (non-separable),
142
+ `[filter_taps]` (separable), or
143
+ `None` (identity).
144
+ up: Integer upsampling factor. Can be a single int or a list/tuple
145
+ `[x, y]` (default: 1).
146
+ down: Integer downsampling factor. Can be a single int or a list/tuple
147
+ `[x, y]` (default: 1).
148
+ padding: Padding with respect to the upsampled image. Can be a single number
149
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
150
+ (default: 0).
151
+ flip_filter: False = convolution, True = correlation (default: False).
152
+ gain: Overall scaling factor for signal magnitude (default: 1).
153
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
154
+
155
+ Returns:
156
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
157
+ """
158
+ assert isinstance(x, torch.Tensor)
159
+ assert impl in ['ref', 'cuda']
160
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
161
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
162
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
163
+
164
+ #----------------------------------------------------------------------------
165
+
166
+ @misc.profiled_function
167
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
168
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
169
+ """
170
+ # Validate arguments.
171
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
172
+ if f is None:
173
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
174
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
175
+ assert f.dtype == torch.float32 and not f.requires_grad
176
+ batch_size, num_channels, in_height, in_width = x.shape
177
+ upx, upy = _parse_scaling(up)
178
+ downx, downy = _parse_scaling(down)
179
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
180
+
181
+ # Check that upsampled buffer is not smaller than the filter.
182
+ upW = in_width * upx + padx0 + padx1
183
+ upH = in_height * upy + pady0 + pady1
184
+ assert upW >= f.shape[-1] and upH >= f.shape[0]
185
+
186
+ # Upsample by inserting zeros.
187
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
188
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
189
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
190
+
191
+ # Pad or crop.
192
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
193
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
194
+
195
+ # Setup filter.
196
+ f = f * (gain ** (f.ndim / 2))
197
+ f = f.to(x.dtype)
198
+ if not flip_filter:
199
+ f = f.flip(list(range(f.ndim)))
200
+
201
+ # Convolve with the filter.
202
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
203
+ if f.ndim == 4:
204
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
205
+ else:
206
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
207
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
208
+
209
+ # Downsample by throwing away pixels.
210
+ x = x[:, :, ::downy, ::downx]
211
+ return x
212
+
213
+ #----------------------------------------------------------------------------
214
+
215
+ _upfirdn2d_cuda_cache = dict()
216
+
217
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
218
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
219
+ """
220
+ # Parse arguments.
221
+ upx, upy = _parse_scaling(up)
222
+ downx, downy = _parse_scaling(down)
223
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
224
+
225
+ # Lookup from cache.
226
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
227
+ if key in _upfirdn2d_cuda_cache:
228
+ return _upfirdn2d_cuda_cache[key]
229
+
230
+ # Forward op.
231
+ class Upfirdn2dCuda(torch.autograd.Function):
232
+ @staticmethod
233
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
234
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
235
+ if f is None:
236
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
237
+ if f.ndim == 1 and f.shape[0] == 1:
238
+ f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
239
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
240
+ y = x
241
+ if f.ndim == 2:
242
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
243
+ else:
244
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
245
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
246
+ ctx.save_for_backward(f)
247
+ ctx.x_shape = x.shape
248
+ return y
249
+
250
+ @staticmethod
251
+ def backward(ctx, dy): # pylint: disable=arguments-differ
252
+ f, = ctx.saved_tensors
253
+ _, _, ih, iw = ctx.x_shape
254
+ _, _, oh, ow = dy.shape
255
+ fw, fh = _get_filter_size(f)
256
+ p = [
257
+ fw - padx0 - 1,
258
+ iw * upx - ow * downx + padx0 - upx + 1,
259
+ fh - pady0 - 1,
260
+ ih * upy - oh * downy + pady0 - upy + 1,
261
+ ]
262
+ dx = None
263
+ df = None
264
+
265
+ if ctx.needs_input_grad[0]:
266
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
267
+
268
+ assert not ctx.needs_input_grad[1]
269
+ return dx, df
270
+
271
+ # Add to cache.
272
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
273
+ return Upfirdn2dCuda
274
+
275
+ #----------------------------------------------------------------------------
276
+
277
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
278
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
279
+
280
+ By default, the result is padded so that its shape matches the input.
281
+ User-specified padding is applied on top of that, with negative values
282
+ indicating cropping. Pixels outside the image are assumed to be zero.
283
+
284
+ Args:
285
+ x: Float32/float64/float16 input tensor of the shape
286
+ `[batch_size, num_channels, in_height, in_width]`.
287
+ f: Float32 FIR filter of the shape
288
+ `[filter_height, filter_width]` (non-separable),
289
+ `[filter_taps]` (separable), or
290
+ `None` (identity).
291
+ padding: Padding with respect to the output. Can be a single number or a
292
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
293
+ (default: 0).
294
+ flip_filter: False = convolution, True = correlation (default: False).
295
+ gain: Overall scaling factor for signal magnitude (default: 1).
296
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
297
+
298
+ Returns:
299
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
300
+ """
301
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
302
+ fw, fh = _get_filter_size(f)
303
+ p = [
304
+ padx0 + fw // 2,
305
+ padx1 + (fw - 1) // 2,
306
+ pady0 + fh // 2,
307
+ pady1 + (fh - 1) // 2,
308
+ ]
309
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
310
+
311
+ #----------------------------------------------------------------------------
312
+
313
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
314
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
315
+
316
+ By default, the result is padded so that its shape is a multiple of the input.
317
+ User-specified padding is applied on top of that, with negative values
318
+ indicating cropping. Pixels outside the image are assumed to be zero.
319
+
320
+ Args:
321
+ x: Float32/float64/float16 input tensor of the shape
322
+ `[batch_size, num_channels, in_height, in_width]`.
323
+ f: Float32 FIR filter of the shape
324
+ `[filter_height, filter_width]` (non-separable),
325
+ `[filter_taps]` (separable), or
326
+ `None` (identity).
327
+ up: Integer upsampling factor. Can be a single int or a list/tuple
328
+ `[x, y]` (default: 1).
329
+ padding: Padding with respect to the output. Can be a single number or a
330
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
331
+ (default: 0).
332
+ flip_filter: False = convolution, True = correlation (default: False).
333
+ gain: Overall scaling factor for signal magnitude (default: 1).
334
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
335
+
336
+ Returns:
337
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
338
+ """
339
+ upx, upy = _parse_scaling(up)
340
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
341
+ fw, fh = _get_filter_size(f)
342
+ p = [
343
+ padx0 + (fw + upx - 1) // 2,
344
+ padx1 + (fw - upx) // 2,
345
+ pady0 + (fh + upy - 1) // 2,
346
+ pady1 + (fh - upy) // 2,
347
+ ]
348
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
349
+
350
+ #----------------------------------------------------------------------------
351
+
352
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
353
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
354
+
355
+ By default, the result is padded so that its shape is a fraction of the input.
356
+ User-specified padding is applied on top of that, with negative values
357
+ indicating cropping. Pixels outside the image are assumed to be zero.
358
+
359
+ Args:
360
+ x: Float32/float64/float16 input tensor of the shape
361
+ `[batch_size, num_channels, in_height, in_width]`.
362
+ f: Float32 FIR filter of the shape
363
+ `[filter_height, filter_width]` (non-separable),
364
+ `[filter_taps]` (separable), or
365
+ `None` (identity).
366
+ down: Integer downsampling factor. Can be a single int or a list/tuple
367
+ `[x, y]` (default: 1).
368
+ padding: Padding with respect to the input. Can be a single number or a
369
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
370
+ (default: 0).
371
+ flip_filter: False = convolution, True = correlation (default: False).
372
+ gain: Overall scaling factor for signal magnitude (default: 1).
373
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
374
+
375
+ Returns:
376
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
377
+ """
378
+ downx, downy = _parse_scaling(down)
379
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
380
+ fw, fh = _get_filter_size(f)
381
+ p = [
382
+ padx0 + (fw - downx + 1) // 2,
383
+ padx1 + (fw - downx) // 2,
384
+ pady0 + (fh - downy + 1) // 2,
385
+ pady1 + (fh - downy) // 2,
386
+ ]
387
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
388
+
389
+ #----------------------------------------------------------------------------
models/torch_utils/persistence.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Facilities for pickling Python code alongside other data.
10
+
11
+ The pickled code is automatically imported into a separate Python module
12
+ during unpickling. This way, any previously exported pickles will remain
13
+ usable even if the original code is no longer available, or if the current
14
+ version of the code is not consistent with what was originally pickled."""
15
+
16
+ import sys
17
+ import pickle
18
+ import io
19
+ import inspect
20
+ import copy
21
+ import uuid
22
+ import types
23
+ # import dnnlib
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ _version = 6 # internal version number
28
+ _decorators = set() # {decorator_class, ...}
29
+ _import_hooks = [] # [hook_function, ...]
30
+ _module_to_src_dict = dict() # {module: src, ...}
31
+ _src_to_module_dict = dict() # {src: module, ...}
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def persistent_class(orig_class):
36
+ r"""Class decorator that extends a given class to save its source code
37
+ when pickled.
38
+
39
+ Example:
40
+
41
+ from torch_utils import persistence
42
+
43
+ @persistence.persistent_class
44
+ class MyNetwork(torch.nn.Module):
45
+ def __init__(self, num_inputs, num_outputs):
46
+ super().__init__()
47
+ self.fc = MyLayer(num_inputs, num_outputs)
48
+ ...
49
+
50
+ @persistence.persistent_class
51
+ class MyLayer(torch.nn.Module):
52
+ ...
53
+
54
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
+ source code alongside other internal state (e.g., parameters, buffers,
56
+ and submodules). This way, any previously exported pickle will remain
57
+ usable even if the class definitions have been modified or are no
58
+ longer available.
59
+
60
+ The decorator saves the source code of the entire Python module
61
+ containing the decorated class. It does *not* save the source code of
62
+ any imported modules. Thus, the imported modules must be available
63
+ during unpickling, also including `torch_utils.persistence` itself.
64
+
65
+ It is ok to call functions defined in the same module from the
66
+ decorated class. However, if the decorated class depends on other
67
+ classes defined in the same module, they must be decorated as well.
68
+ This is illustrated in the above example in the case of `MyLayer`.
69
+
70
+ It is also possible to employ the decorator just-in-time before
71
+ calling the constructor. For example:
72
+
73
+ cls = MyLayer
74
+ if want_to_make_it_persistent:
75
+ cls = persistence.persistent_class(cls)
76
+ layer = cls(num_inputs, num_outputs)
77
+
78
+ As an additional feature, the decorator also keeps track of the
79
+ arguments that were used to construct each instance of the decorated
80
+ class. The arguments can be queried via `obj.init_args` and
81
+ `obj.init_kwargs`, and they are automatically pickled alongside other
82
+ object state. A typical use case is to first unpickle a previous
83
+ instance of a persistent class, and then upgrade it to use the latest
84
+ version of the source code:
85
+
86
+ with open('old_pickle.pkl', 'rb') as f:
87
+ old_net = pickle.load(f)
88
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
89
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90
+ """
91
+ assert isinstance(orig_class, type)
92
+ if is_persistent(orig_class):
93
+ return orig_class
94
+
95
+ assert orig_class.__module__ in sys.modules
96
+ orig_module = sys.modules[orig_class.__module__]
97
+ orig_module_src = _module_to_src(orig_module)
98
+
99
+ class Decorator(orig_class):
100
+ _orig_module_src = orig_module_src
101
+ _orig_class_name = orig_class.__name__
102
+
103
+ def __init__(self, *args, **kwargs):
104
+ super().__init__(*args, **kwargs)
105
+ self._init_args = copy.deepcopy(args)
106
+ self._init_kwargs = copy.deepcopy(kwargs)
107
+ assert orig_class.__name__ in orig_module.__dict__
108
+ _check_pickleable(self.__reduce__())
109
+
110
+ @property
111
+ def init_args(self):
112
+ return copy.deepcopy(self._init_args)
113
+
114
+ @property
115
+ def init_kwargs(self):
116
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117
+
118
+ def __reduce__(self):
119
+ fields = list(super().__reduce__())
120
+ fields += [None] * max(3 - len(fields), 0)
121
+ if fields[0] is not _reconstruct_persistent_obj:
122
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
124
+ fields[1] = (meta,) # reconstruct args
125
+ fields[2] = None # state dict
126
+ return tuple(fields)
127
+
128
+ Decorator.__name__ = orig_class.__name__
129
+ _decorators.add(Decorator)
130
+ return Decorator
131
+
132
+ #----------------------------------------------------------------------------
133
+
134
+ def is_persistent(obj):
135
+ r"""Test whether the given object or class is persistent, i.e.,
136
+ whether it will save its source code when pickled.
137
+ """
138
+ try:
139
+ if obj in _decorators:
140
+ return True
141
+ except TypeError:
142
+ pass
143
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144
+
145
+ #----------------------------------------------------------------------------
146
+
147
+ def import_hook(hook):
148
+ r"""Register an import hook that is called whenever a persistent object
149
+ is being unpickled. A typical use case is to patch the pickled source
150
+ code to avoid errors and inconsistencies when the API of some imported
151
+ module has changed.
152
+
153
+ The hook should have the following signature:
154
+
155
+ hook(meta) -> modified meta
156
+
157
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158
+
159
+ type: Type of the persistent object, e.g. `'class'`.
160
+ version: Internal version number of `torch_utils.persistence`.
161
+ module_src Original source code of the Python module.
162
+ class_name: Class name in the original Python module.
163
+ state: Internal state of the object.
164
+
165
+ Example:
166
+
167
+ @persistence.import_hook
168
+ def wreck_my_network(meta):
169
+ if meta.class_name == 'MyNetwork':
170
+ print('MyNetwork is being imported. I will wreck it!')
171
+ meta.module_src = meta.module_src.replace("True", "False")
172
+ return meta
173
+ """
174
+ assert callable(hook)
175
+ _import_hooks.append(hook)
176
+
177
+ #----------------------------------------------------------------------------
178
+
179
+ def _reconstruct_persistent_obj(meta):
180
+ r"""Hook that is called internally by the `pickle` module to unpickle
181
+ a persistent object.
182
+ """
183
+ meta = dnnlib.EasyDict(meta)
184
+ meta.state = dnnlib.EasyDict(meta.state)
185
+ for hook in _import_hooks:
186
+ meta = hook(meta)
187
+ assert meta is not None
188
+
189
+ assert meta.version == _version
190
+ module = _src_to_module(meta.module_src)
191
+
192
+ assert meta.type == 'class'
193
+ orig_class = module.__dict__[meta.class_name]
194
+ decorator_class = persistent_class(orig_class)
195
+ obj = decorator_class.__new__(decorator_class)
196
+
197
+ setstate = getattr(obj, '__setstate__', None)
198
+ if callable(setstate):
199
+ setstate(meta.state) # pylint: disable=not-callable
200
+ else:
201
+ obj.__dict__.update(meta.state)
202
+ return obj
203
+
204
+ #----------------------------------------------------------------------------
205
+
206
+ def _module_to_src(module):
207
+ r"""Query the source code of a given Python module.
208
+ """
209
+ src = _module_to_src_dict.get(module, None)
210
+ if src is None:
211
+ src = inspect.getsource(module)
212
+ _module_to_src_dict[module] = src
213
+ _src_to_module_dict[src] = module
214
+ return src
215
+
216
+ def _src_to_module(src):
217
+ r"""Get or create a Python module for the given source code.
218
+ """
219
+ module = _src_to_module_dict.get(src, None)
220
+ if module is None:
221
+ module_name = "_imported_module_" + uuid.uuid4().hex
222
+ module = types.ModuleType(module_name)
223
+ sys.modules[module_name] = module
224
+ _module_to_src_dict[module] = src
225
+ _src_to_module_dict[src] = module
226
+ exec(src, module.__dict__) # pylint: disable=exec-used
227
+ return module
228
+
229
+ #----------------------------------------------------------------------------
230
+
231
+ def _check_pickleable(obj):
232
+ r"""Check that the given object is pickleable, raising an exception if
233
+ it is not. This function is expected to be considerably more efficient
234
+ than actually pickling the object.
235
+ """
236
+ def recurse(obj):
237
+ if isinstance(obj, (list, tuple, set)):
238
+ return [recurse(x) for x in obj]
239
+ if isinstance(obj, dict):
240
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
241
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242
+ return None # Python primitive types are pickleable.
243
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
244
+ return None # NumPy arrays and PyTorch tensors are pickleable.
245
+ if is_persistent(obj):
246
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
247
+ return obj
248
+ with io.BytesIO() as f:
249
+ pickle.dump(recurse(obj), f)
250
+
251
+ #----------------------------------------------------------------------------
models/torch_utils/training_stats.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Facilities for reporting and collecting training statistics across
10
+ multiple processes and devices. The interface is designed to minimize
11
+ synchronization overhead as well as the amount of boilerplate in user
12
+ code."""
13
+
14
+ import re
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+
19
+ from . import misc
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
26
+ _rank = 0 # Rank of the current process.
27
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
28
+ _sync_called = False # Has _sync() been called yet?
29
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def init_multiprocessing(rank, sync_device):
35
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
36
+ across multiple processes.
37
+
38
+ This function must be called after
39
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
40
+ The call is not necessary if multi-process collection is not needed.
41
+
42
+ Args:
43
+ rank: Rank of the current process.
44
+ sync_device: PyTorch device to use for inter-process
45
+ communication, or None to disable multi-process
46
+ collection. Typically `torch.device('cuda', rank)`.
47
+ """
48
+ global _rank, _sync_device
49
+ assert not _sync_called
50
+ _rank = rank
51
+ _sync_device = sync_device
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ @misc.profiled_function
56
+ def report(name, value):
57
+ r"""Broadcasts the given set of scalars to all interested instances of
58
+ `Collector`, across device and process boundaries.
59
+
60
+ This function is expected to be extremely cheap and can be safely
61
+ called from anywhere in the training loop, loss function, or inside a
62
+ `torch.nn.Module`.
63
+
64
+ Warning: The current implementation expects the set of unique names to
65
+ be consistent across processes. Please make sure that `report()` is
66
+ called at least once for each unique name by each process, and in the
67
+ same order. If a given process has no scalars to broadcast, it can do
68
+ `report(name, [])` (empty list).
69
+
70
+ Args:
71
+ name: Arbitrary string specifying the name of the statistic.
72
+ Averages are accumulated separately for each unique name.
73
+ value: Arbitrary set of scalars. Can be a list, tuple,
74
+ NumPy array, PyTorch tensor, or Python scalar.
75
+
76
+ Returns:
77
+ The same `value` that was passed in.
78
+ """
79
+ if name not in _counters:
80
+ _counters[name] = dict()
81
+
82
+ elems = torch.as_tensor(value)
83
+ if elems.numel() == 0:
84
+ return value
85
+
86
+ elems = elems.detach().flatten().to(_reduce_dtype)
87
+ moments = torch.stack([
88
+ torch.ones_like(elems).sum(),
89
+ elems.sum(),
90
+ elems.square().sum(),
91
+ ])
92
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
93
+ moments = moments.to(_counter_dtype)
94
+
95
+ device = moments.device
96
+ if device not in _counters[name]:
97
+ _counters[name][device] = torch.zeros_like(moments)
98
+ _counters[name][device].add_(moments)
99
+ return value
100
+
101
+ #----------------------------------------------------------------------------
102
+
103
+ def report0(name, value):
104
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105
+ but ignores any scalars provided by the other processes.
106
+ See `report()` for further details.
107
+ """
108
+ report(name, value if _rank == 0 else [])
109
+ return value
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ class Collector:
114
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
115
+ computes their long-term averages (mean and standard deviation) over
116
+ user-defined periods of time.
117
+
118
+ The averages are first collected into internal counters that are not
119
+ directly visible to the user. They are then copied to the user-visible
120
+ state as a result of calling `update()` and can then be queried using
121
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122
+ internal counters for the next round, so that the user-visible state
123
+ effectively reflects averages collected between the last two calls to
124
+ `update()`.
125
+
126
+ Args:
127
+ regex: Regular expression defining which statistics to
128
+ collect. The default is to collect everything.
129
+ keep_previous: Whether to retain the previous averages if no
130
+ scalars were collected on a given round
131
+ (default: True).
132
+ """
133
+ def __init__(self, regex='.*', keep_previous=True):
134
+ self._regex = re.compile(regex)
135
+ self._keep_previous = keep_previous
136
+ self._cumulative = dict()
137
+ self._moments = dict()
138
+ self.update()
139
+ self._moments.clear()
140
+
141
+ def names(self):
142
+ r"""Returns the names of all statistics broadcasted so far that
143
+ match the regular expression specified at construction time.
144
+ """
145
+ return [name for name in _counters if self._regex.fullmatch(name)]
146
+
147
+ def update(self):
148
+ r"""Copies current values of the internal counters to the
149
+ user-visible state and resets them for the next round.
150
+
151
+ If `keep_previous=True` was specified at construction time, the
152
+ operation is skipped for statistics that have received no scalars
153
+ since the last update, retaining their previous averages.
154
+
155
+ This method performs a number of GPU-to-CPU transfers and one
156
+ `torch.distributed.all_reduce()`. It is intended to be called
157
+ periodically in the main training loop, typically once every
158
+ N training steps.
159
+ """
160
+ if not self._keep_previous:
161
+ self._moments.clear()
162
+ for name, cumulative in _sync(self.names()):
163
+ if name not in self._cumulative:
164
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165
+ delta = cumulative - self._cumulative[name]
166
+ self._cumulative[name].copy_(cumulative)
167
+ if float(delta[0]) != 0:
168
+ self._moments[name] = delta
169
+
170
+ def _get_delta(self, name):
171
+ r"""Returns the raw moments that were accumulated for the given
172
+ statistic between the last two calls to `update()`, or zero if
173
+ no scalars were collected.
174
+ """
175
+ assert self._regex.fullmatch(name)
176
+ if name not in self._moments:
177
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178
+ return self._moments[name]
179
+
180
+ def num(self, name):
181
+ r"""Returns the number of scalars that were accumulated for the given
182
+ statistic between the last two calls to `update()`, or zero if
183
+ no scalars were collected.
184
+ """
185
+ delta = self._get_delta(name)
186
+ return int(delta[0])
187
+
188
+ def mean(self, name):
189
+ r"""Returns the mean of the scalars that were accumulated for the
190
+ given statistic between the last two calls to `update()`, or NaN if
191
+ no scalars were collected.
192
+ """
193
+ delta = self._get_delta(name)
194
+ if int(delta[0]) == 0:
195
+ return float('nan')
196
+ return float(delta[1] / delta[0])
197
+
198
+ def std(self, name):
199
+ r"""Returns the standard deviation of the scalars that were
200
+ accumulated for the given statistic between the last two calls to
201
+ `update()`, or NaN if no scalars were collected.
202
+ """
203
+ delta = self._get_delta(name)
204
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205
+ return float('nan')
206
+ if int(delta[0]) == 1:
207
+ return float(0)
208
+ mean = float(delta[1] / delta[0])
209
+ raw_var = float(delta[2] / delta[0])
210
+ return np.sqrt(max(raw_var - np.square(mean), 0))
211
+
212
+ def as_dict(self):
213
+ r"""Returns the averages accumulated between the last two calls to
214
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
215
+
216
+ dnnlib.EasyDict(
217
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218
+ ...
219
+ )
220
+ """
221
+ stats = dnnlib.EasyDict()
222
+ for name in self.names():
223
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224
+ return stats
225
+
226
+ def __getitem__(self, name):
227
+ r"""Convenience getter.
228
+ `collector[name]` is a synonym for `collector.mean(name)`.
229
+ """
230
+ return self.mean(name)
231
+
232
+ #----------------------------------------------------------------------------
233
+
234
+ def _sync(names):
235
+ r"""Synchronize the global cumulative counters across devices and
236
+ processes. Called internally by `Collector.update()`.
237
+ """
238
+ if len(names) == 0:
239
+ return []
240
+ global _sync_called
241
+ _sync_called = True
242
+
243
+ # Collect deltas within current rank.
244
+ deltas = []
245
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
246
+ for name in names:
247
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248
+ for counter in _counters[name].values():
249
+ delta.add_(counter.to(device))
250
+ counter.copy_(torch.zeros_like(counter))
251
+ deltas.append(delta)
252
+ deltas = torch.stack(deltas)
253
+
254
+ # Sum deltas across ranks.
255
+ if _sync_device is not None:
256
+ torch.distributed.all_reduce(deltas)
257
+
258
+ # Update cumulative values.
259
+ deltas = deltas.cpu()
260
+ for idx, name in enumerate(names):
261
+ if name not in _cumulative:
262
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263
+ _cumulative[name].add_(deltas[idx])
264
+
265
+ # Return name-value pairs.
266
+ return [(name, _cumulative[name]) for name in names]
267
+
268
+ #----------------------------------------------------------------------------