echen01 commited on
Commit
2fec875
β€’
1 Parent(s): 926824a

working demo

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. PTI/.gitignore +0 -1
  2. PTI/LICENSE +0 -21
  3. PTI/README.md +0 -229
  4. PTI/torch_utils/custom_ops.py +0 -126
  5. PTI/torch_utils/misc.py +0 -262
  6. PTI/torch_utils/ops/bias_act.cpp +0 -99
  7. PTI/torch_utils/ops/bias_act.cu +0 -173
  8. PTI/torch_utils/ops/bias_act.h +0 -38
  9. PTI/torch_utils/ops/bias_act.py +0 -212
  10. PTI/torch_utils/ops/conv2d_gradfix.py +0 -170
  11. PTI/torch_utils/ops/conv2d_resample.py +0 -156
  12. PTI/torch_utils/ops/fma.py +0 -60
  13. PTI/torch_utils/ops/grid_sample_gradfix.py +0 -83
  14. PTI/torch_utils/ops/upfirdn2d.cpp +0 -103
  15. PTI/torch_utils/ops/upfirdn2d.cu +0 -350
  16. PTI/torch_utils/ops/upfirdn2d.h +0 -59
  17. PTI/torch_utils/ops/upfirdn2d.py +0 -384
  18. PTI/torch_utils/persistence.py +0 -251
  19. PTI/torch_utils/training_stats.py +0 -268
  20. app.py +62 -13
  21. checkpoints/model_gradio_demo_input.pt +3 -0
  22. PTI/criteria/color_transfer_loss.py β†’ color_transfer_loss.py +0 -0
  23. {PTI/configs β†’ configs}/__init__.py +0 -0
  24. {PTI/configs β†’ configs}/evaluation_config.py +0 -0
  25. {PTI/configs β†’ configs}/global_config.py +2 -2
  26. {PTI/configs β†’ configs}/hyperparameters.py +1 -1
  27. {PTI/configs β†’ configs}/paths_config.py +4 -4
  28. {PTI/criteria β†’ criteria}/__init__.py +0 -0
  29. {PTI/criteria β†’ criteria}/backbones/__init__.py +0 -0
  30. {PTI/criteria β†’ criteria}/backbones/iresnet.py +0 -0
  31. {PTI/criteria β†’ criteria}/backbones/iresnet2060.py +0 -0
  32. {PTI/criteria β†’ criteria}/backbones/mobilefacenet.py +0 -0
  33. {PTI/criteria β†’ criteria}/deeplab.py +0 -0
  34. {PTI/criteria β†’ criteria}/helpers.py +0 -0
  35. {PTI/criteria β†’ criteria}/id_loss.py +0 -0
  36. {PTI/criteria β†’ criteria}/l2_loss.py +0 -0
  37. {PTI/criteria β†’ criteria}/localitly_regulizer.py +0 -0
  38. {PTI/criteria β†’ criteria}/mask.py +0 -0
  39. {PTI/criteria β†’ criteria}/model_irse.py +0 -0
  40. {PTI/criteria β†’ criteria}/validation.py +0 -0
  41. dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  42. dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  43. embeddings/2010/PTI/input/0.pt +3 -0
  44. imgs/Steven-Yeun.jpg +3 -0
  45. imgs/cropped/input.png +3 -0
  46. imgs/input.png +3 -0
  47. {PTI/training β†’ models/StyleCLIP}/__init__.py +0 -0
  48. {PTI/training/coaches β†’ models/StyleCLIP/criteria}/__init__.py +0 -0
  49. models/StyleCLIP/criteria/clip_loss.py +17 -0
  50. models/StyleCLIP/criteria/id_loss.py +39 -0
PTI/.gitignore DELETED
@@ -1 +0,0 @@
1
-
 
 
PTI/LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2021 Daniel Roich
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/README.md DELETED
@@ -1,229 +0,0 @@
1
- # PTI: Pivotal Tuning for Latent-based editing of Real Images
2
-
3
- <!-- > Recently, a surge of advanced facial editing techniques have been proposed
4
- that leverage the generative power of a pre-trained StyleGAN. To successfully
5
- edit an image this way, one must first project (or invert) the image into
6
- the pre-trained generator’s domain. As it turns out, however, StyleGAN’s
7
- latent space induces an inherent tradeoff between distortion and editability,
8
- i.e. between maintaining the original appearance and convincingly altering
9
- some of its attributes. Practically, this means it is still challenging to
10
- apply ID-preserving facial latent-space editing to faces which are out of the
11
- generator’s domain. In this paper, we present an approach to bridge this
12
- gap. Our technique slightly alters the generator, so that an out-of-domain
13
- image is faithfully mapped into an in-domain latent code. The key idea is
14
- pivotal tuning β€” a brief training process that preserves the editing quality
15
- of an in-domain latent region, while changing its portrayed identity and
16
- appearance. In Pivotal Tuning Inversion (PTI), an initial inverted latent code
17
- serves as a pivot, around which the generator is fined-tuned. At the same
18
- time, a regularization term keeps nearby identities intact, to locally contain
19
- the effect. This surgical training process ends up altering appearance features
20
- that represent mostly identity, without affecting editing capabilities.
21
- To supplement this, we further show that pivotal tuning can also adjust the
22
- generator to accommodate a multitude of faces, while introducing negligible
23
- distortion on the rest of the domain. We validate our technique through
24
- inversion and editing metrics, and show preferable scores to state-of-the-art
25
- methods. We further qualitatively demonstrate our technique by applying
26
- advanced edits (such as pose, age, or expression) to numerous images of
27
- well-known and recognizable identities. Finally, we demonstrate resilience
28
- to harder cases, including heavy make-up, elaborate hairstyles and/or headwear,
29
- which otherwise could not have been successfully inverted and edited
30
- by state-of-the-art methods. -->
31
-
32
- <a href="https://arxiv.org/abs/2106.05744"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
33
- <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
34
- Inference Notebook: <a href="https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=20></a>
35
-
36
- <p align="center">
37
- <img src="docs/teaser.jpg"/>
38
- <br>
39
- Pivotal Tuning Inversion (PTI) enables employing off-the-shelf latent based
40
- semantic editing techniques on real images using StyleGAN.
41
- PTI excels in identity preserving edits, portrayed through recognizable figures β€”
42
- Serena Williams and Robert Downey Jr. (top), and in handling faces which
43
- are clearly out-of-domain, e.g., due to heavy makeup (bottom).
44
- </br>
45
- </p>
46
-
47
- ## Description
48
- Official Implementation of our PTI paper + code for evaluation metrics. PTI introduces an optimization mechanizem for solving the StyleGAN inversion task.
49
- Providing near-perfect reconstruction results while maintaining the high editing abilitis of the native StyleGAN latent space W. For more details, see <a href="https://arxiv.org/abs/2106.05744"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
50
-
51
- ## Recent Updates
52
- **2021.07.01**: Fixed files download phase in the inference notebook. Which might caused the notebook not to run smoothly.
53
-
54
- **2021.06.29**: Added support for CPU. In order to run PTI on CPU please change `device` parameter under `configs/global_config.py` to "cpu" instead of "cuda".
55
-
56
- **2021.06.25** : Adding mohawk edit using StyleCLIP+PTI in inference notebook.
57
- Updating documentation in inference notebook due to Google Drive rate limit reached.
58
- Currently, Google Drive does not allow to download the pretrined models using Colab automatically. Manual intervention might be needed.
59
-
60
- ## Getting Started
61
- ### Prerequisites
62
- - Linux or macOS
63
- - NVIDIA GPU + CUDA CuDNN (Not mandatory bur recommended)
64
- - Python 3
65
-
66
- ### Installation
67
- - Dependencies:
68
- 1. lpips
69
- 2. wandb
70
- 3. pytorch
71
- 4. torchvision
72
- 5. matplotlib
73
- 6. dlib
74
- - All dependencies can be installed using *pip install* and the package name
75
-
76
- ## Pretrained Models
77
- Please download the pretrained models from the following links.
78
-
79
- ### Auxiliary Models
80
- We provide various auxiliary models needed for PTI inversion task.
81
- This includes the StyleGAN generator and pre-trained models used for loss computation.
82
- | Path | Description
83
- | :--- | :----------
84
- |[FFHQ StyleGAN](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl) | StyleGAN2-ada model trained on FFHQ with 1024x1024 output resolution.
85
- |[Dlib alignment](https://drive.google.com/file/d/1HKmjg6iXsWr4aFPuU0gBXPGR83wqMzq7/view?usp=sharing) | Dlib alignment used for images preproccessing.
86
- |[FFHQ e4e encoder](https://drive.google.com/file/d/1ALC5CLA89Ouw40TwvxcwebhzWXM5YSCm/view?usp=sharing) | Pretrained e4e encoder. Used for StyleCLIP editing.
87
-
88
- Note: The StyleGAN model is used directly from the official [stylegan2-ada-pytorch implementation](https://github.com/NVlabs/stylegan2-ada-pytorch).
89
- For StyleCLIP pretrained mappers, please see [StyleCLIP's official routes](https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py)
90
-
91
-
92
- By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`.
93
- However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
94
-
95
-
96
- ## Inversion
97
- ### Preparing your Data
98
- In order to invert a real image and edit it you should first align and crop it to the correct size. To do so you should perform *One* of the following steps:
99
- 1. Run `notebooks/align_data.ipynb` and change the "images_path" variable to the raw images path
100
- 2. Run `utils/align_data.py` and change the "images_path" variable to the raw images path
101
-
102
-
103
- ### Weights And Biases
104
- The project supports [Weights And Biases](https://wandb.ai/home) framework for experiment tracking. For the inversion task it enables visualization of the losses progression and the generator intermediate results during the initial inversion and the *Pivotal Tuning*(PT) procedure.
105
-
106
- The log frequency can be adjusted using the parameters defined at `configs/global_config.py` under the "Logs" subsection.
107
-
108
- There is no no need to have an account. However, in order to use the features provided by Weights and Biases you first have to register on their site.
109
-
110
-
111
- ### Running PTI
112
- The main training script is `scripts/run_pti.py`. The script receives aligned and cropped images from paths configured in the "Input info" subscetion in
113
- `configs/paths_config.py`.
114
- Results are saved to directories found at "Dirs for output files" under `configs/paths_config.py`. This includes inversion latent codes and tuned generators.
115
- The hyperparametrs for the inversion task can be found at `configs/hyperparameters.py`. They are intilized to the default values used in the paper.
116
-
117
- ## Editing
118
- By default, we assume that all auxiliary edit directions are downloaded and saved to the directory `editings`.
119
- However, you may use your own paths by changing the necessary values in `configs/path_configs.py` under "Edit directions" subsection.
120
-
121
- Example of editing code can be found at `scripts/latent_editor_wrapper.py`
122
-
123
- ## Inference Notebooks
124
- To help visualize the results of PTI we provide a Jupyter notebook found in `notebooks/inference_playground.ipynb`.
125
- The notebook will download the pretrained models and run inference on a sample image found online or
126
- on images of your choosing. It is recommended to run this in [Google Colab](https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb).
127
-
128
- The notebook demonstrates how to:
129
- - Invert an image using PTI
130
- - Visualise the inversion and use the PTI output
131
- - Edit the image after PTI using InterfaceGAN and StyleCLIP
132
- - Compare to other inversion methods
133
-
134
- ## Evaluation
135
- Currently the repository supports qualitative evaluation for reconstruction of: PTI, SG2 (*W Space*), e4e, SG2Plus (*W+ Space*).
136
- As well as editing using InterfaceGAN and GANSpace for the same inversion methods.
137
- To run the evaluation please see `evaluation/qualitative_edit_comparison.py`. Examples of the evaluation scripts are:
138
-
139
- <p align="center">
140
- <img src="docs/model_rec.jpg"/>
141
- <br>
142
- Reconsturction comparison between different methods. The images order is: Original image, W+ inversion, e4e inversion, W inversion, PTI inversion
143
- </br>
144
- </p>
145
-
146
- <p align="center">
147
- <img src="docs/stern_rotation.jpg"/>
148
- <br>
149
- InterfaceGAN pose edit comparison between different methods. The images order is: Original, W+, e4e, W, PTI
150
- </br>
151
- </p>
152
-
153
- <p align="center">
154
- <img src="docs/tyron_original.jpg" width="220" height="220"/>
155
- <img src="docs/tyron_edit.jpg" width="220" height="220"/>
156
- <br>
157
- Image per edit or several edits without comparison
158
- </br>
159
- </p>
160
-
161
- ### Coming Soon - Quantitative evaluation and StyleCLIP qualitative evaluation
162
-
163
- ## Repository structure
164
- | Path | Description <img width=200>
165
- | :--- | :---
166
- | &boxvr;&nbsp; configs | Folder containing configs defining Hyperparameters, paths and logging
167
- | &boxvr;&nbsp; criteria | Folder containing various loss and regularization criterias for the optimization
168
- | &boxvr;&nbsp; dnnlib | Folder containing internal utils for StyleGAN2-ada
169
- | &boxvr;&nbsp; docs | Folder containing the latent space edit directions
170
- | &boxvr;&nbsp; editings | Folder containing images displayed in the README
171
- | &boxvr;&nbsp; environment | Folder containing Anaconda environment used in our experiments
172
- | &boxvr;&nbsp; licenses | Folder containing licenses of the open source projects used in this repository
173
- | &boxvr;&nbsp; models | Folder containing models used in different editing techniques and first phase inversion
174
- | &boxvr;&nbsp; notebooks | Folder with jupyter notebooks to demonstrate the usage of PTI end-to-end
175
- | &boxvr;&nbsp; scripts | Folder with running scripts for inversion, editing and metric computations
176
- | &boxvr;&nbsp; torch_utils | Folder containing internal utils for StyleGAN2-ada
177
- | &boxvr;&nbsp; training | Folder containing the core training logic of PTI
178
- | &boxvr;&nbsp; utils | Folder with various utility functions
179
-
180
-
181
- ## Credits
182
- **StyleGAN2-ada model and implementation:**
183
- https://github.com/NVlabs/stylegan2-ada-pytorch
184
- Copyright Β© 2021, NVIDIA Corporation.
185
- Nvidia Source Code License https://nvlabs.github.io/stylegan2-ada-pytorch/license.html
186
-
187
- **LPIPS model and implementation:**
188
- https://github.com/richzhang/PerceptualSimilarity
189
- Copyright (c) 2020, Sou Uchida
190
- License (BSD 2-Clause) https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE
191
-
192
- **e4e model and implementation:**
193
- https://github.com/omertov/encoder4editing
194
- Copyright (c) 2021 omertov
195
- License (MIT) https://github.com/omertov/encoder4editing/blob/main/LICENSE
196
-
197
- **StyleCLIP model and implementation:**
198
- https://github.com/orpatashnik/StyleCLIP
199
- Copyright (c) 2021 orpatashnik
200
- License (MIT) https://github.com/orpatashnik/StyleCLIP/blob/main/LICENSE
201
-
202
- **InterfaceGAN implementation:**
203
- https://github.com/genforce/interfacegan
204
- Copyright (c) 2020 genforce
205
- License (MIT) https://github.com/genforce/interfacegan/blob/master/LICENSE
206
-
207
- **GANSpace implementation:**
208
- https://github.com/harskish/ganspace
209
- Copyright (c) 2020 harkish
210
- License (Apache License 2.0) https://github.com/harskish/ganspace/blob/master/LICENSE
211
-
212
-
213
- ## Acknowledgments
214
- This repository structure is based on [encoder4editing](https://github.com/omertov/encoder4editing) and [ReStyle](https://github.com/yuval-alaluf/restyle-encoder) repositories
215
-
216
- ## Contact
217
- For any inquiry please contact us at our email addresses: [email protected] or [email protected]
218
-
219
-
220
- ## Citation
221
- If you use this code for your research, please cite:
222
- ```
223
- @article{roich2021pivotal,
224
- title={Pivotal Tuning for Latent-based Editing of Real Images},
225
- author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
226
- journal={arXiv preprint arXiv:2106.05744},
227
- year={2021}
228
- }
229
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/custom_ops.py DELETED
@@ -1,126 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import glob
11
- import torch
12
- import torch.utils.cpp_extension
13
- import importlib
14
- import hashlib
15
- import shutil
16
- from pathlib import Path
17
-
18
- from torch.utils.file_baton import FileBaton
19
-
20
- #----------------------------------------------------------------------------
21
- # Global options.
22
-
23
- verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24
-
25
- #----------------------------------------------------------------------------
26
- # Internal helper funcs.
27
-
28
- def _find_compiler_bindir():
29
- patterns = [
30
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
- 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34
- ]
35
- for pattern in patterns:
36
- matches = sorted(glob.glob(pattern))
37
- if len(matches):
38
- return matches[-1]
39
- return None
40
-
41
- #----------------------------------------------------------------------------
42
- # Main entry point for compiling and loading C++/CUDA plugins.
43
-
44
- _cached_plugins = dict()
45
-
46
- def get_plugin(module_name, sources, **build_kwargs):
47
- assert verbosity in ['none', 'brief', 'full']
48
-
49
- # Already cached?
50
- if module_name in _cached_plugins:
51
- return _cached_plugins[module_name]
52
-
53
- # Print status.
54
- if verbosity == 'full':
55
- print(f'Setting up PyTorch plugin "{module_name}"...')
56
- elif verbosity == 'brief':
57
- print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58
-
59
- try: # pylint: disable=too-many-nested-blocks
60
- # Make sure we can find the necessary compiler binaries.
61
- if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62
- compiler_bindir = _find_compiler_bindir()
63
- if compiler_bindir is None:
64
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65
- os.environ['PATH'] += ';' + compiler_bindir
66
-
67
- # Compile and load.
68
- verbose_build = (verbosity == 'full')
69
-
70
- # Incremental build md5sum trickery. Copies all the input source files
71
- # into a cached build directory under a combined md5 digest of the input
72
- # source files. Copying is done only if the combined digest has changed.
73
- # This keeps input file timestamps and filenames the same as in previous
74
- # extension builds, allowing for fast incremental rebuilds.
75
- #
76
- # This optimization is done only in case all the source files reside in
77
- # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78
- # environment variable is set (we take this as a signal that the user
79
- # actually cares about this.)
80
- source_dirs_set = set(os.path.dirname(source) for source in sources)
81
- if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82
- all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83
-
84
- # Compute a combined hash digest for all source files in the same
85
- # custom op directory (usually .cu, .cpp, .py and .h files).
86
- hash_md5 = hashlib.md5()
87
- for src in all_source_files:
88
- with open(src, 'rb') as f:
89
- hash_md5.update(f.read())
90
- build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91
- digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92
-
93
- if not os.path.isdir(digest_build_dir):
94
- os.makedirs(digest_build_dir, exist_ok=True)
95
- baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96
- if baton.try_acquire():
97
- try:
98
- for src in all_source_files:
99
- shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100
- finally:
101
- baton.release()
102
- else:
103
- # Someone else is copying source files under the digest dir,
104
- # wait until done and continue.
105
- baton.wait()
106
- digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107
- torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108
- verbose=verbose_build, sources=digest_sources, **build_kwargs)
109
- else:
110
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111
- module = importlib.import_module(module_name)
112
-
113
- except:
114
- if verbosity == 'brief':
115
- print('Failed!')
116
- raise
117
-
118
- # Print status and add to cache.
119
- if verbosity == 'full':
120
- print(f'Done setting up PyTorch plugin "{module_name}".')
121
- elif verbosity == 'brief':
122
- print('Done.')
123
- _cached_plugins[module_name] = module
124
- return module
125
-
126
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/misc.py DELETED
@@ -1,262 +0,0 @@
1
- ο»Ώ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import re
10
- import contextlib
11
- import numpy as np
12
- import torch
13
- import warnings
14
- import dnnlib
15
-
16
- #----------------------------------------------------------------------------
17
- # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
- # same constant is used multiple times.
19
-
20
- _constant_cache = dict()
21
-
22
- def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
- value = np.asarray(value)
24
- if shape is not None:
25
- shape = tuple(shape)
26
- if dtype is None:
27
- dtype = torch.get_default_dtype()
28
- if device is None:
29
- device = torch.device('cpu')
30
- if memory_format is None:
31
- memory_format = torch.contiguous_format
32
-
33
- key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
- tensor = _constant_cache.get(key, None)
35
- if tensor is None:
36
- tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
- if shape is not None:
38
- tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
- tensor = tensor.contiguous(memory_format=memory_format)
40
- _constant_cache[key] = tensor
41
- return tensor
42
-
43
- #----------------------------------------------------------------------------
44
- # Replace NaN/Inf with specified numerical values.
45
-
46
- try:
47
- nan_to_num = torch.nan_to_num # 1.8.0a0
48
- except AttributeError:
49
- def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
- assert isinstance(input, torch.Tensor)
51
- if posinf is None:
52
- posinf = torch.finfo(input.dtype).max
53
- if neginf is None:
54
- neginf = torch.finfo(input.dtype).min
55
- assert nan == 0
56
- return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
-
58
- #----------------------------------------------------------------------------
59
- # Symbolic assert.
60
-
61
- try:
62
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63
- except AttributeError:
64
- symbolic_assert = torch.Assert # 1.7.0
65
-
66
- #----------------------------------------------------------------------------
67
- # Context manager to suppress known warnings in torch.jit.trace().
68
-
69
- class suppress_tracer_warnings(warnings.catch_warnings):
70
- def __enter__(self):
71
- super().__enter__()
72
- warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
73
- return self
74
-
75
- #----------------------------------------------------------------------------
76
- # Assert that the shape of a tensor matches the given list of integers.
77
- # None indicates that the size of a dimension is allowed to vary.
78
- # Performs symbolic assertion when used in torch.jit.trace().
79
-
80
- def assert_shape(tensor, ref_shape):
81
- if tensor.ndim != len(ref_shape):
82
- raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
83
- for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
84
- if ref_size is None:
85
- pass
86
- elif isinstance(ref_size, torch.Tensor):
87
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
88
- symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
89
- elif isinstance(size, torch.Tensor):
90
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
91
- symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
92
- elif size != ref_size:
93
- raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
94
-
95
- #----------------------------------------------------------------------------
96
- # Function decorator that calls torch.autograd.profiler.record_function().
97
-
98
- def profiled_function(fn):
99
- def decorator(*args, **kwargs):
100
- with torch.autograd.profiler.record_function(fn.__name__):
101
- return fn(*args, **kwargs)
102
- decorator.__name__ = fn.__name__
103
- return decorator
104
-
105
- #----------------------------------------------------------------------------
106
- # Sampler for torch.utils.data.DataLoader that loops over the dataset
107
- # indefinitely, shuffling items as it goes.
108
-
109
- class InfiniteSampler(torch.utils.data.Sampler):
110
- def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
111
- assert len(dataset) > 0
112
- assert num_replicas > 0
113
- assert 0 <= rank < num_replicas
114
- assert 0 <= window_size <= 1
115
- super().__init__(dataset)
116
- self.dataset = dataset
117
- self.rank = rank
118
- self.num_replicas = num_replicas
119
- self.shuffle = shuffle
120
- self.seed = seed
121
- self.window_size = window_size
122
-
123
- def __iter__(self):
124
- order = np.arange(len(self.dataset))
125
- rnd = None
126
- window = 0
127
- if self.shuffle:
128
- rnd = np.random.RandomState(self.seed)
129
- rnd.shuffle(order)
130
- window = int(np.rint(order.size * self.window_size))
131
-
132
- idx = 0
133
- while True:
134
- i = idx % order.size
135
- if idx % self.num_replicas == self.rank:
136
- yield order[i]
137
- if window >= 2:
138
- j = (i - rnd.randint(window)) % order.size
139
- order[i], order[j] = order[j], order[i]
140
- idx += 1
141
-
142
- #----------------------------------------------------------------------------
143
- # Utilities for operating with torch.nn.Module parameters and buffers.
144
-
145
- def params_and_buffers(module):
146
- assert isinstance(module, torch.nn.Module)
147
- return list(module.parameters()) + list(module.buffers())
148
-
149
- def named_params_and_buffers(module):
150
- assert isinstance(module, torch.nn.Module)
151
- return list(module.named_parameters()) + list(module.named_buffers())
152
-
153
- def copy_params_and_buffers(src_module, dst_module, require_all=False):
154
- assert isinstance(src_module, torch.nn.Module)
155
- assert isinstance(dst_module, torch.nn.Module)
156
- src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
157
- for name, tensor in named_params_and_buffers(dst_module):
158
- assert (name in src_tensors) or (not require_all)
159
- if name in src_tensors:
160
- tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
161
-
162
- #----------------------------------------------------------------------------
163
- # Context manager for easily enabling/disabling DistributedDataParallel
164
- # synchronization.
165
-
166
- @contextlib.contextmanager
167
- def ddp_sync(module, sync):
168
- assert isinstance(module, torch.nn.Module)
169
- if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
170
- yield
171
- else:
172
- with module.no_sync():
173
- yield
174
-
175
- #----------------------------------------------------------------------------
176
- # Check DistributedDataParallel consistency across processes.
177
-
178
- def check_ddp_consistency(module, ignore_regex=None):
179
- assert isinstance(module, torch.nn.Module)
180
- for name, tensor in named_params_and_buffers(module):
181
- fullname = type(module).__name__ + '.' + name
182
- if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
183
- continue
184
- tensor = tensor.detach()
185
- other = tensor.clone()
186
- torch.distributed.broadcast(tensor=other, src=0)
187
- assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
188
-
189
- #----------------------------------------------------------------------------
190
- # Print summary table of module hierarchy.
191
-
192
- def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
193
- assert isinstance(module, torch.nn.Module)
194
- assert not isinstance(module, torch.jit.ScriptModule)
195
- assert isinstance(inputs, (tuple, list))
196
-
197
- # Register hooks.
198
- entries = []
199
- nesting = [0]
200
- def pre_hook(_mod, _inputs):
201
- nesting[0] += 1
202
- def post_hook(mod, _inputs, outputs):
203
- nesting[0] -= 1
204
- if nesting[0] <= max_nesting:
205
- outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
206
- outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
207
- entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
208
- hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
209
- hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
210
-
211
- # Run module.
212
- outputs = module(*inputs)
213
- for hook in hooks:
214
- hook.remove()
215
-
216
- # Identify unique outputs, parameters, and buffers.
217
- tensors_seen = set()
218
- for e in entries:
219
- e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
220
- e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
221
- e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
222
- tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
223
-
224
- # Filter out redundant entries.
225
- if skip_redundant:
226
- entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
227
-
228
- # Construct table.
229
- rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
230
- rows += [['---'] * len(rows[0])]
231
- param_total = 0
232
- buffer_total = 0
233
- submodule_names = {mod: name for name, mod in module.named_modules()}
234
- for e in entries:
235
- name = '<top-level>' if e.mod is module else submodule_names[e.mod]
236
- param_size = sum(t.numel() for t in e.unique_params)
237
- buffer_size = sum(t.numel() for t in e.unique_buffers)
238
- output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
239
- output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
240
- rows += [[
241
- name + (':0' if len(e.outputs) >= 2 else ''),
242
- str(param_size) if param_size else '-',
243
- str(buffer_size) if buffer_size else '-',
244
- (output_shapes + ['-'])[0],
245
- (output_dtypes + ['-'])[0],
246
- ]]
247
- for idx in range(1, len(e.outputs)):
248
- rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
249
- param_total += param_size
250
- buffer_total += buffer_size
251
- rows += [['---'] * len(rows[0])]
252
- rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
253
-
254
- # Print table.
255
- widths = [max(len(cell) for cell in column) for column in zip(*rows)]
256
- print()
257
- for row in rows:
258
- print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
259
- print()
260
- return outputs
261
-
262
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/bias_act.cpp DELETED
@@ -1,99 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <torch/extension.h>
10
- #include <ATen/cuda/CUDAContext.h>
11
- #include <c10/cuda/CUDAGuard.h>
12
- #include "bias_act.h"
13
-
14
- //------------------------------------------------------------------------
15
-
16
- static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
- {
18
- if (x.dim() != y.dim())
19
- return false;
20
- for (int64_t i = 0; i < x.dim(); i++)
21
- {
22
- if (x.size(i) != y.size(i))
23
- return false;
24
- if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
- return false;
26
- }
27
- return true;
28
- }
29
-
30
- //------------------------------------------------------------------------
31
-
32
- static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
- {
34
- // Validate arguments.
35
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
- TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
- TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
- TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
- TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
- TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
- TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
- TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
- TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
-
46
- // Validate layout.
47
- TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
- TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
- TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
- TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
- TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
-
53
- // Create output tensor.
54
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
- torch::Tensor y = torch::empty_like(x);
56
- TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
-
58
- // Initialize CUDA kernel parameters.
59
- bias_act_kernel_params p;
60
- p.x = x.data_ptr();
61
- p.b = (b.numel()) ? b.data_ptr() : NULL;
62
- p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
- p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
- p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
- p.y = y.data_ptr();
66
- p.grad = grad;
67
- p.act = act;
68
- p.alpha = alpha;
69
- p.gain = gain;
70
- p.clamp = clamp;
71
- p.sizeX = (int)x.numel();
72
- p.sizeB = (int)b.numel();
73
- p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
-
75
- // Choose CUDA kernel.
76
- void* kernel;
77
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
- {
79
- kernel = choose_bias_act_kernel<scalar_t>(p);
80
- });
81
- TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
-
83
- // Launch CUDA kernel.
84
- p.loopX = 4;
85
- int blockSize = 4 * 32;
86
- int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
- void* args[] = {&p};
88
- AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
- return y;
90
- }
91
-
92
- //------------------------------------------------------------------------
93
-
94
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
- {
96
- m.def("bias_act", &bias_act);
97
- }
98
-
99
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/bias_act.cu DELETED
@@ -1,173 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <c10/util/Half.h>
10
- #include "bias_act.h"
11
-
12
- //------------------------------------------------------------------------
13
- // Helpers.
14
-
15
- template <class T> struct InternalType;
16
- template <> struct InternalType<double> { typedef double scalar_t; };
17
- template <> struct InternalType<float> { typedef float scalar_t; };
18
- template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
-
20
- //------------------------------------------------------------------------
21
- // CUDA kernel.
22
-
23
- template <class T, int A>
24
- __global__ void bias_act_kernel(bias_act_kernel_params p)
25
- {
26
- typedef typename InternalType<T>::scalar_t scalar_t;
27
- int G = p.grad;
28
- scalar_t alpha = (scalar_t)p.alpha;
29
- scalar_t gain = (scalar_t)p.gain;
30
- scalar_t clamp = (scalar_t)p.clamp;
31
- scalar_t one = (scalar_t)1;
32
- scalar_t two = (scalar_t)2;
33
- scalar_t expRange = (scalar_t)80;
34
- scalar_t halfExpRange = (scalar_t)40;
35
- scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
- scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
-
38
- // Loop over elements.
39
- int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
- for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
- {
42
- // Load.
43
- scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
- scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
- scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
- scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
- scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
- scalar_t yy = (gain != 0) ? yref / gain : 0;
49
- scalar_t y = 0;
50
-
51
- // Apply bias.
52
- ((G == 0) ? x : xref) += b;
53
-
54
- // linear
55
- if (A == 1)
56
- {
57
- if (G == 0) y = x;
58
- if (G == 1) y = x;
59
- }
60
-
61
- // relu
62
- if (A == 2)
63
- {
64
- if (G == 0) y = (x > 0) ? x : 0;
65
- if (G == 1) y = (yy > 0) ? x : 0;
66
- }
67
-
68
- // lrelu
69
- if (A == 3)
70
- {
71
- if (G == 0) y = (x > 0) ? x : x * alpha;
72
- if (G == 1) y = (yy > 0) ? x : x * alpha;
73
- }
74
-
75
- // tanh
76
- if (A == 4)
77
- {
78
- if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
- if (G == 1) y = x * (one - yy * yy);
80
- if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
- }
82
-
83
- // sigmoid
84
- if (A == 5)
85
- {
86
- if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
- if (G == 1) y = x * yy * (one - yy);
88
- if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
- }
90
-
91
- // elu
92
- if (A == 6)
93
- {
94
- if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
- if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
- }
98
-
99
- // selu
100
- if (A == 7)
101
- {
102
- if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
- if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
- }
106
-
107
- // softplus
108
- if (A == 8)
109
- {
110
- if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
- if (G == 1) y = x * (one - exp(-yy));
112
- if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
- }
114
-
115
- // swish
116
- if (A == 9)
117
- {
118
- if (G == 0)
119
- y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
- else
121
- {
122
- scalar_t c = exp(xref);
123
- scalar_t d = c + one;
124
- if (G == 1)
125
- y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
- else
127
- y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
- yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
- }
130
- }
131
-
132
- // Apply gain.
133
- y *= gain * dy;
134
-
135
- // Clamp.
136
- if (clamp >= 0)
137
- {
138
- if (G == 0)
139
- y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
- else
141
- y = (yref > -clamp & yref < clamp) ? y : 0;
142
- }
143
-
144
- // Store.
145
- ((T*)p.y)[xi] = (T)y;
146
- }
147
- }
148
-
149
- //------------------------------------------------------------------------
150
- // CUDA kernel selection.
151
-
152
- template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
- {
154
- if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
- if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
- if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
- if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
- if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
- if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
- if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
- if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
- if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
- return NULL;
164
- }
165
-
166
- //------------------------------------------------------------------------
167
- // Template specializations.
168
-
169
- template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
- template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
- template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
-
173
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/bias_act.h DELETED
@@ -1,38 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- //------------------------------------------------------------------------
10
- // CUDA kernel parameters.
11
-
12
- struct bias_act_kernel_params
13
- {
14
- const void* x; // [sizeX]
15
- const void* b; // [sizeB] or NULL
16
- const void* xref; // [sizeX] or NULL
17
- const void* yref; // [sizeX] or NULL
18
- const void* dy; // [sizeX] or NULL
19
- void* y; // [sizeX]
20
-
21
- int grad;
22
- int act;
23
- float alpha;
24
- float gain;
25
- float clamp;
26
-
27
- int sizeX;
28
- int sizeB;
29
- int stepB;
30
- int loopX;
31
- };
32
-
33
- //------------------------------------------------------------------------
34
- // CUDA kernel selection.
35
-
36
- template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
-
38
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/bias_act.py DELETED
@@ -1,212 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom PyTorch ops for efficient bias and activation."""
10
-
11
- import os
12
- import warnings
13
- import numpy as np
14
- import torch
15
- import dnnlib
16
- import traceback
17
-
18
- from .. import custom_ops
19
- from .. import misc
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- activation_funcs = {
24
- 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25
- 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26
- 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27
- 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28
- 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29
- 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30
- 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31
- 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32
- 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33
- }
34
-
35
- #----------------------------------------------------------------------------
36
-
37
- _inited = False
38
- _plugin = None
39
- _null_tensor = torch.empty([0])
40
-
41
- def _init():
42
- global _inited, _plugin
43
- if not _inited:
44
- _inited = True
45
- sources = ['bias_act.cpp', 'bias_act.cu']
46
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47
- try:
48
- _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49
- except:
50
- warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51
- return _plugin is not None
52
-
53
- #----------------------------------------------------------------------------
54
-
55
- def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56
- r"""Fused bias and activation function.
57
-
58
- Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59
- and scales the result by `gain`. Each of the steps is optional. In most cases,
60
- the fused op is considerably more efficient than performing the same calculation
61
- using standard PyTorch ops. It supports first and second order gradients,
62
- but not third order gradients.
63
-
64
- Args:
65
- x: Input activation tensor. Can be of any shape.
66
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67
- as `x`. The shape must be known, and it must match the dimension of `x`
68
- corresponding to `dim`.
69
- dim: The dimension in `x` corresponding to the elements of `b`.
70
- The value of `dim` is ignored if `b` is not specified.
71
- act: Name of the activation function to evaluate, or `"linear"` to disable.
72
- Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73
- See `activation_funcs` for a full list. `None` is not allowed.
74
- alpha: Shape parameter for the activation function, or `None` to use the default.
75
- gain: Scaling factor for the output tensor, or `None` to use default.
76
- See `activation_funcs` for the default scaling of each activation function.
77
- If unsure, consider specifying 1.
78
- clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79
- the clamping (default).
80
- impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81
-
82
- Returns:
83
- Tensor of the same shape and datatype as `x`.
84
- """
85
- assert isinstance(x, torch.Tensor)
86
- assert impl in ['ref', 'cuda']
87
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
88
- return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89
- return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90
-
91
- #----------------------------------------------------------------------------
92
-
93
- @misc.profiled_function
94
- def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95
- """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
96
- """
97
- assert isinstance(x, torch.Tensor)
98
- assert clamp is None or clamp >= 0
99
- spec = activation_funcs[act]
100
- alpha = float(alpha if alpha is not None else spec.def_alpha)
101
- gain = float(gain if gain is not None else spec.def_gain)
102
- clamp = float(clamp if clamp is not None else -1)
103
-
104
- # Add bias.
105
- if b is not None:
106
- assert isinstance(b, torch.Tensor) and b.ndim == 1
107
- assert 0 <= dim < x.ndim
108
- assert b.shape[0] == x.shape[dim]
109
- x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110
-
111
- # Evaluate activation function.
112
- alpha = float(alpha)
113
- x = spec.func(x, alpha=alpha)
114
-
115
- # Scale by gain.
116
- gain = float(gain)
117
- if gain != 1:
118
- x = x * gain
119
-
120
- # Clamp.
121
- if clamp >= 0:
122
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
123
- return x
124
-
125
- #----------------------------------------------------------------------------
126
-
127
- _bias_act_cuda_cache = dict()
128
-
129
- def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
130
- """Fast CUDA implementation of `bias_act()` using custom ops.
131
- """
132
- # Parse arguments.
133
- assert clamp is None or clamp >= 0
134
- spec = activation_funcs[act]
135
- alpha = float(alpha if alpha is not None else spec.def_alpha)
136
- gain = float(gain if gain is not None else spec.def_gain)
137
- clamp = float(clamp if clamp is not None else -1)
138
-
139
- # Lookup from cache.
140
- key = (dim, act, alpha, gain, clamp)
141
- if key in _bias_act_cuda_cache:
142
- return _bias_act_cuda_cache[key]
143
-
144
- # Forward op.
145
- class BiasActCuda(torch.autograd.Function):
146
- @staticmethod
147
- def forward(ctx, x, b): # pylint: disable=arguments-differ
148
- ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149
- x = x.contiguous(memory_format=ctx.memory_format)
150
- b = b.contiguous() if b is not None else _null_tensor
151
- y = x
152
- if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153
- y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154
- ctx.save_for_backward(
155
- x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156
- b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157
- y if 'y' in spec.ref else _null_tensor)
158
- return y
159
-
160
- @staticmethod
161
- def backward(ctx, dy): # pylint: disable=arguments-differ
162
- dy = dy.contiguous(memory_format=ctx.memory_format)
163
- x, b, y = ctx.saved_tensors
164
- dx = None
165
- db = None
166
-
167
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
168
- dx = dy
169
- if act != 'linear' or gain != 1 or clamp >= 0:
170
- dx = BiasActCudaGrad.apply(dy, x, b, y)
171
-
172
- if ctx.needs_input_grad[1]:
173
- db = dx.sum([i for i in range(dx.ndim) if i != dim])
174
-
175
- return dx, db
176
-
177
- # Backward op.
178
- class BiasActCudaGrad(torch.autograd.Function):
179
- @staticmethod
180
- def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
181
- ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
182
- dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
183
- ctx.save_for_backward(
184
- dy if spec.has_2nd_grad else _null_tensor,
185
- x, b, y)
186
- return dx
187
-
188
- @staticmethod
189
- def backward(ctx, d_dx): # pylint: disable=arguments-differ
190
- d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
191
- dy, x, b, y = ctx.saved_tensors
192
- d_dy = None
193
- d_x = None
194
- d_b = None
195
- d_y = None
196
-
197
- if ctx.needs_input_grad[0]:
198
- d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
199
-
200
- if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
201
- d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
202
-
203
- if spec.has_2nd_grad and ctx.needs_input_grad[2]:
204
- d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
205
-
206
- return d_dy, d_x, d_b, d_y
207
-
208
- # Add to cache.
209
- _bias_act_cuda_cache[key] = BiasActCuda
210
- return BiasActCuda
211
-
212
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/conv2d_gradfix.py DELETED
@@ -1,170 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom replacement for `torch.nn.functional.conv2d` that supports
10
- arbitrarily high order gradients with zero performance penalty."""
11
-
12
- import warnings
13
- import contextlib
14
- import torch
15
-
16
- # pylint: disable=redefined-builtin
17
- # pylint: disable=arguments-differ
18
- # pylint: disable=protected-access
19
-
20
- #----------------------------------------------------------------------------
21
-
22
- enabled = False # Enable the custom op by setting this to true.
23
- weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24
-
25
- @contextlib.contextmanager
26
- def no_weight_gradients():
27
- global weight_gradients_disabled
28
- old = weight_gradients_disabled
29
- weight_gradients_disabled = True
30
- yield
31
- weight_gradients_disabled = old
32
-
33
- #----------------------------------------------------------------------------
34
-
35
- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
- if _should_use_custom_op(input):
37
- return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
- return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
-
40
- def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
- if _should_use_custom_op(input):
42
- return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
- return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
-
45
- #----------------------------------------------------------------------------
46
-
47
- def _should_use_custom_op(input):
48
- assert isinstance(input, torch.Tensor)
49
- if (not enabled) or (not torch.backends.cudnn.enabled):
50
- return False
51
- if input.device.type != 'cuda':
52
- return False
53
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54
- return True
55
- warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56
- return False
57
-
58
- def _tuple_of_ints(xs, ndim):
59
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60
- assert len(xs) == ndim
61
- assert all(isinstance(x, int) for x in xs)
62
- return xs
63
-
64
- #----------------------------------------------------------------------------
65
-
66
- _conv2d_gradfix_cache = dict()
67
-
68
- def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69
- # Parse arguments.
70
- ndim = 2
71
- weight_shape = tuple(weight_shape)
72
- stride = _tuple_of_ints(stride, ndim)
73
- padding = _tuple_of_ints(padding, ndim)
74
- output_padding = _tuple_of_ints(output_padding, ndim)
75
- dilation = _tuple_of_ints(dilation, ndim)
76
-
77
- # Lookup from cache.
78
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79
- if key in _conv2d_gradfix_cache:
80
- return _conv2d_gradfix_cache[key]
81
-
82
- # Validate arguments.
83
- assert groups >= 1
84
- assert len(weight_shape) == ndim + 2
85
- assert all(stride[i] >= 1 for i in range(ndim))
86
- assert all(padding[i] >= 0 for i in range(ndim))
87
- assert all(dilation[i] >= 0 for i in range(ndim))
88
- if not transpose:
89
- assert all(output_padding[i] == 0 for i in range(ndim))
90
- else: # transpose
91
- assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92
-
93
- # Helpers.
94
- common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95
- def calc_output_padding(input_shape, output_shape):
96
- if transpose:
97
- return [0, 0]
98
- return [
99
- input_shape[i + 2]
100
- - (output_shape[i + 2] - 1) * stride[i]
101
- - (1 - 2 * padding[i])
102
- - dilation[i] * (weight_shape[i + 2] - 1)
103
- for i in range(ndim)
104
- ]
105
-
106
- # Forward & backward.
107
- class Conv2d(torch.autograd.Function):
108
- @staticmethod
109
- def forward(ctx, input, weight, bias):
110
- assert weight.shape == weight_shape
111
- if not transpose:
112
- output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113
- else: # transpose
114
- output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115
- ctx.save_for_backward(input, weight)
116
- return output
117
-
118
- @staticmethod
119
- def backward(ctx, grad_output):
120
- input, weight = ctx.saved_tensors
121
- grad_input = None
122
- grad_weight = None
123
- grad_bias = None
124
-
125
- if ctx.needs_input_grad[0]:
126
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127
- grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128
- assert grad_input.shape == input.shape
129
-
130
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
132
- assert grad_weight.shape == weight_shape
133
-
134
- if ctx.needs_input_grad[2]:
135
- grad_bias = grad_output.sum([0, 2, 3])
136
-
137
- return grad_input, grad_weight, grad_bias
138
-
139
- # Gradient with respect to the weights.
140
- class Conv2dGradWeight(torch.autograd.Function):
141
- @staticmethod
142
- def forward(ctx, grad_output, input):
143
- op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144
- flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145
- grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146
- assert grad_weight.shape == weight_shape
147
- ctx.save_for_backward(grad_output, input)
148
- return grad_weight
149
-
150
- @staticmethod
151
- def backward(ctx, grad2_grad_weight):
152
- grad_output, input = ctx.saved_tensors
153
- grad2_grad_output = None
154
- grad2_input = None
155
-
156
- if ctx.needs_input_grad[0]:
157
- grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158
- assert grad2_grad_output.shape == grad_output.shape
159
-
160
- if ctx.needs_input_grad[1]:
161
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162
- grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163
- assert grad2_input.shape == input.shape
164
-
165
- return grad2_grad_output, grad2_input
166
-
167
- _conv2d_gradfix_cache[key] = Conv2d
168
- return Conv2d
169
-
170
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/conv2d_resample.py DELETED
@@ -1,156 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """2D convolution with optional up/downsampling."""
10
-
11
- import torch
12
-
13
- from .. import misc
14
- from . import conv2d_gradfix
15
- from . import upfirdn2d
16
- from .upfirdn2d import _parse_padding
17
- from .upfirdn2d import _get_filter_size
18
-
19
- #----------------------------------------------------------------------------
20
-
21
- def _get_weight_shape(w):
22
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
- shape = [int(sz) for sz in w.shape]
24
- misc.assert_shape(w, shape)
25
- return shape
26
-
27
- #----------------------------------------------------------------------------
28
-
29
- def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
- """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
- """
32
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33
-
34
- # Flip weight if requested.
35
- if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
- w = w.flip([2, 3])
37
-
38
- # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39
- # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41
- if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42
- if out_channels <= 4 and groups == 1:
43
- in_shape = x.shape
44
- x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45
- x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46
- else:
47
- x = x.to(memory_format=torch.contiguous_format)
48
- w = w.to(memory_format=torch.contiguous_format)
49
- x = conv2d_gradfix.conv2d(x, w, groups=groups)
50
- return x.to(memory_format=torch.channels_last)
51
-
52
- # Otherwise => execute using conv2d_gradfix.
53
- op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54
- return op(x, w, stride=stride, padding=padding, groups=groups)
55
-
56
- #----------------------------------------------------------------------------
57
-
58
- @misc.profiled_function
59
- def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60
- r"""2D convolution with optional up/downsampling.
61
-
62
- Padding is performed only once at the beginning, not between the operations.
63
-
64
- Args:
65
- x: Input tensor of shape
66
- `[batch_size, in_channels, in_height, in_width]`.
67
- w: Weight tensor of shape
68
- `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69
- f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70
- calling upfirdn2d.setup_filter(). None = identity (default).
71
- up: Integer upsampling factor (default: 1).
72
- down: Integer downsampling factor (default: 1).
73
- padding: Padding with respect to the upsampled image. Can be a single number
74
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75
- (default: 0).
76
- groups: Split input channels into N groups (default: 1).
77
- flip_weight: False = convolution, True = correlation (default: True).
78
- flip_filter: False = convolution, True = correlation (default: False).
79
-
80
- Returns:
81
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82
- """
83
- # Validate arguments.
84
- assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85
- assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87
- assert isinstance(up, int) and (up >= 1)
88
- assert isinstance(down, int) and (down >= 1)
89
- assert isinstance(groups, int) and (groups >= 1)
90
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91
- fw, fh = _get_filter_size(f)
92
- px0, px1, py0, py1 = _parse_padding(padding)
93
-
94
- # Adjust padding to account for up/downsampling.
95
- if up > 1:
96
- px0 += (fw + up - 1) // 2
97
- px1 += (fw - up) // 2
98
- py0 += (fh + up - 1) // 2
99
- py1 += (fh - up) // 2
100
- if down > 1:
101
- px0 += (fw - down + 1) // 2
102
- px1 += (fw - down) // 2
103
- py0 += (fh - down + 1) // 2
104
- py1 += (fh - down) // 2
105
-
106
- # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107
- if kw == 1 and kh == 1 and (down > 1 and up == 1):
108
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110
- return x
111
-
112
- # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113
- if kw == 1 and kh == 1 and (up > 1 and down == 1):
114
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115
- x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116
- return x
117
-
118
- # Fast path: downsampling only => use strided convolution.
119
- if down > 1 and up == 1:
120
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121
- x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122
- return x
123
-
124
- # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125
- if up > 1:
126
- if groups == 1:
127
- w = w.transpose(0, 1)
128
- else:
129
- w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130
- w = w.transpose(1, 2)
131
- w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132
- px0 -= kw - 1
133
- px1 -= kw - up
134
- py0 -= kh - 1
135
- py1 -= kh - up
136
- pxt = max(min(-px0, -px1), 0)
137
- pyt = max(min(-py0, -py1), 0)
138
- x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140
- if down > 1:
141
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142
- return x
143
-
144
- # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145
- if up == 1 and down == 1:
146
- if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147
- return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148
-
149
- # Fallback: Generic reference implementation.
150
- x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152
- if down > 1:
153
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154
- return x
155
-
156
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/fma.py DELETED
@@ -1,60 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
-
11
- import torch
12
-
13
- #----------------------------------------------------------------------------
14
-
15
- def fma(a, b, c): # => a * b + c
16
- return _FusedMultiplyAdd.apply(a, b, c)
17
-
18
- #----------------------------------------------------------------------------
19
-
20
- class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
- @staticmethod
22
- def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
- out = torch.addcmul(c, a, b)
24
- ctx.save_for_backward(a, b)
25
- ctx.c_shape = c.shape
26
- return out
27
-
28
- @staticmethod
29
- def backward(ctx, dout): # pylint: disable=arguments-differ
30
- a, b = ctx.saved_tensors
31
- c_shape = ctx.c_shape
32
- da = None
33
- db = None
34
- dc = None
35
-
36
- if ctx.needs_input_grad[0]:
37
- da = _unbroadcast(dout * b, a.shape)
38
-
39
- if ctx.needs_input_grad[1]:
40
- db = _unbroadcast(dout * a, b.shape)
41
-
42
- if ctx.needs_input_grad[2]:
43
- dc = _unbroadcast(dout, c_shape)
44
-
45
- return da, db, dc
46
-
47
- #----------------------------------------------------------------------------
48
-
49
- def _unbroadcast(x, shape):
50
- extra_dims = x.ndim - len(shape)
51
- assert extra_dims >= 0
52
- dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
- if len(dim):
54
- x = x.sum(dim=dim, keepdim=True)
55
- if extra_dims:
56
- x = x.reshape(-1, *x.shape[extra_dims+1:])
57
- assert x.shape == shape
58
- return x
59
-
60
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/grid_sample_gradfix.py DELETED
@@ -1,83 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom replacement for `torch.nn.functional.grid_sample` that
10
- supports arbitrarily high order gradients between the input and output.
11
- Only works on 2D images and assumes
12
- `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
-
14
- import warnings
15
- import torch
16
-
17
- # pylint: disable=redefined-builtin
18
- # pylint: disable=arguments-differ
19
- # pylint: disable=protected-access
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- enabled = False # Enable the custom op by setting this to true.
24
-
25
- #----------------------------------------------------------------------------
26
-
27
- def grid_sample(input, grid):
28
- if _should_use_custom_op():
29
- return _GridSample2dForward.apply(input, grid)
30
- return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31
-
32
- #----------------------------------------------------------------------------
33
-
34
- def _should_use_custom_op():
35
- if not enabled:
36
- return False
37
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
38
- return True
39
- warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40
- return False
41
-
42
- #----------------------------------------------------------------------------
43
-
44
- class _GridSample2dForward(torch.autograd.Function):
45
- @staticmethod
46
- def forward(ctx, input, grid):
47
- assert input.ndim == 4
48
- assert grid.ndim == 4
49
- output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50
- ctx.save_for_backward(input, grid)
51
- return output
52
-
53
- @staticmethod
54
- def backward(ctx, grad_output):
55
- input, grid = ctx.saved_tensors
56
- grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57
- return grad_input, grad_grid
58
-
59
- #----------------------------------------------------------------------------
60
-
61
- class _GridSample2dBackward(torch.autograd.Function):
62
- @staticmethod
63
- def forward(ctx, grad_output, input, grid):
64
- op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66
- ctx.save_for_backward(grid)
67
- return grad_input, grad_grid
68
-
69
- @staticmethod
70
- def backward(ctx, grad2_grad_input, grad2_grad_grid):
71
- _ = grad2_grad_grid # unused
72
- grid, = ctx.saved_tensors
73
- grad2_grad_output = None
74
- grad2_input = None
75
- grad2_grid = None
76
-
77
- if ctx.needs_input_grad[0]:
78
- grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79
-
80
- assert not ctx.needs_input_grad[2]
81
- return grad2_grad_output, grad2_input, grad2_grid
82
-
83
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/upfirdn2d.cpp DELETED
@@ -1,103 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <torch/extension.h>
10
- #include <ATen/cuda/CUDAContext.h>
11
- #include <c10/cuda/CUDAGuard.h>
12
- #include "upfirdn2d.h"
13
-
14
- //------------------------------------------------------------------------
15
-
16
- static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
- {
18
- // Validate arguments.
19
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
- TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
- TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
- TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
- TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
- TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
- TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
- TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
-
30
- // Create output tensor.
31
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
- int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
- int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
- TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
- torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
- TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
-
38
- // Initialize CUDA kernel parameters.
39
- upfirdn2d_kernel_params p;
40
- p.x = x.data_ptr();
41
- p.f = f.data_ptr<float>();
42
- p.y = y.data_ptr();
43
- p.up = make_int2(upx, upy);
44
- p.down = make_int2(downx, downy);
45
- p.pad0 = make_int2(padx0, pady0);
46
- p.flip = (flip) ? 1 : 0;
47
- p.gain = gain;
48
- p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
- p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
- p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
- p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
- p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
- p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
- p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
- p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
-
57
- // Choose CUDA kernel.
58
- upfirdn2d_kernel_spec spec;
59
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60
- {
61
- spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
- });
63
-
64
- // Set looping options.
65
- p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
- p.loopMinor = spec.loopMinor;
67
- p.loopX = spec.loopX;
68
- p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
- p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
-
71
- // Compute grid size.
72
- dim3 blockSize, gridSize;
73
- if (spec.tileOutW < 0) // large
74
- {
75
- blockSize = dim3(4, 32, 1);
76
- gridSize = dim3(
77
- ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
- (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
- p.launchMajor);
80
- }
81
- else // small
82
- {
83
- blockSize = dim3(256, 1, 1);
84
- gridSize = dim3(
85
- ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
- (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
- p.launchMajor);
88
- }
89
-
90
- // Launch CUDA kernel.
91
- void* args[] = {&p};
92
- AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
- return y;
94
- }
95
-
96
- //------------------------------------------------------------------------
97
-
98
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
- {
100
- m.def("upfirdn2d", &upfirdn2d);
101
- }
102
-
103
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/upfirdn2d.cu DELETED
@@ -1,350 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <c10/util/Half.h>
10
- #include "upfirdn2d.h"
11
-
12
- //------------------------------------------------------------------------
13
- // Helpers.
14
-
15
- template <class T> struct InternalType;
16
- template <> struct InternalType<double> { typedef double scalar_t; };
17
- template <> struct InternalType<float> { typedef float scalar_t; };
18
- template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
-
20
- static __device__ __forceinline__ int floor_div(int a, int b)
21
- {
22
- int t = 1 - a / b;
23
- return (a + t * b) / b - t;
24
- }
25
-
26
- //------------------------------------------------------------------------
27
- // Generic CUDA implementation for large filters.
28
-
29
- template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
- {
31
- typedef typename InternalType<T>::scalar_t scalar_t;
32
-
33
- // Calculate thread index.
34
- int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
- int outY = minorBase / p.launchMinor;
36
- minorBase -= outY * p.launchMinor;
37
- int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
- int majorBase = blockIdx.z * p.loopMajor;
39
- if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
- return;
41
-
42
- // Setup Y receptive field.
43
- int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
- int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
- int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
- int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
- if (p.flip)
48
- filterY = p.filterSize.y - 1 - filterY;
49
-
50
- // Loop over major, minor, and X.
51
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
- for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
- {
54
- int nc = major * p.sizeMinor + minor;
55
- int n = nc / p.inSize.z;
56
- int c = nc - n * p.inSize.z;
57
- for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
- {
59
- // Setup X receptive field.
60
- int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
- int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
- int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
- int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
- if (p.flip)
65
- filterX = p.filterSize.x - 1 - filterX;
66
-
67
- // Initialize pointers.
68
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
- const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
- int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
- int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
-
73
- // Inner loop.
74
- scalar_t v = 0;
75
- for (int y = 0; y < h; y++)
76
- {
77
- for (int x = 0; x < w; x++)
78
- {
79
- v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
- xp += p.inStride.x;
81
- fp += filterStepX;
82
- }
83
- xp += p.inStride.y - w * p.inStride.x;
84
- fp += filterStepY - w * filterStepX;
85
- }
86
-
87
- // Store result.
88
- v *= p.gain;
89
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
- }
91
- }
92
- }
93
-
94
- //------------------------------------------------------------------------
95
- // Specialized CUDA implementation for small filters.
96
-
97
- template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
- static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
- {
100
- typedef typename InternalType<T>::scalar_t scalar_t;
101
- const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
- const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
- __shared__ volatile scalar_t sf[filterH][filterW];
104
- __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
-
106
- // Calculate tile index.
107
- int minorBase = blockIdx.x;
108
- int tileOutY = minorBase / p.launchMinor;
109
- minorBase -= tileOutY * p.launchMinor;
110
- minorBase *= loopMinor;
111
- tileOutY *= tileOutH;
112
- int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
- int majorBase = blockIdx.z * p.loopMajor;
114
- if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
- return;
116
-
117
- // Load filter (flipped).
118
- for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
- {
120
- int fy = tapIdx / filterW;
121
- int fx = tapIdx - fy * filterW;
122
- scalar_t v = 0;
123
- if (fx < p.filterSize.x & fy < p.filterSize.y)
124
- {
125
- int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
- int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
- v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
- }
129
- sf[fy][fx] = v;
130
- }
131
-
132
- // Loop over major and X.
133
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
- {
135
- int baseNC = major * p.sizeMinor + minorBase;
136
- int n = baseNC / p.inSize.z;
137
- int baseC = baseNC - n * p.inSize.z;
138
- for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
- {
140
- // Load input pixels.
141
- int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
- int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
- int tileInX = floor_div(tileMidX, upx);
144
- int tileInY = floor_div(tileMidY, upy);
145
- __syncthreads();
146
- for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
- {
148
- int relC = inIdx;
149
- int relInX = relC / loopMinor;
150
- int relInY = relInX / tileInW;
151
- relC -= relInX * loopMinor;
152
- relInX -= relInY * tileInW;
153
- int c = baseC + relC;
154
- int inX = tileInX + relInX;
155
- int inY = tileInY + relInY;
156
- scalar_t v = 0;
157
- if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
- v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
- sx[relInY][relInX][relC] = v;
160
- }
161
-
162
- // Loop over output pixels.
163
- __syncthreads();
164
- for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
- {
166
- int relC = outIdx;
167
- int relOutX = relC / loopMinor;
168
- int relOutY = relOutX / tileOutW;
169
- relC -= relOutX * loopMinor;
170
- relOutX -= relOutY * tileOutW;
171
- int c = baseC + relC;
172
- int outX = tileOutX + relOutX;
173
- int outY = tileOutY + relOutY;
174
-
175
- // Setup receptive field.
176
- int midX = tileMidX + relOutX * downx;
177
- int midY = tileMidY + relOutY * downy;
178
- int inX = floor_div(midX, upx);
179
- int inY = floor_div(midY, upy);
180
- int relInX = inX - tileInX;
181
- int relInY = inY - tileInY;
182
- int filterX = (inX + 1) * upx - midX - 1; // flipped
183
- int filterY = (inY + 1) * upy - midY - 1; // flipped
184
-
185
- // Inner loop.
186
- if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
- {
188
- scalar_t v = 0;
189
- #pragma unroll
190
- for (int y = 0; y < filterH / upy; y++)
191
- #pragma unroll
192
- for (int x = 0; x < filterW / upx; x++)
193
- v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
- v *= p.gain;
195
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
- }
197
- }
198
- }
199
- }
200
- }
201
-
202
- //------------------------------------------------------------------------
203
- // CUDA kernel selection.
204
-
205
- template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
- {
207
- int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
-
209
- upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
210
- if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
211
-
212
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
213
- {
214
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
215
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
216
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
217
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
218
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
219
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
220
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
221
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
222
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
223
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
224
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
225
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
226
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
228
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
229
- }
230
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
231
- {
232
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
233
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
234
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
238
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
239
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
240
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
241
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
242
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
243
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
244
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
245
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
246
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
247
- }
248
- if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
249
- {
250
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
- }
255
- if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
256
- {
257
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
- }
262
- if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
263
- {
264
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
265
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
266
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
268
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
269
- }
270
- if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
271
- {
272
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
273
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
274
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
275
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
276
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
277
- }
278
- if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
279
- {
280
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
281
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
282
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
283
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
284
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
285
- }
286
- if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
287
- {
288
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
289
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
290
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
291
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
292
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
293
- }
294
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
295
- {
296
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
297
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
298
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
299
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
300
- }
301
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
302
- {
303
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
304
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
305
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
306
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
307
- }
308
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
309
- {
310
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
311
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
312
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
313
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
314
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
315
- }
316
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
317
- {
318
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
319
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
320
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
321
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
322
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
323
- }
324
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
325
- {
326
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
327
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
328
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
329
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
330
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
331
- }
332
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
333
- {
334
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
335
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
336
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
337
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
338
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
339
- }
340
- return spec;
341
- }
342
-
343
- //------------------------------------------------------------------------
344
- // Template specializations.
345
-
346
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
347
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
348
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
349
-
350
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/upfirdn2d.h DELETED
@@ -1,59 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <cuda_runtime.h>
10
-
11
- //------------------------------------------------------------------------
12
- // CUDA kernel parameters.
13
-
14
- struct upfirdn2d_kernel_params
15
- {
16
- const void* x;
17
- const float* f;
18
- void* y;
19
-
20
- int2 up;
21
- int2 down;
22
- int2 pad0;
23
- int flip;
24
- float gain;
25
-
26
- int4 inSize; // [width, height, channel, batch]
27
- int4 inStride;
28
- int2 filterSize; // [width, height]
29
- int2 filterStride;
30
- int4 outSize; // [width, height, channel, batch]
31
- int4 outStride;
32
- int sizeMinor;
33
- int sizeMajor;
34
-
35
- int loopMinor;
36
- int loopMajor;
37
- int loopX;
38
- int launchMinor;
39
- int launchMajor;
40
- };
41
-
42
- //------------------------------------------------------------------------
43
- // CUDA kernel specialization.
44
-
45
- struct upfirdn2d_kernel_spec
46
- {
47
- void* kernel;
48
- int tileOutW;
49
- int tileOutH;
50
- int loopMinor;
51
- int loopX;
52
- };
53
-
54
- //------------------------------------------------------------------------
55
- // CUDA kernel selection.
56
-
57
- template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
-
59
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/ops/upfirdn2d.py DELETED
@@ -1,384 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom PyTorch ops for efficient resampling of 2D images."""
10
-
11
- import os
12
- import warnings
13
- import numpy as np
14
- import torch
15
- import traceback
16
-
17
- from .. import custom_ops
18
- from .. import misc
19
- from . import conv2d_gradfix
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- _inited = False
24
- _plugin = None
25
-
26
- def _init():
27
- global _inited, _plugin
28
- if not _inited:
29
- sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
30
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
31
- try:
32
- _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
33
- except:
34
- warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
35
- return _plugin is not None
36
-
37
- def _parse_scaling(scaling):
38
- if isinstance(scaling, int):
39
- scaling = [scaling, scaling]
40
- assert isinstance(scaling, (list, tuple))
41
- assert all(isinstance(x, int) for x in scaling)
42
- sx, sy = scaling
43
- assert sx >= 1 and sy >= 1
44
- return sx, sy
45
-
46
- def _parse_padding(padding):
47
- if isinstance(padding, int):
48
- padding = [padding, padding]
49
- assert isinstance(padding, (list, tuple))
50
- assert all(isinstance(x, int) for x in padding)
51
- if len(padding) == 2:
52
- padx, pady = padding
53
- padding = [padx, padx, pady, pady]
54
- padx0, padx1, pady0, pady1 = padding
55
- return padx0, padx1, pady0, pady1
56
-
57
- def _get_filter_size(f):
58
- if f is None:
59
- return 1, 1
60
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
61
- fw = f.shape[-1]
62
- fh = f.shape[0]
63
- with misc.suppress_tracer_warnings():
64
- fw = int(fw)
65
- fh = int(fh)
66
- misc.assert_shape(f, [fh, fw][:f.ndim])
67
- assert fw >= 1 and fh >= 1
68
- return fw, fh
69
-
70
- #----------------------------------------------------------------------------
71
-
72
- def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
73
- r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
74
-
75
- Args:
76
- f: Torch tensor, numpy array, or python list of the shape
77
- `[filter_height, filter_width]` (non-separable),
78
- `[filter_taps]` (separable),
79
- `[]` (impulse), or
80
- `None` (identity).
81
- device: Result device (default: cpu).
82
- normalize: Normalize the filter so that it retains the magnitude
83
- for constant input signal (DC)? (default: True).
84
- flip_filter: Flip the filter? (default: False).
85
- gain: Overall scaling factor for signal magnitude (default: 1).
86
- separable: Return a separable filter? (default: select automatically).
87
-
88
- Returns:
89
- Float32 tensor of the shape
90
- `[filter_height, filter_width]` (non-separable) or
91
- `[filter_taps]` (separable).
92
- """
93
- # Validate.
94
- if f is None:
95
- f = 1
96
- f = torch.as_tensor(f, dtype=torch.float32)
97
- assert f.ndim in [0, 1, 2]
98
- assert f.numel() > 0
99
- if f.ndim == 0:
100
- f = f[np.newaxis]
101
-
102
- # Separable?
103
- if separable is None:
104
- separable = (f.ndim == 1 and f.numel() >= 8)
105
- if f.ndim == 1 and not separable:
106
- f = f.ger(f)
107
- assert f.ndim == (1 if separable else 2)
108
-
109
- # Apply normalize, flip, gain, and device.
110
- if normalize:
111
- f /= f.sum()
112
- if flip_filter:
113
- f = f.flip(list(range(f.ndim)))
114
- f = f * (gain ** (f.ndim / 2))
115
- f = f.to(device=device)
116
- return f
117
-
118
- #----------------------------------------------------------------------------
119
-
120
- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
121
- r"""Pad, upsample, filter, and downsample a batch of 2D images.
122
-
123
- Performs the following sequence of operations for each channel:
124
-
125
- 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
126
-
127
- 2. Pad the image with the specified number of zeros on each side (`padding`).
128
- Negative padding corresponds to cropping the image.
129
-
130
- 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
131
- so that the footprint of all output pixels lies within the input image.
132
-
133
- 4. Downsample the image by keeping every Nth pixel (`down`).
134
-
135
- This sequence of operations bears close resemblance to scipy.signal.upfirdn().
136
- The fused op is considerably more efficient than performing the same calculation
137
- using standard PyTorch ops. It supports gradients of arbitrary order.
138
-
139
- Args:
140
- x: Float32/float64/float16 input tensor of the shape
141
- `[batch_size, num_channels, in_height, in_width]`.
142
- f: Float32 FIR filter of the shape
143
- `[filter_height, filter_width]` (non-separable),
144
- `[filter_taps]` (separable), or
145
- `None` (identity).
146
- up: Integer upsampling factor. Can be a single int or a list/tuple
147
- `[x, y]` (default: 1).
148
- down: Integer downsampling factor. Can be a single int or a list/tuple
149
- `[x, y]` (default: 1).
150
- padding: Padding with respect to the upsampled image. Can be a single number
151
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
152
- (default: 0).
153
- flip_filter: False = convolution, True = correlation (default: False).
154
- gain: Overall scaling factor for signal magnitude (default: 1).
155
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
156
-
157
- Returns:
158
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
159
- """
160
- assert isinstance(x, torch.Tensor)
161
- assert impl in ['ref', 'cuda']
162
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
163
- return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
164
- return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
165
-
166
- #----------------------------------------------------------------------------
167
-
168
- @misc.profiled_function
169
- def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
170
- """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
171
- """
172
- # Validate arguments.
173
- assert isinstance(x, torch.Tensor) and x.ndim == 4
174
- if f is None:
175
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
176
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
177
- assert f.dtype == torch.float32 and not f.requires_grad
178
- batch_size, num_channels, in_height, in_width = x.shape
179
- upx, upy = _parse_scaling(up)
180
- downx, downy = _parse_scaling(down)
181
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
182
-
183
- # Upsample by inserting zeros.
184
- x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
185
- x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
186
- x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
187
-
188
- # Pad or crop.
189
- x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
190
- x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
191
-
192
- # Setup filter.
193
- f = f * (gain ** (f.ndim / 2))
194
- f = f.to(x.dtype)
195
- if not flip_filter:
196
- f = f.flip(list(range(f.ndim)))
197
-
198
- # Convolve with the filter.
199
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
200
- if f.ndim == 4:
201
- x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
202
- else:
203
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
204
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
205
-
206
- # Downsample by throwing away pixels.
207
- x = x[:, :, ::downy, ::downx]
208
- return x
209
-
210
- #----------------------------------------------------------------------------
211
-
212
- _upfirdn2d_cuda_cache = dict()
213
-
214
- def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
215
- """Fast CUDA implementation of `upfirdn2d()` using custom ops.
216
- """
217
- # Parse arguments.
218
- upx, upy = _parse_scaling(up)
219
- downx, downy = _parse_scaling(down)
220
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
221
-
222
- # Lookup from cache.
223
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
224
- if key in _upfirdn2d_cuda_cache:
225
- return _upfirdn2d_cuda_cache[key]
226
-
227
- # Forward op.
228
- class Upfirdn2dCuda(torch.autograd.Function):
229
- @staticmethod
230
- def forward(ctx, x, f): # pylint: disable=arguments-differ
231
- assert isinstance(x, torch.Tensor) and x.ndim == 4
232
- if f is None:
233
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
234
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
235
- y = x
236
- if f.ndim == 2:
237
- y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
238
- else:
239
- y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
240
- y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
241
- ctx.save_for_backward(f)
242
- ctx.x_shape = x.shape
243
- return y
244
-
245
- @staticmethod
246
- def backward(ctx, dy): # pylint: disable=arguments-differ
247
- f, = ctx.saved_tensors
248
- _, _, ih, iw = ctx.x_shape
249
- _, _, oh, ow = dy.shape
250
- fw, fh = _get_filter_size(f)
251
- p = [
252
- fw - padx0 - 1,
253
- iw * upx - ow * downx + padx0 - upx + 1,
254
- fh - pady0 - 1,
255
- ih * upy - oh * downy + pady0 - upy + 1,
256
- ]
257
- dx = None
258
- df = None
259
-
260
- if ctx.needs_input_grad[0]:
261
- dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
262
-
263
- assert not ctx.needs_input_grad[1]
264
- return dx, df
265
-
266
- # Add to cache.
267
- _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
268
- return Upfirdn2dCuda
269
-
270
- #----------------------------------------------------------------------------
271
-
272
- def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
273
- r"""Filter a batch of 2D images using the given 2D FIR filter.
274
-
275
- By default, the result is padded so that its shape matches the input.
276
- User-specified padding is applied on top of that, with negative values
277
- indicating cropping. Pixels outside the image are assumed to be zero.
278
-
279
- Args:
280
- x: Float32/float64/float16 input tensor of the shape
281
- `[batch_size, num_channels, in_height, in_width]`.
282
- f: Float32 FIR filter of the shape
283
- `[filter_height, filter_width]` (non-separable),
284
- `[filter_taps]` (separable), or
285
- `None` (identity).
286
- padding: Padding with respect to the output. Can be a single number or a
287
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
288
- (default: 0).
289
- flip_filter: False = convolution, True = correlation (default: False).
290
- gain: Overall scaling factor for signal magnitude (default: 1).
291
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
292
-
293
- Returns:
294
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
295
- """
296
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
297
- fw, fh = _get_filter_size(f)
298
- p = [
299
- padx0 + fw // 2,
300
- padx1 + (fw - 1) // 2,
301
- pady0 + fh // 2,
302
- pady1 + (fh - 1) // 2,
303
- ]
304
- return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
305
-
306
- #----------------------------------------------------------------------------
307
-
308
- def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
309
- r"""Upsample a batch of 2D images using the given 2D FIR filter.
310
-
311
- By default, the result is padded so that its shape is a multiple of the input.
312
- User-specified padding is applied on top of that, with negative values
313
- indicating cropping. Pixels outside the image are assumed to be zero.
314
-
315
- Args:
316
- x: Float32/float64/float16 input tensor of the shape
317
- `[batch_size, num_channels, in_height, in_width]`.
318
- f: Float32 FIR filter of the shape
319
- `[filter_height, filter_width]` (non-separable),
320
- `[filter_taps]` (separable), or
321
- `None` (identity).
322
- up: Integer upsampling factor. Can be a single int or a list/tuple
323
- `[x, y]` (default: 1).
324
- padding: Padding with respect to the output. Can be a single number or a
325
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
326
- (default: 0).
327
- flip_filter: False = convolution, True = correlation (default: False).
328
- gain: Overall scaling factor for signal magnitude (default: 1).
329
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
330
-
331
- Returns:
332
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
333
- """
334
- upx, upy = _parse_scaling(up)
335
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
336
- fw, fh = _get_filter_size(f)
337
- p = [
338
- padx0 + (fw + upx - 1) // 2,
339
- padx1 + (fw - upx) // 2,
340
- pady0 + (fh + upy - 1) // 2,
341
- pady1 + (fh - upy) // 2,
342
- ]
343
- return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
344
-
345
- #----------------------------------------------------------------------------
346
-
347
- def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
348
- r"""Downsample a batch of 2D images using the given 2D FIR filter.
349
-
350
- By default, the result is padded so that its shape is a fraction of the input.
351
- User-specified padding is applied on top of that, with negative values
352
- indicating cropping. Pixels outside the image are assumed to be zero.
353
-
354
- Args:
355
- x: Float32/float64/float16 input tensor of the shape
356
- `[batch_size, num_channels, in_height, in_width]`.
357
- f: Float32 FIR filter of the shape
358
- `[filter_height, filter_width]` (non-separable),
359
- `[filter_taps]` (separable), or
360
- `None` (identity).
361
- down: Integer downsampling factor. Can be a single int or a list/tuple
362
- `[x, y]` (default: 1).
363
- padding: Padding with respect to the input. Can be a single number or a
364
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
365
- (default: 0).
366
- flip_filter: False = convolution, True = correlation (default: False).
367
- gain: Overall scaling factor for signal magnitude (default: 1).
368
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
369
-
370
- Returns:
371
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
372
- """
373
- downx, downy = _parse_scaling(down)
374
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
375
- fw, fh = _get_filter_size(f)
376
- p = [
377
- padx0 + (fw - downx + 1) // 2,
378
- padx1 + (fw - downx) // 2,
379
- pady0 + (fh - downy + 1) // 2,
380
- pady1 + (fh - downy) // 2,
381
- ]
382
- return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
383
-
384
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/persistence.py DELETED
@@ -1,251 +0,0 @@
1
- ο»Ώ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Facilities for pickling Python code alongside other data.
10
-
11
- The pickled code is automatically imported into a separate Python module
12
- during unpickling. This way, any previously exported pickles will remain
13
- usable even if the original code is no longer available, or if the current
14
- version of the code is not consistent with what was originally pickled."""
15
-
16
- import sys
17
- import pickle
18
- import io
19
- import inspect
20
- import copy
21
- import uuid
22
- import types
23
- import dnnlib
24
-
25
- #----------------------------------------------------------------------------
26
-
27
- _version = 6 # internal version number
28
- _decorators = set() # {decorator_class, ...}
29
- _import_hooks = [] # [hook_function, ...]
30
- _module_to_src_dict = dict() # {module: src, ...}
31
- _src_to_module_dict = dict() # {src: module, ...}
32
-
33
- #----------------------------------------------------------------------------
34
-
35
- def persistent_class(orig_class):
36
- r"""Class decorator that extends a given class to save its source code
37
- when pickled.
38
-
39
- Example:
40
-
41
- from torch_utils import persistence
42
-
43
- @persistence.persistent_class
44
- class MyNetwork(torch.nn.Module):
45
- def __init__(self, num_inputs, num_outputs):
46
- super().__init__()
47
- self.fc = MyLayer(num_inputs, num_outputs)
48
- ...
49
-
50
- @persistence.persistent_class
51
- class MyLayer(torch.nn.Module):
52
- ...
53
-
54
- When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
- source code alongside other internal state (e.g., parameters, buffers,
56
- and submodules). This way, any previously exported pickle will remain
57
- usable even if the class definitions have been modified or are no
58
- longer available.
59
-
60
- The decorator saves the source code of the entire Python module
61
- containing the decorated class. It does *not* save the source code of
62
- any imported modules. Thus, the imported modules must be available
63
- during unpickling, also including `torch_utils.persistence` itself.
64
-
65
- It is ok to call functions defined in the same module from the
66
- decorated class. However, if the decorated class depends on other
67
- classes defined in the same module, they must be decorated as well.
68
- This is illustrated in the above example in the case of `MyLayer`.
69
-
70
- It is also possible to employ the decorator just-in-time before
71
- calling the constructor. For example:
72
-
73
- cls = MyLayer
74
- if want_to_make_it_persistent:
75
- cls = persistence.persistent_class(cls)
76
- layer = cls(num_inputs, num_outputs)
77
-
78
- As an additional feature, the decorator also keeps track of the
79
- arguments that were used to construct each instance of the decorated
80
- class. The arguments can be queried via `obj.init_args` and
81
- `obj.init_kwargs`, and they are automatically pickled alongside other
82
- object state. A typical use case is to first unpickle a previous
83
- instance of a persistent class, and then upgrade it to use the latest
84
- version of the source code:
85
-
86
- with open('old_pickle.pkl', 'rb') as f:
87
- old_net = pickle.load(f)
88
- new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
89
- misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90
- """
91
- assert isinstance(orig_class, type)
92
- if is_persistent(orig_class):
93
- return orig_class
94
-
95
- assert orig_class.__module__ in sys.modules
96
- orig_module = sys.modules[orig_class.__module__]
97
- orig_module_src = _module_to_src(orig_module)
98
-
99
- class Decorator(orig_class):
100
- _orig_module_src = orig_module_src
101
- _orig_class_name = orig_class.__name__
102
-
103
- def __init__(self, *args, **kwargs):
104
- super().__init__(*args, **kwargs)
105
- self._init_args = copy.deepcopy(args)
106
- self._init_kwargs = copy.deepcopy(kwargs)
107
- assert orig_class.__name__ in orig_module.__dict__
108
- _check_pickleable(self.__reduce__())
109
-
110
- @property
111
- def init_args(self):
112
- return copy.deepcopy(self._init_args)
113
-
114
- @property
115
- def init_kwargs(self):
116
- return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117
-
118
- def __reduce__(self):
119
- fields = list(super().__reduce__())
120
- fields += [None] * max(3 - len(fields), 0)
121
- if fields[0] is not _reconstruct_persistent_obj:
122
- meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123
- fields[0] = _reconstruct_persistent_obj # reconstruct func
124
- fields[1] = (meta,) # reconstruct args
125
- fields[2] = None # state dict
126
- return tuple(fields)
127
-
128
- Decorator.__name__ = orig_class.__name__
129
- _decorators.add(Decorator)
130
- return Decorator
131
-
132
- #----------------------------------------------------------------------------
133
-
134
- def is_persistent(obj):
135
- r"""Test whether the given object or class is persistent, i.e.,
136
- whether it will save its source code when pickled.
137
- """
138
- try:
139
- if obj in _decorators:
140
- return True
141
- except TypeError:
142
- pass
143
- return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144
-
145
- #----------------------------------------------------------------------------
146
-
147
- def import_hook(hook):
148
- r"""Register an import hook that is called whenever a persistent object
149
- is being unpickled. A typical use case is to patch the pickled source
150
- code to avoid errors and inconsistencies when the API of some imported
151
- module has changed.
152
-
153
- The hook should have the following signature:
154
-
155
- hook(meta) -> modified meta
156
-
157
- `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158
-
159
- type: Type of the persistent object, e.g. `'class'`.
160
- version: Internal version number of `torch_utils.persistence`.
161
- module_src Original source code of the Python module.
162
- class_name: Class name in the original Python module.
163
- state: Internal state of the object.
164
-
165
- Example:
166
-
167
- @persistence.import_hook
168
- def wreck_my_network(meta):
169
- if meta.class_name == 'MyNetwork':
170
- print('MyNetwork is being imported. I will wreck it!')
171
- meta.module_src = meta.module_src.replace("True", "False")
172
- return meta
173
- """
174
- assert callable(hook)
175
- _import_hooks.append(hook)
176
-
177
- #----------------------------------------------------------------------------
178
-
179
- def _reconstruct_persistent_obj(meta):
180
- r"""Hook that is called internally by the `pickle` module to unpickle
181
- a persistent object.
182
- """
183
- meta = dnnlib.EasyDict(meta)
184
- meta.state = dnnlib.EasyDict(meta.state)
185
- for hook in _import_hooks:
186
- meta = hook(meta)
187
- assert meta is not None
188
-
189
- assert meta.version == _version
190
- module = _src_to_module(meta.module_src)
191
-
192
- assert meta.type == 'class'
193
- orig_class = module.__dict__[meta.class_name]
194
- decorator_class = persistent_class(orig_class)
195
- obj = decorator_class.__new__(decorator_class)
196
-
197
- setstate = getattr(obj, '__setstate__', None)
198
- if callable(setstate):
199
- setstate(meta.state) # pylint: disable=not-callable
200
- else:
201
- obj.__dict__.update(meta.state)
202
- return obj
203
-
204
- #----------------------------------------------------------------------------
205
-
206
- def _module_to_src(module):
207
- r"""Query the source code of a given Python module.
208
- """
209
- src = _module_to_src_dict.get(module, None)
210
- if src is None:
211
- src = inspect.getsource(module)
212
- _module_to_src_dict[module] = src
213
- _src_to_module_dict[src] = module
214
- return src
215
-
216
- def _src_to_module(src):
217
- r"""Get or create a Python module for the given source code.
218
- """
219
- module = _src_to_module_dict.get(src, None)
220
- if module is None:
221
- module_name = "_imported_module_" + uuid.uuid4().hex
222
- module = types.ModuleType(module_name)
223
- sys.modules[module_name] = module
224
- _module_to_src_dict[module] = src
225
- _src_to_module_dict[src] = module
226
- exec(src, module.__dict__) # pylint: disable=exec-used
227
- return module
228
-
229
- #----------------------------------------------------------------------------
230
-
231
- def _check_pickleable(obj):
232
- r"""Check that the given object is pickleable, raising an exception if
233
- it is not. This function is expected to be considerably more efficient
234
- than actually pickling the object.
235
- """
236
- def recurse(obj):
237
- if isinstance(obj, (list, tuple, set)):
238
- return [recurse(x) for x in obj]
239
- if isinstance(obj, dict):
240
- return [[recurse(x), recurse(y)] for x, y in obj.items()]
241
- if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242
- return None # Python primitive types are pickleable.
243
- if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244
- return None # NumPy arrays and PyTorch tensors are pickleable.
245
- if is_persistent(obj):
246
- return None # Persistent objects are pickleable, by virtue of the constructor check.
247
- return obj
248
- with io.BytesIO() as f:
249
- pickle.dump(recurse(obj), f)
250
-
251
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
PTI/torch_utils/training_stats.py DELETED
@@ -1,268 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Facilities for reporting and collecting training statistics across
10
- multiple processes and devices. The interface is designed to minimize
11
- synchronization overhead as well as the amount of boilerplate in user
12
- code."""
13
-
14
- import re
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
-
19
- from . import misc
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24
- _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25
- _counter_dtype = torch.float64 # Data type to use for the internal counters.
26
- _rank = 0 # Rank of the current process.
27
- _sync_device = None # Device to use for multiprocess communication. None = single-process.
28
- _sync_called = False # Has _sync() been called yet?
29
- _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30
- _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31
-
32
- #----------------------------------------------------------------------------
33
-
34
- def init_multiprocessing(rank, sync_device):
35
- r"""Initializes `torch_utils.training_stats` for collecting statistics
36
- across multiple processes.
37
-
38
- This function must be called after
39
- `torch.distributed.init_process_group()` and before `Collector.update()`.
40
- The call is not necessary if multi-process collection is not needed.
41
-
42
- Args:
43
- rank: Rank of the current process.
44
- sync_device: PyTorch device to use for inter-process
45
- communication, or None to disable multi-process
46
- collection. Typically `torch.device('cuda', rank)`.
47
- """
48
- global _rank, _sync_device
49
- assert not _sync_called
50
- _rank = rank
51
- _sync_device = sync_device
52
-
53
- #----------------------------------------------------------------------------
54
-
55
- @misc.profiled_function
56
- def report(name, value):
57
- r"""Broadcasts the given set of scalars to all interested instances of
58
- `Collector`, across device and process boundaries.
59
-
60
- This function is expected to be extremely cheap and can be safely
61
- called from anywhere in the training loop, loss function, or inside a
62
- `torch.nn.Module`.
63
-
64
- Warning: The current implementation expects the set of unique names to
65
- be consistent across processes. Please make sure that `report()` is
66
- called at least once for each unique name by each process, and in the
67
- same order. If a given process has no scalars to broadcast, it can do
68
- `report(name, [])` (empty list).
69
-
70
- Args:
71
- name: Arbitrary string specifying the name of the statistic.
72
- Averages are accumulated separately for each unique name.
73
- value: Arbitrary set of scalars. Can be a list, tuple,
74
- NumPy array, PyTorch tensor, or Python scalar.
75
-
76
- Returns:
77
- The same `value` that was passed in.
78
- """
79
- if name not in _counters:
80
- _counters[name] = dict()
81
-
82
- elems = torch.as_tensor(value)
83
- if elems.numel() == 0:
84
- return value
85
-
86
- elems = elems.detach().flatten().to(_reduce_dtype)
87
- moments = torch.stack([
88
- torch.ones_like(elems).sum(),
89
- elems.sum(),
90
- elems.square().sum(),
91
- ])
92
- assert moments.ndim == 1 and moments.shape[0] == _num_moments
93
- moments = moments.to(_counter_dtype)
94
-
95
- device = moments.device
96
- if device not in _counters[name]:
97
- _counters[name][device] = torch.zeros_like(moments)
98
- _counters[name][device].add_(moments)
99
- return value
100
-
101
- #----------------------------------------------------------------------------
102
-
103
- def report0(name, value):
104
- r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105
- but ignores any scalars provided by the other processes.
106
- See `report()` for further details.
107
- """
108
- report(name, value if _rank == 0 else [])
109
- return value
110
-
111
- #----------------------------------------------------------------------------
112
-
113
- class Collector:
114
- r"""Collects the scalars broadcasted by `report()` and `report0()` and
115
- computes their long-term averages (mean and standard deviation) over
116
- user-defined periods of time.
117
-
118
- The averages are first collected into internal counters that are not
119
- directly visible to the user. They are then copied to the user-visible
120
- state as a result of calling `update()` and can then be queried using
121
- `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122
- internal counters for the next round, so that the user-visible state
123
- effectively reflects averages collected between the last two calls to
124
- `update()`.
125
-
126
- Args:
127
- regex: Regular expression defining which statistics to
128
- collect. The default is to collect everything.
129
- keep_previous: Whether to retain the previous averages if no
130
- scalars were collected on a given round
131
- (default: True).
132
- """
133
- def __init__(self, regex='.*', keep_previous=True):
134
- self._regex = re.compile(regex)
135
- self._keep_previous = keep_previous
136
- self._cumulative = dict()
137
- self._moments = dict()
138
- self.update()
139
- self._moments.clear()
140
-
141
- def names(self):
142
- r"""Returns the names of all statistics broadcasted so far that
143
- match the regular expression specified at construction time.
144
- """
145
- return [name for name in _counters if self._regex.fullmatch(name)]
146
-
147
- def update(self):
148
- r"""Copies current values of the internal counters to the
149
- user-visible state and resets them for the next round.
150
-
151
- If `keep_previous=True` was specified at construction time, the
152
- operation is skipped for statistics that have received no scalars
153
- since the last update, retaining their previous averages.
154
-
155
- This method performs a number of GPU-to-CPU transfers and one
156
- `torch.distributed.all_reduce()`. It is intended to be called
157
- periodically in the main training loop, typically once every
158
- N training steps.
159
- """
160
- if not self._keep_previous:
161
- self._moments.clear()
162
- for name, cumulative in _sync(self.names()):
163
- if name not in self._cumulative:
164
- self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165
- delta = cumulative - self._cumulative[name]
166
- self._cumulative[name].copy_(cumulative)
167
- if float(delta[0]) != 0:
168
- self._moments[name] = delta
169
-
170
- def _get_delta(self, name):
171
- r"""Returns the raw moments that were accumulated for the given
172
- statistic between the last two calls to `update()`, or zero if
173
- no scalars were collected.
174
- """
175
- assert self._regex.fullmatch(name)
176
- if name not in self._moments:
177
- self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178
- return self._moments[name]
179
-
180
- def num(self, name):
181
- r"""Returns the number of scalars that were accumulated for the given
182
- statistic between the last two calls to `update()`, or zero if
183
- no scalars were collected.
184
- """
185
- delta = self._get_delta(name)
186
- return int(delta[0])
187
-
188
- def mean(self, name):
189
- r"""Returns the mean of the scalars that were accumulated for the
190
- given statistic between the last two calls to `update()`, or NaN if
191
- no scalars were collected.
192
- """
193
- delta = self._get_delta(name)
194
- if int(delta[0]) == 0:
195
- return float('nan')
196
- return float(delta[1] / delta[0])
197
-
198
- def std(self, name):
199
- r"""Returns the standard deviation of the scalars that were
200
- accumulated for the given statistic between the last two calls to
201
- `update()`, or NaN if no scalars were collected.
202
- """
203
- delta = self._get_delta(name)
204
- if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205
- return float('nan')
206
- if int(delta[0]) == 1:
207
- return float(0)
208
- mean = float(delta[1] / delta[0])
209
- raw_var = float(delta[2] / delta[0])
210
- return np.sqrt(max(raw_var - np.square(mean), 0))
211
-
212
- def as_dict(self):
213
- r"""Returns the averages accumulated between the last two calls to
214
- `update()` as an `dnnlib.EasyDict`. The contents are as follows:
215
-
216
- dnnlib.EasyDict(
217
- NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218
- ...
219
- )
220
- """
221
- stats = dnnlib.EasyDict()
222
- for name in self.names():
223
- stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224
- return stats
225
-
226
- def __getitem__(self, name):
227
- r"""Convenience getter.
228
- `collector[name]` is a synonym for `collector.mean(name)`.
229
- """
230
- return self.mean(name)
231
-
232
- #----------------------------------------------------------------------------
233
-
234
- def _sync(names):
235
- r"""Synchronize the global cumulative counters across devices and
236
- processes. Called internally by `Collector.update()`.
237
- """
238
- if len(names) == 0:
239
- return []
240
- global _sync_called
241
- _sync_called = True
242
-
243
- # Collect deltas within current rank.
244
- deltas = []
245
- device = _sync_device if _sync_device is not None else torch.device('cpu')
246
- for name in names:
247
- delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248
- for counter in _counters[name].values():
249
- delta.add_(counter.to(device))
250
- counter.copy_(torch.zeros_like(counter))
251
- deltas.append(delta)
252
- deltas = torch.stack(deltas)
253
-
254
- # Sum deltas across ranks.
255
- if _sync_device is not None:
256
- torch.distributed.all_reduce(deltas)
257
-
258
- # Update cumulative values.
259
- deltas = deltas.cpu()
260
- for idx, name in enumerate(names):
261
- if name not in _cumulative:
262
- _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263
- _cumulative[name].add_(deltas[idx])
264
-
265
- # Return name-value pairs.
266
- return [(name, _cumulative[name]) for name in names]
267
-
268
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,40 +1,89 @@
1
  import gradio as gr
2
- import utils
3
  from PIL import Image
4
  import torch
5
  import math
6
  from torchvision import transforms
7
-
8
-
9
  device = "cpu"
10
  years = [str(y) for y in range(1880, 2020, 10)]
 
11
 
12
 
 
 
 
 
 
13
  orig_models = {}
14
 
15
  for year in years:
16
  G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
17
- orig_models[year] = { "G": G.eval()}
18
 
19
 
20
  def run_alignment(image_path,idx=None):
21
  import dlib
22
  from align_all_parallel import align_face
23
  predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
24
- aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
25
- print("Aligned image has shape: {}".format(aligned_image.size))
26
 
27
  return aligned_image
28
 
29
- def predict(inp):
30
  #with torch.no_grad():
31
  inp.save("imgs/input.png")
32
- out = run_alignment("imgs/input.png", idx=0)
33
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  gr.Interface(fn=predict,
37
- inputs=gr.Image(type="pil"),
38
- outputs=gr.Image(type="pil"),
39
- #examples=["lion.jpg", "cheetah.jpg"]
40
- ).launch()
 
 
1
  import gradio as gr
2
+ import utils.utils as utils
3
  from PIL import Image
4
  import torch
5
  import math
6
  from torchvision import transforms
7
+ from run_pti import run_PTI
 
8
  device = "cpu"
9
  years = [str(y) for y in range(1880, 2020, 10)]
10
+ decades = [y + "s" for y in years]
11
 
12
 
13
+ transform = transforms.Compose([
14
+ transforms.Resize((256, 256)),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
17
+
18
  orig_models = {}
19
 
20
  for year in years:
21
  G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
22
+ orig_models[year] = { "G": G.eval().float()}
23
 
24
 
25
  def run_alignment(image_path,idx=None):
26
  import dlib
27
  from align_all_parallel import align_face
28
  predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
29
+ aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
30
+
31
 
32
  return aligned_image
33
 
34
+ def predict(inp, in_decade):
35
  #with torch.no_grad():
36
  inp.save("imgs/input.png")
37
+ inversion = run_alignment("imgs/input.png", idx=0)
38
+ inversion.save("imgs/cropped/input.png")
39
+ run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
40
+ #inversion = Image.open("imgs/cropped/input.png")
41
+
42
+ in_year = in_decade[:-1]
43
+ pti_models = {}
44
+
45
+ for year in years:
46
+ G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
47
+ pti_models[year] = { "G": G.eval().float()}
48
+
49
+
50
+ pti_models[in_year]['G'] = torch.load(f"checkpoints/model_gradio_demo_input.pt", device).eval().float()
51
+
52
+ for year in years:
53
+ if year != in_year:
54
+ for p_pti, p_orig, (names, p) in zip(pti_models[in_year]['G'].parameters(),orig_models[in_year]['G'].parameters(), pti_models[year]['G'].named_parameters()):
55
+ with torch.no_grad():
56
+ delta = p_pti - p_orig
57
+ p += delta
58
+
59
+ space = 0
60
+ dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
61
+
62
+
63
+ w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
64
+
65
+ border_width = 10
66
+ #fill_color = 'red'
67
+ dst.paste(inversion, (0, 0))
68
+
69
+
70
+
71
+ for i in range(0, len(years)):
72
+ year = str(years[i])
73
+ with torch.no_grad():
74
+ child_tensor = pti_models[year]["G"].synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
75
+ img = utils.tensor2im(child_tensor.squeeze(0))
76
+ # if year == in_year:
77
+ # img = img.crop((border_width, border_width, 256 - border_width, 256-border_width))
78
+ # img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color)
79
+ dst.paste(img, ((256 + space) * (i+1), 0))
80
+ dst
81
+ return dst
82
 
83
 
84
  gr.Interface(fn=predict,
85
+ inputs=[gr.Image(label="Input Image", type="pil"), gr.Dropdown(label="Input Decade", choices=decades, value="2010s")],
86
+ outputs=gr.Image(label="Decade Transformations", type="pil"),
87
+ examples=[["imgs/Steven-Yeun.jpg", "2010s"]]
88
+
89
+ ).launch() #.launch(server_name="0.0.0.0", server_port=8098)
checkpoints/model_gradio_demo_input.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65ee5644ec8ab0966a4eb51971995c071f8178db765750d45e32e0ed18a09738
3
+ size 99867041
PTI/criteria/color_transfer_loss.py β†’ color_transfer_loss.py RENAMED
File without changes
{PTI/configs β†’ configs}/__init__.py RENAMED
File without changes
{PTI/configs β†’ configs}/evaluation_config.py RENAMED
File without changes
{PTI/configs β†’ configs}/global_config.py RENAMED
@@ -1,6 +1,6 @@
1
  ## Device
2
- cuda_visible_devices = "1"
3
- device = "cuda:0"
4
 
5
  ## Logs
6
  training_step = 1
 
1
  ## Device
2
+ cuda_visible_devices = "0"
3
+ device = "cpu"
4
 
5
  ## Logs
6
  training_step = 1
{PTI/configs β†’ configs}/hyperparameters.py RENAMED
@@ -28,4 +28,4 @@ max_images_to_invert = 10
28
  pti_learning_rate = 3e-4
29
  first_inv_lr = 5e-3
30
  train_batch_size = 1
31
- use_last_w_pivots = True
 
28
  pti_learning_rate = 3e-4
29
  first_inv_lr = 5e-3
30
  train_batch_size = 1
31
+ use_last_w_pivots = False
{PTI/configs β†’ configs}/paths_config.py RENAMED
@@ -4,12 +4,12 @@ year = "2010"
4
  e4e = "./pretrained_models/e4e_ffhq_encode.pt"
5
 
6
 
7
- stylegan2_ada_ffhq = f"../pretrained_models/{year}.pkl"
8
 
9
  style_clip_pretrained_mappers = ""
10
- ir_se50 = "/share/phoenix/nfs04/S7/wikitime_models/model_ir_se50.pth"
11
  dlib = "./pretrained_models/align.dat"
12
- deeplab = "/share/phoenix/nfs04/S7/wikitime_models/deeplab_model/deeplab_model.pth"
13
 
14
  ## Dirs for output files
15
  checkpoints_dir = "./checkpoints"
@@ -20,7 +20,7 @@ experiments_output_dir = "./output"
20
  ## Input info
21
  ### Input dir, where the images reside
22
  input_data_path = (
23
- f"/share/phoenix/nfs04/S7/emc348/WikiFaces/datasets/new_crops/test/{year}"
24
  )
25
  input_data_id = f"{year}"
26
 
 
4
  e4e = "./pretrained_models/e4e_ffhq_encode.pt"
5
 
6
 
7
+ stylegan2_ada_ffhq = f"pretrained_models/{year}.pkl"
8
 
9
  style_clip_pretrained_mappers = ""
10
+ ir_se50 = "pretrained_models/model_ir_se50.pth"
11
  dlib = "./pretrained_models/align.dat"
12
+ deeplab = "pretrained_models/deeplab_model/deeplab_model.pth"
13
 
14
  ## Dirs for output files
15
  checkpoints_dir = "./checkpoints"
 
20
  ## Input info
21
  ### Input dir, where the images reside
22
  input_data_path = (
23
+ f"imgs/cropped"
24
  )
25
  input_data_id = f"{year}"
26
 
{PTI/criteria β†’ criteria}/__init__.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/backbones/__init__.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/backbones/iresnet.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/backbones/iresnet2060.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/backbones/mobilefacenet.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/deeplab.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/helpers.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/id_loss.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/l2_loss.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/localitly_regulizer.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/mask.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/model_irse.py RENAMED
File without changes
{PTI/criteria β†’ criteria}/validation.py RENAMED
File without changes
dnnlib/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/dnnlib/__pycache__/__init__.cpython-39.pyc and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ
 
dnnlib/__pycache__/util.cpython-39.pyc CHANGED
Binary files a/dnnlib/__pycache__/util.cpython-39.pyc and b/dnnlib/__pycache__/util.cpython-39.pyc differ
 
embeddings/2010/PTI/input/0.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb9b3fc3d08f6dd3ee5c87299fff8cd932e4a6afaa90fe06f3d5c5f9503ebf26
3
+ size 29419
imgs/Steven-Yeun.jpg ADDED

Git LFS Details

  • SHA256: f7d9da1331c75fc2b8ac8caa024c804ac500c9c29b5ed4edf60bf30247eae8a5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.63 MB
imgs/cropped/input.png ADDED

Git LFS Details

  • SHA256: ba7b8df0bffe226c723eb22c537e66ff9de844e6aae7845a6c88e696f03b6a40
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
imgs/input.png ADDED

Git LFS Details

  • SHA256: 3f8c1b42d80f44efcf0cb03a301072284a0b7ad6ae6f11871be44b7fae79613e
  • Pointer size: 133 Bytes
  • Size of remote file: 13.7 MB
{PTI/training β†’ models/StyleCLIP}/__init__.py RENAMED
File without changes
{PTI/training/coaches β†’ models/StyleCLIP/criteria}/__init__.py RENAMED
File without changes
models/StyleCLIP/criteria/clip_loss.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import clip
4
+
5
+
6
+ class CLIPLoss(torch.nn.Module):
7
+
8
+ def __init__(self, opts):
9
+ super(CLIPLoss, self).__init__()
10
+ self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
11
+ self.upsample = torch.nn.Upsample(scale_factor=7)
12
+ self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
13
+
14
+ def forward(self, image, text):
15
+ image = self.avg_pool(self.upsample(image))
16
+ similarity = 1 - self.model(image, text)[0] / 100
17
+ return similarity
models/StyleCLIP/criteria/id_loss.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from models.facial_recognition.model_irse import Backbone
5
+
6
+
7
+ class IDLoss(nn.Module):
8
+ def __init__(self, opts):
9
+ super(IDLoss, self).__init__()
10
+ print('Loading ResNet ArcFace')
11
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12
+ self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
13
+ self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
14
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
15
+ self.facenet.eval()
16
+ self.opts = opts
17
+
18
+ def extract_feats(self, x):
19
+ if x.shape[2] != 256:
20
+ x = self.pool(x)
21
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
22
+ x = self.face_pool(x)
23
+ x_feats = self.facenet(x)
24
+ return x_feats
25
+
26
+ def forward(self, y_hat, y):
27
+ n_samples = y.shape[0]
28
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
29
+ y_hat_feats = self.extract_feats(y_hat)
30
+ y_feats = y_feats.detach()
31
+ loss = 0
32
+ sim_improvement = 0
33
+ count = 0
34
+ for i in range(n_samples):
35
+ diff_target = y_hat_feats[i].dot(y_feats[i])
36
+ loss += 1 - diff_target
37
+ count += 1
38
+
39
+ return loss / count, sim_improvement / count