Spaces:
Runtime error
Runtime error
echen01
commited on
Commit
β’
2fec875
1
Parent(s):
926824a
working demo
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- PTI/.gitignore +0 -1
- PTI/LICENSE +0 -21
- PTI/README.md +0 -229
- PTI/torch_utils/custom_ops.py +0 -126
- PTI/torch_utils/misc.py +0 -262
- PTI/torch_utils/ops/bias_act.cpp +0 -99
- PTI/torch_utils/ops/bias_act.cu +0 -173
- PTI/torch_utils/ops/bias_act.h +0 -38
- PTI/torch_utils/ops/bias_act.py +0 -212
- PTI/torch_utils/ops/conv2d_gradfix.py +0 -170
- PTI/torch_utils/ops/conv2d_resample.py +0 -156
- PTI/torch_utils/ops/fma.py +0 -60
- PTI/torch_utils/ops/grid_sample_gradfix.py +0 -83
- PTI/torch_utils/ops/upfirdn2d.cpp +0 -103
- PTI/torch_utils/ops/upfirdn2d.cu +0 -350
- PTI/torch_utils/ops/upfirdn2d.h +0 -59
- PTI/torch_utils/ops/upfirdn2d.py +0 -384
- PTI/torch_utils/persistence.py +0 -251
- PTI/torch_utils/training_stats.py +0 -268
- app.py +62 -13
- checkpoints/model_gradio_demo_input.pt +3 -0
- PTI/criteria/color_transfer_loss.py β color_transfer_loss.py +0 -0
- {PTI/configs β configs}/__init__.py +0 -0
- {PTI/configs β configs}/evaluation_config.py +0 -0
- {PTI/configs β configs}/global_config.py +2 -2
- {PTI/configs β configs}/hyperparameters.py +1 -1
- {PTI/configs β configs}/paths_config.py +4 -4
- {PTI/criteria β criteria}/__init__.py +0 -0
- {PTI/criteria β criteria}/backbones/__init__.py +0 -0
- {PTI/criteria β criteria}/backbones/iresnet.py +0 -0
- {PTI/criteria β criteria}/backbones/iresnet2060.py +0 -0
- {PTI/criteria β criteria}/backbones/mobilefacenet.py +0 -0
- {PTI/criteria β criteria}/deeplab.py +0 -0
- {PTI/criteria β criteria}/helpers.py +0 -0
- {PTI/criteria β criteria}/id_loss.py +0 -0
- {PTI/criteria β criteria}/l2_loss.py +0 -0
- {PTI/criteria β criteria}/localitly_regulizer.py +0 -0
- {PTI/criteria β criteria}/mask.py +0 -0
- {PTI/criteria β criteria}/model_irse.py +0 -0
- {PTI/criteria β criteria}/validation.py +0 -0
- dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
- dnnlib/__pycache__/util.cpython-39.pyc +0 -0
- embeddings/2010/PTI/input/0.pt +3 -0
- imgs/Steven-Yeun.jpg +3 -0
- imgs/cropped/input.png +3 -0
- imgs/input.png +3 -0
- {PTI/training β models/StyleCLIP}/__init__.py +0 -0
- {PTI/training/coaches β models/StyleCLIP/criteria}/__init__.py +0 -0
- models/StyleCLIP/criteria/clip_loss.py +17 -0
- models/StyleCLIP/criteria/id_loss.py +39 -0
PTI/.gitignore
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
|
|
|
|
PTI/LICENSE
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
MIT License
|
2 |
-
|
3 |
-
Copyright (c) 2021 Daniel Roich
|
4 |
-
|
5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
of this software and associated documentation files (the "Software"), to deal
|
7 |
-
in the Software without restriction, including without limitation the rights
|
8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
copies of the Software, and to permit persons to whom the Software is
|
10 |
-
furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
The above copyright notice and this permission notice shall be included in all
|
13 |
-
copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/README.md
DELETED
@@ -1,229 +0,0 @@
|
|
1 |
-
# PTI: Pivotal Tuning for Latent-based editing of Real Images
|
2 |
-
|
3 |
-
<!-- > Recently, a surge of advanced facial editing techniques have been proposed
|
4 |
-
that leverage the generative power of a pre-trained StyleGAN. To successfully
|
5 |
-
edit an image this way, one must first project (or invert) the image into
|
6 |
-
the pre-trained generatorβs domain. As it turns out, however, StyleGANβs
|
7 |
-
latent space induces an inherent tradeoff between distortion and editability,
|
8 |
-
i.e. between maintaining the original appearance and convincingly altering
|
9 |
-
some of its attributes. Practically, this means it is still challenging to
|
10 |
-
apply ID-preserving facial latent-space editing to faces which are out of the
|
11 |
-
generatorβs domain. In this paper, we present an approach to bridge this
|
12 |
-
gap. Our technique slightly alters the generator, so that an out-of-domain
|
13 |
-
image is faithfully mapped into an in-domain latent code. The key idea is
|
14 |
-
pivotal tuning β a brief training process that preserves the editing quality
|
15 |
-
of an in-domain latent region, while changing its portrayed identity and
|
16 |
-
appearance. In Pivotal Tuning Inversion (PTI), an initial inverted latent code
|
17 |
-
serves as a pivot, around which the generator is fined-tuned. At the same
|
18 |
-
time, a regularization term keeps nearby identities intact, to locally contain
|
19 |
-
the effect. This surgical training process ends up altering appearance features
|
20 |
-
that represent mostly identity, without affecting editing capabilities.
|
21 |
-
To supplement this, we further show that pivotal tuning can also adjust the
|
22 |
-
generator to accommodate a multitude of faces, while introducing negligible
|
23 |
-
distortion on the rest of the domain. We validate our technique through
|
24 |
-
inversion and editing metrics, and show preferable scores to state-of-the-art
|
25 |
-
methods. We further qualitatively demonstrate our technique by applying
|
26 |
-
advanced edits (such as pose, age, or expression) to numerous images of
|
27 |
-
well-known and recognizable identities. Finally, we demonstrate resilience
|
28 |
-
to harder cases, including heavy make-up, elaborate hairstyles and/or headwear,
|
29 |
-
which otherwise could not have been successfully inverted and edited
|
30 |
-
by state-of-the-art methods. -->
|
31 |
-
|
32 |
-
<a href="https://arxiv.org/abs/2106.05744"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
|
33 |
-
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
|
34 |
-
Inference Notebook: <a href="https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=20></a>
|
35 |
-
|
36 |
-
<p align="center">
|
37 |
-
<img src="docs/teaser.jpg"/>
|
38 |
-
<br>
|
39 |
-
Pivotal Tuning Inversion (PTI) enables employing off-the-shelf latent based
|
40 |
-
semantic editing techniques on real images using StyleGAN.
|
41 |
-
PTI excels in identity preserving edits, portrayed through recognizable figures β
|
42 |
-
Serena Williams and Robert Downey Jr. (top), and in handling faces which
|
43 |
-
are clearly out-of-domain, e.g., due to heavy makeup (bottom).
|
44 |
-
</br>
|
45 |
-
</p>
|
46 |
-
|
47 |
-
## Description
|
48 |
-
Official Implementation of our PTI paper + code for evaluation metrics. PTI introduces an optimization mechanizem for solving the StyleGAN inversion task.
|
49 |
-
Providing near-perfect reconstruction results while maintaining the high editing abilitis of the native StyleGAN latent space W. For more details, see <a href="https://arxiv.org/abs/2106.05744"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
|
50 |
-
|
51 |
-
## Recent Updates
|
52 |
-
**2021.07.01**: Fixed files download phase in the inference notebook. Which might caused the notebook not to run smoothly.
|
53 |
-
|
54 |
-
**2021.06.29**: Added support for CPU. In order to run PTI on CPU please change `device` parameter under `configs/global_config.py` to "cpu" instead of "cuda".
|
55 |
-
|
56 |
-
**2021.06.25** : Adding mohawk edit using StyleCLIP+PTI in inference notebook.
|
57 |
-
Updating documentation in inference notebook due to Google Drive rate limit reached.
|
58 |
-
Currently, Google Drive does not allow to download the pretrined models using Colab automatically. Manual intervention might be needed.
|
59 |
-
|
60 |
-
## Getting Started
|
61 |
-
### Prerequisites
|
62 |
-
- Linux or macOS
|
63 |
-
- NVIDIA GPU + CUDA CuDNN (Not mandatory bur recommended)
|
64 |
-
- Python 3
|
65 |
-
|
66 |
-
### Installation
|
67 |
-
- Dependencies:
|
68 |
-
1. lpips
|
69 |
-
2. wandb
|
70 |
-
3. pytorch
|
71 |
-
4. torchvision
|
72 |
-
5. matplotlib
|
73 |
-
6. dlib
|
74 |
-
- All dependencies can be installed using *pip install* and the package name
|
75 |
-
|
76 |
-
## Pretrained Models
|
77 |
-
Please download the pretrained models from the following links.
|
78 |
-
|
79 |
-
### Auxiliary Models
|
80 |
-
We provide various auxiliary models needed for PTI inversion task.
|
81 |
-
This includes the StyleGAN generator and pre-trained models used for loss computation.
|
82 |
-
| Path | Description
|
83 |
-
| :--- | :----------
|
84 |
-
|[FFHQ StyleGAN](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl) | StyleGAN2-ada model trained on FFHQ with 1024x1024 output resolution.
|
85 |
-
|[Dlib alignment](https://drive.google.com/file/d/1HKmjg6iXsWr4aFPuU0gBXPGR83wqMzq7/view?usp=sharing) | Dlib alignment used for images preproccessing.
|
86 |
-
|[FFHQ e4e encoder](https://drive.google.com/file/d/1ALC5CLA89Ouw40TwvxcwebhzWXM5YSCm/view?usp=sharing) | Pretrained e4e encoder. Used for StyleCLIP editing.
|
87 |
-
|
88 |
-
Note: The StyleGAN model is used directly from the official [stylegan2-ada-pytorch implementation](https://github.com/NVlabs/stylegan2-ada-pytorch).
|
89 |
-
For StyleCLIP pretrained mappers, please see [StyleCLIP's official routes](https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py)
|
90 |
-
|
91 |
-
|
92 |
-
By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`.
|
93 |
-
However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
|
94 |
-
|
95 |
-
|
96 |
-
## Inversion
|
97 |
-
### Preparing your Data
|
98 |
-
In order to invert a real image and edit it you should first align and crop it to the correct size. To do so you should perform *One* of the following steps:
|
99 |
-
1. Run `notebooks/align_data.ipynb` and change the "images_path" variable to the raw images path
|
100 |
-
2. Run `utils/align_data.py` and change the "images_path" variable to the raw images path
|
101 |
-
|
102 |
-
|
103 |
-
### Weights And Biases
|
104 |
-
The project supports [Weights And Biases](https://wandb.ai/home) framework for experiment tracking. For the inversion task it enables visualization of the losses progression and the generator intermediate results during the initial inversion and the *Pivotal Tuning*(PT) procedure.
|
105 |
-
|
106 |
-
The log frequency can be adjusted using the parameters defined at `configs/global_config.py` under the "Logs" subsection.
|
107 |
-
|
108 |
-
There is no no need to have an account. However, in order to use the features provided by Weights and Biases you first have to register on their site.
|
109 |
-
|
110 |
-
|
111 |
-
### Running PTI
|
112 |
-
The main training script is `scripts/run_pti.py`. The script receives aligned and cropped images from paths configured in the "Input info" subscetion in
|
113 |
-
`configs/paths_config.py`.
|
114 |
-
Results are saved to directories found at "Dirs for output files" under `configs/paths_config.py`. This includes inversion latent codes and tuned generators.
|
115 |
-
The hyperparametrs for the inversion task can be found at `configs/hyperparameters.py`. They are intilized to the default values used in the paper.
|
116 |
-
|
117 |
-
## Editing
|
118 |
-
By default, we assume that all auxiliary edit directions are downloaded and saved to the directory `editings`.
|
119 |
-
However, you may use your own paths by changing the necessary values in `configs/path_configs.py` under "Edit directions" subsection.
|
120 |
-
|
121 |
-
Example of editing code can be found at `scripts/latent_editor_wrapper.py`
|
122 |
-
|
123 |
-
## Inference Notebooks
|
124 |
-
To help visualize the results of PTI we provide a Jupyter notebook found in `notebooks/inference_playground.ipynb`.
|
125 |
-
The notebook will download the pretrained models and run inference on a sample image found online or
|
126 |
-
on images of your choosing. It is recommended to run this in [Google Colab](https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb).
|
127 |
-
|
128 |
-
The notebook demonstrates how to:
|
129 |
-
- Invert an image using PTI
|
130 |
-
- Visualise the inversion and use the PTI output
|
131 |
-
- Edit the image after PTI using InterfaceGAN and StyleCLIP
|
132 |
-
- Compare to other inversion methods
|
133 |
-
|
134 |
-
## Evaluation
|
135 |
-
Currently the repository supports qualitative evaluation for reconstruction of: PTI, SG2 (*W Space*), e4e, SG2Plus (*W+ Space*).
|
136 |
-
As well as editing using InterfaceGAN and GANSpace for the same inversion methods.
|
137 |
-
To run the evaluation please see `evaluation/qualitative_edit_comparison.py`. Examples of the evaluation scripts are:
|
138 |
-
|
139 |
-
<p align="center">
|
140 |
-
<img src="docs/model_rec.jpg"/>
|
141 |
-
<br>
|
142 |
-
Reconsturction comparison between different methods. The images order is: Original image, W+ inversion, e4e inversion, W inversion, PTI inversion
|
143 |
-
</br>
|
144 |
-
</p>
|
145 |
-
|
146 |
-
<p align="center">
|
147 |
-
<img src="docs/stern_rotation.jpg"/>
|
148 |
-
<br>
|
149 |
-
InterfaceGAN pose edit comparison between different methods. The images order is: Original, W+, e4e, W, PTI
|
150 |
-
</br>
|
151 |
-
</p>
|
152 |
-
|
153 |
-
<p align="center">
|
154 |
-
<img src="docs/tyron_original.jpg" width="220" height="220"/>
|
155 |
-
<img src="docs/tyron_edit.jpg" width="220" height="220"/>
|
156 |
-
<br>
|
157 |
-
Image per edit or several edits without comparison
|
158 |
-
</br>
|
159 |
-
</p>
|
160 |
-
|
161 |
-
### Coming Soon - Quantitative evaluation and StyleCLIP qualitative evaluation
|
162 |
-
|
163 |
-
## Repository structure
|
164 |
-
| Path | Description <img width=200>
|
165 |
-
| :--- | :---
|
166 |
-
| ├ configs | Folder containing configs defining Hyperparameters, paths and logging
|
167 |
-
| ├ criteria | Folder containing various loss and regularization criterias for the optimization
|
168 |
-
| ├ dnnlib | Folder containing internal utils for StyleGAN2-ada
|
169 |
-
| ├ docs | Folder containing the latent space edit directions
|
170 |
-
| ├ editings | Folder containing images displayed in the README
|
171 |
-
| ├ environment | Folder containing Anaconda environment used in our experiments
|
172 |
-
| ├ licenses | Folder containing licenses of the open source projects used in this repository
|
173 |
-
| ├ models | Folder containing models used in different editing techniques and first phase inversion
|
174 |
-
| ├ notebooks | Folder with jupyter notebooks to demonstrate the usage of PTI end-to-end
|
175 |
-
| ├ scripts | Folder with running scripts for inversion, editing and metric computations
|
176 |
-
| ├ torch_utils | Folder containing internal utils for StyleGAN2-ada
|
177 |
-
| ├ training | Folder containing the core training logic of PTI
|
178 |
-
| ├ utils | Folder with various utility functions
|
179 |
-
|
180 |
-
|
181 |
-
## Credits
|
182 |
-
**StyleGAN2-ada model and implementation:**
|
183 |
-
https://github.com/NVlabs/stylegan2-ada-pytorch
|
184 |
-
Copyright Β© 2021, NVIDIA Corporation.
|
185 |
-
Nvidia Source Code License https://nvlabs.github.io/stylegan2-ada-pytorch/license.html
|
186 |
-
|
187 |
-
**LPIPS model and implementation:**
|
188 |
-
https://github.com/richzhang/PerceptualSimilarity
|
189 |
-
Copyright (c) 2020, Sou Uchida
|
190 |
-
License (BSD 2-Clause) https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE
|
191 |
-
|
192 |
-
**e4e model and implementation:**
|
193 |
-
https://github.com/omertov/encoder4editing
|
194 |
-
Copyright (c) 2021 omertov
|
195 |
-
License (MIT) https://github.com/omertov/encoder4editing/blob/main/LICENSE
|
196 |
-
|
197 |
-
**StyleCLIP model and implementation:**
|
198 |
-
https://github.com/orpatashnik/StyleCLIP
|
199 |
-
Copyright (c) 2021 orpatashnik
|
200 |
-
License (MIT) https://github.com/orpatashnik/StyleCLIP/blob/main/LICENSE
|
201 |
-
|
202 |
-
**InterfaceGAN implementation:**
|
203 |
-
https://github.com/genforce/interfacegan
|
204 |
-
Copyright (c) 2020 genforce
|
205 |
-
License (MIT) https://github.com/genforce/interfacegan/blob/master/LICENSE
|
206 |
-
|
207 |
-
**GANSpace implementation:**
|
208 |
-
https://github.com/harskish/ganspace
|
209 |
-
Copyright (c) 2020 harkish
|
210 |
-
License (Apache License 2.0) https://github.com/harskish/ganspace/blob/master/LICENSE
|
211 |
-
|
212 |
-
|
213 |
-
## Acknowledgments
|
214 |
-
This repository structure is based on [encoder4editing](https://github.com/omertov/encoder4editing) and [ReStyle](https://github.com/yuval-alaluf/restyle-encoder) repositories
|
215 |
-
|
216 |
-
## Contact
|
217 |
-
For any inquiry please contact us at our email addresses: [email protected] or [email protected]
|
218 |
-
|
219 |
-
|
220 |
-
## Citation
|
221 |
-
If you use this code for your research, please cite:
|
222 |
-
```
|
223 |
-
@article{roich2021pivotal,
|
224 |
-
title={Pivotal Tuning for Latent-based Editing of Real Images},
|
225 |
-
author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
|
226 |
-
journal={arXiv preprint arXiv:2106.05744},
|
227 |
-
year={2021}
|
228 |
-
}
|
229 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/custom_ops.py
DELETED
@@ -1,126 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
import os
|
10 |
-
import glob
|
11 |
-
import torch
|
12 |
-
import torch.utils.cpp_extension
|
13 |
-
import importlib
|
14 |
-
import hashlib
|
15 |
-
import shutil
|
16 |
-
from pathlib import Path
|
17 |
-
|
18 |
-
from torch.utils.file_baton import FileBaton
|
19 |
-
|
20 |
-
#----------------------------------------------------------------------------
|
21 |
-
# Global options.
|
22 |
-
|
23 |
-
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
24 |
-
|
25 |
-
#----------------------------------------------------------------------------
|
26 |
-
# Internal helper funcs.
|
27 |
-
|
28 |
-
def _find_compiler_bindir():
|
29 |
-
patterns = [
|
30 |
-
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
31 |
-
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
32 |
-
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
33 |
-
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
34 |
-
]
|
35 |
-
for pattern in patterns:
|
36 |
-
matches = sorted(glob.glob(pattern))
|
37 |
-
if len(matches):
|
38 |
-
return matches[-1]
|
39 |
-
return None
|
40 |
-
|
41 |
-
#----------------------------------------------------------------------------
|
42 |
-
# Main entry point for compiling and loading C++/CUDA plugins.
|
43 |
-
|
44 |
-
_cached_plugins = dict()
|
45 |
-
|
46 |
-
def get_plugin(module_name, sources, **build_kwargs):
|
47 |
-
assert verbosity in ['none', 'brief', 'full']
|
48 |
-
|
49 |
-
# Already cached?
|
50 |
-
if module_name in _cached_plugins:
|
51 |
-
return _cached_plugins[module_name]
|
52 |
-
|
53 |
-
# Print status.
|
54 |
-
if verbosity == 'full':
|
55 |
-
print(f'Setting up PyTorch plugin "{module_name}"...')
|
56 |
-
elif verbosity == 'brief':
|
57 |
-
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
58 |
-
|
59 |
-
try: # pylint: disable=too-many-nested-blocks
|
60 |
-
# Make sure we can find the necessary compiler binaries.
|
61 |
-
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
62 |
-
compiler_bindir = _find_compiler_bindir()
|
63 |
-
if compiler_bindir is None:
|
64 |
-
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
65 |
-
os.environ['PATH'] += ';' + compiler_bindir
|
66 |
-
|
67 |
-
# Compile and load.
|
68 |
-
verbose_build = (verbosity == 'full')
|
69 |
-
|
70 |
-
# Incremental build md5sum trickery. Copies all the input source files
|
71 |
-
# into a cached build directory under a combined md5 digest of the input
|
72 |
-
# source files. Copying is done only if the combined digest has changed.
|
73 |
-
# This keeps input file timestamps and filenames the same as in previous
|
74 |
-
# extension builds, allowing for fast incremental rebuilds.
|
75 |
-
#
|
76 |
-
# This optimization is done only in case all the source files reside in
|
77 |
-
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
78 |
-
# environment variable is set (we take this as a signal that the user
|
79 |
-
# actually cares about this.)
|
80 |
-
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
81 |
-
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
82 |
-
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
83 |
-
|
84 |
-
# Compute a combined hash digest for all source files in the same
|
85 |
-
# custom op directory (usually .cu, .cpp, .py and .h files).
|
86 |
-
hash_md5 = hashlib.md5()
|
87 |
-
for src in all_source_files:
|
88 |
-
with open(src, 'rb') as f:
|
89 |
-
hash_md5.update(f.read())
|
90 |
-
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
91 |
-
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
92 |
-
|
93 |
-
if not os.path.isdir(digest_build_dir):
|
94 |
-
os.makedirs(digest_build_dir, exist_ok=True)
|
95 |
-
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
96 |
-
if baton.try_acquire():
|
97 |
-
try:
|
98 |
-
for src in all_source_files:
|
99 |
-
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
100 |
-
finally:
|
101 |
-
baton.release()
|
102 |
-
else:
|
103 |
-
# Someone else is copying source files under the digest dir,
|
104 |
-
# wait until done and continue.
|
105 |
-
baton.wait()
|
106 |
-
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
107 |
-
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
108 |
-
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
109 |
-
else:
|
110 |
-
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
111 |
-
module = importlib.import_module(module_name)
|
112 |
-
|
113 |
-
except:
|
114 |
-
if verbosity == 'brief':
|
115 |
-
print('Failed!')
|
116 |
-
raise
|
117 |
-
|
118 |
-
# Print status and add to cache.
|
119 |
-
if verbosity == 'full':
|
120 |
-
print(f'Done setting up PyTorch plugin "{module_name}".')
|
121 |
-
elif verbosity == 'brief':
|
122 |
-
print('Done.')
|
123 |
-
_cached_plugins[module_name] = module
|
124 |
-
return module
|
125 |
-
|
126 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/misc.py
DELETED
@@ -1,262 +0,0 @@
|
|
1 |
-
ο»Ώ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
import re
|
10 |
-
import contextlib
|
11 |
-
import numpy as np
|
12 |
-
import torch
|
13 |
-
import warnings
|
14 |
-
import dnnlib
|
15 |
-
|
16 |
-
#----------------------------------------------------------------------------
|
17 |
-
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
18 |
-
# same constant is used multiple times.
|
19 |
-
|
20 |
-
_constant_cache = dict()
|
21 |
-
|
22 |
-
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
23 |
-
value = np.asarray(value)
|
24 |
-
if shape is not None:
|
25 |
-
shape = tuple(shape)
|
26 |
-
if dtype is None:
|
27 |
-
dtype = torch.get_default_dtype()
|
28 |
-
if device is None:
|
29 |
-
device = torch.device('cpu')
|
30 |
-
if memory_format is None:
|
31 |
-
memory_format = torch.contiguous_format
|
32 |
-
|
33 |
-
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
34 |
-
tensor = _constant_cache.get(key, None)
|
35 |
-
if tensor is None:
|
36 |
-
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
37 |
-
if shape is not None:
|
38 |
-
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
39 |
-
tensor = tensor.contiguous(memory_format=memory_format)
|
40 |
-
_constant_cache[key] = tensor
|
41 |
-
return tensor
|
42 |
-
|
43 |
-
#----------------------------------------------------------------------------
|
44 |
-
# Replace NaN/Inf with specified numerical values.
|
45 |
-
|
46 |
-
try:
|
47 |
-
nan_to_num = torch.nan_to_num # 1.8.0a0
|
48 |
-
except AttributeError:
|
49 |
-
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
50 |
-
assert isinstance(input, torch.Tensor)
|
51 |
-
if posinf is None:
|
52 |
-
posinf = torch.finfo(input.dtype).max
|
53 |
-
if neginf is None:
|
54 |
-
neginf = torch.finfo(input.dtype).min
|
55 |
-
assert nan == 0
|
56 |
-
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
57 |
-
|
58 |
-
#----------------------------------------------------------------------------
|
59 |
-
# Symbolic assert.
|
60 |
-
|
61 |
-
try:
|
62 |
-
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
63 |
-
except AttributeError:
|
64 |
-
symbolic_assert = torch.Assert # 1.7.0
|
65 |
-
|
66 |
-
#----------------------------------------------------------------------------
|
67 |
-
# Context manager to suppress known warnings in torch.jit.trace().
|
68 |
-
|
69 |
-
class suppress_tracer_warnings(warnings.catch_warnings):
|
70 |
-
def __enter__(self):
|
71 |
-
super().__enter__()
|
72 |
-
warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
|
73 |
-
return self
|
74 |
-
|
75 |
-
#----------------------------------------------------------------------------
|
76 |
-
# Assert that the shape of a tensor matches the given list of integers.
|
77 |
-
# None indicates that the size of a dimension is allowed to vary.
|
78 |
-
# Performs symbolic assertion when used in torch.jit.trace().
|
79 |
-
|
80 |
-
def assert_shape(tensor, ref_shape):
|
81 |
-
if tensor.ndim != len(ref_shape):
|
82 |
-
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
83 |
-
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
84 |
-
if ref_size is None:
|
85 |
-
pass
|
86 |
-
elif isinstance(ref_size, torch.Tensor):
|
87 |
-
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
88 |
-
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
89 |
-
elif isinstance(size, torch.Tensor):
|
90 |
-
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
91 |
-
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
92 |
-
elif size != ref_size:
|
93 |
-
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
94 |
-
|
95 |
-
#----------------------------------------------------------------------------
|
96 |
-
# Function decorator that calls torch.autograd.profiler.record_function().
|
97 |
-
|
98 |
-
def profiled_function(fn):
|
99 |
-
def decorator(*args, **kwargs):
|
100 |
-
with torch.autograd.profiler.record_function(fn.__name__):
|
101 |
-
return fn(*args, **kwargs)
|
102 |
-
decorator.__name__ = fn.__name__
|
103 |
-
return decorator
|
104 |
-
|
105 |
-
#----------------------------------------------------------------------------
|
106 |
-
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
107 |
-
# indefinitely, shuffling items as it goes.
|
108 |
-
|
109 |
-
class InfiniteSampler(torch.utils.data.Sampler):
|
110 |
-
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
111 |
-
assert len(dataset) > 0
|
112 |
-
assert num_replicas > 0
|
113 |
-
assert 0 <= rank < num_replicas
|
114 |
-
assert 0 <= window_size <= 1
|
115 |
-
super().__init__(dataset)
|
116 |
-
self.dataset = dataset
|
117 |
-
self.rank = rank
|
118 |
-
self.num_replicas = num_replicas
|
119 |
-
self.shuffle = shuffle
|
120 |
-
self.seed = seed
|
121 |
-
self.window_size = window_size
|
122 |
-
|
123 |
-
def __iter__(self):
|
124 |
-
order = np.arange(len(self.dataset))
|
125 |
-
rnd = None
|
126 |
-
window = 0
|
127 |
-
if self.shuffle:
|
128 |
-
rnd = np.random.RandomState(self.seed)
|
129 |
-
rnd.shuffle(order)
|
130 |
-
window = int(np.rint(order.size * self.window_size))
|
131 |
-
|
132 |
-
idx = 0
|
133 |
-
while True:
|
134 |
-
i = idx % order.size
|
135 |
-
if idx % self.num_replicas == self.rank:
|
136 |
-
yield order[i]
|
137 |
-
if window >= 2:
|
138 |
-
j = (i - rnd.randint(window)) % order.size
|
139 |
-
order[i], order[j] = order[j], order[i]
|
140 |
-
idx += 1
|
141 |
-
|
142 |
-
#----------------------------------------------------------------------------
|
143 |
-
# Utilities for operating with torch.nn.Module parameters and buffers.
|
144 |
-
|
145 |
-
def params_and_buffers(module):
|
146 |
-
assert isinstance(module, torch.nn.Module)
|
147 |
-
return list(module.parameters()) + list(module.buffers())
|
148 |
-
|
149 |
-
def named_params_and_buffers(module):
|
150 |
-
assert isinstance(module, torch.nn.Module)
|
151 |
-
return list(module.named_parameters()) + list(module.named_buffers())
|
152 |
-
|
153 |
-
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
154 |
-
assert isinstance(src_module, torch.nn.Module)
|
155 |
-
assert isinstance(dst_module, torch.nn.Module)
|
156 |
-
src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
|
157 |
-
for name, tensor in named_params_and_buffers(dst_module):
|
158 |
-
assert (name in src_tensors) or (not require_all)
|
159 |
-
if name in src_tensors:
|
160 |
-
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
161 |
-
|
162 |
-
#----------------------------------------------------------------------------
|
163 |
-
# Context manager for easily enabling/disabling DistributedDataParallel
|
164 |
-
# synchronization.
|
165 |
-
|
166 |
-
@contextlib.contextmanager
|
167 |
-
def ddp_sync(module, sync):
|
168 |
-
assert isinstance(module, torch.nn.Module)
|
169 |
-
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
170 |
-
yield
|
171 |
-
else:
|
172 |
-
with module.no_sync():
|
173 |
-
yield
|
174 |
-
|
175 |
-
#----------------------------------------------------------------------------
|
176 |
-
# Check DistributedDataParallel consistency across processes.
|
177 |
-
|
178 |
-
def check_ddp_consistency(module, ignore_regex=None):
|
179 |
-
assert isinstance(module, torch.nn.Module)
|
180 |
-
for name, tensor in named_params_and_buffers(module):
|
181 |
-
fullname = type(module).__name__ + '.' + name
|
182 |
-
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
183 |
-
continue
|
184 |
-
tensor = tensor.detach()
|
185 |
-
other = tensor.clone()
|
186 |
-
torch.distributed.broadcast(tensor=other, src=0)
|
187 |
-
assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
|
188 |
-
|
189 |
-
#----------------------------------------------------------------------------
|
190 |
-
# Print summary table of module hierarchy.
|
191 |
-
|
192 |
-
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
193 |
-
assert isinstance(module, torch.nn.Module)
|
194 |
-
assert not isinstance(module, torch.jit.ScriptModule)
|
195 |
-
assert isinstance(inputs, (tuple, list))
|
196 |
-
|
197 |
-
# Register hooks.
|
198 |
-
entries = []
|
199 |
-
nesting = [0]
|
200 |
-
def pre_hook(_mod, _inputs):
|
201 |
-
nesting[0] += 1
|
202 |
-
def post_hook(mod, _inputs, outputs):
|
203 |
-
nesting[0] -= 1
|
204 |
-
if nesting[0] <= max_nesting:
|
205 |
-
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
206 |
-
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
207 |
-
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
208 |
-
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
209 |
-
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
210 |
-
|
211 |
-
# Run module.
|
212 |
-
outputs = module(*inputs)
|
213 |
-
for hook in hooks:
|
214 |
-
hook.remove()
|
215 |
-
|
216 |
-
# Identify unique outputs, parameters, and buffers.
|
217 |
-
tensors_seen = set()
|
218 |
-
for e in entries:
|
219 |
-
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
220 |
-
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
221 |
-
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
222 |
-
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
223 |
-
|
224 |
-
# Filter out redundant entries.
|
225 |
-
if skip_redundant:
|
226 |
-
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
227 |
-
|
228 |
-
# Construct table.
|
229 |
-
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
230 |
-
rows += [['---'] * len(rows[0])]
|
231 |
-
param_total = 0
|
232 |
-
buffer_total = 0
|
233 |
-
submodule_names = {mod: name for name, mod in module.named_modules()}
|
234 |
-
for e in entries:
|
235 |
-
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
236 |
-
param_size = sum(t.numel() for t in e.unique_params)
|
237 |
-
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
238 |
-
output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
|
239 |
-
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
240 |
-
rows += [[
|
241 |
-
name + (':0' if len(e.outputs) >= 2 else ''),
|
242 |
-
str(param_size) if param_size else '-',
|
243 |
-
str(buffer_size) if buffer_size else '-',
|
244 |
-
(output_shapes + ['-'])[0],
|
245 |
-
(output_dtypes + ['-'])[0],
|
246 |
-
]]
|
247 |
-
for idx in range(1, len(e.outputs)):
|
248 |
-
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
249 |
-
param_total += param_size
|
250 |
-
buffer_total += buffer_size
|
251 |
-
rows += [['---'] * len(rows[0])]
|
252 |
-
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
253 |
-
|
254 |
-
# Print table.
|
255 |
-
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
256 |
-
print()
|
257 |
-
for row in rows:
|
258 |
-
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
259 |
-
print()
|
260 |
-
return outputs
|
261 |
-
|
262 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/bias_act.cpp
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <torch/extension.h>
|
10 |
-
#include <ATen/cuda/CUDAContext.h>
|
11 |
-
#include <c10/cuda/CUDAGuard.h>
|
12 |
-
#include "bias_act.h"
|
13 |
-
|
14 |
-
//------------------------------------------------------------------------
|
15 |
-
|
16 |
-
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
17 |
-
{
|
18 |
-
if (x.dim() != y.dim())
|
19 |
-
return false;
|
20 |
-
for (int64_t i = 0; i < x.dim(); i++)
|
21 |
-
{
|
22 |
-
if (x.size(i) != y.size(i))
|
23 |
-
return false;
|
24 |
-
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
25 |
-
return false;
|
26 |
-
}
|
27 |
-
return true;
|
28 |
-
}
|
29 |
-
|
30 |
-
//------------------------------------------------------------------------
|
31 |
-
|
32 |
-
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
33 |
-
{
|
34 |
-
// Validate arguments.
|
35 |
-
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
36 |
-
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
37 |
-
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
38 |
-
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
39 |
-
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
40 |
-
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
41 |
-
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
42 |
-
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
43 |
-
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
44 |
-
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
45 |
-
|
46 |
-
// Validate layout.
|
47 |
-
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
48 |
-
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
49 |
-
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
50 |
-
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
51 |
-
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
52 |
-
|
53 |
-
// Create output tensor.
|
54 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
55 |
-
torch::Tensor y = torch::empty_like(x);
|
56 |
-
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
57 |
-
|
58 |
-
// Initialize CUDA kernel parameters.
|
59 |
-
bias_act_kernel_params p;
|
60 |
-
p.x = x.data_ptr();
|
61 |
-
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
62 |
-
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
63 |
-
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
64 |
-
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
65 |
-
p.y = y.data_ptr();
|
66 |
-
p.grad = grad;
|
67 |
-
p.act = act;
|
68 |
-
p.alpha = alpha;
|
69 |
-
p.gain = gain;
|
70 |
-
p.clamp = clamp;
|
71 |
-
p.sizeX = (int)x.numel();
|
72 |
-
p.sizeB = (int)b.numel();
|
73 |
-
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
74 |
-
|
75 |
-
// Choose CUDA kernel.
|
76 |
-
void* kernel;
|
77 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
78 |
-
{
|
79 |
-
kernel = choose_bias_act_kernel<scalar_t>(p);
|
80 |
-
});
|
81 |
-
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
82 |
-
|
83 |
-
// Launch CUDA kernel.
|
84 |
-
p.loopX = 4;
|
85 |
-
int blockSize = 4 * 32;
|
86 |
-
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
87 |
-
void* args[] = {&p};
|
88 |
-
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
89 |
-
return y;
|
90 |
-
}
|
91 |
-
|
92 |
-
//------------------------------------------------------------------------
|
93 |
-
|
94 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
95 |
-
{
|
96 |
-
m.def("bias_act", &bias_act);
|
97 |
-
}
|
98 |
-
|
99 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/bias_act.cu
DELETED
@@ -1,173 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <c10/util/Half.h>
|
10 |
-
#include "bias_act.h"
|
11 |
-
|
12 |
-
//------------------------------------------------------------------------
|
13 |
-
// Helpers.
|
14 |
-
|
15 |
-
template <class T> struct InternalType;
|
16 |
-
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
-
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
-
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
-
|
20 |
-
//------------------------------------------------------------------------
|
21 |
-
// CUDA kernel.
|
22 |
-
|
23 |
-
template <class T, int A>
|
24 |
-
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
25 |
-
{
|
26 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
27 |
-
int G = p.grad;
|
28 |
-
scalar_t alpha = (scalar_t)p.alpha;
|
29 |
-
scalar_t gain = (scalar_t)p.gain;
|
30 |
-
scalar_t clamp = (scalar_t)p.clamp;
|
31 |
-
scalar_t one = (scalar_t)1;
|
32 |
-
scalar_t two = (scalar_t)2;
|
33 |
-
scalar_t expRange = (scalar_t)80;
|
34 |
-
scalar_t halfExpRange = (scalar_t)40;
|
35 |
-
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
36 |
-
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
37 |
-
|
38 |
-
// Loop over elements.
|
39 |
-
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
40 |
-
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
41 |
-
{
|
42 |
-
// Load.
|
43 |
-
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
44 |
-
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
45 |
-
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
46 |
-
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
47 |
-
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
48 |
-
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
49 |
-
scalar_t y = 0;
|
50 |
-
|
51 |
-
// Apply bias.
|
52 |
-
((G == 0) ? x : xref) += b;
|
53 |
-
|
54 |
-
// linear
|
55 |
-
if (A == 1)
|
56 |
-
{
|
57 |
-
if (G == 0) y = x;
|
58 |
-
if (G == 1) y = x;
|
59 |
-
}
|
60 |
-
|
61 |
-
// relu
|
62 |
-
if (A == 2)
|
63 |
-
{
|
64 |
-
if (G == 0) y = (x > 0) ? x : 0;
|
65 |
-
if (G == 1) y = (yy > 0) ? x : 0;
|
66 |
-
}
|
67 |
-
|
68 |
-
// lrelu
|
69 |
-
if (A == 3)
|
70 |
-
{
|
71 |
-
if (G == 0) y = (x > 0) ? x : x * alpha;
|
72 |
-
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
73 |
-
}
|
74 |
-
|
75 |
-
// tanh
|
76 |
-
if (A == 4)
|
77 |
-
{
|
78 |
-
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
79 |
-
if (G == 1) y = x * (one - yy * yy);
|
80 |
-
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
81 |
-
}
|
82 |
-
|
83 |
-
// sigmoid
|
84 |
-
if (A == 5)
|
85 |
-
{
|
86 |
-
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
87 |
-
if (G == 1) y = x * yy * (one - yy);
|
88 |
-
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
89 |
-
}
|
90 |
-
|
91 |
-
// elu
|
92 |
-
if (A == 6)
|
93 |
-
{
|
94 |
-
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
95 |
-
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
96 |
-
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
97 |
-
}
|
98 |
-
|
99 |
-
// selu
|
100 |
-
if (A == 7)
|
101 |
-
{
|
102 |
-
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
103 |
-
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
104 |
-
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
105 |
-
}
|
106 |
-
|
107 |
-
// softplus
|
108 |
-
if (A == 8)
|
109 |
-
{
|
110 |
-
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
111 |
-
if (G == 1) y = x * (one - exp(-yy));
|
112 |
-
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
113 |
-
}
|
114 |
-
|
115 |
-
// swish
|
116 |
-
if (A == 9)
|
117 |
-
{
|
118 |
-
if (G == 0)
|
119 |
-
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
120 |
-
else
|
121 |
-
{
|
122 |
-
scalar_t c = exp(xref);
|
123 |
-
scalar_t d = c + one;
|
124 |
-
if (G == 1)
|
125 |
-
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
126 |
-
else
|
127 |
-
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
128 |
-
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
129 |
-
}
|
130 |
-
}
|
131 |
-
|
132 |
-
// Apply gain.
|
133 |
-
y *= gain * dy;
|
134 |
-
|
135 |
-
// Clamp.
|
136 |
-
if (clamp >= 0)
|
137 |
-
{
|
138 |
-
if (G == 0)
|
139 |
-
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
140 |
-
else
|
141 |
-
y = (yref > -clamp & yref < clamp) ? y : 0;
|
142 |
-
}
|
143 |
-
|
144 |
-
// Store.
|
145 |
-
((T*)p.y)[xi] = (T)y;
|
146 |
-
}
|
147 |
-
}
|
148 |
-
|
149 |
-
//------------------------------------------------------------------------
|
150 |
-
// CUDA kernel selection.
|
151 |
-
|
152 |
-
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
153 |
-
{
|
154 |
-
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
155 |
-
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
156 |
-
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
157 |
-
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
158 |
-
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
159 |
-
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
160 |
-
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
161 |
-
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
162 |
-
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
163 |
-
return NULL;
|
164 |
-
}
|
165 |
-
|
166 |
-
//------------------------------------------------------------------------
|
167 |
-
// Template specializations.
|
168 |
-
|
169 |
-
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
170 |
-
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
171 |
-
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
172 |
-
|
173 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/bias_act.h
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
//------------------------------------------------------------------------
|
10 |
-
// CUDA kernel parameters.
|
11 |
-
|
12 |
-
struct bias_act_kernel_params
|
13 |
-
{
|
14 |
-
const void* x; // [sizeX]
|
15 |
-
const void* b; // [sizeB] or NULL
|
16 |
-
const void* xref; // [sizeX] or NULL
|
17 |
-
const void* yref; // [sizeX] or NULL
|
18 |
-
const void* dy; // [sizeX] or NULL
|
19 |
-
void* y; // [sizeX]
|
20 |
-
|
21 |
-
int grad;
|
22 |
-
int act;
|
23 |
-
float alpha;
|
24 |
-
float gain;
|
25 |
-
float clamp;
|
26 |
-
|
27 |
-
int sizeX;
|
28 |
-
int sizeB;
|
29 |
-
int stepB;
|
30 |
-
int loopX;
|
31 |
-
};
|
32 |
-
|
33 |
-
//------------------------------------------------------------------------
|
34 |
-
// CUDA kernel selection.
|
35 |
-
|
36 |
-
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
37 |
-
|
38 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/bias_act.py
DELETED
@@ -1,212 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Custom PyTorch ops for efficient bias and activation."""
|
10 |
-
|
11 |
-
import os
|
12 |
-
import warnings
|
13 |
-
import numpy as np
|
14 |
-
import torch
|
15 |
-
import dnnlib
|
16 |
-
import traceback
|
17 |
-
|
18 |
-
from .. import custom_ops
|
19 |
-
from .. import misc
|
20 |
-
|
21 |
-
#----------------------------------------------------------------------------
|
22 |
-
|
23 |
-
activation_funcs = {
|
24 |
-
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
25 |
-
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
26 |
-
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
27 |
-
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
28 |
-
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
29 |
-
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
30 |
-
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
31 |
-
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
32 |
-
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
33 |
-
}
|
34 |
-
|
35 |
-
#----------------------------------------------------------------------------
|
36 |
-
|
37 |
-
_inited = False
|
38 |
-
_plugin = None
|
39 |
-
_null_tensor = torch.empty([0])
|
40 |
-
|
41 |
-
def _init():
|
42 |
-
global _inited, _plugin
|
43 |
-
if not _inited:
|
44 |
-
_inited = True
|
45 |
-
sources = ['bias_act.cpp', 'bias_act.cu']
|
46 |
-
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
47 |
-
try:
|
48 |
-
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
49 |
-
except:
|
50 |
-
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
51 |
-
return _plugin is not None
|
52 |
-
|
53 |
-
#----------------------------------------------------------------------------
|
54 |
-
|
55 |
-
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
56 |
-
r"""Fused bias and activation function.
|
57 |
-
|
58 |
-
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
59 |
-
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
60 |
-
the fused op is considerably more efficient than performing the same calculation
|
61 |
-
using standard PyTorch ops. It supports first and second order gradients,
|
62 |
-
but not third order gradients.
|
63 |
-
|
64 |
-
Args:
|
65 |
-
x: Input activation tensor. Can be of any shape.
|
66 |
-
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
67 |
-
as `x`. The shape must be known, and it must match the dimension of `x`
|
68 |
-
corresponding to `dim`.
|
69 |
-
dim: The dimension in `x` corresponding to the elements of `b`.
|
70 |
-
The value of `dim` is ignored if `b` is not specified.
|
71 |
-
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
72 |
-
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
73 |
-
See `activation_funcs` for a full list. `None` is not allowed.
|
74 |
-
alpha: Shape parameter for the activation function, or `None` to use the default.
|
75 |
-
gain: Scaling factor for the output tensor, or `None` to use default.
|
76 |
-
See `activation_funcs` for the default scaling of each activation function.
|
77 |
-
If unsure, consider specifying 1.
|
78 |
-
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
79 |
-
the clamping (default).
|
80 |
-
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
81 |
-
|
82 |
-
Returns:
|
83 |
-
Tensor of the same shape and datatype as `x`.
|
84 |
-
"""
|
85 |
-
assert isinstance(x, torch.Tensor)
|
86 |
-
assert impl in ['ref', 'cuda']
|
87 |
-
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
88 |
-
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
89 |
-
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
90 |
-
|
91 |
-
#----------------------------------------------------------------------------
|
92 |
-
|
93 |
-
@misc.profiled_function
|
94 |
-
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
95 |
-
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
96 |
-
"""
|
97 |
-
assert isinstance(x, torch.Tensor)
|
98 |
-
assert clamp is None or clamp >= 0
|
99 |
-
spec = activation_funcs[act]
|
100 |
-
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
101 |
-
gain = float(gain if gain is not None else spec.def_gain)
|
102 |
-
clamp = float(clamp if clamp is not None else -1)
|
103 |
-
|
104 |
-
# Add bias.
|
105 |
-
if b is not None:
|
106 |
-
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
107 |
-
assert 0 <= dim < x.ndim
|
108 |
-
assert b.shape[0] == x.shape[dim]
|
109 |
-
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
110 |
-
|
111 |
-
# Evaluate activation function.
|
112 |
-
alpha = float(alpha)
|
113 |
-
x = spec.func(x, alpha=alpha)
|
114 |
-
|
115 |
-
# Scale by gain.
|
116 |
-
gain = float(gain)
|
117 |
-
if gain != 1:
|
118 |
-
x = x * gain
|
119 |
-
|
120 |
-
# Clamp.
|
121 |
-
if clamp >= 0:
|
122 |
-
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
123 |
-
return x
|
124 |
-
|
125 |
-
#----------------------------------------------------------------------------
|
126 |
-
|
127 |
-
_bias_act_cuda_cache = dict()
|
128 |
-
|
129 |
-
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
130 |
-
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
131 |
-
"""
|
132 |
-
# Parse arguments.
|
133 |
-
assert clamp is None or clamp >= 0
|
134 |
-
spec = activation_funcs[act]
|
135 |
-
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
136 |
-
gain = float(gain if gain is not None else spec.def_gain)
|
137 |
-
clamp = float(clamp if clamp is not None else -1)
|
138 |
-
|
139 |
-
# Lookup from cache.
|
140 |
-
key = (dim, act, alpha, gain, clamp)
|
141 |
-
if key in _bias_act_cuda_cache:
|
142 |
-
return _bias_act_cuda_cache[key]
|
143 |
-
|
144 |
-
# Forward op.
|
145 |
-
class BiasActCuda(torch.autograd.Function):
|
146 |
-
@staticmethod
|
147 |
-
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
148 |
-
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
|
149 |
-
x = x.contiguous(memory_format=ctx.memory_format)
|
150 |
-
b = b.contiguous() if b is not None else _null_tensor
|
151 |
-
y = x
|
152 |
-
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
153 |
-
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
154 |
-
ctx.save_for_backward(
|
155 |
-
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
156 |
-
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
157 |
-
y if 'y' in spec.ref else _null_tensor)
|
158 |
-
return y
|
159 |
-
|
160 |
-
@staticmethod
|
161 |
-
def backward(ctx, dy): # pylint: disable=arguments-differ
|
162 |
-
dy = dy.contiguous(memory_format=ctx.memory_format)
|
163 |
-
x, b, y = ctx.saved_tensors
|
164 |
-
dx = None
|
165 |
-
db = None
|
166 |
-
|
167 |
-
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
168 |
-
dx = dy
|
169 |
-
if act != 'linear' or gain != 1 or clamp >= 0:
|
170 |
-
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
171 |
-
|
172 |
-
if ctx.needs_input_grad[1]:
|
173 |
-
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
174 |
-
|
175 |
-
return dx, db
|
176 |
-
|
177 |
-
# Backward op.
|
178 |
-
class BiasActCudaGrad(torch.autograd.Function):
|
179 |
-
@staticmethod
|
180 |
-
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
181 |
-
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
|
182 |
-
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
183 |
-
ctx.save_for_backward(
|
184 |
-
dy if spec.has_2nd_grad else _null_tensor,
|
185 |
-
x, b, y)
|
186 |
-
return dx
|
187 |
-
|
188 |
-
@staticmethod
|
189 |
-
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
190 |
-
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
191 |
-
dy, x, b, y = ctx.saved_tensors
|
192 |
-
d_dy = None
|
193 |
-
d_x = None
|
194 |
-
d_b = None
|
195 |
-
d_y = None
|
196 |
-
|
197 |
-
if ctx.needs_input_grad[0]:
|
198 |
-
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
199 |
-
|
200 |
-
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
201 |
-
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
202 |
-
|
203 |
-
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
204 |
-
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
205 |
-
|
206 |
-
return d_dy, d_x, d_b, d_y
|
207 |
-
|
208 |
-
# Add to cache.
|
209 |
-
_bias_act_cuda_cache[key] = BiasActCuda
|
210 |
-
return BiasActCuda
|
211 |
-
|
212 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/conv2d_gradfix.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
10 |
-
arbitrarily high order gradients with zero performance penalty."""
|
11 |
-
|
12 |
-
import warnings
|
13 |
-
import contextlib
|
14 |
-
import torch
|
15 |
-
|
16 |
-
# pylint: disable=redefined-builtin
|
17 |
-
# pylint: disable=arguments-differ
|
18 |
-
# pylint: disable=protected-access
|
19 |
-
|
20 |
-
#----------------------------------------------------------------------------
|
21 |
-
|
22 |
-
enabled = False # Enable the custom op by setting this to true.
|
23 |
-
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
24 |
-
|
25 |
-
@contextlib.contextmanager
|
26 |
-
def no_weight_gradients():
|
27 |
-
global weight_gradients_disabled
|
28 |
-
old = weight_gradients_disabled
|
29 |
-
weight_gradients_disabled = True
|
30 |
-
yield
|
31 |
-
weight_gradients_disabled = old
|
32 |
-
|
33 |
-
#----------------------------------------------------------------------------
|
34 |
-
|
35 |
-
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
36 |
-
if _should_use_custom_op(input):
|
37 |
-
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
38 |
-
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
39 |
-
|
40 |
-
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
41 |
-
if _should_use_custom_op(input):
|
42 |
-
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
43 |
-
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
44 |
-
|
45 |
-
#----------------------------------------------------------------------------
|
46 |
-
|
47 |
-
def _should_use_custom_op(input):
|
48 |
-
assert isinstance(input, torch.Tensor)
|
49 |
-
if (not enabled) or (not torch.backends.cudnn.enabled):
|
50 |
-
return False
|
51 |
-
if input.device.type != 'cuda':
|
52 |
-
return False
|
53 |
-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
54 |
-
return True
|
55 |
-
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
|
56 |
-
return False
|
57 |
-
|
58 |
-
def _tuple_of_ints(xs, ndim):
|
59 |
-
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
60 |
-
assert len(xs) == ndim
|
61 |
-
assert all(isinstance(x, int) for x in xs)
|
62 |
-
return xs
|
63 |
-
|
64 |
-
#----------------------------------------------------------------------------
|
65 |
-
|
66 |
-
_conv2d_gradfix_cache = dict()
|
67 |
-
|
68 |
-
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
69 |
-
# Parse arguments.
|
70 |
-
ndim = 2
|
71 |
-
weight_shape = tuple(weight_shape)
|
72 |
-
stride = _tuple_of_ints(stride, ndim)
|
73 |
-
padding = _tuple_of_ints(padding, ndim)
|
74 |
-
output_padding = _tuple_of_ints(output_padding, ndim)
|
75 |
-
dilation = _tuple_of_ints(dilation, ndim)
|
76 |
-
|
77 |
-
# Lookup from cache.
|
78 |
-
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
79 |
-
if key in _conv2d_gradfix_cache:
|
80 |
-
return _conv2d_gradfix_cache[key]
|
81 |
-
|
82 |
-
# Validate arguments.
|
83 |
-
assert groups >= 1
|
84 |
-
assert len(weight_shape) == ndim + 2
|
85 |
-
assert all(stride[i] >= 1 for i in range(ndim))
|
86 |
-
assert all(padding[i] >= 0 for i in range(ndim))
|
87 |
-
assert all(dilation[i] >= 0 for i in range(ndim))
|
88 |
-
if not transpose:
|
89 |
-
assert all(output_padding[i] == 0 for i in range(ndim))
|
90 |
-
else: # transpose
|
91 |
-
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
92 |
-
|
93 |
-
# Helpers.
|
94 |
-
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
95 |
-
def calc_output_padding(input_shape, output_shape):
|
96 |
-
if transpose:
|
97 |
-
return [0, 0]
|
98 |
-
return [
|
99 |
-
input_shape[i + 2]
|
100 |
-
- (output_shape[i + 2] - 1) * stride[i]
|
101 |
-
- (1 - 2 * padding[i])
|
102 |
-
- dilation[i] * (weight_shape[i + 2] - 1)
|
103 |
-
for i in range(ndim)
|
104 |
-
]
|
105 |
-
|
106 |
-
# Forward & backward.
|
107 |
-
class Conv2d(torch.autograd.Function):
|
108 |
-
@staticmethod
|
109 |
-
def forward(ctx, input, weight, bias):
|
110 |
-
assert weight.shape == weight_shape
|
111 |
-
if not transpose:
|
112 |
-
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
113 |
-
else: # transpose
|
114 |
-
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
115 |
-
ctx.save_for_backward(input, weight)
|
116 |
-
return output
|
117 |
-
|
118 |
-
@staticmethod
|
119 |
-
def backward(ctx, grad_output):
|
120 |
-
input, weight = ctx.saved_tensors
|
121 |
-
grad_input = None
|
122 |
-
grad_weight = None
|
123 |
-
grad_bias = None
|
124 |
-
|
125 |
-
if ctx.needs_input_grad[0]:
|
126 |
-
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
127 |
-
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
|
128 |
-
assert grad_input.shape == input.shape
|
129 |
-
|
130 |
-
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
131 |
-
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
132 |
-
assert grad_weight.shape == weight_shape
|
133 |
-
|
134 |
-
if ctx.needs_input_grad[2]:
|
135 |
-
grad_bias = grad_output.sum([0, 2, 3])
|
136 |
-
|
137 |
-
return grad_input, grad_weight, grad_bias
|
138 |
-
|
139 |
-
# Gradient with respect to the weights.
|
140 |
-
class Conv2dGradWeight(torch.autograd.Function):
|
141 |
-
@staticmethod
|
142 |
-
def forward(ctx, grad_output, input):
|
143 |
-
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
|
144 |
-
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
145 |
-
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
146 |
-
assert grad_weight.shape == weight_shape
|
147 |
-
ctx.save_for_backward(grad_output, input)
|
148 |
-
return grad_weight
|
149 |
-
|
150 |
-
@staticmethod
|
151 |
-
def backward(ctx, grad2_grad_weight):
|
152 |
-
grad_output, input = ctx.saved_tensors
|
153 |
-
grad2_grad_output = None
|
154 |
-
grad2_input = None
|
155 |
-
|
156 |
-
if ctx.needs_input_grad[0]:
|
157 |
-
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
158 |
-
assert grad2_grad_output.shape == grad_output.shape
|
159 |
-
|
160 |
-
if ctx.needs_input_grad[1]:
|
161 |
-
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
162 |
-
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
|
163 |
-
assert grad2_input.shape == input.shape
|
164 |
-
|
165 |
-
return grad2_grad_output, grad2_input
|
166 |
-
|
167 |
-
_conv2d_gradfix_cache[key] = Conv2d
|
168 |
-
return Conv2d
|
169 |
-
|
170 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/conv2d_resample.py
DELETED
@@ -1,156 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""2D convolution with optional up/downsampling."""
|
10 |
-
|
11 |
-
import torch
|
12 |
-
|
13 |
-
from .. import misc
|
14 |
-
from . import conv2d_gradfix
|
15 |
-
from . import upfirdn2d
|
16 |
-
from .upfirdn2d import _parse_padding
|
17 |
-
from .upfirdn2d import _get_filter_size
|
18 |
-
|
19 |
-
#----------------------------------------------------------------------------
|
20 |
-
|
21 |
-
def _get_weight_shape(w):
|
22 |
-
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
23 |
-
shape = [int(sz) for sz in w.shape]
|
24 |
-
misc.assert_shape(w, shape)
|
25 |
-
return shape
|
26 |
-
|
27 |
-
#----------------------------------------------------------------------------
|
28 |
-
|
29 |
-
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
30 |
-
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
31 |
-
"""
|
32 |
-
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
33 |
-
|
34 |
-
# Flip weight if requested.
|
35 |
-
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
36 |
-
w = w.flip([2, 3])
|
37 |
-
|
38 |
-
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
39 |
-
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
40 |
-
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
|
41 |
-
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
42 |
-
if out_channels <= 4 and groups == 1:
|
43 |
-
in_shape = x.shape
|
44 |
-
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
|
45 |
-
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
46 |
-
else:
|
47 |
-
x = x.to(memory_format=torch.contiguous_format)
|
48 |
-
w = w.to(memory_format=torch.contiguous_format)
|
49 |
-
x = conv2d_gradfix.conv2d(x, w, groups=groups)
|
50 |
-
return x.to(memory_format=torch.channels_last)
|
51 |
-
|
52 |
-
# Otherwise => execute using conv2d_gradfix.
|
53 |
-
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
54 |
-
return op(x, w, stride=stride, padding=padding, groups=groups)
|
55 |
-
|
56 |
-
#----------------------------------------------------------------------------
|
57 |
-
|
58 |
-
@misc.profiled_function
|
59 |
-
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
60 |
-
r"""2D convolution with optional up/downsampling.
|
61 |
-
|
62 |
-
Padding is performed only once at the beginning, not between the operations.
|
63 |
-
|
64 |
-
Args:
|
65 |
-
x: Input tensor of shape
|
66 |
-
`[batch_size, in_channels, in_height, in_width]`.
|
67 |
-
w: Weight tensor of shape
|
68 |
-
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
69 |
-
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
70 |
-
calling upfirdn2d.setup_filter(). None = identity (default).
|
71 |
-
up: Integer upsampling factor (default: 1).
|
72 |
-
down: Integer downsampling factor (default: 1).
|
73 |
-
padding: Padding with respect to the upsampled image. Can be a single number
|
74 |
-
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
75 |
-
(default: 0).
|
76 |
-
groups: Split input channels into N groups (default: 1).
|
77 |
-
flip_weight: False = convolution, True = correlation (default: True).
|
78 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
79 |
-
|
80 |
-
Returns:
|
81 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
82 |
-
"""
|
83 |
-
# Validate arguments.
|
84 |
-
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
85 |
-
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
86 |
-
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
87 |
-
assert isinstance(up, int) and (up >= 1)
|
88 |
-
assert isinstance(down, int) and (down >= 1)
|
89 |
-
assert isinstance(groups, int) and (groups >= 1)
|
90 |
-
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
91 |
-
fw, fh = _get_filter_size(f)
|
92 |
-
px0, px1, py0, py1 = _parse_padding(padding)
|
93 |
-
|
94 |
-
# Adjust padding to account for up/downsampling.
|
95 |
-
if up > 1:
|
96 |
-
px0 += (fw + up - 1) // 2
|
97 |
-
px1 += (fw - up) // 2
|
98 |
-
py0 += (fh + up - 1) // 2
|
99 |
-
py1 += (fh - up) // 2
|
100 |
-
if down > 1:
|
101 |
-
px0 += (fw - down + 1) // 2
|
102 |
-
px1 += (fw - down) // 2
|
103 |
-
py0 += (fh - down + 1) // 2
|
104 |
-
py1 += (fh - down) // 2
|
105 |
-
|
106 |
-
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
107 |
-
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
108 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
109 |
-
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
110 |
-
return x
|
111 |
-
|
112 |
-
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
113 |
-
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
114 |
-
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
115 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
116 |
-
return x
|
117 |
-
|
118 |
-
# Fast path: downsampling only => use strided convolution.
|
119 |
-
if down > 1 and up == 1:
|
120 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
121 |
-
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
122 |
-
return x
|
123 |
-
|
124 |
-
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
125 |
-
if up > 1:
|
126 |
-
if groups == 1:
|
127 |
-
w = w.transpose(0, 1)
|
128 |
-
else:
|
129 |
-
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
130 |
-
w = w.transpose(1, 2)
|
131 |
-
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
132 |
-
px0 -= kw - 1
|
133 |
-
px1 -= kw - up
|
134 |
-
py0 -= kh - 1
|
135 |
-
py1 -= kh - up
|
136 |
-
pxt = max(min(-px0, -px1), 0)
|
137 |
-
pyt = max(min(-py0, -py1), 0)
|
138 |
-
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
139 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
140 |
-
if down > 1:
|
141 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
142 |
-
return x
|
143 |
-
|
144 |
-
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
145 |
-
if up == 1 and down == 1:
|
146 |
-
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
147 |
-
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
148 |
-
|
149 |
-
# Fallback: Generic reference implementation.
|
150 |
-
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
151 |
-
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
152 |
-
if down > 1:
|
153 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
154 |
-
return x
|
155 |
-
|
156 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/fma.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
10 |
-
|
11 |
-
import torch
|
12 |
-
|
13 |
-
#----------------------------------------------------------------------------
|
14 |
-
|
15 |
-
def fma(a, b, c): # => a * b + c
|
16 |
-
return _FusedMultiplyAdd.apply(a, b, c)
|
17 |
-
|
18 |
-
#----------------------------------------------------------------------------
|
19 |
-
|
20 |
-
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
21 |
-
@staticmethod
|
22 |
-
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
23 |
-
out = torch.addcmul(c, a, b)
|
24 |
-
ctx.save_for_backward(a, b)
|
25 |
-
ctx.c_shape = c.shape
|
26 |
-
return out
|
27 |
-
|
28 |
-
@staticmethod
|
29 |
-
def backward(ctx, dout): # pylint: disable=arguments-differ
|
30 |
-
a, b = ctx.saved_tensors
|
31 |
-
c_shape = ctx.c_shape
|
32 |
-
da = None
|
33 |
-
db = None
|
34 |
-
dc = None
|
35 |
-
|
36 |
-
if ctx.needs_input_grad[0]:
|
37 |
-
da = _unbroadcast(dout * b, a.shape)
|
38 |
-
|
39 |
-
if ctx.needs_input_grad[1]:
|
40 |
-
db = _unbroadcast(dout * a, b.shape)
|
41 |
-
|
42 |
-
if ctx.needs_input_grad[2]:
|
43 |
-
dc = _unbroadcast(dout, c_shape)
|
44 |
-
|
45 |
-
return da, db, dc
|
46 |
-
|
47 |
-
#----------------------------------------------------------------------------
|
48 |
-
|
49 |
-
def _unbroadcast(x, shape):
|
50 |
-
extra_dims = x.ndim - len(shape)
|
51 |
-
assert extra_dims >= 0
|
52 |
-
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
53 |
-
if len(dim):
|
54 |
-
x = x.sum(dim=dim, keepdim=True)
|
55 |
-
if extra_dims:
|
56 |
-
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
57 |
-
assert x.shape == shape
|
58 |
-
return x
|
59 |
-
|
60 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/grid_sample_gradfix.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
10 |
-
supports arbitrarily high order gradients between the input and output.
|
11 |
-
Only works on 2D images and assumes
|
12 |
-
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
13 |
-
|
14 |
-
import warnings
|
15 |
-
import torch
|
16 |
-
|
17 |
-
# pylint: disable=redefined-builtin
|
18 |
-
# pylint: disable=arguments-differ
|
19 |
-
# pylint: disable=protected-access
|
20 |
-
|
21 |
-
#----------------------------------------------------------------------------
|
22 |
-
|
23 |
-
enabled = False # Enable the custom op by setting this to true.
|
24 |
-
|
25 |
-
#----------------------------------------------------------------------------
|
26 |
-
|
27 |
-
def grid_sample(input, grid):
|
28 |
-
if _should_use_custom_op():
|
29 |
-
return _GridSample2dForward.apply(input, grid)
|
30 |
-
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
31 |
-
|
32 |
-
#----------------------------------------------------------------------------
|
33 |
-
|
34 |
-
def _should_use_custom_op():
|
35 |
-
if not enabled:
|
36 |
-
return False
|
37 |
-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
38 |
-
return True
|
39 |
-
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
|
40 |
-
return False
|
41 |
-
|
42 |
-
#----------------------------------------------------------------------------
|
43 |
-
|
44 |
-
class _GridSample2dForward(torch.autograd.Function):
|
45 |
-
@staticmethod
|
46 |
-
def forward(ctx, input, grid):
|
47 |
-
assert input.ndim == 4
|
48 |
-
assert grid.ndim == 4
|
49 |
-
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
50 |
-
ctx.save_for_backward(input, grid)
|
51 |
-
return output
|
52 |
-
|
53 |
-
@staticmethod
|
54 |
-
def backward(ctx, grad_output):
|
55 |
-
input, grid = ctx.saved_tensors
|
56 |
-
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
57 |
-
return grad_input, grad_grid
|
58 |
-
|
59 |
-
#----------------------------------------------------------------------------
|
60 |
-
|
61 |
-
class _GridSample2dBackward(torch.autograd.Function):
|
62 |
-
@staticmethod
|
63 |
-
def forward(ctx, grad_output, input, grid):
|
64 |
-
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
65 |
-
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
66 |
-
ctx.save_for_backward(grid)
|
67 |
-
return grad_input, grad_grid
|
68 |
-
|
69 |
-
@staticmethod
|
70 |
-
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
71 |
-
_ = grad2_grad_grid # unused
|
72 |
-
grid, = ctx.saved_tensors
|
73 |
-
grad2_grad_output = None
|
74 |
-
grad2_input = None
|
75 |
-
grad2_grid = None
|
76 |
-
|
77 |
-
if ctx.needs_input_grad[0]:
|
78 |
-
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
79 |
-
|
80 |
-
assert not ctx.needs_input_grad[2]
|
81 |
-
return grad2_grad_output, grad2_input, grad2_grid
|
82 |
-
|
83 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/upfirdn2d.cpp
DELETED
@@ -1,103 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <torch/extension.h>
|
10 |
-
#include <ATen/cuda/CUDAContext.h>
|
11 |
-
#include <c10/cuda/CUDAGuard.h>
|
12 |
-
#include "upfirdn2d.h"
|
13 |
-
|
14 |
-
//------------------------------------------------------------------------
|
15 |
-
|
16 |
-
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
17 |
-
{
|
18 |
-
// Validate arguments.
|
19 |
-
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
20 |
-
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
21 |
-
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
22 |
-
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
23 |
-
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
24 |
-
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
25 |
-
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
26 |
-
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
27 |
-
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
28 |
-
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
29 |
-
|
30 |
-
// Create output tensor.
|
31 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
32 |
-
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
33 |
-
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
34 |
-
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
35 |
-
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
36 |
-
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
37 |
-
|
38 |
-
// Initialize CUDA kernel parameters.
|
39 |
-
upfirdn2d_kernel_params p;
|
40 |
-
p.x = x.data_ptr();
|
41 |
-
p.f = f.data_ptr<float>();
|
42 |
-
p.y = y.data_ptr();
|
43 |
-
p.up = make_int2(upx, upy);
|
44 |
-
p.down = make_int2(downx, downy);
|
45 |
-
p.pad0 = make_int2(padx0, pady0);
|
46 |
-
p.flip = (flip) ? 1 : 0;
|
47 |
-
p.gain = gain;
|
48 |
-
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
49 |
-
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
50 |
-
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
51 |
-
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
52 |
-
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
53 |
-
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
54 |
-
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
55 |
-
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
56 |
-
|
57 |
-
// Choose CUDA kernel.
|
58 |
-
upfirdn2d_kernel_spec spec;
|
59 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
60 |
-
{
|
61 |
-
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
62 |
-
});
|
63 |
-
|
64 |
-
// Set looping options.
|
65 |
-
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
66 |
-
p.loopMinor = spec.loopMinor;
|
67 |
-
p.loopX = spec.loopX;
|
68 |
-
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
69 |
-
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
70 |
-
|
71 |
-
// Compute grid size.
|
72 |
-
dim3 blockSize, gridSize;
|
73 |
-
if (spec.tileOutW < 0) // large
|
74 |
-
{
|
75 |
-
blockSize = dim3(4, 32, 1);
|
76 |
-
gridSize = dim3(
|
77 |
-
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
78 |
-
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
79 |
-
p.launchMajor);
|
80 |
-
}
|
81 |
-
else // small
|
82 |
-
{
|
83 |
-
blockSize = dim3(256, 1, 1);
|
84 |
-
gridSize = dim3(
|
85 |
-
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
86 |
-
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
87 |
-
p.launchMajor);
|
88 |
-
}
|
89 |
-
|
90 |
-
// Launch CUDA kernel.
|
91 |
-
void* args[] = {&p};
|
92 |
-
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
93 |
-
return y;
|
94 |
-
}
|
95 |
-
|
96 |
-
//------------------------------------------------------------------------
|
97 |
-
|
98 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
99 |
-
{
|
100 |
-
m.def("upfirdn2d", &upfirdn2d);
|
101 |
-
}
|
102 |
-
|
103 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/upfirdn2d.cu
DELETED
@@ -1,350 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <c10/util/Half.h>
|
10 |
-
#include "upfirdn2d.h"
|
11 |
-
|
12 |
-
//------------------------------------------------------------------------
|
13 |
-
// Helpers.
|
14 |
-
|
15 |
-
template <class T> struct InternalType;
|
16 |
-
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
-
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
-
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
-
|
20 |
-
static __device__ __forceinline__ int floor_div(int a, int b)
|
21 |
-
{
|
22 |
-
int t = 1 - a / b;
|
23 |
-
return (a + t * b) / b - t;
|
24 |
-
}
|
25 |
-
|
26 |
-
//------------------------------------------------------------------------
|
27 |
-
// Generic CUDA implementation for large filters.
|
28 |
-
|
29 |
-
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
30 |
-
{
|
31 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
32 |
-
|
33 |
-
// Calculate thread index.
|
34 |
-
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
35 |
-
int outY = minorBase / p.launchMinor;
|
36 |
-
minorBase -= outY * p.launchMinor;
|
37 |
-
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
38 |
-
int majorBase = blockIdx.z * p.loopMajor;
|
39 |
-
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
40 |
-
return;
|
41 |
-
|
42 |
-
// Setup Y receptive field.
|
43 |
-
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
44 |
-
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
45 |
-
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
46 |
-
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
47 |
-
if (p.flip)
|
48 |
-
filterY = p.filterSize.y - 1 - filterY;
|
49 |
-
|
50 |
-
// Loop over major, minor, and X.
|
51 |
-
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
52 |
-
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
53 |
-
{
|
54 |
-
int nc = major * p.sizeMinor + minor;
|
55 |
-
int n = nc / p.inSize.z;
|
56 |
-
int c = nc - n * p.inSize.z;
|
57 |
-
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
58 |
-
{
|
59 |
-
// Setup X receptive field.
|
60 |
-
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
61 |
-
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
62 |
-
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
63 |
-
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
64 |
-
if (p.flip)
|
65 |
-
filterX = p.filterSize.x - 1 - filterX;
|
66 |
-
|
67 |
-
// Initialize pointers.
|
68 |
-
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
69 |
-
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
70 |
-
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
71 |
-
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
72 |
-
|
73 |
-
// Inner loop.
|
74 |
-
scalar_t v = 0;
|
75 |
-
for (int y = 0; y < h; y++)
|
76 |
-
{
|
77 |
-
for (int x = 0; x < w; x++)
|
78 |
-
{
|
79 |
-
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
80 |
-
xp += p.inStride.x;
|
81 |
-
fp += filterStepX;
|
82 |
-
}
|
83 |
-
xp += p.inStride.y - w * p.inStride.x;
|
84 |
-
fp += filterStepY - w * filterStepX;
|
85 |
-
}
|
86 |
-
|
87 |
-
// Store result.
|
88 |
-
v *= p.gain;
|
89 |
-
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
90 |
-
}
|
91 |
-
}
|
92 |
-
}
|
93 |
-
|
94 |
-
//------------------------------------------------------------------------
|
95 |
-
// Specialized CUDA implementation for small filters.
|
96 |
-
|
97 |
-
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
98 |
-
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
99 |
-
{
|
100 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
101 |
-
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
102 |
-
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
103 |
-
__shared__ volatile scalar_t sf[filterH][filterW];
|
104 |
-
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
105 |
-
|
106 |
-
// Calculate tile index.
|
107 |
-
int minorBase = blockIdx.x;
|
108 |
-
int tileOutY = minorBase / p.launchMinor;
|
109 |
-
minorBase -= tileOutY * p.launchMinor;
|
110 |
-
minorBase *= loopMinor;
|
111 |
-
tileOutY *= tileOutH;
|
112 |
-
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
113 |
-
int majorBase = blockIdx.z * p.loopMajor;
|
114 |
-
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
115 |
-
return;
|
116 |
-
|
117 |
-
// Load filter (flipped).
|
118 |
-
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
119 |
-
{
|
120 |
-
int fy = tapIdx / filterW;
|
121 |
-
int fx = tapIdx - fy * filterW;
|
122 |
-
scalar_t v = 0;
|
123 |
-
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
124 |
-
{
|
125 |
-
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
126 |
-
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
127 |
-
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
128 |
-
}
|
129 |
-
sf[fy][fx] = v;
|
130 |
-
}
|
131 |
-
|
132 |
-
// Loop over major and X.
|
133 |
-
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
134 |
-
{
|
135 |
-
int baseNC = major * p.sizeMinor + minorBase;
|
136 |
-
int n = baseNC / p.inSize.z;
|
137 |
-
int baseC = baseNC - n * p.inSize.z;
|
138 |
-
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
139 |
-
{
|
140 |
-
// Load input pixels.
|
141 |
-
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
142 |
-
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
143 |
-
int tileInX = floor_div(tileMidX, upx);
|
144 |
-
int tileInY = floor_div(tileMidY, upy);
|
145 |
-
__syncthreads();
|
146 |
-
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
147 |
-
{
|
148 |
-
int relC = inIdx;
|
149 |
-
int relInX = relC / loopMinor;
|
150 |
-
int relInY = relInX / tileInW;
|
151 |
-
relC -= relInX * loopMinor;
|
152 |
-
relInX -= relInY * tileInW;
|
153 |
-
int c = baseC + relC;
|
154 |
-
int inX = tileInX + relInX;
|
155 |
-
int inY = tileInY + relInY;
|
156 |
-
scalar_t v = 0;
|
157 |
-
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
158 |
-
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
159 |
-
sx[relInY][relInX][relC] = v;
|
160 |
-
}
|
161 |
-
|
162 |
-
// Loop over output pixels.
|
163 |
-
__syncthreads();
|
164 |
-
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
165 |
-
{
|
166 |
-
int relC = outIdx;
|
167 |
-
int relOutX = relC / loopMinor;
|
168 |
-
int relOutY = relOutX / tileOutW;
|
169 |
-
relC -= relOutX * loopMinor;
|
170 |
-
relOutX -= relOutY * tileOutW;
|
171 |
-
int c = baseC + relC;
|
172 |
-
int outX = tileOutX + relOutX;
|
173 |
-
int outY = tileOutY + relOutY;
|
174 |
-
|
175 |
-
// Setup receptive field.
|
176 |
-
int midX = tileMidX + relOutX * downx;
|
177 |
-
int midY = tileMidY + relOutY * downy;
|
178 |
-
int inX = floor_div(midX, upx);
|
179 |
-
int inY = floor_div(midY, upy);
|
180 |
-
int relInX = inX - tileInX;
|
181 |
-
int relInY = inY - tileInY;
|
182 |
-
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
183 |
-
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
184 |
-
|
185 |
-
// Inner loop.
|
186 |
-
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
187 |
-
{
|
188 |
-
scalar_t v = 0;
|
189 |
-
#pragma unroll
|
190 |
-
for (int y = 0; y < filterH / upy; y++)
|
191 |
-
#pragma unroll
|
192 |
-
for (int x = 0; x < filterW / upx; x++)
|
193 |
-
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
194 |
-
v *= p.gain;
|
195 |
-
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
196 |
-
}
|
197 |
-
}
|
198 |
-
}
|
199 |
-
}
|
200 |
-
}
|
201 |
-
|
202 |
-
//------------------------------------------------------------------------
|
203 |
-
// CUDA kernel selection.
|
204 |
-
|
205 |
-
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
206 |
-
{
|
207 |
-
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
208 |
-
|
209 |
-
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
210 |
-
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
211 |
-
|
212 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
213 |
-
{
|
214 |
-
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
215 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
216 |
-
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
217 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
218 |
-
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
219 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
220 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
221 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
222 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
223 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
224 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
225 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
226 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
227 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
228 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
229 |
-
}
|
230 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
231 |
-
{
|
232 |
-
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
233 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
234 |
-
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
235 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
236 |
-
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
237 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
238 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
239 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
240 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
241 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
242 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
243 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
244 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
245 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
246 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
247 |
-
}
|
248 |
-
if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
249 |
-
{
|
250 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
251 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
252 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
253 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
254 |
-
}
|
255 |
-
if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
256 |
-
{
|
257 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
258 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
259 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
260 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
261 |
-
}
|
262 |
-
if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
263 |
-
{
|
264 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
265 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
266 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
267 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
268 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
269 |
-
}
|
270 |
-
if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
271 |
-
{
|
272 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
273 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
274 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
275 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
276 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
277 |
-
}
|
278 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
279 |
-
{
|
280 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
281 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
282 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
283 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
284 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
285 |
-
}
|
286 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
287 |
-
{
|
288 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
289 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
290 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
291 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
292 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
293 |
-
}
|
294 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
|
295 |
-
{
|
296 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
297 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
298 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
299 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
300 |
-
}
|
301 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
|
302 |
-
{
|
303 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
304 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
305 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
306 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
307 |
-
}
|
308 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
|
309 |
-
{
|
310 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
311 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
|
312 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
313 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
|
314 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
315 |
-
}
|
316 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
|
317 |
-
{
|
318 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
319 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
|
320 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
321 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
|
322 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
323 |
-
}
|
324 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
|
325 |
-
{
|
326 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
327 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
|
328 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
329 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
|
330 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
331 |
-
}
|
332 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
|
333 |
-
{
|
334 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
335 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
|
336 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
337 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
|
338 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
339 |
-
}
|
340 |
-
return spec;
|
341 |
-
}
|
342 |
-
|
343 |
-
//------------------------------------------------------------------------
|
344 |
-
// Template specializations.
|
345 |
-
|
346 |
-
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
347 |
-
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
348 |
-
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
349 |
-
|
350 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/upfirdn2d.h
DELETED
@@ -1,59 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <cuda_runtime.h>
|
10 |
-
|
11 |
-
//------------------------------------------------------------------------
|
12 |
-
// CUDA kernel parameters.
|
13 |
-
|
14 |
-
struct upfirdn2d_kernel_params
|
15 |
-
{
|
16 |
-
const void* x;
|
17 |
-
const float* f;
|
18 |
-
void* y;
|
19 |
-
|
20 |
-
int2 up;
|
21 |
-
int2 down;
|
22 |
-
int2 pad0;
|
23 |
-
int flip;
|
24 |
-
float gain;
|
25 |
-
|
26 |
-
int4 inSize; // [width, height, channel, batch]
|
27 |
-
int4 inStride;
|
28 |
-
int2 filterSize; // [width, height]
|
29 |
-
int2 filterStride;
|
30 |
-
int4 outSize; // [width, height, channel, batch]
|
31 |
-
int4 outStride;
|
32 |
-
int sizeMinor;
|
33 |
-
int sizeMajor;
|
34 |
-
|
35 |
-
int loopMinor;
|
36 |
-
int loopMajor;
|
37 |
-
int loopX;
|
38 |
-
int launchMinor;
|
39 |
-
int launchMajor;
|
40 |
-
};
|
41 |
-
|
42 |
-
//------------------------------------------------------------------------
|
43 |
-
// CUDA kernel specialization.
|
44 |
-
|
45 |
-
struct upfirdn2d_kernel_spec
|
46 |
-
{
|
47 |
-
void* kernel;
|
48 |
-
int tileOutW;
|
49 |
-
int tileOutH;
|
50 |
-
int loopMinor;
|
51 |
-
int loopX;
|
52 |
-
};
|
53 |
-
|
54 |
-
//------------------------------------------------------------------------
|
55 |
-
// CUDA kernel selection.
|
56 |
-
|
57 |
-
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
58 |
-
|
59 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/ops/upfirdn2d.py
DELETED
@@ -1,384 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
10 |
-
|
11 |
-
import os
|
12 |
-
import warnings
|
13 |
-
import numpy as np
|
14 |
-
import torch
|
15 |
-
import traceback
|
16 |
-
|
17 |
-
from .. import custom_ops
|
18 |
-
from .. import misc
|
19 |
-
from . import conv2d_gradfix
|
20 |
-
|
21 |
-
#----------------------------------------------------------------------------
|
22 |
-
|
23 |
-
_inited = False
|
24 |
-
_plugin = None
|
25 |
-
|
26 |
-
def _init():
|
27 |
-
global _inited, _plugin
|
28 |
-
if not _inited:
|
29 |
-
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
|
30 |
-
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
31 |
-
try:
|
32 |
-
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
33 |
-
except:
|
34 |
-
warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
35 |
-
return _plugin is not None
|
36 |
-
|
37 |
-
def _parse_scaling(scaling):
|
38 |
-
if isinstance(scaling, int):
|
39 |
-
scaling = [scaling, scaling]
|
40 |
-
assert isinstance(scaling, (list, tuple))
|
41 |
-
assert all(isinstance(x, int) for x in scaling)
|
42 |
-
sx, sy = scaling
|
43 |
-
assert sx >= 1 and sy >= 1
|
44 |
-
return sx, sy
|
45 |
-
|
46 |
-
def _parse_padding(padding):
|
47 |
-
if isinstance(padding, int):
|
48 |
-
padding = [padding, padding]
|
49 |
-
assert isinstance(padding, (list, tuple))
|
50 |
-
assert all(isinstance(x, int) for x in padding)
|
51 |
-
if len(padding) == 2:
|
52 |
-
padx, pady = padding
|
53 |
-
padding = [padx, padx, pady, pady]
|
54 |
-
padx0, padx1, pady0, pady1 = padding
|
55 |
-
return padx0, padx1, pady0, pady1
|
56 |
-
|
57 |
-
def _get_filter_size(f):
|
58 |
-
if f is None:
|
59 |
-
return 1, 1
|
60 |
-
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
61 |
-
fw = f.shape[-1]
|
62 |
-
fh = f.shape[0]
|
63 |
-
with misc.suppress_tracer_warnings():
|
64 |
-
fw = int(fw)
|
65 |
-
fh = int(fh)
|
66 |
-
misc.assert_shape(f, [fh, fw][:f.ndim])
|
67 |
-
assert fw >= 1 and fh >= 1
|
68 |
-
return fw, fh
|
69 |
-
|
70 |
-
#----------------------------------------------------------------------------
|
71 |
-
|
72 |
-
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
73 |
-
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
74 |
-
|
75 |
-
Args:
|
76 |
-
f: Torch tensor, numpy array, or python list of the shape
|
77 |
-
`[filter_height, filter_width]` (non-separable),
|
78 |
-
`[filter_taps]` (separable),
|
79 |
-
`[]` (impulse), or
|
80 |
-
`None` (identity).
|
81 |
-
device: Result device (default: cpu).
|
82 |
-
normalize: Normalize the filter so that it retains the magnitude
|
83 |
-
for constant input signal (DC)? (default: True).
|
84 |
-
flip_filter: Flip the filter? (default: False).
|
85 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
86 |
-
separable: Return a separable filter? (default: select automatically).
|
87 |
-
|
88 |
-
Returns:
|
89 |
-
Float32 tensor of the shape
|
90 |
-
`[filter_height, filter_width]` (non-separable) or
|
91 |
-
`[filter_taps]` (separable).
|
92 |
-
"""
|
93 |
-
# Validate.
|
94 |
-
if f is None:
|
95 |
-
f = 1
|
96 |
-
f = torch.as_tensor(f, dtype=torch.float32)
|
97 |
-
assert f.ndim in [0, 1, 2]
|
98 |
-
assert f.numel() > 0
|
99 |
-
if f.ndim == 0:
|
100 |
-
f = f[np.newaxis]
|
101 |
-
|
102 |
-
# Separable?
|
103 |
-
if separable is None:
|
104 |
-
separable = (f.ndim == 1 and f.numel() >= 8)
|
105 |
-
if f.ndim == 1 and not separable:
|
106 |
-
f = f.ger(f)
|
107 |
-
assert f.ndim == (1 if separable else 2)
|
108 |
-
|
109 |
-
# Apply normalize, flip, gain, and device.
|
110 |
-
if normalize:
|
111 |
-
f /= f.sum()
|
112 |
-
if flip_filter:
|
113 |
-
f = f.flip(list(range(f.ndim)))
|
114 |
-
f = f * (gain ** (f.ndim / 2))
|
115 |
-
f = f.to(device=device)
|
116 |
-
return f
|
117 |
-
|
118 |
-
#----------------------------------------------------------------------------
|
119 |
-
|
120 |
-
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
121 |
-
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
122 |
-
|
123 |
-
Performs the following sequence of operations for each channel:
|
124 |
-
|
125 |
-
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
126 |
-
|
127 |
-
2. Pad the image with the specified number of zeros on each side (`padding`).
|
128 |
-
Negative padding corresponds to cropping the image.
|
129 |
-
|
130 |
-
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
131 |
-
so that the footprint of all output pixels lies within the input image.
|
132 |
-
|
133 |
-
4. Downsample the image by keeping every Nth pixel (`down`).
|
134 |
-
|
135 |
-
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
136 |
-
The fused op is considerably more efficient than performing the same calculation
|
137 |
-
using standard PyTorch ops. It supports gradients of arbitrary order.
|
138 |
-
|
139 |
-
Args:
|
140 |
-
x: Float32/float64/float16 input tensor of the shape
|
141 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
142 |
-
f: Float32 FIR filter of the shape
|
143 |
-
`[filter_height, filter_width]` (non-separable),
|
144 |
-
`[filter_taps]` (separable), or
|
145 |
-
`None` (identity).
|
146 |
-
up: Integer upsampling factor. Can be a single int or a list/tuple
|
147 |
-
`[x, y]` (default: 1).
|
148 |
-
down: Integer downsampling factor. Can be a single int or a list/tuple
|
149 |
-
`[x, y]` (default: 1).
|
150 |
-
padding: Padding with respect to the upsampled image. Can be a single number
|
151 |
-
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
152 |
-
(default: 0).
|
153 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
154 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
155 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
156 |
-
|
157 |
-
Returns:
|
158 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
159 |
-
"""
|
160 |
-
assert isinstance(x, torch.Tensor)
|
161 |
-
assert impl in ['ref', 'cuda']
|
162 |
-
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
163 |
-
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
164 |
-
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
165 |
-
|
166 |
-
#----------------------------------------------------------------------------
|
167 |
-
|
168 |
-
@misc.profiled_function
|
169 |
-
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
170 |
-
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
171 |
-
"""
|
172 |
-
# Validate arguments.
|
173 |
-
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
174 |
-
if f is None:
|
175 |
-
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
176 |
-
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
177 |
-
assert f.dtype == torch.float32 and not f.requires_grad
|
178 |
-
batch_size, num_channels, in_height, in_width = x.shape
|
179 |
-
upx, upy = _parse_scaling(up)
|
180 |
-
downx, downy = _parse_scaling(down)
|
181 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
182 |
-
|
183 |
-
# Upsample by inserting zeros.
|
184 |
-
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
185 |
-
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
186 |
-
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
187 |
-
|
188 |
-
# Pad or crop.
|
189 |
-
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
190 |
-
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
191 |
-
|
192 |
-
# Setup filter.
|
193 |
-
f = f * (gain ** (f.ndim / 2))
|
194 |
-
f = f.to(x.dtype)
|
195 |
-
if not flip_filter:
|
196 |
-
f = f.flip(list(range(f.ndim)))
|
197 |
-
|
198 |
-
# Convolve with the filter.
|
199 |
-
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
200 |
-
if f.ndim == 4:
|
201 |
-
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
202 |
-
else:
|
203 |
-
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
204 |
-
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
205 |
-
|
206 |
-
# Downsample by throwing away pixels.
|
207 |
-
x = x[:, :, ::downy, ::downx]
|
208 |
-
return x
|
209 |
-
|
210 |
-
#----------------------------------------------------------------------------
|
211 |
-
|
212 |
-
_upfirdn2d_cuda_cache = dict()
|
213 |
-
|
214 |
-
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
215 |
-
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
216 |
-
"""
|
217 |
-
# Parse arguments.
|
218 |
-
upx, upy = _parse_scaling(up)
|
219 |
-
downx, downy = _parse_scaling(down)
|
220 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
221 |
-
|
222 |
-
# Lookup from cache.
|
223 |
-
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
224 |
-
if key in _upfirdn2d_cuda_cache:
|
225 |
-
return _upfirdn2d_cuda_cache[key]
|
226 |
-
|
227 |
-
# Forward op.
|
228 |
-
class Upfirdn2dCuda(torch.autograd.Function):
|
229 |
-
@staticmethod
|
230 |
-
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
231 |
-
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
232 |
-
if f is None:
|
233 |
-
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
234 |
-
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
235 |
-
y = x
|
236 |
-
if f.ndim == 2:
|
237 |
-
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
238 |
-
else:
|
239 |
-
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
|
240 |
-
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
|
241 |
-
ctx.save_for_backward(f)
|
242 |
-
ctx.x_shape = x.shape
|
243 |
-
return y
|
244 |
-
|
245 |
-
@staticmethod
|
246 |
-
def backward(ctx, dy): # pylint: disable=arguments-differ
|
247 |
-
f, = ctx.saved_tensors
|
248 |
-
_, _, ih, iw = ctx.x_shape
|
249 |
-
_, _, oh, ow = dy.shape
|
250 |
-
fw, fh = _get_filter_size(f)
|
251 |
-
p = [
|
252 |
-
fw - padx0 - 1,
|
253 |
-
iw * upx - ow * downx + padx0 - upx + 1,
|
254 |
-
fh - pady0 - 1,
|
255 |
-
ih * upy - oh * downy + pady0 - upy + 1,
|
256 |
-
]
|
257 |
-
dx = None
|
258 |
-
df = None
|
259 |
-
|
260 |
-
if ctx.needs_input_grad[0]:
|
261 |
-
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
262 |
-
|
263 |
-
assert not ctx.needs_input_grad[1]
|
264 |
-
return dx, df
|
265 |
-
|
266 |
-
# Add to cache.
|
267 |
-
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
268 |
-
return Upfirdn2dCuda
|
269 |
-
|
270 |
-
#----------------------------------------------------------------------------
|
271 |
-
|
272 |
-
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
273 |
-
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
274 |
-
|
275 |
-
By default, the result is padded so that its shape matches the input.
|
276 |
-
User-specified padding is applied on top of that, with negative values
|
277 |
-
indicating cropping. Pixels outside the image are assumed to be zero.
|
278 |
-
|
279 |
-
Args:
|
280 |
-
x: Float32/float64/float16 input tensor of the shape
|
281 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
282 |
-
f: Float32 FIR filter of the shape
|
283 |
-
`[filter_height, filter_width]` (non-separable),
|
284 |
-
`[filter_taps]` (separable), or
|
285 |
-
`None` (identity).
|
286 |
-
padding: Padding with respect to the output. Can be a single number or a
|
287 |
-
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
288 |
-
(default: 0).
|
289 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
290 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
291 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
292 |
-
|
293 |
-
Returns:
|
294 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
295 |
-
"""
|
296 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
297 |
-
fw, fh = _get_filter_size(f)
|
298 |
-
p = [
|
299 |
-
padx0 + fw // 2,
|
300 |
-
padx1 + (fw - 1) // 2,
|
301 |
-
pady0 + fh // 2,
|
302 |
-
pady1 + (fh - 1) // 2,
|
303 |
-
]
|
304 |
-
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
305 |
-
|
306 |
-
#----------------------------------------------------------------------------
|
307 |
-
|
308 |
-
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
309 |
-
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
310 |
-
|
311 |
-
By default, the result is padded so that its shape is a multiple of the input.
|
312 |
-
User-specified padding is applied on top of that, with negative values
|
313 |
-
indicating cropping. Pixels outside the image are assumed to be zero.
|
314 |
-
|
315 |
-
Args:
|
316 |
-
x: Float32/float64/float16 input tensor of the shape
|
317 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
318 |
-
f: Float32 FIR filter of the shape
|
319 |
-
`[filter_height, filter_width]` (non-separable),
|
320 |
-
`[filter_taps]` (separable), or
|
321 |
-
`None` (identity).
|
322 |
-
up: Integer upsampling factor. Can be a single int or a list/tuple
|
323 |
-
`[x, y]` (default: 1).
|
324 |
-
padding: Padding with respect to the output. Can be a single number or a
|
325 |
-
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
326 |
-
(default: 0).
|
327 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
328 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
329 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
330 |
-
|
331 |
-
Returns:
|
332 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
333 |
-
"""
|
334 |
-
upx, upy = _parse_scaling(up)
|
335 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
336 |
-
fw, fh = _get_filter_size(f)
|
337 |
-
p = [
|
338 |
-
padx0 + (fw + upx - 1) // 2,
|
339 |
-
padx1 + (fw - upx) // 2,
|
340 |
-
pady0 + (fh + upy - 1) // 2,
|
341 |
-
pady1 + (fh - upy) // 2,
|
342 |
-
]
|
343 |
-
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
344 |
-
|
345 |
-
#----------------------------------------------------------------------------
|
346 |
-
|
347 |
-
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
348 |
-
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
349 |
-
|
350 |
-
By default, the result is padded so that its shape is a fraction of the input.
|
351 |
-
User-specified padding is applied on top of that, with negative values
|
352 |
-
indicating cropping. Pixels outside the image are assumed to be zero.
|
353 |
-
|
354 |
-
Args:
|
355 |
-
x: Float32/float64/float16 input tensor of the shape
|
356 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
357 |
-
f: Float32 FIR filter of the shape
|
358 |
-
`[filter_height, filter_width]` (non-separable),
|
359 |
-
`[filter_taps]` (separable), or
|
360 |
-
`None` (identity).
|
361 |
-
down: Integer downsampling factor. Can be a single int or a list/tuple
|
362 |
-
`[x, y]` (default: 1).
|
363 |
-
padding: Padding with respect to the input. Can be a single number or a
|
364 |
-
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
365 |
-
(default: 0).
|
366 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
367 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
368 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
369 |
-
|
370 |
-
Returns:
|
371 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
372 |
-
"""
|
373 |
-
downx, downy = _parse_scaling(down)
|
374 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
375 |
-
fw, fh = _get_filter_size(f)
|
376 |
-
p = [
|
377 |
-
padx0 + (fw - downx + 1) // 2,
|
378 |
-
padx1 + (fw - downx) // 2,
|
379 |
-
pady0 + (fh - downy + 1) // 2,
|
380 |
-
pady1 + (fh - downy) // 2,
|
381 |
-
]
|
382 |
-
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
383 |
-
|
384 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/persistence.py
DELETED
@@ -1,251 +0,0 @@
|
|
1 |
-
ο»Ώ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Facilities for pickling Python code alongside other data.
|
10 |
-
|
11 |
-
The pickled code is automatically imported into a separate Python module
|
12 |
-
during unpickling. This way, any previously exported pickles will remain
|
13 |
-
usable even if the original code is no longer available, or if the current
|
14 |
-
version of the code is not consistent with what was originally pickled."""
|
15 |
-
|
16 |
-
import sys
|
17 |
-
import pickle
|
18 |
-
import io
|
19 |
-
import inspect
|
20 |
-
import copy
|
21 |
-
import uuid
|
22 |
-
import types
|
23 |
-
import dnnlib
|
24 |
-
|
25 |
-
#----------------------------------------------------------------------------
|
26 |
-
|
27 |
-
_version = 6 # internal version number
|
28 |
-
_decorators = set() # {decorator_class, ...}
|
29 |
-
_import_hooks = [] # [hook_function, ...]
|
30 |
-
_module_to_src_dict = dict() # {module: src, ...}
|
31 |
-
_src_to_module_dict = dict() # {src: module, ...}
|
32 |
-
|
33 |
-
#----------------------------------------------------------------------------
|
34 |
-
|
35 |
-
def persistent_class(orig_class):
|
36 |
-
r"""Class decorator that extends a given class to save its source code
|
37 |
-
when pickled.
|
38 |
-
|
39 |
-
Example:
|
40 |
-
|
41 |
-
from torch_utils import persistence
|
42 |
-
|
43 |
-
@persistence.persistent_class
|
44 |
-
class MyNetwork(torch.nn.Module):
|
45 |
-
def __init__(self, num_inputs, num_outputs):
|
46 |
-
super().__init__()
|
47 |
-
self.fc = MyLayer(num_inputs, num_outputs)
|
48 |
-
...
|
49 |
-
|
50 |
-
@persistence.persistent_class
|
51 |
-
class MyLayer(torch.nn.Module):
|
52 |
-
...
|
53 |
-
|
54 |
-
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
55 |
-
source code alongside other internal state (e.g., parameters, buffers,
|
56 |
-
and submodules). This way, any previously exported pickle will remain
|
57 |
-
usable even if the class definitions have been modified or are no
|
58 |
-
longer available.
|
59 |
-
|
60 |
-
The decorator saves the source code of the entire Python module
|
61 |
-
containing the decorated class. It does *not* save the source code of
|
62 |
-
any imported modules. Thus, the imported modules must be available
|
63 |
-
during unpickling, also including `torch_utils.persistence` itself.
|
64 |
-
|
65 |
-
It is ok to call functions defined in the same module from the
|
66 |
-
decorated class. However, if the decorated class depends on other
|
67 |
-
classes defined in the same module, they must be decorated as well.
|
68 |
-
This is illustrated in the above example in the case of `MyLayer`.
|
69 |
-
|
70 |
-
It is also possible to employ the decorator just-in-time before
|
71 |
-
calling the constructor. For example:
|
72 |
-
|
73 |
-
cls = MyLayer
|
74 |
-
if want_to_make_it_persistent:
|
75 |
-
cls = persistence.persistent_class(cls)
|
76 |
-
layer = cls(num_inputs, num_outputs)
|
77 |
-
|
78 |
-
As an additional feature, the decorator also keeps track of the
|
79 |
-
arguments that were used to construct each instance of the decorated
|
80 |
-
class. The arguments can be queried via `obj.init_args` and
|
81 |
-
`obj.init_kwargs`, and they are automatically pickled alongside other
|
82 |
-
object state. A typical use case is to first unpickle a previous
|
83 |
-
instance of a persistent class, and then upgrade it to use the latest
|
84 |
-
version of the source code:
|
85 |
-
|
86 |
-
with open('old_pickle.pkl', 'rb') as f:
|
87 |
-
old_net = pickle.load(f)
|
88 |
-
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
89 |
-
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
90 |
-
"""
|
91 |
-
assert isinstance(orig_class, type)
|
92 |
-
if is_persistent(orig_class):
|
93 |
-
return orig_class
|
94 |
-
|
95 |
-
assert orig_class.__module__ in sys.modules
|
96 |
-
orig_module = sys.modules[orig_class.__module__]
|
97 |
-
orig_module_src = _module_to_src(orig_module)
|
98 |
-
|
99 |
-
class Decorator(orig_class):
|
100 |
-
_orig_module_src = orig_module_src
|
101 |
-
_orig_class_name = orig_class.__name__
|
102 |
-
|
103 |
-
def __init__(self, *args, **kwargs):
|
104 |
-
super().__init__(*args, **kwargs)
|
105 |
-
self._init_args = copy.deepcopy(args)
|
106 |
-
self._init_kwargs = copy.deepcopy(kwargs)
|
107 |
-
assert orig_class.__name__ in orig_module.__dict__
|
108 |
-
_check_pickleable(self.__reduce__())
|
109 |
-
|
110 |
-
@property
|
111 |
-
def init_args(self):
|
112 |
-
return copy.deepcopy(self._init_args)
|
113 |
-
|
114 |
-
@property
|
115 |
-
def init_kwargs(self):
|
116 |
-
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
|
117 |
-
|
118 |
-
def __reduce__(self):
|
119 |
-
fields = list(super().__reduce__())
|
120 |
-
fields += [None] * max(3 - len(fields), 0)
|
121 |
-
if fields[0] is not _reconstruct_persistent_obj:
|
122 |
-
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
123 |
-
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
124 |
-
fields[1] = (meta,) # reconstruct args
|
125 |
-
fields[2] = None # state dict
|
126 |
-
return tuple(fields)
|
127 |
-
|
128 |
-
Decorator.__name__ = orig_class.__name__
|
129 |
-
_decorators.add(Decorator)
|
130 |
-
return Decorator
|
131 |
-
|
132 |
-
#----------------------------------------------------------------------------
|
133 |
-
|
134 |
-
def is_persistent(obj):
|
135 |
-
r"""Test whether the given object or class is persistent, i.e.,
|
136 |
-
whether it will save its source code when pickled.
|
137 |
-
"""
|
138 |
-
try:
|
139 |
-
if obj in _decorators:
|
140 |
-
return True
|
141 |
-
except TypeError:
|
142 |
-
pass
|
143 |
-
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
144 |
-
|
145 |
-
#----------------------------------------------------------------------------
|
146 |
-
|
147 |
-
def import_hook(hook):
|
148 |
-
r"""Register an import hook that is called whenever a persistent object
|
149 |
-
is being unpickled. A typical use case is to patch the pickled source
|
150 |
-
code to avoid errors and inconsistencies when the API of some imported
|
151 |
-
module has changed.
|
152 |
-
|
153 |
-
The hook should have the following signature:
|
154 |
-
|
155 |
-
hook(meta) -> modified meta
|
156 |
-
|
157 |
-
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
|
158 |
-
|
159 |
-
type: Type of the persistent object, e.g. `'class'`.
|
160 |
-
version: Internal version number of `torch_utils.persistence`.
|
161 |
-
module_src Original source code of the Python module.
|
162 |
-
class_name: Class name in the original Python module.
|
163 |
-
state: Internal state of the object.
|
164 |
-
|
165 |
-
Example:
|
166 |
-
|
167 |
-
@persistence.import_hook
|
168 |
-
def wreck_my_network(meta):
|
169 |
-
if meta.class_name == 'MyNetwork':
|
170 |
-
print('MyNetwork is being imported. I will wreck it!')
|
171 |
-
meta.module_src = meta.module_src.replace("True", "False")
|
172 |
-
return meta
|
173 |
-
"""
|
174 |
-
assert callable(hook)
|
175 |
-
_import_hooks.append(hook)
|
176 |
-
|
177 |
-
#----------------------------------------------------------------------------
|
178 |
-
|
179 |
-
def _reconstruct_persistent_obj(meta):
|
180 |
-
r"""Hook that is called internally by the `pickle` module to unpickle
|
181 |
-
a persistent object.
|
182 |
-
"""
|
183 |
-
meta = dnnlib.EasyDict(meta)
|
184 |
-
meta.state = dnnlib.EasyDict(meta.state)
|
185 |
-
for hook in _import_hooks:
|
186 |
-
meta = hook(meta)
|
187 |
-
assert meta is not None
|
188 |
-
|
189 |
-
assert meta.version == _version
|
190 |
-
module = _src_to_module(meta.module_src)
|
191 |
-
|
192 |
-
assert meta.type == 'class'
|
193 |
-
orig_class = module.__dict__[meta.class_name]
|
194 |
-
decorator_class = persistent_class(orig_class)
|
195 |
-
obj = decorator_class.__new__(decorator_class)
|
196 |
-
|
197 |
-
setstate = getattr(obj, '__setstate__', None)
|
198 |
-
if callable(setstate):
|
199 |
-
setstate(meta.state) # pylint: disable=not-callable
|
200 |
-
else:
|
201 |
-
obj.__dict__.update(meta.state)
|
202 |
-
return obj
|
203 |
-
|
204 |
-
#----------------------------------------------------------------------------
|
205 |
-
|
206 |
-
def _module_to_src(module):
|
207 |
-
r"""Query the source code of a given Python module.
|
208 |
-
"""
|
209 |
-
src = _module_to_src_dict.get(module, None)
|
210 |
-
if src is None:
|
211 |
-
src = inspect.getsource(module)
|
212 |
-
_module_to_src_dict[module] = src
|
213 |
-
_src_to_module_dict[src] = module
|
214 |
-
return src
|
215 |
-
|
216 |
-
def _src_to_module(src):
|
217 |
-
r"""Get or create a Python module for the given source code.
|
218 |
-
"""
|
219 |
-
module = _src_to_module_dict.get(src, None)
|
220 |
-
if module is None:
|
221 |
-
module_name = "_imported_module_" + uuid.uuid4().hex
|
222 |
-
module = types.ModuleType(module_name)
|
223 |
-
sys.modules[module_name] = module
|
224 |
-
_module_to_src_dict[module] = src
|
225 |
-
_src_to_module_dict[src] = module
|
226 |
-
exec(src, module.__dict__) # pylint: disable=exec-used
|
227 |
-
return module
|
228 |
-
|
229 |
-
#----------------------------------------------------------------------------
|
230 |
-
|
231 |
-
def _check_pickleable(obj):
|
232 |
-
r"""Check that the given object is pickleable, raising an exception if
|
233 |
-
it is not. This function is expected to be considerably more efficient
|
234 |
-
than actually pickling the object.
|
235 |
-
"""
|
236 |
-
def recurse(obj):
|
237 |
-
if isinstance(obj, (list, tuple, set)):
|
238 |
-
return [recurse(x) for x in obj]
|
239 |
-
if isinstance(obj, dict):
|
240 |
-
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
241 |
-
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
242 |
-
return None # Python primitive types are pickleable.
|
243 |
-
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
|
244 |
-
return None # NumPy arrays and PyTorch tensors are pickleable.
|
245 |
-
if is_persistent(obj):
|
246 |
-
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
247 |
-
return obj
|
248 |
-
with io.BytesIO() as f:
|
249 |
-
pickle.dump(recurse(obj), f)
|
250 |
-
|
251 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PTI/torch_utils/training_stats.py
DELETED
@@ -1,268 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
"""Facilities for reporting and collecting training statistics across
|
10 |
-
multiple processes and devices. The interface is designed to minimize
|
11 |
-
synchronization overhead as well as the amount of boilerplate in user
|
12 |
-
code."""
|
13 |
-
|
14 |
-
import re
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
import dnnlib
|
18 |
-
|
19 |
-
from . import misc
|
20 |
-
|
21 |
-
#----------------------------------------------------------------------------
|
22 |
-
|
23 |
-
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
|
24 |
-
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
|
25 |
-
_counter_dtype = torch.float64 # Data type to use for the internal counters.
|
26 |
-
_rank = 0 # Rank of the current process.
|
27 |
-
_sync_device = None # Device to use for multiprocess communication. None = single-process.
|
28 |
-
_sync_called = False # Has _sync() been called yet?
|
29 |
-
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
|
30 |
-
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
31 |
-
|
32 |
-
#----------------------------------------------------------------------------
|
33 |
-
|
34 |
-
def init_multiprocessing(rank, sync_device):
|
35 |
-
r"""Initializes `torch_utils.training_stats` for collecting statistics
|
36 |
-
across multiple processes.
|
37 |
-
|
38 |
-
This function must be called after
|
39 |
-
`torch.distributed.init_process_group()` and before `Collector.update()`.
|
40 |
-
The call is not necessary if multi-process collection is not needed.
|
41 |
-
|
42 |
-
Args:
|
43 |
-
rank: Rank of the current process.
|
44 |
-
sync_device: PyTorch device to use for inter-process
|
45 |
-
communication, or None to disable multi-process
|
46 |
-
collection. Typically `torch.device('cuda', rank)`.
|
47 |
-
"""
|
48 |
-
global _rank, _sync_device
|
49 |
-
assert not _sync_called
|
50 |
-
_rank = rank
|
51 |
-
_sync_device = sync_device
|
52 |
-
|
53 |
-
#----------------------------------------------------------------------------
|
54 |
-
|
55 |
-
@misc.profiled_function
|
56 |
-
def report(name, value):
|
57 |
-
r"""Broadcasts the given set of scalars to all interested instances of
|
58 |
-
`Collector`, across device and process boundaries.
|
59 |
-
|
60 |
-
This function is expected to be extremely cheap and can be safely
|
61 |
-
called from anywhere in the training loop, loss function, or inside a
|
62 |
-
`torch.nn.Module`.
|
63 |
-
|
64 |
-
Warning: The current implementation expects the set of unique names to
|
65 |
-
be consistent across processes. Please make sure that `report()` is
|
66 |
-
called at least once for each unique name by each process, and in the
|
67 |
-
same order. If a given process has no scalars to broadcast, it can do
|
68 |
-
`report(name, [])` (empty list).
|
69 |
-
|
70 |
-
Args:
|
71 |
-
name: Arbitrary string specifying the name of the statistic.
|
72 |
-
Averages are accumulated separately for each unique name.
|
73 |
-
value: Arbitrary set of scalars. Can be a list, tuple,
|
74 |
-
NumPy array, PyTorch tensor, or Python scalar.
|
75 |
-
|
76 |
-
Returns:
|
77 |
-
The same `value` that was passed in.
|
78 |
-
"""
|
79 |
-
if name not in _counters:
|
80 |
-
_counters[name] = dict()
|
81 |
-
|
82 |
-
elems = torch.as_tensor(value)
|
83 |
-
if elems.numel() == 0:
|
84 |
-
return value
|
85 |
-
|
86 |
-
elems = elems.detach().flatten().to(_reduce_dtype)
|
87 |
-
moments = torch.stack([
|
88 |
-
torch.ones_like(elems).sum(),
|
89 |
-
elems.sum(),
|
90 |
-
elems.square().sum(),
|
91 |
-
])
|
92 |
-
assert moments.ndim == 1 and moments.shape[0] == _num_moments
|
93 |
-
moments = moments.to(_counter_dtype)
|
94 |
-
|
95 |
-
device = moments.device
|
96 |
-
if device not in _counters[name]:
|
97 |
-
_counters[name][device] = torch.zeros_like(moments)
|
98 |
-
_counters[name][device].add_(moments)
|
99 |
-
return value
|
100 |
-
|
101 |
-
#----------------------------------------------------------------------------
|
102 |
-
|
103 |
-
def report0(name, value):
|
104 |
-
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
|
105 |
-
but ignores any scalars provided by the other processes.
|
106 |
-
See `report()` for further details.
|
107 |
-
"""
|
108 |
-
report(name, value if _rank == 0 else [])
|
109 |
-
return value
|
110 |
-
|
111 |
-
#----------------------------------------------------------------------------
|
112 |
-
|
113 |
-
class Collector:
|
114 |
-
r"""Collects the scalars broadcasted by `report()` and `report0()` and
|
115 |
-
computes their long-term averages (mean and standard deviation) over
|
116 |
-
user-defined periods of time.
|
117 |
-
|
118 |
-
The averages are first collected into internal counters that are not
|
119 |
-
directly visible to the user. They are then copied to the user-visible
|
120 |
-
state as a result of calling `update()` and can then be queried using
|
121 |
-
`mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
|
122 |
-
internal counters for the next round, so that the user-visible state
|
123 |
-
effectively reflects averages collected between the last two calls to
|
124 |
-
`update()`.
|
125 |
-
|
126 |
-
Args:
|
127 |
-
regex: Regular expression defining which statistics to
|
128 |
-
collect. The default is to collect everything.
|
129 |
-
keep_previous: Whether to retain the previous averages if no
|
130 |
-
scalars were collected on a given round
|
131 |
-
(default: True).
|
132 |
-
"""
|
133 |
-
def __init__(self, regex='.*', keep_previous=True):
|
134 |
-
self._regex = re.compile(regex)
|
135 |
-
self._keep_previous = keep_previous
|
136 |
-
self._cumulative = dict()
|
137 |
-
self._moments = dict()
|
138 |
-
self.update()
|
139 |
-
self._moments.clear()
|
140 |
-
|
141 |
-
def names(self):
|
142 |
-
r"""Returns the names of all statistics broadcasted so far that
|
143 |
-
match the regular expression specified at construction time.
|
144 |
-
"""
|
145 |
-
return [name for name in _counters if self._regex.fullmatch(name)]
|
146 |
-
|
147 |
-
def update(self):
|
148 |
-
r"""Copies current values of the internal counters to the
|
149 |
-
user-visible state and resets them for the next round.
|
150 |
-
|
151 |
-
If `keep_previous=True` was specified at construction time, the
|
152 |
-
operation is skipped for statistics that have received no scalars
|
153 |
-
since the last update, retaining their previous averages.
|
154 |
-
|
155 |
-
This method performs a number of GPU-to-CPU transfers and one
|
156 |
-
`torch.distributed.all_reduce()`. It is intended to be called
|
157 |
-
periodically in the main training loop, typically once every
|
158 |
-
N training steps.
|
159 |
-
"""
|
160 |
-
if not self._keep_previous:
|
161 |
-
self._moments.clear()
|
162 |
-
for name, cumulative in _sync(self.names()):
|
163 |
-
if name not in self._cumulative:
|
164 |
-
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
165 |
-
delta = cumulative - self._cumulative[name]
|
166 |
-
self._cumulative[name].copy_(cumulative)
|
167 |
-
if float(delta[0]) != 0:
|
168 |
-
self._moments[name] = delta
|
169 |
-
|
170 |
-
def _get_delta(self, name):
|
171 |
-
r"""Returns the raw moments that were accumulated for the given
|
172 |
-
statistic between the last two calls to `update()`, or zero if
|
173 |
-
no scalars were collected.
|
174 |
-
"""
|
175 |
-
assert self._regex.fullmatch(name)
|
176 |
-
if name not in self._moments:
|
177 |
-
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
178 |
-
return self._moments[name]
|
179 |
-
|
180 |
-
def num(self, name):
|
181 |
-
r"""Returns the number of scalars that were accumulated for the given
|
182 |
-
statistic between the last two calls to `update()`, or zero if
|
183 |
-
no scalars were collected.
|
184 |
-
"""
|
185 |
-
delta = self._get_delta(name)
|
186 |
-
return int(delta[0])
|
187 |
-
|
188 |
-
def mean(self, name):
|
189 |
-
r"""Returns the mean of the scalars that were accumulated for the
|
190 |
-
given statistic between the last two calls to `update()`, or NaN if
|
191 |
-
no scalars were collected.
|
192 |
-
"""
|
193 |
-
delta = self._get_delta(name)
|
194 |
-
if int(delta[0]) == 0:
|
195 |
-
return float('nan')
|
196 |
-
return float(delta[1] / delta[0])
|
197 |
-
|
198 |
-
def std(self, name):
|
199 |
-
r"""Returns the standard deviation of the scalars that were
|
200 |
-
accumulated for the given statistic between the last two calls to
|
201 |
-
`update()`, or NaN if no scalars were collected.
|
202 |
-
"""
|
203 |
-
delta = self._get_delta(name)
|
204 |
-
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
|
205 |
-
return float('nan')
|
206 |
-
if int(delta[0]) == 1:
|
207 |
-
return float(0)
|
208 |
-
mean = float(delta[1] / delta[0])
|
209 |
-
raw_var = float(delta[2] / delta[0])
|
210 |
-
return np.sqrt(max(raw_var - np.square(mean), 0))
|
211 |
-
|
212 |
-
def as_dict(self):
|
213 |
-
r"""Returns the averages accumulated between the last two calls to
|
214 |
-
`update()` as an `dnnlib.EasyDict`. The contents are as follows:
|
215 |
-
|
216 |
-
dnnlib.EasyDict(
|
217 |
-
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
|
218 |
-
...
|
219 |
-
)
|
220 |
-
"""
|
221 |
-
stats = dnnlib.EasyDict()
|
222 |
-
for name in self.names():
|
223 |
-
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
|
224 |
-
return stats
|
225 |
-
|
226 |
-
def __getitem__(self, name):
|
227 |
-
r"""Convenience getter.
|
228 |
-
`collector[name]` is a synonym for `collector.mean(name)`.
|
229 |
-
"""
|
230 |
-
return self.mean(name)
|
231 |
-
|
232 |
-
#----------------------------------------------------------------------------
|
233 |
-
|
234 |
-
def _sync(names):
|
235 |
-
r"""Synchronize the global cumulative counters across devices and
|
236 |
-
processes. Called internally by `Collector.update()`.
|
237 |
-
"""
|
238 |
-
if len(names) == 0:
|
239 |
-
return []
|
240 |
-
global _sync_called
|
241 |
-
_sync_called = True
|
242 |
-
|
243 |
-
# Collect deltas within current rank.
|
244 |
-
deltas = []
|
245 |
-
device = _sync_device if _sync_device is not None else torch.device('cpu')
|
246 |
-
for name in names:
|
247 |
-
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
|
248 |
-
for counter in _counters[name].values():
|
249 |
-
delta.add_(counter.to(device))
|
250 |
-
counter.copy_(torch.zeros_like(counter))
|
251 |
-
deltas.append(delta)
|
252 |
-
deltas = torch.stack(deltas)
|
253 |
-
|
254 |
-
# Sum deltas across ranks.
|
255 |
-
if _sync_device is not None:
|
256 |
-
torch.distributed.all_reduce(deltas)
|
257 |
-
|
258 |
-
# Update cumulative values.
|
259 |
-
deltas = deltas.cpu()
|
260 |
-
for idx, name in enumerate(names):
|
261 |
-
if name not in _cumulative:
|
262 |
-
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
263 |
-
_cumulative[name].add_(deltas[idx])
|
264 |
-
|
265 |
-
# Return name-value pairs.
|
266 |
-
return [(name, _cumulative[name]) for name in names]
|
267 |
-
|
268 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,40 +1,89 @@
|
|
1 |
import gradio as gr
|
2 |
-
import utils
|
3 |
from PIL import Image
|
4 |
import torch
|
5 |
import math
|
6 |
from torchvision import transforms
|
7 |
-
|
8 |
-
|
9 |
device = "cpu"
|
10 |
years = [str(y) for y in range(1880, 2020, 10)]
|
|
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
orig_models = {}
|
14 |
|
15 |
for year in years:
|
16 |
G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
|
17 |
-
orig_models[year] = { "G": G.eval()}
|
18 |
|
19 |
|
20 |
def run_alignment(image_path,idx=None):
|
21 |
import dlib
|
22 |
from align_all_parallel import align_face
|
23 |
predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
|
24 |
-
aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
|
25 |
-
|
26 |
|
27 |
return aligned_image
|
28 |
|
29 |
-
def predict(inp):
|
30 |
#with torch.no_grad():
|
31 |
inp.save("imgs/input.png")
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
gr.Interface(fn=predict,
|
37 |
-
inputs=gr.Image(type="pil"),
|
38 |
-
outputs=gr.Image(type="pil"),
|
39 |
-
|
40 |
-
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import utils.utils as utils
|
3 |
from PIL import Image
|
4 |
import torch
|
5 |
import math
|
6 |
from torchvision import transforms
|
7 |
+
from run_pti import run_PTI
|
|
|
8 |
device = "cpu"
|
9 |
years = [str(y) for y in range(1880, 2020, 10)]
|
10 |
+
decades = [y + "s" for y in years]
|
11 |
|
12 |
|
13 |
+
transform = transforms.Compose([
|
14 |
+
transforms.Resize((256, 256)),
|
15 |
+
transforms.ToTensor(),
|
16 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
17 |
+
|
18 |
orig_models = {}
|
19 |
|
20 |
for year in years:
|
21 |
G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
|
22 |
+
orig_models[year] = { "G": G.eval().float()}
|
23 |
|
24 |
|
25 |
def run_alignment(image_path,idx=None):
|
26 |
import dlib
|
27 |
from align_all_parallel import align_face
|
28 |
predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
|
29 |
+
aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
|
30 |
+
|
31 |
|
32 |
return aligned_image
|
33 |
|
34 |
+
def predict(inp, in_decade):
|
35 |
#with torch.no_grad():
|
36 |
inp.save("imgs/input.png")
|
37 |
+
inversion = run_alignment("imgs/input.png", idx=0)
|
38 |
+
inversion.save("imgs/cropped/input.png")
|
39 |
+
run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
|
40 |
+
#inversion = Image.open("imgs/cropped/input.png")
|
41 |
+
|
42 |
+
in_year = in_decade[:-1]
|
43 |
+
pti_models = {}
|
44 |
+
|
45 |
+
for year in years:
|
46 |
+
G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
|
47 |
+
pti_models[year] = { "G": G.eval().float()}
|
48 |
+
|
49 |
+
|
50 |
+
pti_models[in_year]['G'] = torch.load(f"checkpoints/model_gradio_demo_input.pt", device).eval().float()
|
51 |
+
|
52 |
+
for year in years:
|
53 |
+
if year != in_year:
|
54 |
+
for p_pti, p_orig, (names, p) in zip(pti_models[in_year]['G'].parameters(),orig_models[in_year]['G'].parameters(), pti_models[year]['G'].named_parameters()):
|
55 |
+
with torch.no_grad():
|
56 |
+
delta = p_pti - p_orig
|
57 |
+
p += delta
|
58 |
+
|
59 |
+
space = 0
|
60 |
+
dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
|
61 |
+
|
62 |
+
|
63 |
+
w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
|
64 |
+
|
65 |
+
border_width = 10
|
66 |
+
#fill_color = 'red'
|
67 |
+
dst.paste(inversion, (0, 0))
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
for i in range(0, len(years)):
|
72 |
+
year = str(years[i])
|
73 |
+
with torch.no_grad():
|
74 |
+
child_tensor = pti_models[year]["G"].synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
|
75 |
+
img = utils.tensor2im(child_tensor.squeeze(0))
|
76 |
+
# if year == in_year:
|
77 |
+
# img = img.crop((border_width, border_width, 256 - border_width, 256-border_width))
|
78 |
+
# img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color)
|
79 |
+
dst.paste(img, ((256 + space) * (i+1), 0))
|
80 |
+
dst
|
81 |
+
return dst
|
82 |
|
83 |
|
84 |
gr.Interface(fn=predict,
|
85 |
+
inputs=[gr.Image(label="Input Image", type="pil"), gr.Dropdown(label="Input Decade", choices=decades, value="2010s")],
|
86 |
+
outputs=gr.Image(label="Decade Transformations", type="pil"),
|
87 |
+
examples=[["imgs/Steven-Yeun.jpg", "2010s"]]
|
88 |
+
|
89 |
+
).launch() #.launch(server_name="0.0.0.0", server_port=8098)
|
checkpoints/model_gradio_demo_input.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65ee5644ec8ab0966a4eb51971995c071f8178db765750d45e32e0ed18a09738
|
3 |
+
size 99867041
|
PTI/criteria/color_transfer_loss.py β color_transfer_loss.py
RENAMED
File without changes
|
{PTI/configs β configs}/__init__.py
RENAMED
File without changes
|
{PTI/configs β configs}/evaluation_config.py
RENAMED
File without changes
|
{PTI/configs β configs}/global_config.py
RENAMED
@@ -1,6 +1,6 @@
|
|
1 |
## Device
|
2 |
-
cuda_visible_devices = "
|
3 |
-
device = "
|
4 |
|
5 |
## Logs
|
6 |
training_step = 1
|
|
|
1 |
## Device
|
2 |
+
cuda_visible_devices = "0"
|
3 |
+
device = "cpu"
|
4 |
|
5 |
## Logs
|
6 |
training_step = 1
|
{PTI/configs β configs}/hyperparameters.py
RENAMED
@@ -28,4 +28,4 @@ max_images_to_invert = 10
|
|
28 |
pti_learning_rate = 3e-4
|
29 |
first_inv_lr = 5e-3
|
30 |
train_batch_size = 1
|
31 |
-
use_last_w_pivots =
|
|
|
28 |
pti_learning_rate = 3e-4
|
29 |
first_inv_lr = 5e-3
|
30 |
train_batch_size = 1
|
31 |
+
use_last_w_pivots = False
|
{PTI/configs β configs}/paths_config.py
RENAMED
@@ -4,12 +4,12 @@ year = "2010"
|
|
4 |
e4e = "./pretrained_models/e4e_ffhq_encode.pt"
|
5 |
|
6 |
|
7 |
-
stylegan2_ada_ffhq = f"
|
8 |
|
9 |
style_clip_pretrained_mappers = ""
|
10 |
-
ir_se50 = "/
|
11 |
dlib = "./pretrained_models/align.dat"
|
12 |
-
deeplab = "/
|
13 |
|
14 |
## Dirs for output files
|
15 |
checkpoints_dir = "./checkpoints"
|
@@ -20,7 +20,7 @@ experiments_output_dir = "./output"
|
|
20 |
## Input info
|
21 |
### Input dir, where the images reside
|
22 |
input_data_path = (
|
23 |
-
f"/
|
24 |
)
|
25 |
input_data_id = f"{year}"
|
26 |
|
|
|
4 |
e4e = "./pretrained_models/e4e_ffhq_encode.pt"
|
5 |
|
6 |
|
7 |
+
stylegan2_ada_ffhq = f"pretrained_models/{year}.pkl"
|
8 |
|
9 |
style_clip_pretrained_mappers = ""
|
10 |
+
ir_se50 = "pretrained_models/model_ir_se50.pth"
|
11 |
dlib = "./pretrained_models/align.dat"
|
12 |
+
deeplab = "pretrained_models/deeplab_model/deeplab_model.pth"
|
13 |
|
14 |
## Dirs for output files
|
15 |
checkpoints_dir = "./checkpoints"
|
|
|
20 |
## Input info
|
21 |
### Input dir, where the images reside
|
22 |
input_data_path = (
|
23 |
+
f"imgs/cropped"
|
24 |
)
|
25 |
input_data_id = f"{year}"
|
26 |
|
{PTI/criteria β criteria}/__init__.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/backbones/__init__.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/backbones/iresnet.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/backbones/iresnet2060.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/backbones/mobilefacenet.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/deeplab.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/helpers.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/id_loss.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/l2_loss.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/localitly_regulizer.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/mask.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/model_irse.py
RENAMED
File without changes
|
{PTI/criteria β criteria}/validation.py
RENAMED
File without changes
|
dnnlib/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/dnnlib/__pycache__/__init__.cpython-39.pyc and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ
|
|
dnnlib/__pycache__/util.cpython-39.pyc
CHANGED
Binary files a/dnnlib/__pycache__/util.cpython-39.pyc and b/dnnlib/__pycache__/util.cpython-39.pyc differ
|
|
embeddings/2010/PTI/input/0.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb9b3fc3d08f6dd3ee5c87299fff8cd932e4a6afaa90fe06f3d5c5f9503ebf26
|
3 |
+
size 29419
|
imgs/Steven-Yeun.jpg
ADDED
Git LFS Details
|
imgs/cropped/input.png
ADDED
Git LFS Details
|
imgs/input.png
ADDED
Git LFS Details
|
{PTI/training β models/StyleCLIP}/__init__.py
RENAMED
File without changes
|
{PTI/training/coaches β models/StyleCLIP/criteria}/__init__.py
RENAMED
File without changes
|
models/StyleCLIP/criteria/clip_loss.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import clip
|
4 |
+
|
5 |
+
|
6 |
+
class CLIPLoss(torch.nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, opts):
|
9 |
+
super(CLIPLoss, self).__init__()
|
10 |
+
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
11 |
+
self.upsample = torch.nn.Upsample(scale_factor=7)
|
12 |
+
self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
|
13 |
+
|
14 |
+
def forward(self, image, text):
|
15 |
+
image = self.avg_pool(self.upsample(image))
|
16 |
+
similarity = 1 - self.model(image, text)[0] / 100
|
17 |
+
return similarity
|
models/StyleCLIP/criteria/id_loss.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from models.facial_recognition.model_irse import Backbone
|
5 |
+
|
6 |
+
|
7 |
+
class IDLoss(nn.Module):
|
8 |
+
def __init__(self, opts):
|
9 |
+
super(IDLoss, self).__init__()
|
10 |
+
print('Loading ResNet ArcFace')
|
11 |
+
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
12 |
+
self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
|
13 |
+
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
14 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
15 |
+
self.facenet.eval()
|
16 |
+
self.opts = opts
|
17 |
+
|
18 |
+
def extract_feats(self, x):
|
19 |
+
if x.shape[2] != 256:
|
20 |
+
x = self.pool(x)
|
21 |
+
x = x[:, :, 35:223, 32:220] # Crop interesting region
|
22 |
+
x = self.face_pool(x)
|
23 |
+
x_feats = self.facenet(x)
|
24 |
+
return x_feats
|
25 |
+
|
26 |
+
def forward(self, y_hat, y):
|
27 |
+
n_samples = y.shape[0]
|
28 |
+
y_feats = self.extract_feats(y) # Otherwise use the feature from there
|
29 |
+
y_hat_feats = self.extract_feats(y_hat)
|
30 |
+
y_feats = y_feats.detach()
|
31 |
+
loss = 0
|
32 |
+
sim_improvement = 0
|
33 |
+
count = 0
|
34 |
+
for i in range(n_samples):
|
35 |
+
diff_target = y_hat_feats[i].dot(y_feats[i])
|
36 |
+
loss += 1 - diff_target
|
37 |
+
count += 1
|
38 |
+
|
39 |
+
return loss / count, sim_improvement / count
|