diff --git a/PTI/.gitignore b/PTI/.gitignore
deleted file mode 100644
index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000
--- a/PTI/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/PTI/LICENSE b/PTI/LICENSE
deleted file mode 100644
index 490821a63031f8f8d115b2edee69ae26ac5ea0b1..0000000000000000000000000000000000000000
--- a/PTI/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2021 Daniel Roich
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/PTI/README.md b/PTI/README.md
deleted file mode 100644
index e57569893ab1695dde78a5c2265553a3bb1fb50a..0000000000000000000000000000000000000000
--- a/PTI/README.md
+++ /dev/null
@@ -1,229 +0,0 @@
-# PTI: Pivotal Tuning for Latent-based editing of Real Images
-
-Inference Notebook: https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb
-
-Pivotal Tuning Inversion (PTI) enables employing off-the-shelf latent based
-semantic editing techniques on real images using StyleGAN.
-PTI excels in identity preserving edits, portrayed through recognizable figures —
-Serena Williams and Robert Downey Jr. (top), and in handling faces which
-are clearly out-of-domain, e.g., due to heavy makeup (bottom).
-
-
-
-## Description
-Official implementation of our PTI paper, plus code for the evaluation metrics. PTI introduces an optimization mechanism for solving the StyleGAN inversion task,
-providing near-perfect reconstruction results while maintaining the high editing ability of the native StyleGAN latent space W. For more details, see the paper: https://arxiv.org/abs/2106.05744
-
-## Recent Updates
-**2021.07.01**: Fixed the file-download phase in the inference notebook, which previously could prevent the notebook from running smoothly.
-
-**2021.06.29**: Added support for CPU. To run PTI on CPU, change the `device` parameter in `configs/global_config.py` from "cuda" to "cpu".
-
-**2021.06.25**: Added a mohawk edit using StyleCLIP+PTI to the inference notebook.
- Updated the documentation in the inference notebook because the Google Drive rate limit was reached.
- Currently, Google Drive does not allow the pretrained models to be downloaded automatically from Colab, so manual intervention might be needed.
-
-## Getting Started
-### Prerequisites
-- Linux or macOS
-- NVIDIA GPU + CUDA CuDNN (not mandatory but recommended)
-- Python 3
-
-### Installation
-- Dependencies:
- 1. lpips
- 2. wandb
- 3. pytorch
- 4. torchvision
- 5. matplotlib
- 6. dlib
-- All dependencies can be installed with *pip install* followed by the package name, e.g. `pip install lpips wandb torch torchvision matplotlib dlib`
-
-## Pretrained Models
-Please download the pretrained models from the following links.
-
-### Auxiliary Models
-We provide various auxiliary models needed for the PTI inversion task.
-This includes the StyleGAN generator and pre-trained models used for loss computation.
-| Path | Description
-| :--- | :----------
-|[FFHQ StyleGAN](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl) | StyleGAN2-ada model trained on FFHQ with 1024x1024 output resolution.
-|[Dlib alignment](https://drive.google.com/file/d/1HKmjg6iXsWr4aFPuU0gBXPGR83wqMzq7/view?usp=sharing) | Dlib alignment model used for image preprocessing.
-|[FFHQ e4e encoder](https://drive.google.com/file/d/1ALC5CLA89Ouw40TwvxcwebhzWXM5YSCm/view?usp=sharing) | Pretrained e4e encoder. Used for StyleCLIP editing.
-
-Note: The StyleGAN model is used directly from the official [stylegan2-ada-pytorch implementation](https://github.com/NVlabs/stylegan2-ada-pytorch).
-For StyleCLIP pretrained mappers, please see [StyleCLIP's official routes](https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py)
-
-
-By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`.
-However, you may use your own paths by changing the necessary values in `configs/paths_config.py`.
-
-
-## Inversion
-### Preparing your Data
-To invert a real image and edit it, you should first align and crop it to the correct size. To do so, perform *one* of the following steps:
-1. Run `notebooks/align_data.ipynb` and change the `images_path` variable to point to the raw images directory
-2. Run `utils/align_data.py` and change the `images_path` variable to point to the raw images directory
-
-
-### Weights And Biases
-The project supports the [Weights And Biases](https://wandb.ai/home) framework for experiment tracking. For the inversion task, it enables visualization of the loss progression and of the generator's intermediate results during both the initial inversion and the *Pivotal Tuning* (PT) procedure.
-
-The log frequency can be adjusted using the parameters defined in `configs/global_config.py` under the "Logs" subsection.
-
-An account is not required to run the code; however, to use the experiment-tracking features provided by Weights and Biases, you first have to register on their site.
-
-
-### Running PTI
-The main training script is `scripts/run_pti.py`. It receives aligned and cropped images from the paths configured in the "Input info" subsection of `configs/paths_config.py`.
-Results are saved to the directories listed under "Dirs for output files" in `configs/paths_config.py`; this includes the inversion latent codes and the tuned generators.
-The hyperparameters for the inversion task can be found in `configs/hyperparameters.py` and are initialized to the default values used in the paper.
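-
-For orientation, below is a minimal sketch of how a run might be driven from Python. The entry-point and attribute names marked as assumptions should be verified against `scripts/run_pti.py` and the files under `configs/`.
-
-```python
-# Hedged sketch of an inversion run; inputs are assumed to be aligned and cropped already.
-# `run_PTI`, `input_data_path` and `max_pti_steps` are assumed names taken from the repository
-# layout; verify them against scripts/run_pti.py, configs/paths_config.py and configs/hyperparameters.py.
-from configs import global_config, paths_config, hyperparameters
-from scripts.run_pti import run_PTI  # assumed entry point
-
-global_config.device = 'cuda'                        # 'cpu' is also supported (see Recent Updates)
-paths_config.input_data_path = './image_processed'   # folder with aligned inputs (assumed name)
-hyperparameters.max_pti_steps = 350                  # number of pivotal-tuning steps (assumed name)
-
-model_id = run_PTI(run_name='my_run', use_wandb=False)  # saves the latent codes and the tuned generator
-```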
-
-## Editing
-By default, we assume that all auxiliary edit directions are downloaded and saved to the directory `editings`.
-However, you may use your own paths by changing the necessary values in `configs/paths_config.py` under the "Edit directions" subsection.
-
-An example of editing code can be found in `scripts/latent_editor_wrapper.py`; a schematic sketch of the underlying latent-space operation is shown below.
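-
-The sketch assumes `G` is the tuned generator and `w_pivot` the inverted latent code produced by the inversion step; the direction file path is only an example.
-
-```python
-# Hedged sketch of a linear, InterfaceGAN-style latent edit applied to a PTI pivot.
-import torch
-
-direction = torch.load('editings/interfacegan_directions/age.pt')  # example path; adjust to your setup
-alpha = 3.0                                                        # edit strength
-
-w_edit = w_pivot + alpha * direction           # walk along the semantic direction in latent space
-img = G.synthesis(w_edit, noise_mode='const')  # StyleGAN2-ada synthesis call
-```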
-
-## Inference Notebooks
-To help visualize the results of PTI we provide a Jupyter notebook found in `notebooks/inference_playground.ipynb`.
-The notebook will download the pretrained models and run inference on a sample image found online or
-on images of your choosing. It is recommended to run this in [Google Colab](https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb).
-
-The notebook demonstrates how to:
-- Invert an image using PTI
-- Visualise the inversion and use the PTI output
-- Edit the image after PTI using InterfaceGAN and StyleCLIP
-- Compare to other inversion methods
-
-## Evaluation
-Currently, the repository supports qualitative reconstruction evaluation for PTI, SG2 (*W* space), e4e, and SG2Plus (*W+* space),
-as well as editing with InterfaceGAN and GANSpace for the same inversion methods.
-To run the evaluation, see `evaluation/qualitative_edit_comparison.py`. Example outputs of the evaluation script:
-
-Reconstruction comparison between the different methods. The image order is: original image, W+ inversion, e4e inversion, W inversion, PTI inversion.
-
-InterfaceGAN pose-edit comparison between the different methods. The image order is: original, W+, e4e, W, PTI.
-
-A single image per edit, or several edits, can also be produced without a comparison.
-
-### Coming Soon - Quantitative evaluation and StyleCLIP qualitative evaluation
-
-## Repository structure
-| Path | Description
-| :--- | :---
-| ├ configs | Folder containing configs defining Hyperparameters, paths and logging
-| ├ criteria | Folder containing various loss and regularization criteria for the optimization
-| ├ dnnlib | Folder containing internal utils for StyleGAN2-ada
-| ├ docs | Folder containing images displayed in the README
-| ├ editings | Folder containing the latent space edit directions
-| ├ environment | Folder containing Anaconda environment used in our experiments
-| ├ licenses | Folder containing licenses of the open source projects used in this repository
-| ├ models | Folder containing models used in different editing techniques and first phase inversion
-| ├ notebooks | Folder with jupyter notebooks to demonstrate the usage of PTI end-to-end
-| ├ scripts | Folder with running scripts for inversion, editing and metric computations
-| ├ torch_utils | Folder containing internal utils for StyleGAN2-ada
-| ├ training | Folder containing the core training logic of PTI
-| ├ utils | Folder with various utility functions
-
-
-## Credits
-**StyleGAN2-ada model and implementation:**
-https://github.com/NVlabs/stylegan2-ada-pytorch
-Copyright © 2021, NVIDIA Corporation.
-Nvidia Source Code License https://nvlabs.github.io/stylegan2-ada-pytorch/license.html
-
-**LPIPS model and implementation:**
-https://github.com/richzhang/PerceptualSimilarity
-Copyright (c) 2020, Sou Uchida
-License (BSD 2-Clause) https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE
-
-**e4e model and implementation:**
-https://github.com/omertov/encoder4editing
-Copyright (c) 2021 omertov
-License (MIT) https://github.com/omertov/encoder4editing/blob/main/LICENSE
-
-**StyleCLIP model and implementation:**
-https://github.com/orpatashnik/StyleCLIP
-Copyright (c) 2021 orpatashnik
-License (MIT) https://github.com/orpatashnik/StyleCLIP/blob/main/LICENSE
-
-**InterfaceGAN implementation:**
-https://github.com/genforce/interfacegan
-Copyright (c) 2020 genforce
-License (MIT) https://github.com/genforce/interfacegan/blob/master/LICENSE
-
-**GANSpace implementation:**
-https://github.com/harskish/ganspace
-Copyright (c) 2020 harskish
-License (Apache License 2.0) https://github.com/harskish/ganspace/blob/master/LICENSE
-
-
-## Acknowledgments
-This repository structure is based on the [encoder4editing](https://github.com/omertov/encoder4editing) and [ReStyle](https://github.com/yuval-alaluf/restyle-encoder) repositories.
-
-## Contact
-For any inquiry please contact us at our email addresses: danielroich@gmail.com or ron.mokady@gmail.com
-
-
-## Citation
-If you use this code for your research, please cite:
-```
-@article{roich2021pivotal,
- title={Pivotal Tuning for Latent-based Editing of Real Images},
- author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
- journal={arXiv preprint arXiv:2106.05744},
- year={2021}
-}
-```
diff --git a/PTI/torch_utils/custom_ops.py b/PTI/torch_utils/custom_ops.py
deleted file mode 100644
index 4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/custom_ops.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import os
-import glob
-import torch
-import torch.utils.cpp_extension
-import importlib
-import hashlib
-import shutil
-from pathlib import Path
-
-from torch.utils.file_baton import FileBaton
-
-#----------------------------------------------------------------------------
-# Global options.
-
-verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
-
-#----------------------------------------------------------------------------
-# Internal helper funcs.
-
-def _find_compiler_bindir():
- patterns = [
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
- 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
- ]
- for pattern in patterns:
- matches = sorted(glob.glob(pattern))
- if len(matches):
- return matches[-1]
- return None
-
-#----------------------------------------------------------------------------
-# Main entry point for compiling and loading C++/CUDA plugins.
-
-_cached_plugins = dict()
-
-def get_plugin(module_name, sources, **build_kwargs):
- assert verbosity in ['none', 'brief', 'full']
-
- # Already cached?
- if module_name in _cached_plugins:
- return _cached_plugins[module_name]
-
- # Print status.
- if verbosity == 'full':
- print(f'Setting up PyTorch plugin "{module_name}"...')
- elif verbosity == 'brief':
- print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
-
- try: # pylint: disable=too-many-nested-blocks
- # Make sure we can find the necessary compiler binaries.
- if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
- compiler_bindir = _find_compiler_bindir()
- if compiler_bindir is None:
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
- os.environ['PATH'] += ';' + compiler_bindir
-
- # Compile and load.
- verbose_build = (verbosity == 'full')
-
- # Incremental build md5sum trickery. Copies all the input source files
- # into a cached build directory under a combined md5 digest of the input
- # source files. Copying is done only if the combined digest has changed.
- # This keeps input file timestamps and filenames the same as in previous
- # extension builds, allowing for fast incremental rebuilds.
- #
- # This optimization is done only in case all the source files reside in
- # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
- # environment variable is set (we take this as a signal that the user
- # actually cares about this.)
- source_dirs_set = set(os.path.dirname(source) for source in sources)
- if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
- all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
-
- # Compute a combined hash digest for all source files in the same
- # custom op directory (usually .cu, .cpp, .py and .h files).
- hash_md5 = hashlib.md5()
- for src in all_source_files:
- with open(src, 'rb') as f:
- hash_md5.update(f.read())
- build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
- digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
-
- if not os.path.isdir(digest_build_dir):
- os.makedirs(digest_build_dir, exist_ok=True)
- baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
- if baton.try_acquire():
- try:
- for src in all_source_files:
- shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
- finally:
- baton.release()
- else:
- # Someone else is copying source files under the digest dir,
- # wait until done and continue.
- baton.wait()
- digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
- torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
- verbose=verbose_build, sources=digest_sources, **build_kwargs)
- else:
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
- module = importlib.import_module(module_name)
-
- except:
- if verbosity == 'brief':
- print('Failed!')
- raise
-
- # Print status and add to cache.
- if verbosity == 'full':
- print(f'Done setting up PyTorch plugin "{module_name}".')
- elif verbosity == 'brief':
- print('Done.')
- _cached_plugins[module_name] = module
- return module
-
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/misc.py b/PTI/torch_utils/misc.py
deleted file mode 100644
index 7829f4d9f168557ce8a9a6dec289aa964234cb8c..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/misc.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import re
-import contextlib
-import numpy as np
-import torch
-import warnings
-import dnnlib
-
-#----------------------------------------------------------------------------
-# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
-# same constant is used multiple times.
-
-_constant_cache = dict()
-
-def constant(value, shape=None, dtype=None, device=None, memory_format=None):
- value = np.asarray(value)
- if shape is not None:
- shape = tuple(shape)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device('cpu')
- if memory_format is None:
- memory_format = torch.contiguous_format
-
- key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
- tensor = _constant_cache.get(key, None)
- if tensor is None:
- tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
- if shape is not None:
- tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
- tensor = tensor.contiguous(memory_format=memory_format)
- _constant_cache[key] = tensor
- return tensor
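-
-# Example (illustrative): repeated calls with the same value/arguments return the cached tensor,
-# so e.g. a resample filter can be fetched every forward pass without a new CPU=>GPU copy:
-#   f = constant([1., 3., 3., 1.], device=torch.device('cuda'))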
-
-#----------------------------------------------------------------------------
-# Replace NaN/Inf with specified numerical values.
-
-try:
- nan_to_num = torch.nan_to_num # 1.8.0a0
-except AttributeError:
- def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
- assert isinstance(input, torch.Tensor)
- if posinf is None:
- posinf = torch.finfo(input.dtype).max
- if neginf is None:
- neginf = torch.finfo(input.dtype).min
- assert nan == 0
- return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
-
-#----------------------------------------------------------------------------
-# Symbolic assert.
-
-try:
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
-except AttributeError:
- symbolic_assert = torch.Assert # 1.7.0
-
-#----------------------------------------------------------------------------
-# Context manager to suppress known warnings in torch.jit.trace().
-
-class suppress_tracer_warnings(warnings.catch_warnings):
- def __enter__(self):
- super().__enter__()
- warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
- return self
-
-#----------------------------------------------------------------------------
-# Assert that the shape of a tensor matches the given list of integers.
-# None indicates that the size of a dimension is allowed to vary.
-# Performs symbolic assertion when used in torch.jit.trace().
-
-def assert_shape(tensor, ref_shape):
- if tensor.ndim != len(ref_shape):
- raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
- for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
- if ref_size is None:
- pass
- elif isinstance(ref_size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
- elif isinstance(size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
- elif size != ref_size:
- raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
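-
-# Example (illustrative): assert_shape(img, [None, 3, 256, 256]) accepts any batch size but
-# requires exactly 3 channels and a 256x256 spatial resolution.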
-
-#----------------------------------------------------------------------------
-# Function decorator that calls torch.autograd.profiler.record_function().
-
-def profiled_function(fn):
- def decorator(*args, **kwargs):
- with torch.autograd.profiler.record_function(fn.__name__):
- return fn(*args, **kwargs)
- decorator.__name__ = fn.__name__
- return decorator
-
-#----------------------------------------------------------------------------
-# Sampler for torch.utils.data.DataLoader that loops over the dataset
-# indefinitely, shuffling items as it goes.
-
-class InfiniteSampler(torch.utils.data.Sampler):
- def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
- assert len(dataset) > 0
- assert num_replicas > 0
- assert 0 <= rank < num_replicas
- assert 0 <= window_size <= 1
- super().__init__(dataset)
- self.dataset = dataset
- self.rank = rank
- self.num_replicas = num_replicas
- self.shuffle = shuffle
- self.seed = seed
- self.window_size = window_size
-
- def __iter__(self):
- order = np.arange(len(self.dataset))
- rnd = None
- window = 0
- if self.shuffle:
- rnd = np.random.RandomState(self.seed)
- rnd.shuffle(order)
- window = int(np.rint(order.size * self.window_size))
-
- idx = 0
- while True:
- i = idx % order.size
- if idx % self.num_replicas == self.rank:
- yield order[i]
- if window >= 2:
- j = (i - rnd.randint(window)) % order.size
- order[i], order[j] = order[j], order[i]
- idx += 1
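-
-# Example usage (illustrative batch size):
-#   sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
-#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)
-# Iteration never stops on its own; the training loop decides when to break.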
-
-#----------------------------------------------------------------------------
-# Utilities for operating with torch.nn.Module parameters and buffers.
-
-def params_and_buffers(module):
- assert isinstance(module, torch.nn.Module)
- return list(module.parameters()) + list(module.buffers())
-
-def named_params_and_buffers(module):
- assert isinstance(module, torch.nn.Module)
- return list(module.named_parameters()) + list(module.named_buffers())
-
-def copy_params_and_buffers(src_module, dst_module, require_all=False):
- assert isinstance(src_module, torch.nn.Module)
- assert isinstance(dst_module, torch.nn.Module)
- src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
- for name, tensor in named_params_and_buffers(dst_module):
- assert (name in src_tensors) or (not require_all)
- if name in src_tensors:
- tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
-
-#----------------------------------------------------------------------------
-# Context manager for easily enabling/disabling DistributedDataParallel
-# synchronization.
-
-@contextlib.contextmanager
-def ddp_sync(module, sync):
- assert isinstance(module, torch.nn.Module)
- if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
- yield
- else:
- with module.no_sync():
- yield
-
-#----------------------------------------------------------------------------
-# Check DistributedDataParallel consistency across processes.
-
-def check_ddp_consistency(module, ignore_regex=None):
- assert isinstance(module, torch.nn.Module)
- for name, tensor in named_params_and_buffers(module):
- fullname = type(module).__name__ + '.' + name
- if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
- continue
- tensor = tensor.detach()
- other = tensor.clone()
- torch.distributed.broadcast(tensor=other, src=0)
- assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
-
-#----------------------------------------------------------------------------
-# Print summary table of module hierarchy.
-
-def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
- assert isinstance(module, torch.nn.Module)
- assert not isinstance(module, torch.jit.ScriptModule)
- assert isinstance(inputs, (tuple, list))
-
- # Register hooks.
- entries = []
- nesting = [0]
- def pre_hook(_mod, _inputs):
- nesting[0] += 1
- def post_hook(mod, _inputs, outputs):
- nesting[0] -= 1
- if nesting[0] <= max_nesting:
- outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
- outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
- entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
- hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
- hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
-
- # Run module.
- outputs = module(*inputs)
- for hook in hooks:
- hook.remove()
-
- # Identify unique outputs, parameters, and buffers.
- tensors_seen = set()
- for e in entries:
- e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
- e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
- e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
- tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
-
- # Filter out redundant entries.
- if skip_redundant:
- entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
-
- # Construct table.
- rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
- rows += [['---'] * len(rows[0])]
- param_total = 0
- buffer_total = 0
- submodule_names = {mod: name for name, mod in module.named_modules()}
- for e in entries:
-        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
- param_size = sum(t.numel() for t in e.unique_params)
- buffer_size = sum(t.numel() for t in e.unique_buffers)
-        output_shapes = [str(list(t.shape)) for t in e.outputs]
- output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
- rows += [[
- name + (':0' if len(e.outputs) >= 2 else ''),
- str(param_size) if param_size else '-',
- str(buffer_size) if buffer_size else '-',
- (output_shapes + ['-'])[0],
- (output_dtypes + ['-'])[0],
- ]]
- for idx in range(1, len(e.outputs)):
- rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
- param_total += param_size
- buffer_total += buffer_size
- rows += [['---'] * len(rows[0])]
- rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
-
- # Print table.
- widths = [max(len(cell) for cell in column) for column in zip(*rows)]
- print()
- for row in rows:
- print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
- print()
- return outputs
-
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/bias_act.cpp b/PTI/torch_utils/ops/bias_act.cpp
deleted file mode 100644
index 5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/bias_act.cpp
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include "bias_act.h"
-
-//------------------------------------------------------------------------
-
-static bool has_same_layout(torch::Tensor x, torch::Tensor y)
-{
- if (x.dim() != y.dim())
- return false;
- for (int64_t i = 0; i < x.dim(); i++)
- {
- if (x.size(i) != y.size(i))
- return false;
- if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
- return false;
- }
- return true;
-}
-
-//------------------------------------------------------------------------
-
-static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
-{
- // Validate arguments.
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
- TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
- TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
- TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
- TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
- TORCH_CHECK(b.dim() == 1, "b must have rank 1");
- TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
- TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
- TORCH_CHECK(grad >= 0, "grad must be non-negative");
-
- // Validate layout.
- TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
- TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
- TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
- TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
- TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
-
- // Create output tensor.
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
- torch::Tensor y = torch::empty_like(x);
- TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
-
- // Initialize CUDA kernel parameters.
- bias_act_kernel_params p;
- p.x = x.data_ptr();
- p.b = (b.numel()) ? b.data_ptr() : NULL;
- p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
- p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
- p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
- p.y = y.data_ptr();
- p.grad = grad;
- p.act = act;
- p.alpha = alpha;
- p.gain = gain;
- p.clamp = clamp;
- p.sizeX = (int)x.numel();
- p.sizeB = (int)b.numel();
- p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
-
- // Choose CUDA kernel.
- void* kernel;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
- {
-        kernel = choose_bias_act_kernel<scalar_t>(p);
- });
- TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
-
- // Launch CUDA kernel.
- p.loopX = 4;
- int blockSize = 4 * 32;
- int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
- void* args[] = {&p};
- AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
- return y;
-}
-
-//------------------------------------------------------------------------
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
- m.def("bias_act", &bias_act);
-}
-
-//------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/bias_act.cu b/PTI/torch_utils/ops/bias_act.cu
deleted file mode 100644
index dd8fc4756d7d94727f94af738665b68d9c518880..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/bias_act.cu
+++ /dev/null
@@ -1,173 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <c10/util/Half.h>
-#include "bias_act.h"
-
-//------------------------------------------------------------------------
-// Helpers.
-
-template <class T> struct InternalType;
-template <> struct InternalType<double>     { typedef double scalar_t; };
-template <> struct InternalType<float>      { typedef float  scalar_t; };
-template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
-
-//------------------------------------------------------------------------
-// CUDA kernel.
-
-template <class T, int A>
-__global__ void bias_act_kernel(bias_act_kernel_params p)
-{
- typedef typename InternalType::scalar_t scalar_t;
- int G = p.grad;
- scalar_t alpha = (scalar_t)p.alpha;
- scalar_t gain = (scalar_t)p.gain;
- scalar_t clamp = (scalar_t)p.clamp;
- scalar_t one = (scalar_t)1;
- scalar_t two = (scalar_t)2;
- scalar_t expRange = (scalar_t)80;
- scalar_t halfExpRange = (scalar_t)40;
- scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
- scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
-
- // Loop over elements.
- int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
- for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
- {
- // Load.
- scalar_t x = (scalar_t)((const T*)p.x)[xi];
- scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
- scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
- scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
- scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
- scalar_t yy = (gain != 0) ? yref / gain : 0;
- scalar_t y = 0;
-
- // Apply bias.
- ((G == 0) ? x : xref) += b;
-
- // linear
- if (A == 1)
- {
- if (G == 0) y = x;
- if (G == 1) y = x;
- }
-
- // relu
- if (A == 2)
- {
- if (G == 0) y = (x > 0) ? x : 0;
- if (G == 1) y = (yy > 0) ? x : 0;
- }
-
- // lrelu
- if (A == 3)
- {
- if (G == 0) y = (x > 0) ? x : x * alpha;
- if (G == 1) y = (yy > 0) ? x : x * alpha;
- }
-
- // tanh
- if (A == 4)
- {
- if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
- if (G == 1) y = x * (one - yy * yy);
- if (G == 2) y = x * (one - yy * yy) * (-two * yy);
- }
-
- // sigmoid
- if (A == 5)
- {
- if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
- if (G == 1) y = x * yy * (one - yy);
- if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
- }
-
- // elu
- if (A == 6)
- {
- if (G == 0) y = (x >= 0) ? x : exp(x) - one;
- if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
- }
-
- // selu
- if (A == 7)
- {
- if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
- if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
- }
-
- // softplus
- if (A == 8)
- {
- if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
- if (G == 1) y = x * (one - exp(-yy));
- if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
- }
-
- // swish
- if (A == 9)
- {
- if (G == 0)
- y = (x < -expRange) ? 0 : x / (exp(-x) + one);
- else
- {
- scalar_t c = exp(xref);
- scalar_t d = c + one;
- if (G == 1)
- y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
- else
- y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
- yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
- }
- }
-
- // Apply gain.
- y *= gain * dy;
-
- // Clamp.
- if (clamp >= 0)
- {
- if (G == 0)
- y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
- else
- y = (yref > -clamp & yref < clamp) ? y : 0;
- }
-
- // Store.
- ((T*)p.y)[xi] = (T)y;
- }
-}
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
-{
-    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
-    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
-    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
-    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
-    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
-    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
-    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
-    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
-    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
-    return NULL;
-}
-
-//------------------------------------------------------------------------
-// Template specializations.
-
-template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
-template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
-template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);
-
-//------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/bias_act.h b/PTI/torch_utils/ops/bias_act.h
deleted file mode 100644
index a32187e1fb7e3bae509d4eceaf900866866875a4..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/bias_act.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-//------------------------------------------------------------------------
-// CUDA kernel parameters.
-
-struct bias_act_kernel_params
-{
- const void* x; // [sizeX]
- const void* b; // [sizeB] or NULL
- const void* xref; // [sizeX] or NULL
- const void* yref; // [sizeX] or NULL
- const void* dy; // [sizeX] or NULL
- void* y; // [sizeX]
-
- int grad;
- int act;
- float alpha;
- float gain;
- float clamp;
-
- int sizeX;
- int sizeB;
- int stepB;
- int loopX;
-};
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
-
-//------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/bias_act.py b/PTI/torch_utils/ops/bias_act.py
deleted file mode 100644
index 4bcb409a89ccf6c6f6ecfca5962683df2d280b1f..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/bias_act.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom PyTorch ops for efficient bias and activation."""
-
-import os
-import warnings
-import numpy as np
-import torch
-import dnnlib
-import traceback
-
-from .. import custom_ops
-from .. import misc
-
-#----------------------------------------------------------------------------
-
-activation_funcs = {
- 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
- 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
- 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
- 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
- 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
- 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
- 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
- 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
- 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
-}
-
-#----------------------------------------------------------------------------
-
-_inited = False
-_plugin = None
-_null_tensor = torch.empty([0])
-
-def _init():
- global _inited, _plugin
- if not _inited:
- _inited = True
- sources = ['bias_act.cpp', 'bias_act.cu']
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
- try:
- _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
- except:
- warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
- return _plugin is not None
-
-#----------------------------------------------------------------------------
-
-def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
- r"""Fused bias and activation function.
-
- Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
- and scales the result by `gain`. Each of the steps is optional. In most cases,
- the fused op is considerably more efficient than performing the same calculation
- using standard PyTorch ops. It supports first and second order gradients,
- but not third order gradients.
-
- Args:
- x: Input activation tensor. Can be of any shape.
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
- as `x`. The shape must be known, and it must match the dimension of `x`
- corresponding to `dim`.
- dim: The dimension in `x` corresponding to the elements of `b`.
- The value of `dim` is ignored if `b` is not specified.
- act: Name of the activation function to evaluate, or `"linear"` to disable.
- Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
- See `activation_funcs` for a full list. `None` is not allowed.
- alpha: Shape parameter for the activation function, or `None` to use the default.
- gain: Scaling factor for the output tensor, or `None` to use default.
- See `activation_funcs` for the default scaling of each activation function.
- If unsure, consider specifying 1.
- clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
- the clamping (default).
- impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
- Returns:
- Tensor of the same shape and datatype as `x`.
- """
- assert isinstance(x, torch.Tensor)
- assert impl in ['ref', 'cuda']
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
- return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
- return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
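-
-# Example (illustrative shapes; impl='ref' avoids building the CUDA plugin):
-#   x = torch.randn(4, 512, 16, 16)
-#   b = torch.zeros(512)
-#   y = bias_act(x, b, dim=1, act='lrelu', impl='ref')  # same shape and dtype as x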
-
-#----------------------------------------------------------------------------
-
-@misc.profiled_function
-def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
- """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
- """
- assert isinstance(x, torch.Tensor)
- assert clamp is None or clamp >= 0
- spec = activation_funcs[act]
- alpha = float(alpha if alpha is not None else spec.def_alpha)
- gain = float(gain if gain is not None else spec.def_gain)
- clamp = float(clamp if clamp is not None else -1)
-
- # Add bias.
- if b is not None:
- assert isinstance(b, torch.Tensor) and b.ndim == 1
- assert 0 <= dim < x.ndim
- assert b.shape[0] == x.shape[dim]
- x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
-
- # Evaluate activation function.
- alpha = float(alpha)
- x = spec.func(x, alpha=alpha)
-
- # Scale by gain.
- gain = float(gain)
- if gain != 1:
- x = x * gain
-
- # Clamp.
- if clamp >= 0:
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
- return x
-
-#----------------------------------------------------------------------------
-
-_bias_act_cuda_cache = dict()
-
-def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
- """Fast CUDA implementation of `bias_act()` using custom ops.
- """
- # Parse arguments.
- assert clamp is None or clamp >= 0
- spec = activation_funcs[act]
- alpha = float(alpha if alpha is not None else spec.def_alpha)
- gain = float(gain if gain is not None else spec.def_gain)
- clamp = float(clamp if clamp is not None else -1)
-
- # Lookup from cache.
- key = (dim, act, alpha, gain, clamp)
- if key in _bias_act_cuda_cache:
- return _bias_act_cuda_cache[key]
-
- # Forward op.
- class BiasActCuda(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x, b): # pylint: disable=arguments-differ
- ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
- x = x.contiguous(memory_format=ctx.memory_format)
- b = b.contiguous() if b is not None else _null_tensor
- y = x
- if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
- y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
- ctx.save_for_backward(
- x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
- b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
- y if 'y' in spec.ref else _null_tensor)
- return y
-
- @staticmethod
- def backward(ctx, dy): # pylint: disable=arguments-differ
- dy = dy.contiguous(memory_format=ctx.memory_format)
- x, b, y = ctx.saved_tensors
- dx = None
- db = None
-
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
- dx = dy
- if act != 'linear' or gain != 1 or clamp >= 0:
- dx = BiasActCudaGrad.apply(dy, x, b, y)
-
- if ctx.needs_input_grad[1]:
- db = dx.sum([i for i in range(dx.ndim) if i != dim])
-
- return dx, db
-
- # Backward op.
- class BiasActCudaGrad(torch.autograd.Function):
- @staticmethod
- def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
- ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
- dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
- ctx.save_for_backward(
- dy if spec.has_2nd_grad else _null_tensor,
- x, b, y)
- return dx
-
- @staticmethod
- def backward(ctx, d_dx): # pylint: disable=arguments-differ
- d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
- dy, x, b, y = ctx.saved_tensors
- d_dy = None
- d_x = None
- d_b = None
- d_y = None
-
- if ctx.needs_input_grad[0]:
- d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
-
- if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
- d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
-
- if spec.has_2nd_grad and ctx.needs_input_grad[2]:
- d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
-
- return d_dy, d_x, d_b, d_y
-
- # Add to cache.
- _bias_act_cuda_cache[key] = BiasActCuda
- return BiasActCuda
-
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/conv2d_gradfix.py b/PTI/torch_utils/ops/conv2d_gradfix.py
deleted file mode 100644
index e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/conv2d_gradfix.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom replacement for `torch.nn.functional.conv2d` that supports
-arbitrarily high order gradients with zero performance penalty."""
-
-import warnings
-import contextlib
-import torch
-
-# pylint: disable=redefined-builtin
-# pylint: disable=arguments-differ
-# pylint: disable=protected-access
-
-#----------------------------------------------------------------------------
-
-enabled = False # Enable the custom op by setting this to true.
-weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
-
-@contextlib.contextmanager
-def no_weight_gradients():
- global weight_gradients_disabled
- old = weight_gradients_disabled
- weight_gradients_disabled = True
- yield
- weight_gradients_disabled = old
-
-#----------------------------------------------------------------------------
-
-def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
- if _should_use_custom_op(input):
- return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
- return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
-
-def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
- if _should_use_custom_op(input):
- return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
- return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
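-
-# Example (illustrative): both wrappers are drop-in replacements for the corresponding
-# torch.nn.functional calls and silently fall back to them unless `enabled` is True and
-# the input resides on a CUDA device:
-#   y = conv2d(x, weight, bias=None, stride=1, padding=1)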
-
-#----------------------------------------------------------------------------
-
-def _should_use_custom_op(input):
- assert isinstance(input, torch.Tensor)
- if (not enabled) or (not torch.backends.cudnn.enabled):
- return False
- if input.device.type != 'cuda':
- return False
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
- return True
- warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
- return False
-
-def _tuple_of_ints(xs, ndim):
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
- assert len(xs) == ndim
- assert all(isinstance(x, int) for x in xs)
- return xs
-
-#----------------------------------------------------------------------------
-
-_conv2d_gradfix_cache = dict()
-
-def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
- # Parse arguments.
- ndim = 2
- weight_shape = tuple(weight_shape)
- stride = _tuple_of_ints(stride, ndim)
- padding = _tuple_of_ints(padding, ndim)
- output_padding = _tuple_of_ints(output_padding, ndim)
- dilation = _tuple_of_ints(dilation, ndim)
-
- # Lookup from cache.
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
- if key in _conv2d_gradfix_cache:
- return _conv2d_gradfix_cache[key]
-
- # Validate arguments.
- assert groups >= 1
- assert len(weight_shape) == ndim + 2
- assert all(stride[i] >= 1 for i in range(ndim))
- assert all(padding[i] >= 0 for i in range(ndim))
- assert all(dilation[i] >= 0 for i in range(ndim))
- if not transpose:
- assert all(output_padding[i] == 0 for i in range(ndim))
- else: # transpose
- assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
-
- # Helpers.
- common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
- def calc_output_padding(input_shape, output_shape):
- if transpose:
- return [0, 0]
- return [
- input_shape[i + 2]
- - (output_shape[i + 2] - 1) * stride[i]
- - (1 - 2 * padding[i])
- - dilation[i] * (weight_shape[i + 2] - 1)
- for i in range(ndim)
- ]
-
- # Forward & backward.
- class Conv2d(torch.autograd.Function):
- @staticmethod
- def forward(ctx, input, weight, bias):
- assert weight.shape == weight_shape
- if not transpose:
- output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
- else: # transpose
- output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
- ctx.save_for_backward(input, weight)
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- input, weight = ctx.saved_tensors
- grad_input = None
- grad_weight = None
- grad_bias = None
-
- if ctx.needs_input_grad[0]:
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
- grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
- assert grad_input.shape == input.shape
-
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
- assert grad_weight.shape == weight_shape
-
- if ctx.needs_input_grad[2]:
- grad_bias = grad_output.sum([0, 2, 3])
-
- return grad_input, grad_weight, grad_bias
-
- # Gradient with respect to the weights.
- class Conv2dGradWeight(torch.autograd.Function):
- @staticmethod
- def forward(ctx, grad_output, input):
- op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
- flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
- grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
- assert grad_weight.shape == weight_shape
- ctx.save_for_backward(grad_output, input)
- return grad_weight
-
- @staticmethod
- def backward(ctx, grad2_grad_weight):
- grad_output, input = ctx.saved_tensors
- grad2_grad_output = None
- grad2_input = None
-
- if ctx.needs_input_grad[0]:
- grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
- assert grad2_grad_output.shape == grad_output.shape
-
- if ctx.needs_input_grad[1]:
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
- grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
- assert grad2_input.shape == input.shape
-
- return grad2_grad_output, grad2_input
-
- _conv2d_gradfix_cache[key] = Conv2d
- return Conv2d
-
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/conv2d_resample.py b/PTI/torch_utils/ops/conv2d_resample.py
deleted file mode 100644
index cd4750744c83354bab78704d4ef51ad1070fcc4a..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/conv2d_resample.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""2D convolution with optional up/downsampling."""
-
-import torch
-
-from .. import misc
-from . import conv2d_gradfix
-from . import upfirdn2d
-from .upfirdn2d import _parse_padding
-from .upfirdn2d import _get_filter_size
-
-#----------------------------------------------------------------------------
-
-def _get_weight_shape(w):
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
- shape = [int(sz) for sz in w.shape]
- misc.assert_shape(w, shape)
- return shape
-
-#----------------------------------------------------------------------------
-
-def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
- """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
- """
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
-
- # Flip weight if requested.
- if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
- w = w.flip([2, 3])
-
- # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
- # 1x1 kernel + memory_format=channels_last + less than 64 channels.
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
- if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
- if out_channels <= 4 and groups == 1:
- in_shape = x.shape
- x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
- x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
- else:
- x = x.to(memory_format=torch.contiguous_format)
- w = w.to(memory_format=torch.contiguous_format)
- x = conv2d_gradfix.conv2d(x, w, groups=groups)
- return x.to(memory_format=torch.channels_last)
-
- # Otherwise => execute using conv2d_gradfix.
- op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
- return op(x, w, stride=stride, padding=padding, groups=groups)
-
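-# Illustrative sketch (hypothetical helper, not part of the original file): the 1x1 workaround
-# above is valid because a 1x1 convolution is just a matrix product over the channel axis.
-def _example_conv2d_wrapper():
-    x = torch.randn([2, 8, 4, 4])
-    w = torch.randn([16, 8, 1, 1])
-    y = _conv2d_wrapper(x=x, w=w)
-    ref = torch.einsum('oi,bihw->bohw', w[:, :, 0, 0], x)
-    assert torch.allclose(y, ref, atol=1e-5)
-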
-#----------------------------------------------------------------------------
-
-@misc.profiled_function
-def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
- r"""2D convolution with optional up/downsampling.
-
- Padding is performed only once at the beginning, not between the operations.
-
- Args:
- x: Input tensor of shape
- `[batch_size, in_channels, in_height, in_width]`.
- w: Weight tensor of shape
- `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
- f: Low-pass filter for up/downsampling. Must be prepared beforehand by
- calling upfirdn2d.setup_filter(). None = identity (default).
- up: Integer upsampling factor (default: 1).
- down: Integer downsampling factor (default: 1).
- padding: Padding with respect to the upsampled image. Can be a single number
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- groups: Split input channels into N groups (default: 1).
- flip_weight: False = convolution, True = correlation (default: True).
- flip_filter: False = convolution, True = correlation (default: False).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- # Validate arguments.
- assert isinstance(x, torch.Tensor) and (x.ndim == 4)
- assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
- assert isinstance(up, int) and (up >= 1)
- assert isinstance(down, int) and (down >= 1)
- assert isinstance(groups, int) and (groups >= 1)
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
- fw, fh = _get_filter_size(f)
- px0, px1, py0, py1 = _parse_padding(padding)
-
- # Adjust padding to account for up/downsampling.
- if up > 1:
- px0 += (fw + up - 1) // 2
- px1 += (fw - up) // 2
- py0 += (fh + up - 1) // 2
- py1 += (fh - up) // 2
- if down > 1:
- px0 += (fw - down + 1) // 2
- px1 += (fw - down) // 2
- py0 += (fh - down + 1) // 2
- py1 += (fh - down) // 2
-
- # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
- if kw == 1 and kh == 1 and (down > 1 and up == 1):
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- return x
-
- # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
- if kw == 1 and kh == 1 and (up > 1 and down == 1):
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
- return x
-
- # Fast path: downsampling only => use strided convolution.
- if down > 1 and up == 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
- x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
- return x
-
- # Fast path: upsampling with optional downsampling => use transpose strided convolution.
- if up > 1:
- if groups == 1:
- w = w.transpose(0, 1)
- else:
- w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
- w = w.transpose(1, 2)
- w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
- px0 -= kw - 1
- px1 -= kw - up
- py0 -= kh - 1
- py1 -= kh - up
- pxt = max(min(-px0, -px1), 0)
- pyt = max(min(-py0, -py1), 0)
- x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
- if down > 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
- return x
-
- # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
- if up == 1 and down == 1:
- if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
- return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
-
- # Fallback: Generic reference implementation.
- x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- if down > 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
- return x
-
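-# Illustrative sketch (hypothetical helper, not part of the original file): fused 2x upsampling
-# with a 3x3 convolution, using a filter prepared via upfirdn2d.setup_filter() as required above.
-def _example_conv2d_resample():
-    x = torch.randn([1, 8, 16, 16])
-    w = torch.randn([4, 8, 3, 3])
-    f = upfirdn2d.setup_filter([1, 3, 3, 1])
-    y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
-    assert y.shape == (1, 4, 32, 32)  # spatial resolution doubled, channels follow w
-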
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/fma.py b/PTI/torch_utils/ops/fma.py
deleted file mode 100644
index 2eeac58a626c49231e04122b93e321ada954c5d3..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/fma.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
-
-import torch
-
-#----------------------------------------------------------------------------
-
-def fma(a, b, c): # => a * b + c
- return _FusedMultiplyAdd.apply(a, b, c)
-
-#----------------------------------------------------------------------------
-
-class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
- @staticmethod
- def forward(ctx, a, b, c): # pylint: disable=arguments-differ
- out = torch.addcmul(c, a, b)
- ctx.save_for_backward(a, b)
- ctx.c_shape = c.shape
- return out
-
- @staticmethod
- def backward(ctx, dout): # pylint: disable=arguments-differ
- a, b = ctx.saved_tensors
- c_shape = ctx.c_shape
- da = None
- db = None
- dc = None
-
- if ctx.needs_input_grad[0]:
- da = _unbroadcast(dout * b, a.shape)
-
- if ctx.needs_input_grad[1]:
- db = _unbroadcast(dout * a, b.shape)
-
- if ctx.needs_input_grad[2]:
- dc = _unbroadcast(dout, c_shape)
-
- return da, db, dc
-
-#----------------------------------------------------------------------------
-
-def _unbroadcast(x, shape):
- extra_dims = x.ndim - len(shape)
- assert extra_dims >= 0
- dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
- if len(dim):
- x = x.sum(dim=dim, keepdim=True)
- if extra_dims:
- x = x.reshape(-1, *x.shape[extra_dims+1:])
- assert x.shape == shape
- return x
-
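-# Illustrative sketch (hypothetical helper, not part of the original file): fma(a, b, c) computes
-# a * b + c, and _unbroadcast() folds broadcast dimensions back so gradients match input shapes.
-def _example_fma():
-    a = torch.randn([4, 1], requires_grad=True)
-    b = torch.randn([1, 3], requires_grad=True)
-    c = torch.randn([4, 3], requires_grad=True)
-    fma(a, b, c).sum().backward()
-    assert a.grad.shape == a.shape and b.grad.shape == b.shape and c.grad.shape == c.shape
-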
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/grid_sample_gradfix.py b/PTI/torch_utils/ops/grid_sample_gradfix.py
deleted file mode 100644
index ca6b3413ea72a734703c34382c023b84523601fd..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/grid_sample_gradfix.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom replacement for `torch.nn.functional.grid_sample` that
-supports arbitrarily high order gradients between the input and output.
-Only works on 2D images and assumes
-`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
-
-import warnings
-import torch
-
-# pylint: disable=redefined-builtin
-# pylint: disable=arguments-differ
-# pylint: disable=protected-access
-
-#----------------------------------------------------------------------------
-
-enabled = False # Enable the custom op by setting this to true.
-
-#----------------------------------------------------------------------------
-
-def grid_sample(input, grid):
- if _should_use_custom_op():
- return _GridSample2dForward.apply(input, grid)
- return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
-
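-# Illustrative sketch (hypothetical helper, not part of the original file): with `enabled` left at
-# False this simply defers to torch.nn.functional.grid_sample(); an identity grid reproduces the
-# input because sampling happens exactly at pixel centers when align_corners=False.
-def _example_grid_sample():
-    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
-    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine transform
-    grid = torch.nn.functional.affine_grid(theta, [1, 1, 4, 4], align_corners=False)
-    assert torch.allclose(grid_sample(x, grid), x, atol=1e-5)
-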
-#----------------------------------------------------------------------------
-
-def _should_use_custom_op():
- if not enabled:
- return False
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
- return True
- warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
- return False
-
-#----------------------------------------------------------------------------
-
-class _GridSample2dForward(torch.autograd.Function):
- @staticmethod
- def forward(ctx, input, grid):
- assert input.ndim == 4
- assert grid.ndim == 4
- output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
- ctx.save_for_backward(input, grid)
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- input, grid = ctx.saved_tensors
- grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
- return grad_input, grad_grid
-
-#----------------------------------------------------------------------------
-
-class _GridSample2dBackward(torch.autograd.Function):
- @staticmethod
- def forward(ctx, grad_output, input, grid):
- op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
- ctx.save_for_backward(grid)
- return grad_input, grad_grid
-
- @staticmethod
- def backward(ctx, grad2_grad_input, grad2_grad_grid):
- _ = grad2_grad_grid # unused
- grid, = ctx.saved_tensors
- grad2_grad_output = None
- grad2_input = None
- grad2_grid = None
-
- if ctx.needs_input_grad[0]:
- grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
-
- assert not ctx.needs_input_grad[2]
- return grad2_grad_output, grad2_input, grad2_grid
-
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/upfirdn2d.cpp b/PTI/torch_utils/ops/upfirdn2d.cpp
deleted file mode 100644
index 2d7177fc60040751d20e9a8da0301fa3ab64968a..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/upfirdn2d.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include "upfirdn2d.h"
-
-//------------------------------------------------------------------------
-
-static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
-{
- // Validate arguments.
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
- TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
- TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
- TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
- TORCH_CHECK(f.dim() == 2, "f must be rank 2");
- TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
- TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
- TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
-
- // Create output tensor.
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
- int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
- int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
- TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
- torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
- TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
-
- // Initialize CUDA kernel parameters.
- upfirdn2d_kernel_params p;
- p.x = x.data_ptr();
- p.f = f.data_ptr();
- p.y = y.data_ptr();
- p.up = make_int2(upx, upy);
- p.down = make_int2(downx, downy);
- p.pad0 = make_int2(padx0, pady0);
- p.flip = (flip) ? 1 : 0;
- p.gain = gain;
- p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
- p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
- p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
- p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
- p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
- p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
- p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
- p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
-
- // Choose CUDA kernel.
- upfirdn2d_kernel_spec spec;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
- {
- spec = choose_upfirdn2d_kernel(p);
- });
-
- // Set looping options.
- p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
- p.loopMinor = spec.loopMinor;
- p.loopX = spec.loopX;
- p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
- p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
-
- // Compute grid size.
- dim3 blockSize, gridSize;
- if (spec.tileOutW < 0) // large
- {
- blockSize = dim3(4, 32, 1);
- gridSize = dim3(
- ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
- (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
- p.launchMajor);
- }
- else // small
- {
- blockSize = dim3(256, 1, 1);
- gridSize = dim3(
- ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
- (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
- p.launchMajor);
- }
-
- // Launch CUDA kernel.
- void* args[] = {&p};
- AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
- return y;
-}
-
-//------------------------------------------------------------------------
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
- m.def("upfirdn2d", &upfirdn2d);
-}
-
-//------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/upfirdn2d.cu b/PTI/torch_utils/ops/upfirdn2d.cu
deleted file mode 100644
index ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/upfirdn2d.cu
+++ /dev/null
@@ -1,350 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <c10/util/Half.h>
-#include "upfirdn2d.h"
-
-//------------------------------------------------------------------------
-// Helpers.
-
-template <class T> struct InternalType;
-template <> struct InternalType<double>     { typedef double scalar_t; };
-template <> struct InternalType<float>      { typedef float  scalar_t; };
-template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
-
-static __device__ __forceinline__ int floor_div(int a, int b)
-{
- int t = 1 - a / b;
- return (a + t * b) / b - t;
-}
-
-//------------------------------------------------------------------------
-// Generic CUDA implementation for large filters.
-
-template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
-{
-    typedef typename InternalType<T>::scalar_t scalar_t;
-
- // Calculate thread index.
- int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
- int outY = minorBase / p.launchMinor;
- minorBase -= outY * p.launchMinor;
- int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
- int majorBase = blockIdx.z * p.loopMajor;
- if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
- return;
-
- // Setup Y receptive field.
- int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
- int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
- int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
- int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
- if (p.flip)
- filterY = p.filterSize.y - 1 - filterY;
-
- // Loop over major, minor, and X.
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
- for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
- {
- int nc = major * p.sizeMinor + minor;
- int n = nc / p.inSize.z;
- int c = nc - n * p.inSize.z;
- for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
- {
- // Setup X receptive field.
- int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
- int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
- int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
- int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
- if (p.flip)
- filterX = p.filterSize.x - 1 - filterX;
-
- // Initialize pointers.
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
- const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
- int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
- int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
-
- // Inner loop.
- scalar_t v = 0;
- for (int y = 0; y < h; y++)
- {
- for (int x = 0; x < w; x++)
- {
- v += (scalar_t)(*xp) * (scalar_t)(*fp);
- xp += p.inStride.x;
- fp += filterStepX;
- }
- xp += p.inStride.y - w * p.inStride.x;
- fp += filterStepY - w * filterStepX;
- }
-
- // Store result.
- v *= p.gain;
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
- }
- }
-}
-
-//------------------------------------------------------------------------
-// Specialized CUDA implementation for small filters.
-
-template <class T, int upx, int upy, int downx, int downy, int tileOutW, int tileOutH, int loopMinor, int filterH, int filterW>
-static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
-{
-    typedef typename InternalType<T>::scalar_t scalar_t;
- const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
- const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
- __shared__ volatile scalar_t sf[filterH][filterW];
- __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
-
- // Calculate tile index.
- int minorBase = blockIdx.x;
- int tileOutY = minorBase / p.launchMinor;
- minorBase -= tileOutY * p.launchMinor;
- minorBase *= loopMinor;
- tileOutY *= tileOutH;
- int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
- int majorBase = blockIdx.z * p.loopMajor;
- if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
- return;
-
- // Load filter (flipped).
- for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
- {
- int fy = tapIdx / filterW;
- int fx = tapIdx - fy * filterW;
- scalar_t v = 0;
- if (fx < p.filterSize.x & fy < p.filterSize.y)
- {
- int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
- int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
- v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
- }
- sf[fy][fx] = v;
- }
-
- // Loop over major and X.
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
- {
- int baseNC = major * p.sizeMinor + minorBase;
- int n = baseNC / p.inSize.z;
- int baseC = baseNC - n * p.inSize.z;
- for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
- {
- // Load input pixels.
- int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
- int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
- int tileInX = floor_div(tileMidX, upx);
- int tileInY = floor_div(tileMidY, upy);
- __syncthreads();
- for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
- {
- int relC = inIdx;
- int relInX = relC / loopMinor;
- int relInY = relInX / tileInW;
- relC -= relInX * loopMinor;
- relInX -= relInY * tileInW;
- int c = baseC + relC;
- int inX = tileInX + relInX;
- int inY = tileInY + relInY;
- scalar_t v = 0;
- if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
- v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
- sx[relInY][relInX][relC] = v;
- }
-
- // Loop over output pixels.
- __syncthreads();
- for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
- {
- int relC = outIdx;
- int relOutX = relC / loopMinor;
- int relOutY = relOutX / tileOutW;
- relC -= relOutX * loopMinor;
- relOutX -= relOutY * tileOutW;
- int c = baseC + relC;
- int outX = tileOutX + relOutX;
- int outY = tileOutY + relOutY;
-
- // Setup receptive field.
- int midX = tileMidX + relOutX * downx;
- int midY = tileMidY + relOutY * downy;
- int inX = floor_div(midX, upx);
- int inY = floor_div(midY, upy);
- int relInX = inX - tileInX;
- int relInY = inY - tileInY;
- int filterX = (inX + 1) * upx - midX - 1; // flipped
- int filterY = (inY + 1) * upy - midY - 1; // flipped
-
- // Inner loop.
- if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
- {
- scalar_t v = 0;
- #pragma unroll
- for (int y = 0; y < filterH / upy; y++)
- #pragma unroll
- for (int x = 0; x < filterW / upx; x++)
- v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
- v *= p.gain;
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
- }
- }
- }
- }
-}
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
-{
- int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
-
-    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
-    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
-
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
-        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 64,16,1, 7,7>,  64,16,1, 1};
-        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 64,16,1, 6,6>,  64,16,1, 1};
-        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 64,16,1, 5,5>,  64,16,1, 1};
-        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 64,16,1, 4,4>,  64,16,1, 1};
-        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 64,16,1, 3,3>,  64,16,1, 1};
-        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,8,1, 1,24>, 128,8,1, 1};
-        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,8,1, 1,20>, 128,8,1, 1};
-        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,8,1, 1,16>, 128,8,1, 1};
-        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,8,1, 1,12>, 128,8,1, 1};
-        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,8,1, 1,8>,  128,8,1, 1};
-        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 32,32,1, 24,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 32,32,1, 20,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 32,32,1, 16,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 32,32,1, 12,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 32,32,1, 8,1>,  32,32,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
-        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16,8, 7,7>,  16,16,8, 1};
-        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16,8, 6,6>,  16,16,8, 1};
-        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16,8, 5,5>,  16,16,8, 1};
-        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16,8, 4,4>,  16,16,8, 1};
-        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16,8, 3,3>,  16,16,8, 1};
-        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,1,16, 1,24>, 128,1,16, 1};
-        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,1,16, 1,20>, 128,1,16, 1};
-        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,1,16, 1,16>, 128,1,16, 1};
-        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,1,16, 1,12>, 128,1,16, 1};
-        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 128,1,16, 1,8>,  128,1,16, 1};
-        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,128,16, 24,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,128,16, 20,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,128,16, 16,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,128,16, 12,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,128,16, 8,1>,  1,128,16, 1};
- }
- if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
-        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 64,16,1, 8,8>,  64,16,1, 1};
-        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 64,16,1, 6,6>,  64,16,1, 1};
-        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 64,16,1, 4,4>,  64,16,1, 1};
-        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 64,16,1, 2,2>,  64,16,1, 1};
- }
- if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
-        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16,8, 8,8>,  16,16,8, 1};
-        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16,8, 6,6>,  16,16,8, 1};
-        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16,8, 4,4>,  16,16,8, 1};
-        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16,8, 2,2>,  16,16,8, 1};
- }
- if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
-        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,8,1, 1,24>, 128,8,1, 1};
-        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,8,1, 1,20>, 128,8,1, 1};
-        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,8,1, 1,16>, 128,8,1, 1};
-        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,8,1, 1,12>, 128,8,1, 1};
-        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,8,1, 1,8>,  128,8,1, 1};
- }
- if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
-        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,1,16, 1,24>, 128,1,16, 1};
-        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,1,16, 1,20>, 128,1,16, 1};
-        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,1,16, 1,16>, 128,1,16, 1};
-        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,1,16, 1,12>, 128,1,16, 1};
-        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 128,1,16, 1,8>,  128,1,16, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
-        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 32,32,1, 24,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 32,32,1, 20,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 32,32,1, 16,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 32,32,1, 12,1>, 32,32,1, 1};
-        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 32,32,1, 8,1>,  32,32,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
-        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,128,16, 24,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,128,16, 20,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,128,16, 16,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,128,16, 12,1>, 1,128,16, 1};
-        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,128,16, 8,1>,  1,128,16, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
- {
-        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 32,8,1, 8,8>,   32,8,1, 1};
-        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 32,8,1, 6,6>,   32,8,1, 1};
-        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 32,8,1, 4,4>,   32,8,1, 1};
-        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 32,8,1, 2,2>,   32,8,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
- {
-        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8,8, 8,8>,    8,8,8, 1};
-        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8,8, 6,6>,    8,8,8, 1};
-        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8,8, 4,4>,    8,8,8, 1};
-        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8,8, 2,2>,    8,8,8, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
- {
-        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,8,1, 1,24>,  64,8,1, 1};
-        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,8,1, 1,20>,  64,8,1, 1};
-        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,8,1, 1,16>,  64,8,1, 1};
-        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,8,1, 1,12>,  64,8,1, 1};
-        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,8,1, 1,8>,   64,8,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
- {
-        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,1,8, 1,24>,  64,1,8, 1};
-        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,1,8, 1,20>,  64,1,8, 1};
-        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,1,8, 1,16>,  64,1,8, 1};
-        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,1,8, 1,12>,  64,1,8, 1};
-        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 64,1,8, 1,8>,   64,1,8, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
- {
-        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 32,16,1, 24,1>, 32,16,1, 1};
-        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 32,16,1, 20,1>, 32,16,1, 1};
-        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 32,16,1, 16,1>, 32,16,1, 1};
-        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 32,16,1, 12,1>, 32,16,1, 1};
-        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 32,16,1, 8,1>,  32,16,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
- {
-        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,64,8, 24,1>,  1,64,8, 1};
-        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,64,8, 20,1>,  1,64,8, 1};
-        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,64,8, 16,1>,  1,64,8, 1};
-        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,64,8, 12,1>,  1,64,8, 1};
-        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,64,8, 8,1>,   1,64,8, 1};
- }
- return spec;
-}
-
-//------------------------------------------------------------------------
-// Template specializations.
-
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
-
-//------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/upfirdn2d.h b/PTI/torch_utils/ops/upfirdn2d.h
deleted file mode 100644
index c9e2032bcac9d2abde7a75eea4d812da348afadd..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/upfirdn2d.h
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <cuda_runtime.h>
-
-//------------------------------------------------------------------------
-// CUDA kernel parameters.
-
-struct upfirdn2d_kernel_params
-{
- const void* x;
- const float* f;
- void* y;
-
- int2 up;
- int2 down;
- int2 pad0;
- int flip;
- float gain;
-
- int4 inSize; // [width, height, channel, batch]
- int4 inStride;
- int2 filterSize; // [width, height]
- int2 filterStride;
- int4 outSize; // [width, height, channel, batch]
- int4 outStride;
- int sizeMinor;
- int sizeMajor;
-
- int loopMinor;
- int loopMajor;
- int loopX;
- int launchMinor;
- int launchMajor;
-};
-
-//------------------------------------------------------------------------
-// CUDA kernel specialization.
-
-struct upfirdn2d_kernel_spec
-{
- void* kernel;
- int tileOutW;
- int tileOutH;
- int loopMinor;
- int loopX;
-};
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
-
-//------------------------------------------------------------------------
diff --git a/PTI/torch_utils/ops/upfirdn2d.py b/PTI/torch_utils/ops/upfirdn2d.py
deleted file mode 100644
index ceeac2b9834e33b7c601c28bf27f32aa91c69256..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/ops/upfirdn2d.py
+++ /dev/null
@@ -1,384 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom PyTorch ops for efficient resampling of 2D images."""
-
-import os
-import warnings
-import numpy as np
-import torch
-import traceback
-
-from .. import custom_ops
-from .. import misc
-from . import conv2d_gradfix
-
-#----------------------------------------------------------------------------
-
-_inited = False
-_plugin = None
-
-def _init():
- global _inited, _plugin
- if not _inited:
- sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
- try:
- _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
- except:
- warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
- return _plugin is not None
-
-def _parse_scaling(scaling):
- if isinstance(scaling, int):
- scaling = [scaling, scaling]
- assert isinstance(scaling, (list, tuple))
- assert all(isinstance(x, int) for x in scaling)
- sx, sy = scaling
- assert sx >= 1 and sy >= 1
- return sx, sy
-
-def _parse_padding(padding):
- if isinstance(padding, int):
- padding = [padding, padding]
- assert isinstance(padding, (list, tuple))
- assert all(isinstance(x, int) for x in padding)
- if len(padding) == 2:
- padx, pady = padding
- padding = [padx, padx, pady, pady]
- padx0, padx1, pady0, pady1 = padding
- return padx0, padx1, pady0, pady1
-
-def _get_filter_size(f):
- if f is None:
- return 1, 1
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- fw = f.shape[-1]
- fh = f.shape[0]
- with misc.suppress_tracer_warnings():
- fw = int(fw)
- fh = int(fh)
- misc.assert_shape(f, [fh, fw][:f.ndim])
- assert fw >= 1 and fh >= 1
- return fw, fh
-
-#----------------------------------------------------------------------------
-
-def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
- r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
-
- Args:
- f: Torch tensor, numpy array, or python list of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable),
- `[]` (impulse), or
- `None` (identity).
- device: Result device (default: cpu).
- normalize: Normalize the filter so that it retains the magnitude
- for constant input signal (DC)? (default: True).
- flip_filter: Flip the filter? (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- separable: Return a separable filter? (default: select automatically).
-
- Returns:
- Float32 tensor of the shape
- `[filter_height, filter_width]` (non-separable) or
- `[filter_taps]` (separable).
- """
- # Validate.
- if f is None:
- f = 1
- f = torch.as_tensor(f, dtype=torch.float32)
- assert f.ndim in [0, 1, 2]
- assert f.numel() > 0
- if f.ndim == 0:
- f = f[np.newaxis]
-
- # Separable?
- if separable is None:
- separable = (f.ndim == 1 and f.numel() >= 8)
- if f.ndim == 1 and not separable:
- f = f.ger(f)
- assert f.ndim == (1 if separable else 2)
-
- # Apply normalize, flip, gain, and device.
- if normalize:
- f /= f.sum()
- if flip_filter:
- f = f.flip(list(range(f.ndim)))
- f = f * (gain ** (f.ndim / 2))
- f = f.to(device=device)
- return f
-
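-# Illustrative sketch (hypothetical helper, not part of the original file): a short tap list is
-# expanded to a 2D filter via the outer product and normalized so that it preserves DC magnitude.
-def _example_setup_filter():
-    f = setup_filter([1, 3, 3, 1])            # the binomial filter commonly used by StyleGAN2
-    assert f.shape == (4, 4)
-    assert abs(float(f.sum()) - 1.0) < 1e-6   # normalized to unit sum
-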
-#----------------------------------------------------------------------------
-
-def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Pad, upsample, filter, and downsample a batch of 2D images.
-
- Performs the following sequence of operations for each channel:
-
- 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
-
- 2. Pad the image with the specified number of zeros on each side (`padding`).
- Negative padding corresponds to cropping the image.
-
- 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
- so that the footprint of all output pixels lies within the input image.
-
- 4. Downsample the image by keeping every Nth pixel (`down`).
-
- This sequence of operations bears close resemblance to scipy.signal.upfirdn().
- The fused op is considerably more efficient than performing the same calculation
- using standard PyTorch ops. It supports gradients of arbitrary order.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- up: Integer upsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- down: Integer downsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the upsampled image. Can be a single number
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- assert isinstance(x, torch.Tensor)
- assert impl in ['ref', 'cuda']
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
- return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
- return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
-
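-# Illustrative sketch (hypothetical helper, not part of the original file): for an 8x8 input with
-# up=2, padding=1 and a 4x4 filter, the output side length is in*up + pad0 + pad1 - fw + 1 = 15,
-# matching the output-size formula used by the CUDA kernel.
-def _example_upfirdn2d():
-    x = torch.randn([1, 1, 8, 8])
-    f = setup_filter([1, 3, 3, 1])            # expanded to a 4x4 filter
-    y = upfirdn2d(x, f, up=2, padding=1, impl='ref')
-    assert y.shape == (1, 1, 15, 15)
-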
-#----------------------------------------------------------------------------
-
-@misc.profiled_function
-def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
- """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
- """
- # Validate arguments.
- assert isinstance(x, torch.Tensor) and x.ndim == 4
- if f is None:
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- assert f.dtype == torch.float32 and not f.requires_grad
- batch_size, num_channels, in_height, in_width = x.shape
- upx, upy = _parse_scaling(up)
- downx, downy = _parse_scaling(down)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
-
- # Upsample by inserting zeros.
- x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
- x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
- x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
-
- # Pad or crop.
- x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
- x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
-
- # Setup filter.
- f = f * (gain ** (f.ndim / 2))
- f = f.to(x.dtype)
- if not flip_filter:
- f = f.flip(list(range(f.ndim)))
-
- # Convolve with the filter.
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
- if f.ndim == 4:
- x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
- else:
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
-
- # Downsample by throwing away pixels.
- x = x[:, :, ::downy, ::downx]
- return x
-
-#----------------------------------------------------------------------------
-
-_upfirdn2d_cuda_cache = dict()
-
-def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
- """Fast CUDA implementation of `upfirdn2d()` using custom ops.
- """
- # Parse arguments.
- upx, upy = _parse_scaling(up)
- downx, downy = _parse_scaling(down)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
-
- # Lookup from cache.
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
- if key in _upfirdn2d_cuda_cache:
- return _upfirdn2d_cuda_cache[key]
-
- # Forward op.
- class Upfirdn2dCuda(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x, f): # pylint: disable=arguments-differ
- assert isinstance(x, torch.Tensor) and x.ndim == 4
- if f is None:
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- y = x
- if f.ndim == 2:
- y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
- else:
- y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
- y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
- ctx.save_for_backward(f)
- ctx.x_shape = x.shape
- return y
-
- @staticmethod
- def backward(ctx, dy): # pylint: disable=arguments-differ
- f, = ctx.saved_tensors
- _, _, ih, iw = ctx.x_shape
- _, _, oh, ow = dy.shape
- fw, fh = _get_filter_size(f)
- p = [
- fw - padx0 - 1,
- iw * upx - ow * downx + padx0 - upx + 1,
- fh - pady0 - 1,
- ih * upy - oh * downy + pady0 - upy + 1,
- ]
- dx = None
- df = None
-
- if ctx.needs_input_grad[0]:
- dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
-
- assert not ctx.needs_input_grad[1]
- return dx, df
-
- # Add to cache.
- _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
- return Upfirdn2dCuda
-
-#----------------------------------------------------------------------------
-
-def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Filter a batch of 2D images using the given 2D FIR filter.
-
- By default, the result is padded so that its shape matches the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- padding: Padding with respect to the output. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + fw // 2,
- padx1 + (fw - 1) // 2,
- pady0 + fh // 2,
- pady1 + (fh - 1) // 2,
- ]
- return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
-
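-# Illustrative sketch (hypothetical helper, not part of the original file): the implicit padding of
-# fw//2 and (fw-1)//2 adds fw-1 pixels in total, so filtering preserves the input resolution.
-def _example_filter2d():
-    x = torch.randn([2, 3, 16, 16])
-    f = setup_filter([1, 2, 1])               # expanded to a 3x3 filter
-    assert filter2d(x, f, impl='ref').shape == x.shape
-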
-#----------------------------------------------------------------------------
-
-def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Upsample a batch of 2D images using the given 2D FIR filter.
-
- By default, the result is padded so that its shape is a multiple of the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- up: Integer upsampling factor. Can be a single int or a list/tuple
-                 `[x, y]` (default: 2).
- padding: Padding with respect to the output. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- upx, upy = _parse_scaling(up)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + (fw + upx - 1) // 2,
- padx1 + (fw - upx) // 2,
- pady0 + (fh + upy - 1) // 2,
- pady1 + (fh - upy) // 2,
- ]
- return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
-
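-# Illustrative sketch (hypothetical helper, not part of the original file): smoothed 2x upsampling;
-# the gain is scaled by upx*upy internally so the average signal magnitude is preserved.
-def _example_upsample2d():
-    x = torch.randn([1, 3, 16, 16])
-    f = setup_filter([1, 3, 3, 1])
-    assert upsample2d(x, f, up=2, impl='ref').shape == (1, 3, 32, 32)
-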
-#----------------------------------------------------------------------------
-
-def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Downsample a batch of 2D images using the given 2D FIR filter.
-
- By default, the result is padded so that its shape is a fraction of the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- down: Integer downsampling factor. Can be a single int or a list/tuple
-                 `[x, y]` (default: 2).
- padding: Padding with respect to the input. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- downx, downy = _parse_scaling(down)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + (fw - downx + 1) // 2,
- padx1 + (fw - downx) // 2,
- pady0 + (fh - downy + 1) // 2,
- pady1 + (fh - downy) // 2,
- ]
- return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
-
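-# Illustrative sketch (hypothetical helper, not part of the original file): anti-aliased 2x
-# downsampling, i.e. low-pass filtering followed by keeping every second pixel.
-def _example_downsample2d():
-    x = torch.randn([1, 3, 32, 32])
-    f = setup_filter([1, 3, 3, 1])
-    assert downsample2d(x, f, down=2, impl='ref').shape == (1, 3, 16, 16)
-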
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/persistence.py b/PTI/torch_utils/persistence.py
deleted file mode 100644
index 0186cfd97bca0fcb397a7b73643520c1d1105a02..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/persistence.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Facilities for pickling Python code alongside other data.
-
-The pickled code is automatically imported into a separate Python module
-during unpickling. This way, any previously exported pickles will remain
-usable even if the original code is no longer available, or if the current
-version of the code is not consistent with what was originally pickled."""
-
-import sys
-import pickle
-import io
-import inspect
-import copy
-import uuid
-import types
-import dnnlib
-
-#----------------------------------------------------------------------------
-
-_version = 6 # internal version number
-_decorators = set() # {decorator_class, ...}
-_import_hooks = [] # [hook_function, ...]
-_module_to_src_dict = dict() # {module: src, ...}
-_src_to_module_dict = dict() # {src: module, ...}
-
-#----------------------------------------------------------------------------
-
-def persistent_class(orig_class):
- r"""Class decorator that extends a given class to save its source code
- when pickled.
-
- Example:
-
- from torch_utils import persistence
-
- @persistence.persistent_class
- class MyNetwork(torch.nn.Module):
- def __init__(self, num_inputs, num_outputs):
- super().__init__()
- self.fc = MyLayer(num_inputs, num_outputs)
- ...
-
- @persistence.persistent_class
- class MyLayer(torch.nn.Module):
- ...
-
- When pickled, any instance of `MyNetwork` and `MyLayer` will save its
- source code alongside other internal state (e.g., parameters, buffers,
- and submodules). This way, any previously exported pickle will remain
- usable even if the class definitions have been modified or are no
- longer available.
-
- The decorator saves the source code of the entire Python module
- containing the decorated class. It does *not* save the source code of
- any imported modules. Thus, the imported modules must be available
- during unpickling, also including `torch_utils.persistence` itself.
-
- It is ok to call functions defined in the same module from the
- decorated class. However, if the decorated class depends on other
- classes defined in the same module, they must be decorated as well.
- This is illustrated in the above example in the case of `MyLayer`.
-
- It is also possible to employ the decorator just-in-time before
- calling the constructor. For example:
-
- cls = MyLayer
- if want_to_make_it_persistent:
- cls = persistence.persistent_class(cls)
- layer = cls(num_inputs, num_outputs)
-
- As an additional feature, the decorator also keeps track of the
- arguments that were used to construct each instance of the decorated
- class. The arguments can be queried via `obj.init_args` and
- `obj.init_kwargs`, and they are automatically pickled alongside other
- object state. A typical use case is to first unpickle a previous
- instance of a persistent class, and then upgrade it to use the latest
- version of the source code:
-
- with open('old_pickle.pkl', 'rb') as f:
- old_net = pickle.load(f)
-        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
- misc.copy_params_and_buffers(old_net, new_net, require_all=True)
- """
- assert isinstance(orig_class, type)
- if is_persistent(orig_class):
- return orig_class
-
- assert orig_class.__module__ in sys.modules
- orig_module = sys.modules[orig_class.__module__]
- orig_module_src = _module_to_src(orig_module)
-
- class Decorator(orig_class):
- _orig_module_src = orig_module_src
- _orig_class_name = orig_class.__name__
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._init_args = copy.deepcopy(args)
- self._init_kwargs = copy.deepcopy(kwargs)
- assert orig_class.__name__ in orig_module.__dict__
- _check_pickleable(self.__reduce__())
-
- @property
- def init_args(self):
- return copy.deepcopy(self._init_args)
-
- @property
- def init_kwargs(self):
- return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
-
- def __reduce__(self):
- fields = list(super().__reduce__())
- fields += [None] * max(3 - len(fields), 0)
- if fields[0] is not _reconstruct_persistent_obj:
- meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
- fields[0] = _reconstruct_persistent_obj # reconstruct func
- fields[1] = (meta,) # reconstruct args
- fields[2] = None # state dict
- return tuple(fields)
-
- Decorator.__name__ = orig_class.__name__
- _decorators.add(Decorator)
- return Decorator
-
-#----------------------------------------------------------------------------
-
-def is_persistent(obj):
- r"""Test whether the given object or class is persistent, i.e.,
- whether it will save its source code when pickled.
- """
- try:
- if obj in _decorators:
- return True
- except TypeError:
- pass
- return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
-
-#----------------------------------------------------------------------------
-
-def import_hook(hook):
- r"""Register an import hook that is called whenever a persistent object
- is being unpickled. A typical use case is to patch the pickled source
- code to avoid errors and inconsistencies when the API of some imported
- module has changed.
-
- The hook should have the following signature:
-
- hook(meta) -> modified meta
-
- `meta` is an instance of `dnnlib.EasyDict` with the following fields:
-
- type: Type of the persistent object, e.g. `'class'`.
- version: Internal version number of `torch_utils.persistence`.
-        module_src: Original source code of the Python module.
- class_name: Class name in the original Python module.
- state: Internal state of the object.
-
- Example:
-
- @persistence.import_hook
- def wreck_my_network(meta):
- if meta.class_name == 'MyNetwork':
- print('MyNetwork is being imported. I will wreck it!')
- meta.module_src = meta.module_src.replace("True", "False")
- return meta
- """
- assert callable(hook)
- _import_hooks.append(hook)
-
-#----------------------------------------------------------------------------
-
-def _reconstruct_persistent_obj(meta):
- r"""Hook that is called internally by the `pickle` module to unpickle
- a persistent object.
- """
- meta = dnnlib.EasyDict(meta)
- meta.state = dnnlib.EasyDict(meta.state)
- for hook in _import_hooks:
- meta = hook(meta)
- assert meta is not None
-
- assert meta.version == _version
- module = _src_to_module(meta.module_src)
-
- assert meta.type == 'class'
- orig_class = module.__dict__[meta.class_name]
- decorator_class = persistent_class(orig_class)
- obj = decorator_class.__new__(decorator_class)
-
- setstate = getattr(obj, '__setstate__', None)
- if callable(setstate):
- setstate(meta.state) # pylint: disable=not-callable
- else:
- obj.__dict__.update(meta.state)
- return obj
-
-#----------------------------------------------------------------------------
-
-def _module_to_src(module):
- r"""Query the source code of a given Python module.
- """
- src = _module_to_src_dict.get(module, None)
- if src is None:
- src = inspect.getsource(module)
- _module_to_src_dict[module] = src
- _src_to_module_dict[src] = module
- return src
-
-def _src_to_module(src):
- r"""Get or create a Python module for the given source code.
- """
- module = _src_to_module_dict.get(src, None)
- if module is None:
- module_name = "_imported_module_" + uuid.uuid4().hex
- module = types.ModuleType(module_name)
- sys.modules[module_name] = module
- _module_to_src_dict[module] = src
- _src_to_module_dict[src] = module
- exec(src, module.__dict__) # pylint: disable=exec-used
- return module
-
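-# Illustrative sketch (hypothetical helper, not part of the original file): identical source text is
-# turned into an importable module exactly once and then served from the cache on later calls.
-def _example_src_to_module():
-    src = "answer = 42\n"
-    m1 = _src_to_module(src)
-    m2 = _src_to_module(src)
-    assert m1.answer == 42 and m1 is m2
-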
-#----------------------------------------------------------------------------
-
-def _check_pickleable(obj):
- r"""Check that the given object is pickleable, raising an exception if
- it is not. This function is expected to be considerably more efficient
- than actually pickling the object.
- """
- def recurse(obj):
- if isinstance(obj, (list, tuple, set)):
- return [recurse(x) for x in obj]
- if isinstance(obj, dict):
- return [[recurse(x), recurse(y)] for x, y in obj.items()]
- if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
- return None # Python primitive types are pickleable.
- if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
- return None # NumPy arrays and PyTorch tensors are pickleable.
- if is_persistent(obj):
- return None # Persistent objects are pickleable, by virtue of the constructor check.
- return obj
- with io.BytesIO() as f:
- pickle.dump(recurse(obj), f)
-
-#----------------------------------------------------------------------------
diff --git a/PTI/torch_utils/training_stats.py b/PTI/torch_utils/training_stats.py
deleted file mode 100644
index 26f467f9eaa074ee13de1cf2625cd7da44880847..0000000000000000000000000000000000000000
--- a/PTI/torch_utils/training_stats.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Facilities for reporting and collecting training statistics across
-multiple processes and devices. The interface is designed to minimize
-synchronization overhead as well as the amount of boilerplate in user
-code."""
-
-import re
-import numpy as np
-import torch
-import dnnlib
-
-from . import misc
-
-#----------------------------------------------------------------------------
-
-_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
-_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
-_counter_dtype = torch.float64 # Data type to use for the internal counters.
-_rank = 0 # Rank of the current process.
-_sync_device = None # Device to use for multiprocess communication. None = single-process.
-_sync_called = False # Has _sync() been called yet?
-_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
-_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
-
-#----------------------------------------------------------------------------
-
-def init_multiprocessing(rank, sync_device):
- r"""Initializes `torch_utils.training_stats` for collecting statistics
- across multiple processes.
-
- This function must be called after
- `torch.distributed.init_process_group()` and before `Collector.update()`.
- The call is not necessary if multi-process collection is not needed.
-
- Args:
- rank: Rank of the current process.
- sync_device: PyTorch device to use for inter-process
- communication, or None to disable multi-process
- collection. Typically `torch.device('cuda', rank)`.
- """
- global _rank, _sync_device
- assert not _sync_called
- _rank = rank
- _sync_device = sync_device
-
-#----------------------------------------------------------------------------
-
-@misc.profiled_function
-def report(name, value):
- r"""Broadcasts the given set of scalars to all interested instances of
- `Collector`, across device and process boundaries.
-
- This function is expected to be extremely cheap and can be safely
- called from anywhere in the training loop, loss function, or inside a
- `torch.nn.Module`.
-
- Warning: The current implementation expects the set of unique names to
- be consistent across processes. Please make sure that `report()` is
- called at least once for each unique name by each process, and in the
- same order. If a given process has no scalars to broadcast, it can do
- `report(name, [])` (empty list).
-
- Args:
- name: Arbitrary string specifying the name of the statistic.
- Averages are accumulated separately for each unique name.
- value: Arbitrary set of scalars. Can be a list, tuple,
- NumPy array, PyTorch tensor, or Python scalar.
-
- Returns:
- The same `value` that was passed in.
- """
- if name not in _counters:
- _counters[name] = dict()
-
- elems = torch.as_tensor(value)
- if elems.numel() == 0:
- return value
-
- elems = elems.detach().flatten().to(_reduce_dtype)
- moments = torch.stack([
- torch.ones_like(elems).sum(),
- elems.sum(),
- elems.square().sum(),
- ])
- assert moments.ndim == 1 and moments.shape[0] == _num_moments
- moments = moments.to(_counter_dtype)
-
- device = moments.device
- if device not in _counters[name]:
- _counters[name][device] = torch.zeros_like(moments)
- _counters[name][device].add_(moments)
- return value
-
-#----------------------------------------------------------------------------
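`report()` above never stores raw values; it folds every call into three running moments per statistic, `[count, sum, sum_of_squares]`, from which mean and standard deviation are recovered later. A minimal NumPy sketch of that bookkeeping (the values are made up):

```python
import numpy as np

values = np.array([0.5, 1.5, 2.0, 4.0])
moments = np.array([values.size, values.sum(), np.square(values).sum()])

mean = moments[1] / moments[0]
std = np.sqrt(max(moments[2] / moments[0] - mean**2, 0.0))
# mean == values.mean() and std == values.std()
```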
-
-def report0(name, value):
- r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
- but ignores any scalars provided by the other processes.
- See `report()` for further details.
- """
- report(name, value if _rank == 0 else [])
- return value
-
-#----------------------------------------------------------------------------
-
-class Collector:
- r"""Collects the scalars broadcasted by `report()` and `report0()` and
- computes their long-term averages (mean and standard deviation) over
- user-defined periods of time.
-
- The averages are first collected into internal counters that are not
- directly visible to the user. They are then copied to the user-visible
- state as a result of calling `update()` and can then be queried using
- `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
- internal counters for the next round, so that the user-visible state
- effectively reflects averages collected between the last two calls to
- `update()`.
-
- Args:
- regex: Regular expression defining which statistics to
- collect. The default is to collect everything.
- keep_previous: Whether to retain the previous averages if no
- scalars were collected on a given round
- (default: True).
- """
- def __init__(self, regex='.*', keep_previous=True):
- self._regex = re.compile(regex)
- self._keep_previous = keep_previous
- self._cumulative = dict()
- self._moments = dict()
- self.update()
- self._moments.clear()
-
- def names(self):
- r"""Returns the names of all statistics broadcasted so far that
- match the regular expression specified at construction time.
- """
- return [name for name in _counters if self._regex.fullmatch(name)]
-
- def update(self):
- r"""Copies current values of the internal counters to the
- user-visible state and resets them for the next round.
-
- If `keep_previous=True` was specified at construction time, the
- operation is skipped for statistics that have received no scalars
- since the last update, retaining their previous averages.
-
- This method performs a number of GPU-to-CPU transfers and one
- `torch.distributed.all_reduce()`. It is intended to be called
- periodically in the main training loop, typically once every
- N training steps.
- """
- if not self._keep_previous:
- self._moments.clear()
- for name, cumulative in _sync(self.names()):
- if name not in self._cumulative:
- self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
- delta = cumulative - self._cumulative[name]
- self._cumulative[name].copy_(cumulative)
- if float(delta[0]) != 0:
- self._moments[name] = delta
-
- def _get_delta(self, name):
- r"""Returns the raw moments that were accumulated for the given
- statistic between the last two calls to `update()`, or zero if
- no scalars were collected.
- """
- assert self._regex.fullmatch(name)
- if name not in self._moments:
- self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
- return self._moments[name]
-
- def num(self, name):
- r"""Returns the number of scalars that were accumulated for the given
- statistic between the last two calls to `update()`, or zero if
- no scalars were collected.
- """
- delta = self._get_delta(name)
- return int(delta[0])
-
- def mean(self, name):
- r"""Returns the mean of the scalars that were accumulated for the
- given statistic between the last two calls to `update()`, or NaN if
- no scalars were collected.
- """
- delta = self._get_delta(name)
- if int(delta[0]) == 0:
- return float('nan')
- return float(delta[1] / delta[0])
-
- def std(self, name):
- r"""Returns the standard deviation of the scalars that were
- accumulated for the given statistic between the last two calls to
- `update()`, or NaN if no scalars were collected.
- """
- delta = self._get_delta(name)
- if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
- return float('nan')
- if int(delta[0]) == 1:
- return float(0)
- mean = float(delta[1] / delta[0])
- raw_var = float(delta[2] / delta[0])
- return np.sqrt(max(raw_var - np.square(mean), 0))
-
- def as_dict(self):
- r"""Returns the averages accumulated between the last two calls to
-    `update()` as a `dnnlib.EasyDict`. The contents are as follows:
-
- dnnlib.EasyDict(
- NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
- ...
- )
- """
- stats = dnnlib.EasyDict()
- for name in self.names():
- stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
- return stats
-
- def __getitem__(self, name):
- r"""Convenience getter.
- `collector[name]` is a synonym for `collector.mean(name)`.
- """
- return self.mean(name)
-
-#----------------------------------------------------------------------------
-
-def _sync(names):
- r"""Synchronize the global cumulative counters across devices and
- processes. Called internally by `Collector.update()`.
- """
- if len(names) == 0:
- return []
- global _sync_called
- _sync_called = True
-
- # Collect deltas within current rank.
- deltas = []
- device = _sync_device if _sync_device is not None else torch.device('cpu')
- for name in names:
- delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
- for counter in _counters[name].values():
- delta.add_(counter.to(device))
- counter.copy_(torch.zeros_like(counter))
- deltas.append(delta)
- deltas = torch.stack(deltas)
-
- # Sum deltas across ranks.
- if _sync_device is not None:
- torch.distributed.all_reduce(deltas)
-
- # Update cumulative values.
- deltas = deltas.cpu()
- for idx, name in enumerate(names):
- if name not in _cumulative:
- _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
- _cumulative[name].add_(deltas[idx])
-
- # Return name-value pairs.
- return [(name, _cumulative[name]) for name in names]
-
-#----------------------------------------------------------------------------
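For reference, a hedged single-process usage sketch of the module deleted above; `compute_loss()` is a placeholder for whatever produces the scalar being tracked:

```python
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')
for step in range(1000):
    loss = compute_loss()                      # placeholder
    training_stats.report('Loss/total', loss)  # cheap, safe to call anywhere
    if step % 100 == 0:
        collector.update()                     # one all_reduce + GPU-to-CPU sync
        stats = collector.as_dict()            # {'Loss/total': {num, mean, std}}
```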
diff --git a/app.py b/app.py
index 1d5f6bcdc75c1d098989629b34e0b8ca79145394..201c1e76c2d90821cf1c892421dd0d378b396b4f 100644
--- a/app.py
+++ b/app.py
@@ -1,40 +1,89 @@
import gradio as gr
-import utils
+import utils.utils as utils
from PIL import Image
import torch
import math
from torchvision import transforms
-
-
+from run_pti import run_PTI
device = "cpu"
years = [str(y) for y in range(1880, 2020, 10)]
+decades = [y + "s" for y in years]
+transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
orig_models = {}
for year in years:
G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
- orig_models[year] = { "G": G.eval()}
+ orig_models[year] = { "G": G.eval().float()}
def run_alignment(image_path,idx=None):
import dlib
from align_all_parallel import align_face
predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
- aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
- print("Aligned image has shape: {}".format(aligned_image.size))
+ aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
+
return aligned_image
-def predict(inp):
+def predict(inp, in_decade):
#with torch.no_grad():
inp.save("imgs/input.png")
- out = run_alignment("imgs/input.png", idx=0)
- return out
+ inversion = run_alignment("imgs/input.png", idx=0)
+ inversion.save("imgs/cropped/input.png")
+ run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
+ #inversion = Image.open("imgs/cropped/input.png")
+
+ in_year = in_decade[:-1]
+ pti_models = {}
+
+ for year in years:
+ G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
+ pti_models[year] = { "G": G.eval().float()}
+
+
+ pti_models[in_year]['G'] = torch.load(f"checkpoints/model_gradio_demo_input.pt", device).eval().float()
+
+ for year in years:
+ if year != in_year:
+ for p_pti, p_orig, (names, p) in zip(pti_models[in_year]['G'].parameters(),orig_models[in_year]['G'].parameters(), pti_models[year]['G'].named_parameters()):
+ with torch.no_grad():
+ delta = p_pti - p_orig
+ p += delta
+
+ space = 0
+ dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
+
+
+ w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
+
+ border_width = 10
+ #fill_color = 'red'
+ dst.paste(inversion, (0, 0))
+
+
+
+ for i in range(0, len(years)):
+ year = str(years[i])
+ with torch.no_grad():
+ child_tensor = pti_models[year]["G"].synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
+ img = utils.tensor2im(child_tensor.squeeze(0))
+ # if year == in_year:
+ # img = img.crop((border_width, border_width, 256 - border_width, 256-border_width))
+ # img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color)
+ dst.paste(img, ((256 + space) * (i+1), 0))
+ dst
+ return dst
gr.Interface(fn=predict,
- inputs=gr.Image(type="pil"),
- outputs=gr.Image(type="pil"),
- #examples=["lion.jpg", "cheetah.jpg"]
- ).launch()
+ inputs=[gr.Image(label="Input Image", type="pil"), gr.Dropdown(label="Input Decade", choices=decades, value="2010s")],
+ outputs=gr.Image(label="Decade Transformations", type="pil"),
+ examples=[["imgs/Steven-Yeun.jpg", "2010s"]]
+
+ ).launch() #.launch(server_name="0.0.0.0", server_port=8098)
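The core of the new `predict()` is a weight-delta transfer: after PTI fine-tunes the generator of the input decade, the offset between the tuned and original weights is added onto every other decade's generator, so the identity-specific adjustment carries over while each decade keeps its own style. A standalone sketch of that step; the generator variable names are assumptions.

```python
import torch

@torch.no_grad()
def transfer_pti_delta(G_tuned, G_orig, G_other):
    """Add the PTI fine-tuning offset of one generator onto another generator."""
    for p_tuned, p_orig, p_other in zip(G_tuned.parameters(),
                                        G_orig.parameters(),
                                        G_other.parameters()):
        p_other.add_(p_tuned - p_orig)
    return G_other
```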
diff --git a/checkpoints/model_gradio_demo_input.pt b/checkpoints/model_gradio_demo_input.pt
new file mode 100644
index 0000000000000000000000000000000000000000..faddf9dcb32f85217f12c31e9b9cd5a8a3ccdead
--- /dev/null
+++ b/checkpoints/model_gradio_demo_input.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65ee5644ec8ab0966a4eb51971995c071f8178db765750d45e32e0ed18a09738
+size 99867041
diff --git a/PTI/criteria/color_transfer_loss.py b/color_transfer_loss.py
similarity index 100%
rename from PTI/criteria/color_transfer_loss.py
rename to color_transfer_loss.py
diff --git a/PTI/configs/__init__.py b/configs/__init__.py
similarity index 100%
rename from PTI/configs/__init__.py
rename to configs/__init__.py
diff --git a/PTI/configs/evaluation_config.py b/configs/evaluation_config.py
similarity index 100%
rename from PTI/configs/evaluation_config.py
rename to configs/evaluation_config.py
diff --git a/PTI/configs/global_config.py b/configs/global_config.py
similarity index 80%
rename from PTI/configs/global_config.py
rename to configs/global_config.py
index 52dc2253c4d157fe7b7e0e630dc4a6eac9e1e142..82c1033caa6ea35404886cd17be3a510ceea4fbf 100644
--- a/PTI/configs/global_config.py
+++ b/configs/global_config.py
@@ -1,6 +1,6 @@
## Device
-cuda_visible_devices = "1"
-device = "cuda:0"
+cuda_visible_devices = "0"
+device = "cpu"
## Logs
training_step = 1
diff --git a/PTI/configs/hyperparameters.py b/configs/hyperparameters.py
similarity index 95%
rename from PTI/configs/hyperparameters.py
rename to configs/hyperparameters.py
index 67901452855be546851289d8fb67920933922075..62871505436118e52583a5b8e918d9d191dda73b 100644
--- a/PTI/configs/hyperparameters.py
+++ b/configs/hyperparameters.py
@@ -28,4 +28,4 @@ max_images_to_invert = 10
pti_learning_rate = 3e-4
first_inv_lr = 5e-3
train_batch_size = 1
-use_last_w_pivots = True
+use_last_w_pivots = False
diff --git a/PTI/configs/paths_config.py b/configs/paths_config.py
similarity index 75%
rename from PTI/configs/paths_config.py
rename to configs/paths_config.py
index 931d699f37b17ade6560f11520169e664ebfc96e..1513a9d7b347fc7bb56ba0a4affdfe3e893cf220 100644
--- a/PTI/configs/paths_config.py
+++ b/configs/paths_config.py
@@ -4,12 +4,12 @@ year = "2010"
e4e = "./pretrained_models/e4e_ffhq_encode.pt"
-stylegan2_ada_ffhq = f"../pretrained_models/{year}.pkl"
+stylegan2_ada_ffhq = f"pretrained_models/{year}.pkl"
style_clip_pretrained_mappers = ""
-ir_se50 = "/share/phoenix/nfs04/S7/wikitime_models/model_ir_se50.pth"
+ir_se50 = "pretrained_models/model_ir_se50.pth"
dlib = "./pretrained_models/align.dat"
-deeplab = "/share/phoenix/nfs04/S7/wikitime_models/deeplab_model/deeplab_model.pth"
+deeplab = "pretrained_models/deeplab_model/deeplab_model.pth"
## Dirs for output files
checkpoints_dir = "./checkpoints"
@@ -20,7 +20,7 @@ experiments_output_dir = "./output"
## Input info
### Input dir, where the images reside
input_data_path = (
- f"/share/phoenix/nfs04/S7/emc348/WikiFaces/datasets/new_crops/test/{year}"
+ f"imgs/cropped"
)
input_data_id = f"{year}"
diff --git a/PTI/criteria/__init__.py b/criteria/__init__.py
similarity index 100%
rename from PTI/criteria/__init__.py
rename to criteria/__init__.py
diff --git a/PTI/criteria/backbones/__init__.py b/criteria/backbones/__init__.py
similarity index 100%
rename from PTI/criteria/backbones/__init__.py
rename to criteria/backbones/__init__.py
diff --git a/PTI/criteria/backbones/iresnet.py b/criteria/backbones/iresnet.py
similarity index 100%
rename from PTI/criteria/backbones/iresnet.py
rename to criteria/backbones/iresnet.py
diff --git a/PTI/criteria/backbones/iresnet2060.py b/criteria/backbones/iresnet2060.py
similarity index 100%
rename from PTI/criteria/backbones/iresnet2060.py
rename to criteria/backbones/iresnet2060.py
diff --git a/PTI/criteria/backbones/mobilefacenet.py b/criteria/backbones/mobilefacenet.py
similarity index 100%
rename from PTI/criteria/backbones/mobilefacenet.py
rename to criteria/backbones/mobilefacenet.py
diff --git a/PTI/criteria/deeplab.py b/criteria/deeplab.py
similarity index 100%
rename from PTI/criteria/deeplab.py
rename to criteria/deeplab.py
diff --git a/PTI/criteria/helpers.py b/criteria/helpers.py
similarity index 100%
rename from PTI/criteria/helpers.py
rename to criteria/helpers.py
diff --git a/PTI/criteria/id_loss.py b/criteria/id_loss.py
similarity index 100%
rename from PTI/criteria/id_loss.py
rename to criteria/id_loss.py
diff --git a/PTI/criteria/l2_loss.py b/criteria/l2_loss.py
similarity index 100%
rename from PTI/criteria/l2_loss.py
rename to criteria/l2_loss.py
diff --git a/PTI/criteria/localitly_regulizer.py b/criteria/localitly_regulizer.py
similarity index 100%
rename from PTI/criteria/localitly_regulizer.py
rename to criteria/localitly_regulizer.py
diff --git a/PTI/criteria/mask.py b/criteria/mask.py
similarity index 100%
rename from PTI/criteria/mask.py
rename to criteria/mask.py
diff --git a/PTI/criteria/model_irse.py b/criteria/model_irse.py
similarity index 100%
rename from PTI/criteria/model_irse.py
rename to criteria/model_irse.py
diff --git a/PTI/criteria/validation.py b/criteria/validation.py
similarity index 100%
rename from PTI/criteria/validation.py
rename to criteria/validation.py
diff --git a/dnnlib/__pycache__/__init__.cpython-39.pyc b/dnnlib/__pycache__/__init__.cpython-39.pyc
index 0d65d958504c94699a494d504272f86185474ecc..56545ef7c6547f8c4b2046e5ea1d76c302995f9e 100644
Binary files a/dnnlib/__pycache__/__init__.cpython-39.pyc and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dnnlib/__pycache__/util.cpython-39.pyc b/dnnlib/__pycache__/util.cpython-39.pyc
index f89a9685b65a3158471685f56daab3197c58be66..cdb0592847d19c19a5def909c3ff01c37300677d 100644
Binary files a/dnnlib/__pycache__/util.cpython-39.pyc and b/dnnlib/__pycache__/util.cpython-39.pyc differ
diff --git a/embeddings/2010/PTI/input/0.pt b/embeddings/2010/PTI/input/0.pt
new file mode 100644
index 0000000000000000000000000000000000000000..c9bed619221b4030ec4df47345c4d80f4c6267d9
--- /dev/null
+++ b/embeddings/2010/PTI/input/0.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb9b3fc3d08f6dd3ee5c87299fff8cd932e4a6afaa90fe06f3d5c5f9503ebf26
+size 29419
diff --git a/imgs/Steven-Yeun.jpg b/imgs/Steven-Yeun.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..28583afeb8306b28c29ca83419c7ff9d75747403
--- /dev/null
+++ b/imgs/Steven-Yeun.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7d9da1331c75fc2b8ac8caa024c804ac500c9c29b5ed4edf60bf30247eae8a5
+size 1627062
diff --git a/imgs/cropped/input.png b/imgs/cropped/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..b046f6481e637e5f91dfa89b37927ce441811877
--- /dev/null
+++ b/imgs/cropped/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba7b8df0bffe226c723eb22c537e66ff9de844e6aae7845a6c88e696f03b6a40
+size 104592
diff --git a/imgs/input.png b/imgs/input.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ab9444f18b692a7cacd82159370a64b02a29bdd
--- /dev/null
+++ b/imgs/input.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f8c1b42d80f44efcf0cb03a301072284a0b7ad6ae6f11871be44b7fae79613e
+size 13719370
diff --git a/PTI/training/__init__.py b/models/StyleCLIP/__init__.py
similarity index 100%
rename from PTI/training/__init__.py
rename to models/StyleCLIP/__init__.py
diff --git a/PTI/training/coaches/__init__.py b/models/StyleCLIP/criteria/__init__.py
similarity index 100%
rename from PTI/training/coaches/__init__.py
rename to models/StyleCLIP/criteria/__init__.py
diff --git a/models/StyleCLIP/criteria/clip_loss.py b/models/StyleCLIP/criteria/clip_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..18176ee8eb0d992d69d5b951d7f36e2efa92a37b
--- /dev/null
+++ b/models/StyleCLIP/criteria/clip_loss.py
@@ -0,0 +1,17 @@
+
+import torch
+import clip
+
+
+class CLIPLoss(torch.nn.Module):
+
+ def __init__(self, opts):
+ super(CLIPLoss, self).__init__()
+ self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+ self.upsample = torch.nn.Upsample(scale_factor=7)
+ self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
+
+ def forward(self, image, text):
+ image = self.avg_pool(self.upsample(image))
+ similarity = 1 - self.model(image, text)[0] / 100
+ return similarity
\ No newline at end of file
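The loss above upsamples by 7 and average-pools by `stylegan_size // 32`, which maps a 1024x1024 generator output to CLIP's 224x224 input before scoring it against a tokenized prompt. A hedged usage sketch; the options object and tensors are stand-ins, and CLIP is loaded on CUDA by the constructor.

```python
import torch
import clip
from types import SimpleNamespace

opts = SimpleNamespace(stylegan_size=1024)               # 1024 -> pooled to 224x224
clip_loss = CLIPLoss(opts)
image = torch.randn(1, 3, 1024, 1024, device="cuda")     # stand-in for G(w)
text = clip.tokenize(["a face with a mohawk"]).to("cuda")
loss = clip_loss(image, text)                             # lower = closer to the prompt
```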
diff --git a/models/StyleCLIP/criteria/id_loss.py b/models/StyleCLIP/criteria/id_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a828023e115243e48918538d31b91d662cd12d0f
--- /dev/null
+++ b/models/StyleCLIP/criteria/id_loss.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+
+from models.facial_recognition.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+ def __init__(self, opts):
+ super(IDLoss, self).__init__()
+ print('Loading ResNet ArcFace')
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+ self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
+ self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+ self.facenet.eval()
+ self.opts = opts
+
+ def extract_feats(self, x):
+ if x.shape[2] != 256:
+ x = self.pool(x)
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
+ x = self.face_pool(x)
+ x_feats = self.facenet(x)
+ return x_feats
+
+ def forward(self, y_hat, y):
+ n_samples = y.shape[0]
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
+ y_hat_feats = self.extract_feats(y_hat)
+ y_feats = y_feats.detach()
+ loss = 0
+ sim_improvement = 0
+ count = 0
+ for i in range(n_samples):
+ diff_target = y_hat_feats[i].dot(y_feats[i])
+ loss += 1 - diff_target
+ count += 1
+
+ return loss / count, sim_improvement / count
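`IDLoss` crops the face region, pools it to 112x112, extracts ArcFace embeddings for both images, and penalizes one minus the dot product of the (normalized) embeddings. A hedged usage sketch; the weights path follows `configs/paths_config.py` and the tensors are stand-ins.

```python
import torch
from types import SimpleNamespace

opts = SimpleNamespace(ir_se50_weights="pretrained_models/model_ir_se50.pth")
id_loss = IDLoss(opts)

y = torch.randn(1, 3, 256, 256)       # target image in [-1, 1]
y_hat = torch.randn(1, 3, 256, 256)   # generated image
loss, _ = id_loss(y_hat, y)           # 1 - <e(y_hat), e(y)> averaged over the batch
```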
diff --git a/models/StyleCLIP/global_directions/GUI.py b/models/StyleCLIP/global_directions/GUI.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f7f8cce9305819b22664642799200d9e1cfff0
--- /dev/null
+++ b/models/StyleCLIP/global_directions/GUI.py
@@ -0,0 +1,103 @@
+
+
+from tkinter import Tk,Frame ,Label,Button,messagebox,Canvas,Text,Scale
+from tkinter import HORIZONTAL
+
+class View():
+ def __init__(self,master):
+
+ self.width=600
+ self.height=600
+
+
+ self.root=master
+ self.root.geometry("600x600")
+
+ self.left_frame=Frame(self.root,width=600)
+ self.left_frame.pack_propagate(0)
+ self.left_frame.pack(fill='both', side='left', expand='True')
+
+ self.retrieval_frame=Frame(self.root,bg='snow3')
+ self.retrieval_frame.pack_propagate(0)
+ self.retrieval_frame.pack(fill='both', side='right', expand='True')
+
+ self.bg_frame=Frame(self.left_frame,bg='snow3',height=600,width=600)
+ self.bg_frame.pack_propagate(0)
+ self.bg_frame.pack(fill='both', side='top', expand='True')
+
+ self.command_frame=Frame(self.left_frame,bg='snow3')
+ self.command_frame.pack_propagate(0)
+ self.command_frame.pack(fill='both', side='bottom', expand='True')
+# self.command_frame.grid(row=1, column=0,padx=0, pady=0)
+
+ self.bg=Canvas(self.bg_frame,width=self.width,height=self.height, bg='gray')
+ self.bg.place(relx=0.5, rely=0.5, anchor='center')
+
+ self.mani=Canvas(self.retrieval_frame,width=1024,height=1024, bg='gray')
+ self.mani.grid(row=0, column=0,padx=0, pady=42)
+
+ self.SetCommand()
+
+
+
+
+ def run(self):
+ self.root.mainloop()
+
+ def helloCallBack(self):
+ category=self.set_category.get()
+ messagebox.showinfo( "Hello Python",category)
+
+ def SetCommand(self):
+
+ tmp = Label(self.command_frame, text="neutral", width=10 ,bg='snow3')
+ tmp.grid(row=1, column=0,padx=10, pady=10)
+
+ tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3')
+ tmp.grid(row=1, column=1,padx=10, pady=10)
+
+ self.neutral = Text ( self.command_frame, height=2, width=30)
+ self.neutral.grid(row=1, column=2,padx=10, pady=10)
+
+
+ tmp = Label(self.command_frame, text="target", width=10 ,bg='snow3')
+ tmp.grid(row=2, column=0,padx=10, pady=10)
+
+ tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3')
+ tmp.grid(row=2, column=1,padx=10, pady=10)
+
+ self.target = Text ( self.command_frame, height=2, width=30)
+ self.target.grid(row=2, column=2,padx=10, pady=10)
+
+ tmp = Label(self.command_frame, text="strength", width=10 ,bg='snow3')
+ tmp.grid(row=3, column=0,padx=10, pady=10)
+
+ self.alpha = Scale(self.command_frame, from_=-15, to=25, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.01)
+ self.alpha.grid(row=3, column=2,padx=10, pady=10)
+
+
+ tmp = Label(self.command_frame, text="disentangle", width=10 ,bg='snow3')
+ tmp.grid(row=4, column=0,padx=10, pady=10)
+
+ self.beta = Scale(self.command_frame, from_=0.08, to=0.4, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.001)
+ self.beta.grid(row=4, column=2,padx=10, pady=10)
+
+ self.reset = Button(self.command_frame, text='Reset')
+ self.reset.grid(row=5, column=1,padx=10, pady=10)
+
+
+ self.set_init = Button(self.command_frame, text='Accept')
+ self.set_init.grid(row=5, column=2,padx=10, pady=10)
+
+#%%
+if __name__ == "__main__":
+ master=Tk()
+ self=View(master)
+ self.run()
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/models/StyleCLIP/global_directions/GenerateImg.py b/models/StyleCLIP/global_directions/GenerateImg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6dee48f2d6d9ac37c00ee77c7a46c2cc6b25e1
--- /dev/null
+++ b/models/StyleCLIP/global_directions/GenerateImg.py
@@ -0,0 +1,50 @@
+
+import os
+import numpy as np
+import argparse
+from manipulate import Manipulator
+
+from PIL import Image
+#%%
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Process some integers.')
+
+ parser.add_argument('--dataset_name',type=str,default='ffhq',
+ help='name of dataset, for example, ffhq')
+
+ args = parser.parse_args()
+ dataset_name=args.dataset_name
+
+ if not os.path.isdir('./data/'+dataset_name):
+ os.system('mkdir ./data/'+dataset_name)
+ #%%
+ M=Manipulator(dataset_name=dataset_name)
+ np.set_printoptions(suppress=True)
+ print(M.dataset_name)
+ #%%
+
+ M.img_index=0
+ M.num_images=50
+ M.alpha=[0]
+ M.step=1
+ lindex,bname=0,0
+
+ M.manipulate_layers=[lindex]
+ codes,out=M.EditOneC(bname)
+ #%%
+
+ for i in range(len(out)):
+ img=out[i,0]
+ img=Image.fromarray(img)
+ img.save('./data/'+dataset_name+'/'+str(i)+'.jpg')
+ #%%
+ w=np.load('./npy/'+dataset_name+'/W.npy')
+
+ tmp=w[:M.num_images]
+ tmp=tmp[:,None,:]
+ tmp=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1))
+
+ np.save('./data/'+dataset_name+'/w_plus.npy',tmp)
+
+
\ No newline at end of file
diff --git a/models/StyleCLIP/global_directions/GetCode.py b/models/StyleCLIP/global_directions/GetCode.py
new file mode 100644
index 0000000000000000000000000000000000000000..62e64dc8cbc5ad2bb16aef5da8f6d41c26b24170
--- /dev/null
+++ b/models/StyleCLIP/global_directions/GetCode.py
@@ -0,0 +1,232 @@
+
+
+
+import os
+import pickle
+import numpy as np
+from dnnlib import tflib
+import tensorflow as tf
+
+import argparse
+
+def LoadModel(dataset_name):
+ # Initialize TensorFlow.
+ tflib.init_tf()
+ model_path='./model/'
+ model_name=dataset_name+'.pkl'
+
+ tmp=os.path.join(model_path,model_name)
+ with open(tmp, 'rb') as f:
+ _, _, Gs = pickle.load(f)
+ return Gs
+
+def lerp(a,b,t):
+ return a + (b - a) * t
+
+#stylegan-ada
+def SelectName(layer_name,suffix):
+ if suffix==None:
+ tmp1='add:0' in layer_name
+ tmp2='shape=(?,' in layer_name
+ tmp4='G_synthesis_1' in layer_name
+ tmp= tmp1 and tmp2 and tmp4
+ else:
+ tmp1=('/Conv0_up'+suffix) in layer_name
+ tmp2=('/Conv1'+suffix) in layer_name
+ tmp3=('4x4/Conv'+suffix) in layer_name
+ tmp4='G_synthesis_1' in layer_name
+ tmp5=('/ToRGB'+suffix) in layer_name
+ tmp= (tmp1 or tmp2 or tmp3 or tmp5) and tmp4
+ return tmp
+
+
+def GetSNames(suffix):
+ #get style tensor name
+ with tf.Session() as sess:
+ op = sess.graph.get_operations()
+ layers=[m.values() for m in op]
+
+
+ select_layers=[]
+ for layer in layers:
+ layer_name=str(layer)
+ if SelectName(layer_name,suffix):
+ select_layers.append(layer[0])
+ return select_layers
+
+def SelectName2(layer_name):
+ tmp1='mod_bias' in layer_name
+ tmp2='mod_weight' in layer_name
+ tmp3='ToRGB' in layer_name
+
+ tmp= (tmp1 or tmp2) and (not tmp3)
+ return tmp
+
+def GetKName(Gs):
+
+ layers=[var for name, var in Gs.components.synthesis.vars.items()]
+
+ select_layers=[]
+ for layer in layers:
+ layer_name=str(layer)
+ if SelectName2(layer_name):
+ select_layers.append(layer)
+ return select_layers
+
+def GetCode(Gs,random_state,num_img,num_once,dataset_name):
+ rnd = np.random.RandomState(random_state) #5
+
+ truncation_psi=0.7
+ truncation_cutoff=8
+
+ dlatent_avg=Gs.get_var('dlatent_avg')
+
+ dlatents=np.zeros((num_img,512),dtype='float32')
+ for i in range(int(num_img/num_once)):
+ src_latents = rnd.randn(num_once, Gs.input_shape[1])
+ src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
+
+ # Apply truncation trick.
+ if truncation_psi is not None and truncation_cutoff is not None:
+ layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
+ ones = np.ones(layer_idx.shape, dtype=np.float32)
+ coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
+ src_dlatents_np=lerp(dlatent_avg, src_dlatents, coefs)
+ src_dlatents=src_dlatents_np[:,0,:].astype('float32')
+ dlatents[(i*num_once):((i+1)*num_once),:]=src_dlatents
+ print('get all z and w')
+
+ tmp='./npy/'+dataset_name+'/W'
+ np.save(tmp,dlatents)
+
+
+def GetImg(Gs,num_img,num_once,dataset_name,save_name='images'):
+ print('Generate Image')
+ tmp='./npy/'+dataset_name+'/W.npy'
+ dlatents=np.load(tmp)
+ fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
+
+ all_images=[]
+ for i in range(int(num_img/num_once)):
+ print(i)
+ images=[]
+ for k in range(num_once):
+ tmp=dlatents[i*num_once+k]
+ tmp=tmp[None,None,:]
+ tmp=np.tile(tmp,(1,Gs.components.synthesis.input_shape[1],1))
+ image2= Gs.components.synthesis.run(tmp, randomize_noise=False, output_transform=fmt)
+ images.append(image2)
+
+ images=np.concatenate(images)
+
+ all_images.append(images)
+
+ all_images=np.concatenate(all_images)
+
+ tmp='./npy/'+dataset_name+'/'+save_name
+ np.save(tmp,all_images)
+
+def GetS(dataset_name,num_img):
+ print('Generate S')
+ tmp='./npy/'+dataset_name+'/W.npy'
+ dlatents=np.load(tmp)[:num_img]
+
+ with tf.Session() as sess:
+ init = tf.global_variables_initializer()
+ sess.run(init)
+
+ Gs=LoadModel(dataset_name)
+ Gs.print_layers() #for ada
+ select_layers1=GetSNames(suffix=None) #None,'/mul_1:0','/mod_weight/read:0','/MatMul:0'
+ dlatents=dlatents[:,None,:]
+ dlatents=np.tile(dlatents,(1,Gs.components.synthesis.input_shape[1],1))
+
+ all_s = sess.run(
+ select_layers1,
+ feed_dict={'G_synthesis_1/dlatents_in:0': dlatents})
+
+ layer_names=[layer.name for layer in select_layers1]
+ save_tmp=[layer_names,all_s]
+ return save_tmp
+
+
+
+
+def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False):
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
+ Can be used as an output transformation for Network.run().
+ """
+ if nchw_to_nhwc:
+ images = np.transpose(images, [0, 2, 3, 1])
+
+ scale = 255 / (drange[1] - drange[0])
+ images = images * scale + (0.5 - drange[0] * scale)
+
+ np.clip(images, 0, 255, out=images)
+ images=images.astype('uint8')
+ return images
+
+
+def GetCodeMS(dlatents):
+ m=[]
+ std=[]
+ for i in range(len(dlatents)):
+ tmp= dlatents[i]
+ tmp_mean=tmp.mean(axis=0)
+ tmp_std=tmp.std(axis=0)
+ m.append(tmp_mean)
+ std.append(tmp_std)
+ return m,std
+
+
+
+#%%
+if __name__ == "__main__":
+
+
+ parser = argparse.ArgumentParser(description='Process some integers.')
+
+ parser.add_argument('--dataset_name',type=str,default='ffhq',
+ help='name of dataset, for example, ffhq')
+ parser.add_argument('--code_type',choices=['w','s','s_mean_std'],default='w')
+
+ args = parser.parse_args()
+ random_state=5
+ num_img=100_000
+ num_once=1_000
+ dataset_name=args.dataset_name
+
+ if not os.path.isfile('./model/'+dataset_name+'.pkl'):
+ url='https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
+ name='stylegan2-'+dataset_name+'-config-f.pkl'
+ os.system('wget ' +url+name + ' -P ./model/')
+ os.system('mv ./model/'+name+' ./model/'+dataset_name+'.pkl')
+
+ if not os.path.isdir('./npy/'+dataset_name):
+ os.system('mkdir ./npy/'+dataset_name)
+
+ if args.code_type=='w':
+ Gs=LoadModel(dataset_name=dataset_name)
+ GetCode(Gs,random_state,num_img,num_once,dataset_name)
+# GetImg(Gs,num_img=num_img,num_once=num_once,dataset_name=dataset_name,save_name='images_100K') #no need
+ elif args.code_type=='s':
+ save_name='S'
+ save_tmp=GetS(dataset_name,num_img=2_000)
+ tmp='./npy/'+dataset_name+'/'+save_name
+ with open(tmp, "wb") as fp:
+ pickle.dump(save_tmp, fp)
+
+ elif args.code_type=='s_mean_std':
+ save_tmp=GetS(dataset_name,num_img=num_img)
+ dlatents=save_tmp[1]
+ m,std=GetCodeMS(dlatents)
+ save_tmp=[m,std]
+ save_name='S_mean_std'
+ tmp='./npy/'+dataset_name+'/'+save_name
+ with open(tmp, "wb") as fp:
+ pickle.dump(save_tmp, fp)
+
+
+
+
+
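The truncation trick inside `GetCode()` interpolates the mapped latents toward `dlatent_avg` for the first `truncation_cutoff` layers only. A minimal NumPy sketch of that step with illustrative shapes:

```python
import numpy as np

def truncate(w, w_avg, psi=0.7, cutoff=8):
    """Pull early-layer w codes toward the average; later layers are left untouched."""
    layer_idx = np.arange(w.shape[1])[np.newaxis, :, np.newaxis]  # [1, layers, 1]
    coefs = np.where(layer_idx < cutoff, psi, 1.0)
    return w_avg + (w - w_avg) * coefs                            # lerp(w_avg, w, coefs)

w = np.random.randn(4, 18, 512)   # [batch, layers, components]
w_avg = np.zeros(512)
w_trunc = truncate(w, w_avg)
```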
diff --git a/models/StyleCLIP/global_directions/GetGUIData.py b/models/StyleCLIP/global_directions/GetGUIData.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f77213ab88edf8b33eff166b89b9e56ac4ff01
--- /dev/null
+++ b/models/StyleCLIP/global_directions/GetGUIData.py
@@ -0,0 +1,67 @@
+
+import os
+import numpy as np
+import argparse
+from manipulate import Manipulator
+import torch
+from PIL import Image
+#%%
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Process some integers.')
+
+ parser.add_argument('--dataset_name',type=str,default='ffhq',
+ help='name of dataset, for example, ffhq')
+
+ parser.add_argument('--real', action='store_true')
+
+ args = parser.parse_args()
+ dataset_name=args.dataset_name
+
+ if not os.path.isdir('./data/'+dataset_name):
+ os.system('mkdir ./data/'+dataset_name)
+ #%%
+ M=Manipulator(dataset_name=dataset_name)
+ np.set_printoptions(suppress=True)
+ print(M.dataset_name)
+ #%%
+ #remove all .jpg
+ names=os.listdir('./data/'+dataset_name+'/')
+ for name in names:
+ if '.jpg' in name:
+ os.system('rm ./data/'+dataset_name+'/'+name)
+
+
+ #%%
+ if args.real:
+ latents=torch.load('./data/'+dataset_name+'/latents.pt')
+ w_plus=latents.cpu().detach().numpy()
+ else:
+ w=np.load('./npy/'+dataset_name+'/W.npy')
+ tmp=w[:50] #only use 50 images
+ tmp=tmp[:,None,:]
+ w_plus=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1))
+ np.save('./data/'+dataset_name+'/w_plus.npy',w_plus)
+
+ #%%
+ tmp=M.W2S(w_plus)
+ M.dlatents=tmp
+
+ M.img_index=0
+ M.num_images=len(w_plus)
+ M.alpha=[0]
+ M.step=1
+ lindex,bname=0,0
+
+ M.manipulate_layers=[lindex]
+ codes,out=M.EditOneC(bname)
+ #%%
+
+ for i in range(len(out)):
+ img=out[i,0]
+ img=Image.fromarray(img)
+ img.save('./data/'+dataset_name+'/'+str(i)+'.jpg')
+ #%%
+
+
+
\ No newline at end of file
diff --git a/models/StyleCLIP/global_directions/Inference.py b/models/StyleCLIP/global_directions/Inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a292787c88a370b15b4f0d633ac27bb5bed2b510
--- /dev/null
+++ b/models/StyleCLIP/global_directions/Inference.py
@@ -0,0 +1,106 @@
+
+
+from manipulate import Manipulator
+import tensorflow as tf
+import numpy as np
+import torch
+import clip
+from MapTS import GetBoundary,GetDt
+
+class StyleCLIP():
+
+ def __init__(self,dataset_name='ffhq'):
+ print('load clip')
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.model, preprocess = clip.load("ViT-B/32", device=device)
+ self.LoadData(dataset_name)
+
+ def LoadData(self, dataset_name):
+ tf.keras.backend.clear_session()
+ M=Manipulator(dataset_name=dataset_name)
+ np.set_printoptions(suppress=True)
+ fs3=np.load('./npy/'+dataset_name+'/fs3.npy')
+
+ self.M=M
+ self.fs3=fs3
+
+ w_plus=np.load('./data/'+dataset_name+'/w_plus.npy')
+ self.M.dlatents=M.W2S(w_plus)
+
+ if dataset_name=='ffhq':
+ self.c_threshold=20
+ else:
+ self.c_threshold=100
+ self.SetInitP()
+
+ def SetInitP(self):
+ self.M.alpha=[3]
+ self.M.num_images=1
+
+ self.target=''
+ self.neutral=''
+ self.GetDt2()
+ img_index=0
+ self.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.M.dlatents]
+
+
+ def GetDt2(self):
+ classnames=[self.target,self.neutral]
+ dt=GetDt(classnames,self.model)
+
+ self.dt=dt
+ num_cs=[]
+ betas=np.arange(0.1,0.3,0.01)
+ for i in range(len(betas)):
+ boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=betas[i])
+ print(betas[i])
+ num_cs.append(num_c)
+
+ num_cs=np.array(num_cs)
+ select=num_cs>self.c_threshold
+
+ if sum(select)==0:
+ self.beta=0.1
+ else:
+ self.beta=betas[select][-1]
+
+
+ def GetCode(self):
+ boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=self.beta)
+ codes=self.M.MSCode(self.M.dlatent_tmp,boundary_tmp2)
+ return codes
+
+ def GetImg(self):
+
+ codes=self.GetCode()
+ out=self.M.GenerateImg(codes)
+ img=out[0,0]
+ return img
+
+
+
+
+#%%
+if __name__ == "__main__":
+ style_clip=StyleCLIP()
+ self=style_clip
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
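Putting the pieces together, a hedged usage sketch of the `StyleCLIP` wrapper above; the prompts are examples, and the ffhq model, `fs3.npy`, and `w_plus.npy` are assumed to be in place.

```python
style_clip = StyleCLIP(dataset_name='ffhq')

style_clip.neutral = 'face'
style_clip.target = 'face with glasses'
style_clip.GetDt2()           # text direction dt + disentanglement threshold beta
style_clip.M.alpha = [3]      # edit strength along the global direction

img = style_clip.GetImg()     # uint8 array of the edited face
```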
diff --git a/models/StyleCLIP/global_directions/MapTS.py b/models/StyleCLIP/global_directions/MapTS.py
new file mode 100644
index 0000000000000000000000000000000000000000..2160a62cdbb0278d213076637f79b1e6f66db906
--- /dev/null
+++ b/models/StyleCLIP/global_directions/MapTS.py
@@ -0,0 +1,394 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Feb 4 17:36:31 2021
+
+@author: wuzongze
+"""
+
+import os
+#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+#os.environ["CUDA_VISIBLE_DEVICES"] = "1" #(or "1" or "2")
+
+import sys
+
+#sys.path=['', '/usr/local/tensorflow/avx-avx2-gpu/1.14.0/python3.7/site-packages', '/usr/local/matlab/2018b/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python37.zip', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/lib-dynload', '/usr/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/copkmeans-1.5-py3.7.egg', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/spherecluster-0.1.7-py3.7.egg', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages/IPython/extensions']
+
+import tensorflow as tf
+
+import numpy as np
+import torch
+import clip
+from PIL import Image
+import pickle
+import copy
+import matplotlib.pyplot as plt
+
+def GetAlign(out,dt,model,preprocess):
+ imgs=out
+ imgs1=imgs.reshape([-1]+list(imgs.shape[2:]))
+
+ tmp=[]
+ for i in range(len(imgs1)):
+
+ img=Image.fromarray(imgs1[i])
+ image = preprocess(img).unsqueeze(0).to(device)
+ tmp.append(image)
+
+ image=torch.cat(tmp)
+
+ with torch.no_grad():
+ image_features = model.encode_image(image)
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+ image_features1=image_features.cpu().numpy()
+
+ image_features1=image_features1.reshape(list(imgs.shape[:2])+[512])
+
+ fd=image_features1[:,1:,:]-image_features1[:,:-1,:]
+
+ fd1=fd.reshape([-1,512])
+ fd2=fd1/np.linalg.norm(fd1,axis=1)[:,None]
+
+ tmp=np.dot(fd2,dt)
+ m=tmp.mean()
+ acc=np.sum(tmp>0)/len(tmp)
+ print(m,acc)
+ return m,acc
+
+
+def SplitS(ds_p,M,if_std):
+ all_ds=[]
+ start=0
+ for i in M.mindexs:
+ tmp=M.dlatents[i].shape[1]
+ end=start+tmp
+ tmp=ds_p[start:end]
+# tmp=tmp*M.code_std[i]
+
+ all_ds.append(tmp)
+ start=end
+
+ all_ds2=[]
+ tmp_index=0
+ for i in range(len(M.s_names)):
+ if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index])==0):
+
+# tmp=np.abs(all_ds[tmp_index]/M.code_std[i])
+# print(i,tmp.mean())
+# tmp=np.dot(M.latent_codes[i],all_ds[tmp_index])
+# print(tmp)
+ if if_std:
+ tmp=all_ds[tmp_index]*M.code_std[i]
+ else:
+ tmp=all_ds[tmp_index]
+
+ all_ds2.append(tmp)
+ tmp_index+=1
+ else:
+ tmp=np.zeros(len(M.dlatents[i][0]))
+ all_ds2.append(tmp)
+ return all_ds2
+
+
+imagenet_templates = [
+ 'a bad photo of a {}.',
+# 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+]
+
+
+def zeroshot_classifier(classnames, templates,model):
+ with torch.no_grad():
+ zeroshot_weights = []
+ for classname in classnames:
+ texts = [template.format(classname) for template in templates] #format with class
+ texts = clip.tokenize(texts).cuda() #tokenize
+ class_embeddings = model.encode_text(texts) #embed with text encoder
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ class_embedding = class_embeddings.mean(dim=0)
+ class_embedding /= class_embedding.norm()
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
+ return zeroshot_weights
+
+
+def GetDt(classnames,model):
+ text_features=zeroshot_classifier(classnames, imagenet_templates,model).t()
+
+ dt=text_features[0]-text_features[1]
+ dt=dt.cpu().numpy()
+
+# t_m1=t_m/np.linalg.norm(t_m)
+# dt=text_features.cpu().numpy()[0]-t_m1
+ print(np.linalg.norm(dt))
+ dt=dt/np.linalg.norm(dt)
+ return dt
+
+
+def GetBoundary(fs3,dt,M,threshold):
+ tmp=np.dot(fs3,dt)
+
+ ds_imp=copy.copy(tmp)
+    select=np.abs(tmp)<threshold
+        self.view.neutral.bind("<Return>", self.text_n)
+        self.view.target.bind("<Return>", self.text_t)
+        self.view.alpha.bind('<ButtonRelease-1>', self.ChangeAlpha)
+        self.view.beta.bind('<ButtonRelease-1>', self.ChangeBeta)
+        self.view.set_init.bind('<ButtonPress-1>', self.SetInit)
+        self.view.reset.bind('<ButtonPress-1>', self.Reset)
+        self.view.bg.bind('<Double-1>', self.open_img)
+
+
+ self.drawn = None
+
+ self.view.target.delete(1.0, "end")
+ self.view.target.insert("end", self.style_clip.target)
+#
+ self.view.neutral.delete(1.0, "end")
+ self.view.neutral.insert("end", self.style_clip.neutral)
+
+
+ def Reset(self,event):
+ self.style_clip.GetDt2()
+ self.style_clip.M.alpha=[0]
+
+ self.view.beta.set(self.style_clip.beta)
+ self.view.alpha.set(0)
+
+ img=self.style_clip.GetImg()
+ img=Image.fromarray(img)
+ img = ImageTk.PhotoImage(img)
+ self.addImage_m(img)
+
+
+ def SetInit(self,event):
+ codes=self.style_clip.GetCode()
+ self.style_clip.M.dlatent_tmp=[tmp[:,0] for tmp in codes]
+ print('set init')
+
+ def ChangeAlpha(self,event):
+ tmp=self.view.alpha.get()
+ self.style_clip.M.alpha=[float(tmp)]
+
+ img=self.style_clip.GetImg()
+ print('manipulate one')
+ img=Image.fromarray(img)
+ img = ImageTk.PhotoImage(img)
+ self.addImage_m(img)
+
+ def ChangeBeta(self,event):
+ tmp=self.view.beta.get()
+ self.style_clip.beta=float(tmp)
+
+ img=self.style_clip.GetImg()
+ print('manipulate one')
+ img=Image.fromarray(img)
+ img = ImageTk.PhotoImage(img)
+ self.addImage_m(img)
+
+ def ChangeDataset(self,event):
+
+ dataset_name=self.view.set_category.get()
+
+ self.style_clip.LoadData(dataset_name)
+
+ self.view.target.delete(1.0, "end")
+ self.view.target.insert("end", self.style_clip.target)
+
+ self.view.neutral.delete(1.0, "end")
+ self.view.neutral.insert("end", self.style_clip.neutral)
+
+ def text_t(self,event):
+ tmp=self.view.target.get("1.0",'end')
+ tmp=tmp.replace('\n','')
+
+ self.view.target.delete(1.0, "end")
+ self.view.target.insert("end", tmp)
+
+ print('target',tmp,'###')
+ self.style_clip.target=tmp
+ self.style_clip.GetDt2()
+ self.view.beta.set(self.style_clip.beta)
+ self.view.alpha.set(3)
+ self.style_clip.M.alpha=[3]
+
+ img=self.style_clip.GetImg()
+ print('manipulate one')
+ img=Image.fromarray(img)
+ img = ImageTk.PhotoImage(img)
+ self.addImage_m(img)
+
+
+ def text_n(self,event):
+ tmp=self.view.neutral.get("1.0",'end')
+ tmp=tmp.replace('\n','')
+
+ self.view.neutral.delete(1.0, "end")
+ self.view.neutral.insert("end", tmp)
+
+ print('neutral',tmp,'###')
+ self.style_clip.neutral=tmp
+ self.view.target.delete(1.0, "end")
+ self.view.target.insert("end", tmp)
+
+
+ def run(self):
+ self.root.mainloop()
+
+ def addImage(self,img):
+ self.view.bg.create_image(self.view.width/2, self.view.height/2, image=img, anchor='center')
+ self.image=img #save a copy of image. if not the image will disappear
+
+ def addImage_m(self,img):
+ self.view.mani.create_image(512, 512, image=img, anchor='center')
+ self.image2=img
+
+
+ def openfn(self):
+ filename = askopenfilename(title='open',initialdir='./data/'+self.style_clip.M.dataset_name+'/',filetypes=[("all image format", ".jpg"),("all image format", ".png")])
+ return filename
+
+ def open_img(self,event):
+ x = self.openfn()
+ print(x)
+
+
+ img = Image.open(x)
+ img2 = img.resize(( 512,512), Image.ANTIALIAS)
+ img2 = ImageTk.PhotoImage(img2)
+ self.addImage(img2)
+
+ img = ImageTk.PhotoImage(img)
+ self.addImage_m(img)
+
+ img_index=x.split('/')[-1].split('.')[0]
+ img_index=int(img_index)
+ print(img_index)
+ self.style_clip.M.img_index=img_index
+ self.style_clip.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.style_clip.M.dlatents]
+
+
+ self.style_clip.GetDt2()
+ self.view.beta.set(self.style_clip.beta)
+ self.view.alpha.set(3)
+
+ #%%
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Process some integers.')
+
+ parser.add_argument('--dataset_name',type=str,default='ffhq',
+ help='name of dataset, for example, ffhq')
+
+ args = parser.parse_args()
+ dataset_name=args.dataset_name
+
+ self=PlayInteractively(dataset_name)
+ self.run()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/models/StyleCLIP/global_directions/SingleChannel.py b/models/StyleCLIP/global_directions/SingleChannel.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecaa7ec7898d37f8f5db171f9141a5253af3fa73
--- /dev/null
+++ b/models/StyleCLIP/global_directions/SingleChannel.py
@@ -0,0 +1,109 @@
+
+
+
+import numpy as np
+import torch
+import clip
+from PIL import Image
+import copy
+from manipulate import Manipulator
+import argparse
+
+def GetImgF(out,model,preprocess):
+ imgs=out
+ imgs1=imgs.reshape([-1]+list(imgs.shape[2:]))
+
+ tmp=[]
+ for i in range(len(imgs1)):
+
+ img=Image.fromarray(imgs1[i])
+ image = preprocess(img).unsqueeze(0).to(device)
+ tmp.append(image)
+
+ image=torch.cat(tmp)
+ with torch.no_grad():
+ image_features = model.encode_image(image)
+
+ image_features1=image_features.cpu().numpy()
+ image_features1=image_features1.reshape(list(imgs.shape[:2])+[512])
+
+ return image_features1
+
+def GetFs(fs):
+ tmp=np.linalg.norm(fs,axis=-1)
+ fs1=fs/tmp[:,:,:,None]
+ fs2=fs1[:,:,1,:]-fs1[:,:,0,:] # 5*sigma - (-5)* sigma
+ fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None]
+ fs3=fs3.mean(axis=1)
+ fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None]
+ return fs3
+
+#%%
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Process some integers.')
+
+ parser.add_argument('--dataset_name',type=str,default='cat',
+ help='name of dataset, for example, ffhq')
+ args = parser.parse_args()
+ dataset_name=args.dataset_name
+
+ #%%
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = clip.load("ViT-B/32", device=device)
+ #%%
+ M=Manipulator(dataset_name=dataset_name)
+ np.set_printoptions(suppress=True)
+ print(M.dataset_name)
+ #%%
+ img_sindex=0
+ num_images=100
+ dlatents_o=[]
+ tmp=img_sindex*num_images
+ for i in range(len(M.dlatents)):
+ tmp1=M.dlatents[i][tmp:(tmp+num_images)]
+ dlatents_o.append(tmp1)
+ #%%
+
+ all_f=[]
+ M.alpha=[-5,5] #ffhq 5
+ M.step=2
+ M.num_images=num_images
+ select=np.array(M.mindexs)<=16 #below or equal to 128 resolution
+ mindexs2=np.array(M.mindexs)[select]
+ for lindex in mindexs2: #ignore ToRGB layers
+ print(lindex)
+ num_c=M.dlatents[lindex].shape[1]
+ for cindex in range(num_c):
+
+ M.dlatents=copy.copy(dlatents_o)
+ M.dlatents[lindex][:,cindex]=M.code_mean[lindex][cindex]
+
+ M.manipulate_layers=[lindex]
+ codes,out=M.EditOneC(cindex)
+ image_features1=GetImgF(out,model,preprocess)
+ all_f.append(image_features1)
+
+ all_f=np.array(all_f)
+
+ fs3=GetFs(all_f)
+
+ #%%
+ file_path='./npy/'+M.dataset_name+'/'
+ np.save(file_path+'fs3',fs3)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
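`fs3` produced above holds, per style channel, the normalized CLIP-space change caused by perturbing that channel. Editing then reduces to projecting a text direction onto these channel directions and zeroing weakly related channels, as `GetBoundary()` in `MapTS.py` does. A hedged sketch with illustrative values:

```python
import numpy as np

fs3 = np.load('./npy/ffhq/fs3.npy')          # [num_channels, 512] CLIP directions per channel
dt = np.random.randn(512)                    # stand-in for GetDt(...)
dt /= np.linalg.norm(dt)

relevance = fs3.dot(dt)                      # per-channel alignment with the text edit
beta = 0.13                                  # disentanglement threshold
manipulation = np.where(np.abs(relevance) < beta, 0.0, relevance)
```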
diff --git a/PTI/training/projectors/__init__.py b/models/StyleCLIP/global_directions/__init__.py
similarity index 100%
rename from PTI/training/projectors/__init__.py
rename to models/StyleCLIP/global_directions/__init__.py
diff --git a/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy b/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy
new file mode 100644
index 0000000000000000000000000000000000000000..db524aae88e16239679a8f72ccb3403fd16c95a9
--- /dev/null
+++ b/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:394f0f166305654f49cd1b0cd3d4f2b7a51e740a449a1ebfa1c69f79d01399fa
+size 2506880
diff --git a/PTI/dnnlib/__init__.py b/models/StyleCLIP/global_directions/dnnlib/__init__.py
similarity index 86%
rename from PTI/dnnlib/__init__.py
rename to models/StyleCLIP/global_directions/dnnlib/__init__.py
index 2f08cf36f11f9b0fd94c1b7caeadf69b98375b04..c73940d81233142ae3dcd9a37b7ec2185c5d5fc5 100644
--- a/PTI/dnnlib/__init__.py
+++ b/models/StyleCLIP/global_directions/dnnlib/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
diff --git a/PTI/torch_utils/ops/__init__.py b/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py
similarity index 54%
rename from PTI/torch_utils/ops/__init__.py
rename to models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py
index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..ca852844ec488c0134bffa647e25a40646ff4718 100644
--- a/PTI/torch_utils/ops/__init__.py
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
@@ -6,4 +6,15 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
-# empty
+from . import autosummary
+from . import network
+from . import optimizer
+from . import tfutil
+from . import custom_ops
+
+from .tfutil import *
+from .network import Network
+
+from .optimizer import Optimizer
+
+from .custom_ops import get_plugin
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py b/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py
new file mode 100644
index 0000000000000000000000000000000000000000..56dfb96093bb5b1129a99585b4ce655b98d80009
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Helper for adding automatically tracked values to Tensorboard.
+
+Autosummary creates an identity op that internally keeps track of the input
+values and automatically shows up in TensorBoard. The reported value
+represents an average over input components. The average is accumulated
+constantly over time and flushed when save_summaries() is called.
+
+Notes:
+- The output tensor must be used as an input for something else in the
+ graph. Otherwise, the autosummary op will not get executed, and the average
+ value will not get accumulated.
+- It is perfectly fine to include autosummaries with the same name in
+ several places throughout the graph, even if they are executed concurrently.
+- It is ok to also pass in a python scalar or numpy array. In this case, it
+ is added to the average immediately.
+"""
+
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorboard import summary as summary_lib
+from tensorboard.plugins.custom_scalar import layout_pb2
+
+from . import tfutil
+from .tfutil import TfExpression
+from .tfutil import TfExpressionEx
+
+# Enable "Custom scalars" tab in TensorBoard for advanced formatting.
+# Disabled by default to reduce tfevents file size.
+enable_custom_scalars = False
+
+_dtype = tf.float64
+_vars = OrderedDict() # name => [var, ...]
+_immediate = OrderedDict() # name => update_op, update_value
+_finalized = False
+_merge_op = None
+
+
+def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
+ """Internal helper for creating autosummary accumulators."""
+ assert not _finalized
+ name_id = name.replace("/", "_")
+ v = tf.cast(value_expr, _dtype)
+
+ if v.shape.is_fully_defined():
+ size = np.prod(v.shape.as_list())
+ size_expr = tf.constant(size, dtype=_dtype)
+ else:
+ size = None
+ size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
+
+ if size == 1:
+ if v.shape.ndims != 0:
+ v = tf.reshape(v, [])
+ v = [size_expr, v, tf.square(v)]
+ else:
+ v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
+ v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
+
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
+ var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
+ update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
+
+ if name in _vars:
+ _vars[name].append(var)
+ else:
+ _vars[name] = [var]
+ return update_op
+
+
+def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
+ """Create a new autosummary.
+
+ Args:
+ name: Name to use in TensorBoard
+ value: TensorFlow expression or python value to track
+ passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
+
+ Example use of the passthru mechanism:
+
+ n = autosummary('l2loss', loss, passthru=n)
+
+ This is a shorthand for the following code:
+
+ with tf.control_dependencies([autosummary('l2loss', loss)]):
+ n = tf.identity(n)
+ """
+ tfutil.assert_tf_initialized()
+ name_id = name.replace("/", "_")
+
+ if tfutil.is_tf_expression(value):
+ with tf.name_scope("summary_" + name_id), tf.device(value.device):
+ condition = tf.convert_to_tensor(condition, name='condition')
+ update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
+ with tf.control_dependencies([update_op]):
+ return tf.identity(value if passthru is None else passthru)
+
+ else: # python scalar or numpy array
+ assert not tfutil.is_tf_expression(passthru)
+ assert not tfutil.is_tf_expression(condition)
+ if condition:
+ if name not in _immediate:
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
+ update_value = tf.placeholder(_dtype)
+ update_op = _create_var(name, update_value)
+ _immediate[name] = update_op, update_value
+ update_op, update_value = _immediate[name]
+ tfutil.run(update_op, {update_value: value})
+ return value if passthru is None else passthru
+
+
+def finalize_autosummaries() -> None:
+ """Create the necessary ops to include autosummaries in TensorBoard report.
+ Note: This should be done only once per graph.
+ """
+ global _finalized
+ tfutil.assert_tf_initialized()
+
+ if _finalized:
+ return None
+
+ _finalized = True
+ tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
+
+ # Create summary ops.
+ with tf.device(None), tf.control_dependencies(None):
+ for name, vars_list in _vars.items():
+ name_id = name.replace("/", "_")
+ with tfutil.absolute_name_scope("Autosummary/" + name_id):
+ moments = tf.add_n(vars_list)
+ moments /= moments[0]
+ with tf.control_dependencies([moments]): # read before resetting
+ reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
+ with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
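+ # After normalization, moments == [1, E[x], E[x^2]], so the reported value is
+ # mean = E[x] and std = sqrt(E[x^2] - E[x]^2).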
+ mean = moments[1]
+ std = tf.sqrt(moments[2] - tf.square(moments[1]))
+ tf.summary.scalar(name, mean)
+ if enable_custom_scalars:
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
+
+ # Setup layout for custom scalars.
+ layout = None
+ if enable_custom_scalars:
+ cat_dict = OrderedDict()
+ for series_name in sorted(_vars.keys()):
+ p = series_name.split("/")
+ cat = p[0] if len(p) >= 2 else ""
+ chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
+ if cat not in cat_dict:
+ cat_dict[cat] = OrderedDict()
+ if chart not in cat_dict[cat]:
+ cat_dict[cat][chart] = []
+ cat_dict[cat][chart].append(series_name)
+ categories = []
+ for cat_name, chart_dict in cat_dict.items():
+ charts = []
+ for chart_name, series_names in chart_dict.items():
+ series = []
+ for series_name in series_names:
+ series.append(layout_pb2.MarginChartContent.Series(
+ value=series_name,
+ lower="xCustomScalars/" + series_name + "/margin_lo",
+ upper="xCustomScalars/" + series_name + "/margin_hi"))
+ margin = layout_pb2.MarginChartContent(series=series)
+ charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
+ categories.append(layout_pb2.Category(title=cat_name, chart=charts))
+ layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
+ return layout
+
+def save_summaries(file_writer, global_step=None):
+ """Call FileWriter.add_summary() with all summaries in the default graph,
+ automatically finalizing and merging them on the first call.
+ """
+ global _merge_op
+ tfutil.assert_tf_initialized()
+
+ if _merge_op is None:
+ layout = finalize_autosummaries()
+ if layout is not None:
+ file_writer.add_summary(layout)
+ with tf.device(None), tf.control_dependencies(None):
+ _merge_op = tf.summary.merge_all()
+
+ file_writer.add_summary(_merge_op.eval(), global_step)
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py b/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..702471e2006af6858345c1225c1e55b0acd17d32
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""TensorFlow custom ops builder.
+"""
+
+import glob
+import os
+import re
+import uuid
+import hashlib
+import tempfile
+import shutil
+import tensorflow as tf
+from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
+
+from .. import util
+
+#----------------------------------------------------------------------------
+# Global configs.
+
+cuda_cache_path = None
+cuda_cache_version_tag = 'v1'
+do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change.
+verbose = True # Print status messages to stdout.
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
+ if hostx64_paths != []:
+ return hostx64_paths[0]
+ hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
+ if hostx64_paths != []:
+ return hostx64_paths[0]
+ hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
+ if hostx64_paths != []:
+ return hostx64_paths[0]
+ vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
+ if os.path.isdir(vc_bin_dir):
+ return vc_bin_dir
+ return None
+
+def _get_compute_cap(device):
+ caps_str = device.physical_device_desc
+ m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
+ major = m.group(1)
+ minor = m.group(2)
+ return (major, minor)
+
+def _get_cuda_gpu_arch_string():
+ gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
+ if len(gpus) == 0:
+ raise RuntimeError('No GPU devices found')
+ (major, minor) = _get_compute_cap(gpus[0])
+ return 'sm_%s%s' % (major, minor)
+
+def _run_cmd(cmd):
+ with os.popen(cmd) as pipe:
+ output = pipe.read()
+ status = pipe.close()
+ if status is not None:
+ raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
+
+def _prepare_nvcc_cli(opts):
+ cmd = 'nvcc ' + opts.strip()
+ cmd += ' --disable-warnings'
+ cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
+
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ # Require that _find_compiler_bindir succeeds on Windows. Allow
+ # nvcc to use whatever is the default on Linux.
+ if os.name == 'nt':
+ raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
+ else:
+ cmd += ' --compiler-bindir "%s"' % compiler_bindir
+ cmd += ' 2>&1'
+ return cmd
+
+#----------------------------------------------------------------------------
+# Main entry point.
+
+_plugin_cache = dict()
+
+def get_plugin(cuda_file, extra_nvcc_options=[]):
+ cuda_file_base = os.path.basename(cuda_file)
+ cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
+
+ # Already in cache?
+ if cuda_file in _plugin_cache:
+ return _plugin_cache[cuda_file]
+
+ # Setup plugin.
+ if verbose:
+ print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
+ try:
+ # Hash CUDA source.
+ md5 = hashlib.md5()
+ with open(cuda_file, 'rb') as f:
+ md5.update(f.read())
+ md5.update(b'\n')
+
+ # Hash headers included by the CUDA code by running it through the preprocessor.
+ if not do_not_hash_included_headers:
+ if verbose:
+ print('Preprocessing... ', end='', flush=True)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
+ _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
+ with open(tmp_file, 'rb') as f:
+ bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
+ good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
+ for ln in f:
+ if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
+ ln = ln.replace(bad_file_str, good_file_str)
+ md5.update(ln)
+ md5.update(b'\n')
+
+ # Select compiler configs.
+ compile_opts = ''
+ if os.name == 'nt':
+ compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
+ elif os.name == 'posix':
+ compile_opts += f' --compiler-options \'-fPIC\''
+ compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\''
+ compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\''
+ else:
+ assert False # not Windows or Linux, w00t?
+ compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}'
+ compile_opts += ' --use_fast_math'
+ for opt in extra_nvcc_options:
+ compile_opts += ' ' + opt
+ nvcc_cmd = _prepare_nvcc_cli(compile_opts)
+
+ # Hash build configuration.
+ md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
+ md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
+ md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
+
+ # Compile if not already compiled.
+ cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path
+ bin_file_ext = '.dll' if os.name == 'nt' else '.so'
+ bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
+ if not os.path.isfile(bin_file):
+ if verbose:
+ print('Compiling... ', end='', flush=True)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
+ _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
+ os.makedirs(cache_dir, exist_ok=True)
+ intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
+ shutil.copyfile(tmp_file, intermediate_file)
+ os.rename(intermediate_file, bin_file) # atomic
+
+ # Load.
+ if verbose:
+ print('Loading... ', end='', flush=True)
+ plugin = tf.load_op_library(bin_file)
+
+ # Add to cache.
+ _plugin_cache[cuda_file] = plugin
+ if verbose:
+ print('Done.', flush=True)
+ return plugin
+
+ except:
+ if verbose:
+ print('Failed!', flush=True)
+ raise
+
+#----------------------------------------------------------------------------
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/network.py b/models/StyleCLIP/global_directions/dnnlib/tflib/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff0c169eabdc579041dac0650fbc6da956646594
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/network.py
@@ -0,0 +1,781 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Helper for managing networks."""
+
+import types
+import inspect
+import re
+import uuid
+import sys
+import copy
+import numpy as np
+import tensorflow as tf
+
+from collections import OrderedDict
+from typing import Any, List, Tuple, Union, Callable
+
+from . import tfutil
+from .. import util
+
+from .tfutil import TfExpression, TfExpressionEx
+
+# pylint: disable=protected-access
+# pylint: disable=attribute-defined-outside-init
+# pylint: disable=too-many-public-methods
+
+_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
+_import_module_src = dict() # Source code for temporary modules created during pickle import.
+
+
+def import_handler(handler_func):
+ """Function decorator for declaring custom import handlers."""
+ _import_handlers.append(handler_func)
+ return handler_func
+
+
+class Network:
+ """Generic network abstraction.
+
+ Acts as a convenience wrapper for a parameterized network construction
+ function, providing several utility methods and convenient access to
+ the inputs/outputs/weights.
+
+ Network objects can be safely pickled and unpickled for long-term
+ archival purposes. The pickling works reliably as long as the underlying
+ network construction function is defined in a standalone Python module
+ that has no side effects or application-specific imports.
+
+ Args:
+ name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None.
+ func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
+ static_kwargs: Keyword arguments to be passed in to the network construction function.
+ """
+
+ def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
+ # Locate the user-specified build function.
+ assert isinstance(func_name, str) or util.is_top_level_function(func_name)
+ if util.is_top_level_function(func_name):
+ func_name = util.get_top_level_function_name(func_name)
+ module, func_name = util.get_module_from_obj_name(func_name)
+ func = util.get_obj_from_module(module, func_name)
+
+ # Dig up source code for the module containing the build function.
+ module_src = _import_module_src.get(module, None)
+ if module_src is None:
+ module_src = inspect.getsource(module)
+
+ # Initialize fields.
+ self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src)
+
+ def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None:
+ tfutil.assert_tf_initialized()
+ assert isinstance(name, str)
+ assert len(name) >= 1
+ assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name)
+ assert isinstance(static_kwargs, dict)
+ assert util.is_pickleable(static_kwargs)
+ assert callable(build_func)
+ assert isinstance(build_func_name, str)
+ assert isinstance(build_module_src, str)
+
+ # Choose TensorFlow name scope.
+ with tf.name_scope(None):
+ scope = tf.get_default_graph().unique_name(name, mark_as_used=True)
+
+ # Query current TensorFlow device.
+ with tfutil.absolute_name_scope(scope), tf.control_dependencies(None):
+ device = tf.no_op(name="_QueryDevice").device
+
+ # Immutable state.
+ self._name = name
+ self._scope = scope
+ self._device = device
+ self._static_kwargs = util.EasyDict(copy.deepcopy(static_kwargs))
+ self._build_func = build_func
+ self._build_func_name = build_func_name
+ self._build_module_src = build_module_src
+
+ # State before _init_graph().
+ self._var_inits = dict() # var_name => initial_value, set to None by _init_graph()
+ self._all_inits_known = False # Do we know for sure that _var_inits covers all the variables?
+ self._components = None # subnet_name => Network, None if the components are not known yet
+
+ # Initialized by _init_graph().
+ self._input_templates = None
+ self._output_templates = None
+ self._own_vars = None
+
+ # Cached values initialized the respective methods.
+ self._input_shapes = None
+ self._output_shapes = None
+ self._input_names = None
+ self._output_names = None
+ self._vars = None
+ self._trainables = None
+ self._var_global_to_local = None
+ self._run_cache = dict()
+
+ def _init_graph(self) -> None:
+ assert self._var_inits is not None
+ assert self._input_templates is None
+ assert self._output_templates is None
+ assert self._own_vars is None
+
+ # Initialize components.
+ if self._components is None:
+ self._components = util.EasyDict()
+
+ # Choose build func kwargs.
+ build_kwargs = dict(self.static_kwargs)
+ build_kwargs["is_template_graph"] = True
+ build_kwargs["components"] = self._components
+
+ # Override scope and device, and ignore surrounding control dependencies.
+ with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None):
+ assert tf.get_variable_scope().name == self.scope
+ assert tf.get_default_graph().get_name_scope() == self.scope
+
+ # Create input templates.
+ self._input_templates = []
+ for param in inspect.signature(self._build_func).parameters.values():
+ if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
+ self._input_templates.append(tf.placeholder(tf.float32, name=param.name))
+
+ # Call build func.
+ out_expr = self._build_func(*self._input_templates, **build_kwargs)
+
+ # Collect output templates and variables.
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
+ self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
+ self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
+
+ # Check for errors.
+ if len(self._input_templates) == 0:
+ raise ValueError("Network build func did not list any inputs.")
+ if len(self._output_templates) == 0:
+ raise ValueError("Network build func did not return any outputs.")
+ if any(not tfutil.is_tf_expression(t) for t in self._output_templates):
+ raise ValueError("Network outputs must be TensorFlow expressions.")
+ if any(t.shape.ndims is None for t in self._input_templates):
+ raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
+ if any(t.shape.ndims is None for t in self._output_templates):
+ raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
+ if any(not isinstance(comp, Network) for comp in self._components.values()):
+ raise ValueError("Components of a Network must be Networks themselves.")
+ if len(self._components) != len(set(comp.name for comp in self._components.values())):
+ raise ValueError("Components of a Network must have unique names.")
+
+ # Initialize variables.
+ if len(self._var_inits):
+ tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()})
+ remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits]
+ if self._all_inits_known:
+ assert len(remaining_inits) == 0
+ else:
+ tfutil.run(remaining_inits)
+ self._var_inits = None
+
+ @property
+ def name(self):
+ """User-specified name string."""
+ return self._name
+
+ @property
+ def scope(self):
+ """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name."""
+ return self._scope
+
+ @property
+ def device(self):
+ """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time."""
+ return self._device
+
+ @property
+ def static_kwargs(self):
+ """EasyDict of arguments passed to the user-supplied build func."""
+ return copy.deepcopy(self._static_kwargs)
+
+ @property
+ def components(self):
+ """EasyDict of sub-networks created by the build func."""
+ return copy.copy(self._get_components())
+
+ def _get_components(self):
+ if self._components is None:
+ self._init_graph()
+ assert self._components is not None
+ return self._components
+
+ @property
+ def input_shapes(self):
+ """List of input tensor shapes, including minibatch dimension."""
+ if self._input_shapes is None:
+ self._input_shapes = [t.shape.as_list() for t in self.input_templates]
+ return copy.deepcopy(self._input_shapes)
+
+ @property
+ def output_shapes(self):
+ """List of output tensor shapes, including minibatch dimension."""
+ if self._output_shapes is None:
+ self._output_shapes = [t.shape.as_list() for t in self.output_templates]
+ return copy.deepcopy(self._output_shapes)
+
+ @property
+ def input_shape(self):
+ """Short-hand for input_shapes[0]."""
+ return self.input_shapes[0]
+
+ @property
+ def output_shape(self):
+ """Short-hand for output_shapes[0]."""
+ return self.output_shapes[0]
+
+ @property
+ def num_inputs(self):
+ """Number of input tensors."""
+ return len(self.input_shapes)
+
+ @property
+ def num_outputs(self):
+ """Number of output tensors."""
+ return len(self.output_shapes)
+
+ @property
+ def input_names(self):
+ """Name string for each input."""
+ if self._input_names is None:
+ self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates]
+ return copy.copy(self._input_names)
+
+ @property
+ def output_names(self):
+ """Name string for each output."""
+ if self._output_names is None:
+ self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
+ return copy.copy(self._output_names)
+
+ @property
+ def input_templates(self):
+ """Input placeholders in the template graph."""
+ if self._input_templates is None:
+ self._init_graph()
+ assert self._input_templates is not None
+ return copy.copy(self._input_templates)
+
+ @property
+ def output_templates(self):
+ """Output tensors in the template graph."""
+ if self._output_templates is None:
+ self._init_graph()
+ assert self._output_templates is not None
+ return copy.copy(self._output_templates)
+
+ @property
+ def own_vars(self):
+ """Variables defined by this network (local_name => var), excluding sub-networks."""
+ return copy.copy(self._get_own_vars())
+
+ def _get_own_vars(self):
+ if self._own_vars is None:
+ self._init_graph()
+ assert self._own_vars is not None
+ return self._own_vars
+
+ @property
+ def vars(self):
+ """All variables (local_name => var)."""
+ return copy.copy(self._get_vars())
+
+ def _get_vars(self):
+ if self._vars is None:
+ self._vars = OrderedDict(self._get_own_vars())
+ for comp in self._get_components().values():
+ self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items())
+ return self._vars
+
+ @property
+ def trainables(self):
+ """All trainable variables (local_name => var)."""
+ return copy.copy(self._get_trainables())
+
+ def _get_trainables(self):
+ if self._trainables is None:
+ self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
+ return self._trainables
+
+ @property
+ def var_global_to_local(self):
+ """Mapping from variable global names to local names."""
+ return copy.copy(self._get_var_global_to_local())
+
+ def _get_var_global_to_local(self):
+ if self._var_global_to_local is None:
+ self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
+ return self._var_global_to_local
+
+ def reset_own_vars(self) -> None:
+ """Re-initialize all variables of this network, excluding sub-networks."""
+ if self._var_inits is None or self._components is None:
+ tfutil.run([var.initializer for var in self._get_own_vars().values()])
+ else:
+ self._var_inits.clear()
+ self._all_inits_known = False
+
+ def reset_vars(self) -> None:
+ """Re-initialize all variables of this network, including sub-networks."""
+ if self._var_inits is None:
+ tfutil.run([var.initializer for var in self._get_vars().values()])
+ else:
+ self._var_inits.clear()
+ self._all_inits_known = False
+ if self._components is not None:
+ for comp in self._components.values():
+ comp.reset_vars()
+
+ def reset_trainables(self) -> None:
+ """Re-initialize all trainable variables of this network, including sub-networks."""
+ tfutil.run([var.initializer for var in self._get_trainables().values()])
+
+ def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
+ """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).
+ The graph is placed on the current TensorFlow device."""
+ assert len(in_expr) == self.num_inputs
+ assert not all(expr is None for expr in in_expr)
+ self._get_vars() # ensure that all variables have been created
+
+ # Choose build func kwargs.
+ build_kwargs = dict(self.static_kwargs)
+ build_kwargs.update(dynamic_kwargs)
+ build_kwargs["is_template_graph"] = False
+ build_kwargs["components"] = self._components
+
+ # Build TensorFlow graph to evaluate the network.
+ with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
+ assert tf.get_variable_scope().name == self.scope
+ valid_inputs = [expr for expr in in_expr if expr is not None]
+ final_inputs = []
+ for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
+ if expr is not None:
+ expr = tf.identity(expr, name=name)
+ else:
+ expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
+ final_inputs.append(expr)
+ out_expr = self._build_func(*final_inputs, **build_kwargs)
+
+ # Propagate input shapes back to the user-specified expressions.
+ for expr, final in zip(in_expr, final_inputs):
+ if isinstance(expr, tf.Tensor):
+ expr.set_shape(final.shape)
+
+ # Express outputs in the desired format.
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
+ if return_as_list:
+ out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
+ return out_expr
+
+ def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
+ """Get the local name of a given variable, without any surrounding name scopes."""
+ assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
+ global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
+ return self._get_var_global_to_local()[global_name]
+
+ def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
+ """Find variable by local or global name."""
+ assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
+ return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
+
+ def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
+ """Get the value of a given variable as NumPy array.
+ Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
+ return self.find_var(var_or_local_name).eval()
+
+ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
+ """Set the value of a given variable based on the given NumPy array.
+ Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
+ tfutil.set_vars({self.find_var(var_or_local_name): new_value})
+
+ def __getstate__(self) -> dict:
+ """Pickle export."""
+ state = dict()
+ state["version"] = 5
+ state["name"] = self.name
+ state["static_kwargs"] = dict(self.static_kwargs)
+ state["components"] = dict(self.components)
+ state["build_module_src"] = self._build_module_src
+ state["build_func_name"] = self._build_func_name
+ state["variables"] = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values()))))
+ state["input_shapes"] = self.input_shapes
+ state["output_shapes"] = self.output_shapes
+ state["input_names"] = self.input_names
+ state["output_names"] = self.output_names
+ return state
+
+ def __setstate__(self, state: dict) -> None:
+ """Pickle import."""
+
+ # Execute custom import handlers.
+ for handler in _import_handlers:
+ state = handler(state)
+
+ # Get basic fields.
+ assert state["version"] in [2, 3, 4, 5]
+ name = state["name"]
+ static_kwargs = state["static_kwargs"]
+ build_module_src = state["build_module_src"]
+ build_func_name = state["build_func_name"]
+
+ # Create temporary module from the imported source code.
+ module_name = "_tflib_network_import_" + uuid.uuid4().hex
+ module = types.ModuleType(module_name)
+ sys.modules[module_name] = module
+ _import_module_src[module] = build_module_src
+ exec(build_module_src, module.__dict__) # pylint: disable=exec-used
+ build_func = util.get_obj_from_module(module, build_func_name)
+
+ # Initialize fields.
+ self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src)
+ self._var_inits.update(copy.deepcopy(state["variables"]))
+ self._all_inits_known = True
+ self._components = util.EasyDict(state.get("components", {}))
+ self._input_shapes = copy.deepcopy(state.get("input_shapes", None))
+ self._output_shapes = copy.deepcopy(state.get("output_shapes", None))
+ self._input_names = copy.deepcopy(state.get("input_names", None))
+ self._output_names = copy.deepcopy(state.get("output_names", None))
+
+ def clone(self, name: str = None, **new_static_kwargs) -> "Network":
+ """Create a clone of this network with its own copy of the variables."""
+ static_kwargs = dict(self.static_kwargs)
+ static_kwargs.update(new_static_kwargs)
+ net = object.__new__(Network)
+ net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src)
+ net.copy_vars_from(self)
+ return net
+
+ def copy_own_vars_from(self, src_net: "Network") -> None:
+ """Copy the values of all variables from the given network, excluding sub-networks."""
+
+ # Source has unknown variables or unknown components => init now.
+ if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
+ src_net._get_vars()
+
+ # Both networks are inited => copy directly.
+ if src_net._var_inits is None and self._var_inits is None:
+ names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()]
+ tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
+ return
+
+ # Read from source.
+ if src_net._var_inits is None:
+ value_dict = tfutil.run(src_net._get_own_vars())
+ else:
+ value_dict = src_net._var_inits
+
+ # Write to destination.
+ if self._var_inits is None:
+ tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()})
+ else:
+ self._var_inits.update(value_dict)
+
+ def copy_vars_from(self, src_net: "Network") -> None:
+ """Copy the values of all variables from the given network, including sub-networks."""
+
+ # Source has unknown variables or unknown components => init now.
+ if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
+ src_net._get_vars()
+
+ # Source is inited, but destination components have not been created yet => set as initial values.
+ if src_net._var_inits is None and self._components is None:
+ self._var_inits.update(tfutil.run(src_net._get_vars()))
+ return
+
+ # Destination has unknown components => init now.
+ if self._components is None:
+ self._get_vars()
+
+ # Both networks are inited => copy directly.
+ if src_net._var_inits is None and self._var_inits is None:
+ names = [name for name in self._get_vars().keys() if name in src_net._get_vars()]
+ tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
+ return
+
+ # Copy recursively, component by component.
+ self.copy_own_vars_from(src_net)
+ for name, src_comp in src_net._components.items():
+ if name in self._components:
+ self._components[name].copy_vars_from(src_comp)
+
+ def copy_trainables_from(self, src_net: "Network") -> None:
+ """Copy the values of all trainable variables from the given network, including sub-networks."""
+ names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
+ tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
+
+ def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
+ """Create new network with the given parameters, and copy all variables from this network."""
+ if new_name is None:
+ new_name = self.name
+ static_kwargs = dict(self.static_kwargs)
+ static_kwargs.update(new_static_kwargs)
+ net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
+ net.copy_vars_from(self)
+ return net
+
+ def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
+ """Construct a TensorFlow op that updates the variables of this network
+ to be slightly closer to those of the given network."""
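+ # The returned op performs an exponential moving average: each destination variable
+ # is assigned lerp(src, dst, beta) = beta * dst + (1 - beta) * src, with
+ # beta_nontrainable used in place of beta for variables that are not trainable.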
+ with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
+ ops = []
+ for name, var in self._get_vars().items():
+ if name in src_net._get_vars():
+ cur_beta = beta if var.trainable else beta_nontrainable
+ new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
+ ops.append(var.assign(new_value))
+ return tf.group(*ops)
+
+ def run(self,
+ *in_arrays: Tuple[Union[np.ndarray, None], ...],
+ input_transform: dict = None,
+ output_transform: dict = None,
+ return_as_list: bool = False,
+ print_progress: bool = False,
+ minibatch_size: int = None,
+ num_gpus: int = 1,
+ assume_frozen: bool = False,
+ **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
+ """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
+
+ Args:
+ input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the input
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
+ output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the output
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
+ return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
+ print_progress: Print progress to the console? Useful for very large input arrays.
+ minibatch_size: Maximum minibatch size to use, None = disable batching.
+ num_gpus: Number of GPUs to use.
+ assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
+ dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
+ """
+ assert len(in_arrays) == self.num_inputs
+ assert not all(arr is None for arr in in_arrays)
+ assert input_transform is None or util.is_top_level_function(input_transform["func"])
+ assert output_transform is None or util.is_top_level_function(output_transform["func"])
+ output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
+ num_items = in_arrays[0].shape[0]
+ if minibatch_size is None:
+ minibatch_size = num_items
+
+ # Construct unique hash key from all arguments that affect the TensorFlow graph.
+ key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
+ def unwind_key(obj):
+ if isinstance(obj, dict):
+ return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
+ if callable(obj):
+ return util.get_top_level_function_name(obj)
+ return obj
+ key = repr(unwind_key(key))
+
+ # Build graph.
+ if key not in self._run_cache:
+ with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
+ with tf.device("/cpu:0"):
+ in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
+ in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
+
+ out_split = []
+ for gpu in range(num_gpus):
+ with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu):
+ net_gpu = self.clone() if assume_frozen else self
+ in_gpu = in_split[gpu]
+
+ if input_transform is not None:
+ in_kwargs = dict(input_transform)
+ in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
+ in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
+
+ assert len(in_gpu) == self.num_inputs
+ out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
+
+ if output_transform is not None:
+ out_kwargs = dict(output_transform)
+ out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
+ out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
+
+ assert len(out_gpu) == self.num_outputs
+ out_split.append(out_gpu)
+
+ with tf.device("/cpu:0"):
+ out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
+ self._run_cache[key] = in_expr, out_expr
+
+ # Run minibatches.
+ in_expr, out_expr = self._run_cache[key]
+ out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
+
+ for mb_begin in range(0, num_items, minibatch_size):
+ if print_progress:
+ print("\r%d / %d" % (mb_begin, num_items), end="")
+
+ mb_end = min(mb_begin + minibatch_size, num_items)
+ mb_num = mb_end - mb_begin
+ mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
+ mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
+
+ for dst, src in zip(out_arrays, mb_out):
+ dst[mb_begin: mb_end] = src
+
+ # Done.
+ if print_progress:
+ print("\r%d / %d" % (num_items, num_items))
+
+ if not return_as_list:
+ out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
+ return out_arrays
+
+ def list_ops(self) -> List[TfExpression]:
+ _ = self.output_templates # ensure that the template graph has been created
+ include_prefix = self.scope + "/"
+ exclude_prefix = include_prefix + "_"
+ ops = tf.get_default_graph().get_operations()
+ ops = [op for op in ops if op.name.startswith(include_prefix)]
+ ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
+ return ops
+
+ def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
+ """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
+ individual layers of the network. Mainly intended to be used for reporting."""
+ layers = []
+
+ def recurse(scope, parent_ops, parent_vars, level):
+ if len(parent_ops) == 0 and len(parent_vars) == 0:
+ return
+
+ # Ignore specific patterns.
+ if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
+ return
+
+ # Filter ops and vars by scope.
+ global_prefix = scope + "/"
+ local_prefix = global_prefix[len(self.scope) + 1:]
+ cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
+ cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
+ if not cur_ops and not cur_vars:
+ return
+
+ # Filter out all ops related to variables.
+ for var in [op for op in cur_ops if op.type.startswith("Variable")]:
+ var_prefix = var.name + "/"
+ cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
+
+ # Scope does not contain ops as immediate children => recurse deeper.
+ contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
+ if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0):
+ visited = set()
+ for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
+ token = rel_name.split("/")[0]
+ if token not in visited:
+ recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
+ visited.add(token)
+ return
+
+ # Report layer.
+ layer_name = scope[len(self.scope) + 1:]
+ layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
+ layer_trainables = [var for _name, var in cur_vars if var.trainable]
+ layers.append((layer_name, layer_output, layer_trainables))
+
+ recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0)
+ return layers
+
+ def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
+ """Print a summary table of the network structure."""
+ rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
+ rows += [["---"] * 4]
+ total_params = 0
+
+ for layer_name, layer_output, layer_trainables in self.list_layers():
+ num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
+ weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
+ weights.sort(key=lambda x: len(x.name))
+ if len(weights) == 0 and len(layer_trainables) == 1:
+ weights = layer_trainables
+ total_params += num_params
+
+ if not hide_layers_with_no_params or num_params != 0:
+ num_params_str = str(num_params) if num_params > 0 else "-"
+ output_shape_str = str(layer_output.shape)
+ weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
+ rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
+
+ rows += [["---"] * 4]
+ rows += [["Total", str(total_params), "", ""]]
+
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+
+ def setup_weight_histograms(self, title: str = None) -> None:
+ """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
+ if title is None:
+ title = self.name
+
+ with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
+ for local_name, var in self._get_trainables().items():
+ if "/" in local_name:
+ p = local_name.split("/")
+ name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
+ else:
+ name = title + "_toplevel/" + local_name
+
+ tf.summary.histogram(name, var)
+
+#----------------------------------------------------------------------------
+# Backwards-compatible emulation of legacy output transformation in Network.run().
+
+_print_legacy_warning = True
+
+def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
+ global _print_legacy_warning
+ legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
+ if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
+ return output_transform, dynamic_kwargs
+
+ if _print_legacy_warning:
+ _print_legacy_warning = False
+ print()
+ print("WARNING: Old-style output transformations in Network.run() are deprecated.")
+ print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
+ print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
+ print()
+ assert output_transform is None
+
+ new_kwargs = dict(dynamic_kwargs)
+ new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
+ new_transform["func"] = _legacy_output_transform_func
+ return new_transform, new_kwargs
+
+def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
+ if out_mul != 1.0:
+ expr = [x * out_mul for x in expr]
+
+ if out_add != 0.0:
+ expr = [x + out_add for x in expr]
+
+ if out_shrink > 1:
+ ksize = [1, 1, out_shrink, out_shrink]
+ expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
+
+ if out_dtype is not None:
+ if tf.as_dtype(out_dtype).is_integer:
+ expr = [tf.round(x) for x in expr]
+ expr = [tf.saturate_cast(x, out_dtype) for x in expr]
+ return expr
diff --git a/PTI/torch_utils/__init__.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py
similarity index 85%
rename from PTI/torch_utils/__init__.py
rename to models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py
index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..43cce37364064146fd30e18612b1d9e3a84f513a 100644
--- a/PTI/torch_utils/__init__.py
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0268f14395319003240b4a5a59141d703e9a4257
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu
@@ -0,0 +1,220 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#define EIGEN_USE_GPU
+#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include <stdio.h>
+
+using namespace tensorflow;
+using namespace tensorflow::shape_inference;
+
+#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T>
+struct FusedBiasActKernelParams
+{
+ const T* x; // [sizeX]
+ const T* b; // [sizeB] or NULL
+ const T* xref; // [sizeX] or NULL
+ const T* yref; // [sizeX] or NULL
+ T* y; // [sizeX]
+
+ int grad;
+ int axis;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+template <class T>
+static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
+{
+ const float expRange = 80.0f;
+ const float halfExpRange = 40.0f;
+ const float seluScale = 1.0507009873554804934193349852946f;
+ const float seluAlpha = 1.6732632423543772848170429916717f;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load and apply bias.
+ float x = (float)p.x[xi];
+ if (p.b)
+ x += (float)p.b[(xi / p.stepB) % p.sizeB];
+ float xref = (p.xref) ? (float)p.xref[xi] : 0.0f;
+ float yref = (p.yref) ? (float)p.yref[xi] : 0.0f;
+ float yy = (p.gain != 0.0f) ? yref / p.gain : 0.0f;
+
+ // Evaluate activation func.
+ float y;
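+ // Each case label encodes (activation index * 10 + gradient order); the activation
+ // index matches cuda_idx in ops/fused_bias_act.py, and grad 0/1/2 select the forward
+ // value, first derivative and second derivative respectively.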
+ switch (p.act * 10 + p.grad)
+ {
+ // linear
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0f; break;
+
+ // relu
+ case 20: y = (x > 0.0f) ? x : 0.0f; break;
+ case 21: y = (yy > 0.0f) ? x : 0.0f; break;
+ case 22: y = 0.0f; break;
+
+ // lrelu
+ case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
+ case 31: y = (yy > 0.0f) ? x : x * p.alpha; break;
+ case 32: y = 0.0f; break;
+
+ // tanh
+ case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
+ case 41: y = x * (1.0f - yy * yy); break;
+ case 42: y = x * (1.0f - yy * yy) * (-2.0f * yy); break;
+
+ // sigmoid
+ case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
+ case 51: y = x * yy * (1.0f - yy); break;
+ case 52: y = x * yy * (1.0f - yy) * (1.0f - 2.0f * yy); break;
+
+ // elu
+ case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
+ case 61: y = (yy >= 0.0f) ? x : x * (yy + 1.0f); break;
+ case 62: y = (yy >= 0.0f) ? 0.0f : x * (yy + 1.0f); break;
+
+ // selu
+ case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
+ case 71: y = (yy >= 0.0f) ? x * seluScale : x * (yy + seluScale * seluAlpha); break;
+ case 72: y = (yy >= 0.0f) ? 0.0f : x * (yy + seluScale * seluAlpha); break;
+
+ // softplus
+ case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
+ case 81: y = x * (1.0f - expf(-yy)); break;
+ case 82: { float c = expf(-yy); y = x * c * (1.0f - c); } break;
+
+ // swish
+ case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
+ case 91:
+ case 92:
+ {
+ float c = expf(xref);
+ float d = c + 1.0f;
+ if (p.grad == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0.0f : x * c * (xref * (2.0f - d) + 2.0f * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0.0f : xref / (expf(-xref) + 1.0f) * p.gain;
+ }
+ break;
+ }
+
+ // Apply gain.
+ y *= p.gain;
+
+ // Clamp.
+ if (p.clamp >= 0.0f)
+ {
+ if (p.grad == 0)
+ y = (fabsf(y) < p.clamp) ? y : (y >= 0.0f) ? p.clamp : -p.clamp;
+ else
+ y = (fabsf(yref) < p.clamp) ? y : 0.0f;
+ }
+
+ // Store.
+ p.y[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// TensorFlow op.
+
+template <class T>
+struct FusedBiasActOp : public OpKernel
+{
+ FusedBiasActKernelParams<T> m_attribs;
+
+ FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
+ {
+ memset(&m_attribs, 0, sizeof(m_attribs));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("clamp", &m_attribs.clamp));
+ OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
+ OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
+ OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
+ }
+
+ void Compute(OpKernelContext* ctx)
+ {
+ FusedBiasActKernelParams<T> p = m_attribs;
+ cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
+
+ const Tensor& x = ctx->input(0); // [...]
+ const Tensor& b = ctx->input(1); // [sizeB] or [0]
+ const Tensor& xref = ctx->input(2); // x.shape or [0]
+ const Tensor& yref = ctx->input(3); // x.shape or [0]
+ p.x = x.flat<T>().data();
+ p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
+ p.xref = (xref.NumElements()) ? xref.flat<T>().data() : NULL;
+ p.yref = (yref.NumElements()) ? yref.flat<T>().data() : NULL;
+ OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
+ OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
+ OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
+ OP_REQUIRES(ctx, xref.NumElements() == 0 || xref.NumElements() == x.NumElements(), errors::InvalidArgument("xref has wrong number of elements"));
+ OP_REQUIRES(ctx, yref.NumElements() == 0 || yref.NumElements() == x.NumElements(), errors::InvalidArgument("yref has wrong number of elements"));
+ OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
+
+ p.sizeX = (int)x.NumElements();
+ p.sizeB = (int)b.NumElements();
+ p.stepB = 1;
+ for (int i = m_attribs.axis + 1; i < x.dims(); i++)
+ p.stepB *= (int)x.dim_size(i);
+
+ Tensor* y = NULL; // x.shape
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
+ p.y = y->flat<T>().data();
+
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
+ }
+};
+
+REGISTER_OP("FusedBiasAct")
+ .Input ("x: T")
+ .Input ("b: T")
+ .Input ("xref: T")
+ .Input ("yref: T")
+ .Output ("y: T")
+ .Attr ("T: {float, half}")
+ .Attr ("grad: int = 0")
+ .Attr ("axis: int = 1")
+ .Attr ("act: int = 0")
+ .Attr ("alpha: float = 0.0")
+ .Attr ("gain: float = 1.0")
+ .Attr ("clamp: float = -1.0");
+REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
+REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
+
+//------------------------------------------------------------------------
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..79991b0497d3d92f25194a31668b9568048163f8
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom TensorFlow ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import tensorflow as tf
+from .. import custom_ops
+from ...util import EasyDict
+
+def _get_plugin():
+ return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
+ 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
+ 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
+ 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
+ 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
+ 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
+ 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
+ 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
+}
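+
+# Note: cuda_idx selects the matching branch of the CUDA kernel in fused_bias_act.cu
+# (case label = cuda_idx * 10 + gradient order), 'ref' chooses whether the gradient
+# kernels read back the input ('x') or the output ('y') of the forward pass, and
+# zero_2nd_grad marks activations whose second derivative is identically zero.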
+
+#----------------------------------------------------------------------------
+
+def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard TensorFlow ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can have any shape, but if `b` is defined, the
+ dimension corresponding to `axis`, as well as the rank, must be known.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `axis`.
+ axis: The dimension in `x` corresponding to the elements of `b`.
+ The value of `axis` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying `1.0`.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+
+ impl_dict = {
+ 'ref': _fused_bias_act_ref,
+ 'cuda': _fused_bias_act_cuda,
+ }
+ return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp)
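+# Minimal usage sketch (illustrative only; the tensors below are hypothetical). With impl='ref',
+# the call is equivalent to adding the bias, applying leaky ReLU, and scaling by the default gain:
+#
+#   x = tf.random.normal([4, 512])   # [N, C]
+#   b = tf.zeros([512])              # per-channel bias
+#   y = fused_bias_act(x, b, axis=1, act='lrelu', impl='ref')
+#   # same result as: tf.nn.leaky_relu(x + b, 0.2) * np.sqrt(2)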
+
+#----------------------------------------------------------------------------
+
+def _fused_bias_act_ref(x, b, axis, act, alpha, gain, clamp):
+ """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
+
+ # Validate arguments.
+ x = tf.convert_to_tensor(x)
+ b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
+ act_spec = activation_funcs[act]
+ assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
+ assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
+ if alpha is None:
+ alpha = act_spec.def_alpha
+ if gain is None:
+ gain = act_spec.def_gain
+
+ # Add bias.
+ if b.shape[0] != 0:
+ x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
+
+ # Evaluate activation function.
+ x = act_spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ if gain != 1:
+ x *= gain
+
+ # Clamp.
+ if clamp is not None:
+ clamp = np.asarray(clamp, dtype=x.dtype.name)
+ assert clamp.shape == () and clamp >= 0
+ x = tf.clip_by_value(x, -clamp, clamp)
+ return x
+
+#----------------------------------------------------------------------------
+
+def _fused_bias_act_cuda(x, b, axis, act, alpha, gain, clamp):
+ """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
+
+ # Validate arguments.
+ x = tf.convert_to_tensor(x)
+ empty_tensor = tf.constant([], dtype=x.dtype)
+ b = tf.convert_to_tensor(b) if b is not None else empty_tensor
+ act_spec = activation_funcs[act]
+ assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
+ assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
+ if alpha is None:
+ alpha = act_spec.def_alpha
+ if gain is None:
+ gain = act_spec.def_gain
+
+ # Special cases.
+ if act == 'linear' and b is None and gain == 1.0:
+ return x
+ if act_spec.cuda_idx is None:
+ return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+ # CUDA op.
+ cuda_op = _get_plugin().fused_bias_act
+ cuda_kwargs = dict(axis=int(axis), act=int(act_spec.cuda_idx), gain=float(gain))
+ if alpha is not None:
+ cuda_kwargs['alpha'] = float(alpha)
+ if clamp is not None:
+ clamp = np.asarray(clamp, dtype=x.dtype.name)
+ assert clamp.shape == () and clamp >= 0
+ cuda_kwargs['clamp'] = float(clamp.astype(np.float32))
+ def ref(tensor, name):
+ return tensor if act_spec.ref == name else empty_tensor
+
+ # Forward pass: y = func(x, b).
+ def func_y(x, b):
+ y = cuda_op(x=x, b=b, xref=empty_tensor, yref=empty_tensor, grad=0, **cuda_kwargs)
+ y.set_shape(x.shape)
+ return y
+
+ # Backward pass: dx, db = grad(dy, x, y)
+ def grad_dx(dy, x, y):
+ dx = cuda_op(x=dy, b=empty_tensor, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs)
+ dx.set_shape(x.shape)
+ return dx
+ def grad_db(dx):
+ if b.shape[0] == 0:
+ return empty_tensor
+ db = dx
+ if axis < x.shape.rank - 1:
+ db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
+ if axis > 0:
+ db = tf.reduce_sum(db, list(range(axis)))
+ db.set_shape(b.shape)
+ return db
+
+ # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
+ def grad2_d_dy(d_dx, d_db, x, y):
+ d_dy = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs)
+ d_dy.set_shape(x.shape)
+ return d_dy
+ def grad2_d_x(d_dx, d_db, x, y):
+ d_x = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=2, **cuda_kwargs)
+ d_x.set_shape(x.shape)
+ return d_x
+
+ # Fast version for piecewise-linear activation funcs.
+ @tf.custom_gradient
+ def func_zero_2nd_grad(x, b):
+ y = func_y(x, b)
+ @tf.custom_gradient
+ def grad(dy):
+ dx = grad_dx(dy, x, y)
+ db = grad_db(dx)
+ def grad2(d_dx, d_db):
+ d_dy = grad2_d_dy(d_dx, d_db, x, y)
+ return d_dy
+ return (dx, db), grad2
+ return y, grad
+
+ # Slow version for general activation funcs.
+ @tf.custom_gradient
+ def func_nonzero_2nd_grad(x, b):
+ y = func_y(x, b)
+ def grad_wrap(dy):
+ @tf.custom_gradient
+ def grad_impl(dy, x):
+ dx = grad_dx(dy, x, y)
+ db = grad_db(dx)
+ def grad2(d_dx, d_db):
+ d_dy = grad2_d_dy(d_dx, d_db, x, y)
+ d_x = grad2_d_x(d_dx, d_db, x, y)
+ return d_dy, d_x
+ return (dx, db), grad2
+ return grad_impl(dy, x)
+ return y, grad_wrap
+
+ # Which version to use?
+ if act_spec.zero_2nd_grad:
+ return func_zero_2nd_grad(x, b)
+ return func_nonzero_2nd_grad(x, b)
+
+#----------------------------------------------------------------------------
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7aad60d53e57d4f3e60f36a24df80a6278f1bb63
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu
@@ -0,0 +1,359 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#define EIGEN_USE_GPU
+#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include <stdio.h>
+
+using namespace tensorflow;
+using namespace tensorflow::shape_inference;
+
+//------------------------------------------------------------------------
+// Helpers.
+
+#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
+
+static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
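+// Editor's note: unlike C++ integer division, floorDiv() rounds toward negative infinity, which
+// the kernels below rely on when padding makes intermediate coordinates negative.
+// For example, floorDiv(-3, 2) == -2, whereas -3 / 2 == -1 in C++.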
+
+//------------------------------------------------------------------------
+// CUDA kernel params.
+
+template <class T>
+struct UpFirDn2DKernelParams
+{
+ const T* x; // [majorDim, inH, inW, minorDim]
+ const T* k; // [kernelH, kernelW]
+ T* y; // [majorDim, outH, outW, minorDim]
+
+ int upx;
+ int upy;
+ int downx;
+ int downy;
+ int padx0;
+ int padx1;
+ int pady0;
+ int pady1;
+
+ int majorDim;
+ int inH;
+ int inW;
+ int minorDim;
+ int kernelH;
+ int kernelW;
+ int outH;
+ int outW;
+ int loopMajor;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// General CUDA implementation for large filter kernels.
+
+template <class T>
+static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
+{
+ // Calculate thread index.
+ int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorIdx / p.minorDim;
+ minorIdx -= outY * p.minorDim;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorIdxBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.downy + p.upy - 1 - p.pady0;
+ int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
+ int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
+ int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
+
+ // Loop over majorDim and outX.
+ for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.downx + p.upx - 1 - p.padx0;
+ int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
+ int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
+ int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
+
+ // Initialize pointers.
+ const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
+ const T* kp = &p.k[kernelY * p.kernelW + kernelX];
+ int xpx = p.minorDim;
+ int kpx = -p.upx;
+ int xpy = p.inW * p.minorDim;
+ int kpy = -p.upy * p.kernelW;
+
+ // Inner loop.
+ float v = 0.0f;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (float)(*xp) * (float)(*kp);
+ xp += xpx;
+ kp += kpx;
+ }
+ xp += xpy - w * xpx;
+ kp += kpy - w * kpx;
+ }
+
+ // Store result.
+ p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filter kernels.
+
+template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
+static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
+{
+ //assert(kernelW % upx == 0);
+ //assert(kernelH % upy == 0);
+ const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
+ __shared__ volatile float sk[kernelH][kernelW];
+ __shared__ volatile float sx[tileInH][tileInW];
+
+ // Calculate tile index.
+ int minorIdx = blockIdx.x;
+ int tileOutY = minorIdx / p.minorDim;
+ minorIdx -= tileOutY * p.minorDim;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorIdxBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
+ return;
+
+ // Load filter kernel (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
+ {
+ int ky = tapIdx / kernelW;
+ int kx = tapIdx - ky * kernelW;
+ float v = 0.0f;
+ if (kx < p.kernelW & ky < p.kernelH)
+ v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
+ sk[ky][kx] = v;
+ }
+
+ // Loop over majorDim and outX.
+ for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
+ int tileInX = floorDiv(tileMidX, upx);
+ int tileInY = floorDiv(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
+ {
+ int relInY = inIdx / tileInW;
+ int relInX = inIdx - relInY * tileInW;
+ int inX = relInX + tileInX;
+ int inY = relInY + tileInY;
+ float v = 0.0f;
+ if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
+ v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
+ sx[relInY][relInX] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
+ {
+ int relOutY = outIdx / tileOutW;
+ int relOutX = outIdx - relOutY * tileOutW;
+ int outX = relOutX + tileOutX;
+ int outY = relOutY + tileOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floorDiv(midX, upx);
+ int inY = floorDiv(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int kernelX = (inX + 1) * upx - midX - 1; // flipped
+ int kernelY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ float v = 0.0f;
+ #pragma unroll
+ for (int y = 0; y < kernelH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < kernelW / upx; x++)
+ v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
+
+ // Store result.
+ if (outX < p.outW & outY < p.outH)
+ p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// TensorFlow op.
+
+template <class T>
+struct UpFirDn2DOp : public OpKernel
+{
+ UpFirDn2DKernelParams<T> m_attribs;
+
+ UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
+ {
+ memset(&m_attribs, 0, sizeof(m_attribs));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
+ OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
+ OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
+ }
+
+ void Compute(OpKernelContext* ctx)
+ {
+ UpFirDn2DKernelParams<T> p = m_attribs;
+ cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
+
+ const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
+ const Tensor& k = ctx->input(1); // [kernelH, kernelW]
+ p.x = x.flat<T>().data();
+ p.k = k.flat<T>().data();
+ OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
+ OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
+ OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
+ OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
+
+ p.majorDim = (int)x.dim_size(0);
+ p.inH = (int)x.dim_size(1);
+ p.inW = (int)x.dim_size(2);
+ p.minorDim = (int)x.dim_size(3);
+ p.kernelH = (int)k.dim_size(0);
+ p.kernelW = (int)k.dim_size(1);
+ OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
+
+ p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
+ p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
+ OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
+
+ Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
+ TensorShape ys;
+ ys.AddDim(p.majorDim);
+ ys.AddDim(p.outH);
+ ys.AddDim(p.outW);
+ ys.AddDim(p.minorDim);
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
+ p.y = y->flat<T>().data();
+ OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
+
+ // Choose CUDA kernel to use.
+ void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
+ int tileOutW = -1;
+ int tileOutH = -1;
+
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 8,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,24, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,20, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,16, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,12, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,8, 32,32>; tileOutW = 32; tileOutH = 32; }
+
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
+ if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 8,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,24, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,20, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,16, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,12, 32,32>; tileOutW = 32; tileOutH = 32; }
+ if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,8, 32,32>; tileOutW = 32; tileOutH = 32; }
+
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 24,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 20,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 16,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 12,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 8,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,24, 32,16>; tileOutW = 32; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,20, 32,16>; tileOutW = 32; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,16, 32,16>; tileOutW = 32; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,12, 32,16>; tileOutW = 32; tileOutH = 16; }
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,8, 32,16>; tileOutW = 32; tileOutH = 16; }
+
+ // Choose launch params.
+ dim3 blockSize;
+ dim3 gridSize;
+ if (tileOutW > 0 && tileOutH > 0) // small
+ {
+ p.loopMajor = (p.majorDim - 1) / 16384 + 1;
+ p.loopX = 1;
+ blockSize = dim3(32 * 8, 1, 1);
+ gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
+ }
+ else // large
+ {
+ p.loopMajor = (p.majorDim - 1) / 16384 + 1;
+ p.loopX = 4;
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
+ }
+};
+
+REGISTER_OP("UpFirDn2D")
+ .Input ("x: T")
+ .Input ("k: T")
+ .Output ("y: T")
+ .Attr ("T: {float, half}")
+ .Attr ("upx: int = 1")
+ .Attr ("upy: int = 1")
+ .Attr ("downx: int = 1")
+ .Attr ("downy: int = 1")
+ .Attr ("padx0: int = 0")
+ .Attr ("padx1: int = 0")
+ .Attr ("pady0: int = 0")
+ .Attr ("pady1: int = 0");
+REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp);
+REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp);
+
+//------------------------------------------------------------------------
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a31af7e146da7afeb964db018f14aca3134920
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py
@@ -0,0 +1,418 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom TensorFlow ops for efficient resampling of 2D images."""
+
+import os
+import numpy as np
+import tensorflow as tf
+from .. import custom_ops
+
+def _get_plugin():
+ return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
+
+#----------------------------------------------------------------------------
+
+def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
+ r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
+
+ Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
+ and performs the following operations for each image, batched across
+ `majorDim` and `minorDim`:
+
+ 1. Upsample the image by inserting zeros after each pixel (`upx`, `upy`).
+
+ 2. Pad the image with zeros by the specified number of pixels on each side
+ (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
+ corresponds to cropping the image.
+
+ 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
+ image so that the footprint of all output pixels lies within the input image.
+
+ 4. Downsample the image by throwing away pixels (`downx`, `downy`).
+
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
+ The fused op is considerably more efficient than performing the same calculation
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
+ k: 2D FIR filter of the shape `[firH, firW]`.
+ upx: Integer upsampling factor along the X-axis (default: 1).
+ upy: Integer upsampling factor along the Y-axis (default: 1).
+ downx: Integer downsampling factor along the X-axis (default: 1).
+ downy: Integer downsampling factor along the Y-axis (default: 1).
+ padx0: Number of pixels to pad on the left side (default: 0).
+ padx1: Number of pixels to pad on the right side (default: 0).
+ pady0: Number of pixels to pad on the top side (default: 0).
+ pady1: Number of pixels to pad on the bottom side (default: 0).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
+ """
+
+ impl_dict = {
+ 'ref': _upfirdn_2d_ref,
+ 'cuda': _upfirdn_2d_cuda,
+ }
+ return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
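+# Output-shape sketch (editor's note, matching the computation in _upfirdn_2d_cuda() below):
+#
+#   outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
+#   outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
+#
+# e.g. inW=8, upx=2, padx0=padx1=1, kernelW=4, downx=1  ->  outW = (16 + 2 - 4) // 1 + 1 = 15.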
+
+#----------------------------------------------------------------------------
+
+def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
+ """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""
+
+ x = tf.convert_to_tensor(x)
+ k = np.asarray(k, dtype=np.float32)
+ assert x.shape.rank == 4
+ inH = x.shape[1].value
+ inW = x.shape[2].value
+ minorDim = _shape(x, 3)
+ kernelH, kernelW = k.shape
+ assert inW >= 1 and inH >= 1
+ assert kernelW >= 1 and kernelH >= 1
+ assert isinstance(upx, int) and isinstance(upy, int)
+ assert isinstance(downx, int) and isinstance(downy, int)
+ assert isinstance(padx0, int) and isinstance(padx1, int)
+ assert isinstance(pady0, int) and isinstance(pady1, int)
+
+ # Upsample (insert zeros).
+ x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
+ x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
+ x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])
+
+ # Pad (crop if negative).
+ x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
+ x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]
+
+ # Convolve with filter.
+ x = tf.transpose(x, [0, 3, 1, 2])
+ x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
+ w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
+ x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
+ x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
+ x = tf.transpose(x, [0, 2, 3, 1])
+
+ # Downsample (throw away pixels).
+ return x[:, ::downy, ::downx, :]
+
+#----------------------------------------------------------------------------
+
+def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
+ """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
+
+ x = tf.convert_to_tensor(x)
+ k = np.asarray(k, dtype=np.float32)
+ majorDim, inH, inW, minorDim = x.shape.as_list()
+ kernelH, kernelW = k.shape
+ assert inW >= 1 and inH >= 1
+ assert kernelW >= 1 and kernelH >= 1
+ assert isinstance(upx, int) and isinstance(upy, int)
+ assert isinstance(downx, int) and isinstance(downy, int)
+ assert isinstance(padx0, int) and isinstance(padx1, int)
+ assert isinstance(pady0, int) and isinstance(pady1, int)
+
+ outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
+ outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
+ assert outW >= 1 and outH >= 1
+
+ cuda_op = _get_plugin().up_fir_dn2d
+ kc = tf.constant(k, dtype=x.dtype)
+ gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
+ gpadx0 = kernelW - padx0 - 1
+ gpady0 = kernelH - pady0 - 1
+ gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
+ gpady1 = inH * upy - outH * downy + pady0 - upy + 1
+
+ @tf.custom_gradient
+ def func(x):
+ y = cuda_op(x=x, k=kc, upx=int(upx), upy=int(upy), downx=int(downx), downy=int(downy), padx0=int(padx0), padx1=int(padx1), pady0=int(pady0), pady1=int(pady1))
+ y.set_shape([majorDim, outH, outW, minorDim])
+ @tf.custom_gradient
+ def grad(dy):
+ dx = cuda_op(x=dy, k=gkc, upx=int(downx), upy=int(downy), downx=int(upx), downy=int(upy), padx0=int(gpadx0), padx1=int(gpadx1), pady0=int(gpady0), pady1=int(gpady1))
+ dx.set_shape([majorDim, inH, inW, minorDim])
+ return dx, func
+ return y, grad
+ return func(x)
+
+#----------------------------------------------------------------------------
+
+def filter_2d(x, k, gain=1, padding=0, data_format='NCHW', impl='cuda'):
+ r"""Filter a batch of 2D images with the given FIR filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+ and filters each image with the given filter. The filter is normalized so that
+ if the input pixels are constant, they will be scaled by the specified `gain`.
+ Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+ padding: Number of pixels to pad or crop the output on each side (default: 0).
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+
+ assert isinstance(padding, int)
+ k = _FilterKernel(k=k, gain=gain)
+ assert k.w == k.h
+ pad0 = k.w // 2 + padding
+ pad1 = (k.w - 1) // 2 + padding
+ return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
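+# Usage sketch (hypothetical values): blur NCHW feature maps with a small binomial filter. The 1D
+# kernel is normalized and outer-multiplied with itself internally, so constant inputs are
+# preserved when gain=1:
+#
+#   y = filter_2d(x, k=[1, 2, 1], data_format='NCHW', impl='ref')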
+
+#----------------------------------------------------------------------------
+
+def upsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
+ r"""Upsample a batch of 2D images with the given filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+ and upsamples each image with the given filter. The filter is normalized so that
+ if the input pixels are constant, they will be scaled by the specified `gain`.
+ Pixels outside the image are assumed to be zero, and the filter is padded with
+ zeros so that its shape is a multiple of the upsampling factor.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to nearest-neighbor
+ upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+ padding: Number of pixels to pad or crop the output on each side (default: 0).
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ assert isinstance(padding, int)
+ k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2))
+ assert k.w == k.h
+ pad0 = (k.w + factor - 1) // 2 + padding
+ pad1 = (k.w - factor) // 2 + padding
+ return _simple_upfirdn_2d(x, k, up=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
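+# Usage sketch (hypothetical values): 2x upsampling with the FIR filter commonly used by
+# StyleGAN2-style networks; leaving k=None instead reduces to nearest-neighbor upsampling:
+#
+#   y = upsample_2d(x, k=[1, 3, 3, 1], factor=2, data_format='NCHW', impl='ref')
+#   # x: [N, C, H, W]  ->  y: [N, C, H * 2, W * 2]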
+
+#----------------------------------------------------------------------------
+
+def downsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
+ r"""Downsample a batch of 2D images with the given filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+ and downsamples each image with the given filter. The filter is normalized so that
+ if the input pixels are constant, they will be scaled by the specified `gain`.
+ Pixels outside the image are assumed to be zero, and the filter is padded with
+ zeros so that its shape is a multiple of the downsampling factor.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+ padding: Number of pixels to pad or crop the output on each side (default: 0).
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ assert isinstance(padding, int)
+ k = _FilterKernel(k if k is not None else [1] * factor, gain)
+ assert k.w == k.h
+ pad0 = (k.w - factor + 1) // 2 + padding * factor
+ pad1 = (k.w - factor) // 2 + padding * factor
+ return _simple_upfirdn_2d(x, k, down=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
+ r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+
+ Padding is performed only once at the beginning, not between the operations.
+ The fused op is considerably more efficient than performing the same calculation
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to nearest-neighbor
+ upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+ padding: Number of pixels to pad or crop the output on each side (default: 0).
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ assert isinstance(padding, int)
+
+ # Check weight shape.
+ w = tf.convert_to_tensor(w)
+ ch, cw, _inC, _outC = w.shape.as_list()
+ inC = _shape(w, 2)
+ outC = _shape(w, 3)
+ assert cw == ch
+
+ # Fast path for 1x1 convolution.
+ if cw == 1 and ch == 1:
+ x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID')
+ x = upsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl)
+ return x
+
+ # Setup filter kernel.
+ k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2))
+ assert k.w == k.h
+
+ # Determine data dimensions.
+ if data_format == 'NCHW':
+ stride = [1, 1, factor, factor]
+ output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + ch, (_shape(x, 3) - 1) * factor + cw]
+ num_groups = _shape(x, 1) // inC
+ else:
+ stride = [1, factor, factor, 1]
+ output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + ch, (_shape(x, 2) - 1) * factor + cw, outC]
+ num_groups = _shape(x, 3) // inC
+
+ # Transpose weights.
+ w = tf.reshape(w, [ch, cw, inC, num_groups, -1])
+ w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
+ w = tf.reshape(w, [ch, cw, -1, num_groups * inC])
+
+ # Execute.
+ x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
+ pad0 = (k.w + factor - cw) // 2 + padding
+ pad1 = (k.w - factor - cw + 3) // 2 + padding
+ return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
+ r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+
+ Padding is performed only once at the beginning, not between the operations.
+ The fused op is considerably more efficient than performing the same calculation
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+ padding: Number of pixels to pad or crop the output on each side (default: 0).
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ assert isinstance(padding, int)
+
+ # Check weight shape.
+ w = tf.convert_to_tensor(w)
+ ch, cw, _inC, _outC = w.shape.as_list()
+ assert cw == ch
+
+ # Fast path for 1x1 convolution.
+ if cw == 1 and ch == 1:
+ x = downsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl)
+ x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID')
+ return x
+
+ # Setup filter kernel.
+ k = _FilterKernel(k if k is not None else [1] * factor, gain)
+ assert k.w == k.h
+
+ # Determine stride.
+ if data_format == 'NCHW':
+ s = [1, 1, factor, factor]
+ else:
+ s = [1, factor, factor, 1]
+
+ # Execute.
+ pad0 = (k.w - factor + cw) // 2 + padding * factor
+ pad1 = (k.w - factor + cw - 1) // 2 + padding * factor
+ x = _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
+ return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
+
+#----------------------------------------------------------------------------
+# Internal helpers.
+
+class _FilterKernel:
+ def __init__(self, k, gain=1):
+ k = np.asarray(k, dtype=np.float32)
+ k /= np.sum(k)
+
+ # Separable.
+ if k.ndim == 1 and k.size >= 8:
+ self.w = k.size
+ self.h = k.size
+ self.kx = k[np.newaxis, :]
+ self.ky = k[:, np.newaxis] * gain
+ self.kxy = None
+
+ # Non-separable.
+ else:
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ assert k.ndim == 2
+ self.w = k.shape[1]
+ self.h = k.shape[0]
+ self.kx = None
+ self.ky = None
+ self.kxy = k * gain
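+# Editor's note: 1D kernels with at least 8 taps are kept separable and applied as two 1D passes in
+# _simple_upfirdn_2d() below; shorter 1D kernels are expanded to a dense 2D kernel via the outer
+# product. In either case the gain is folded into the last kernel applied (ky or kxy).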
+
+def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
+ assert isinstance(k, _FilterKernel)
+ assert data_format in ['NCHW', 'NHWC']
+ assert x.shape.rank == 4
+ y = x
+ if data_format == 'NCHW':
+ y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
+ if k.kx is not None:
+ y = upfirdn_2d(y, k.kx, upx=up, downx=down, padx0=pad0, padx1=pad1, impl=impl)
+ if k.ky is not None:
+ y = upfirdn_2d(y, k.ky, upy=up, downy=down, pady0=pad0, pady1=pad1, impl=impl)
+ if k.kxy is not None:
+ y = upfirdn_2d(y, k.kxy, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
+ if data_format == 'NCHW':
+ y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
+ return y
+
+def _shape(tf_expr, dim_idx):
+ if tf_expr.shape.rank is not None:
+ dim = tf_expr.shape[dim_idx].value
+ if dim is not None:
+ return dim
+ return tf.shape(tf_expr)[dim_idx]
+
+#----------------------------------------------------------------------------
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py b/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cae5ffff3d11aaccd705d6936e080175ab97dd0e
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py
@@ -0,0 +1,372 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Helper wrapper for a Tensorflow optimizer."""
+
+import platform
+import numpy as np
+import tensorflow as tf
+
+from collections import OrderedDict
+from typing import List, Union
+
+from . import autosummary
+from . import tfutil
+from .. import util
+
+from .tfutil import TfExpression, TfExpressionEx
+
+_collective_ops_warning_printed = False
+_collective_ops_group_key = 831766147
+_collective_ops_instance_key = 436340067
+
+class Optimizer:
+ """A Wrapper for tf.train.Optimizer.
+
+ Automatically takes care of:
+ - Gradient averaging for multi-GPU training.
+ - Gradient accumulation for arbitrarily large minibatches.
+ - Dynamic loss scaling and typecasts for FP16 training.
+ - Ignoring corrupted gradients that contain NaNs/Infs.
+ - Reporting statistics.
+ - Well-chosen default settings.
+ """
+
+ def __init__(self,
+ name: str = "Train", # Name string that will appear in TensorFlow graph.
+ tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
+ learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
+ minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
+ share: "Optimizer" = None, # Share internal state with a previously created optimizer?
+ use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
+ loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
+ loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
+ loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
+ report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
+ **kwargs):
+
+ # Public fields.
+ self.name = name
+ self.learning_rate = learning_rate
+ self.minibatch_multiplier = minibatch_multiplier
+ self.id = self.name.replace("/", ".")
+ self.scope = tf.get_default_graph().unique_name(self.id)
+ self.optimizer_class = util.get_obj_by_name(tf_optimizer)
+ self.optimizer_kwargs = dict(kwargs)
+ self.use_loss_scaling = use_loss_scaling
+ self.loss_scaling_init = loss_scaling_init
+ self.loss_scaling_inc = loss_scaling_inc
+ self.loss_scaling_dec = loss_scaling_dec
+
+ # Private fields.
+ self._updates_applied = False
+ self._devices = OrderedDict() # device_name => EasyDict()
+ self._shared_optimizers = OrderedDict() # device_name => optimizer_class
+ self._gradient_shapes = None # [shape, ...]
+ self._report_mem_usage = report_mem_usage
+
+ # Validate arguments.
+ assert callable(self.optimizer_class)
+
+ # Share internal state if requested.
+ if share is not None:
+ assert isinstance(share, Optimizer)
+ assert self.optimizer_class is share.optimizer_class
+ assert self.learning_rate is share.learning_rate
+ assert self.optimizer_kwargs == share.optimizer_kwargs
+ self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
+
+ def _get_device(self, device_name: str):
+ """Get internal state for the given TensorFlow device."""
+ tfutil.assert_tf_initialized()
+ if device_name in self._devices:
+ return self._devices[device_name]
+
+ # Initialize fields.
+ device = util.EasyDict()
+ device.name = device_name
+ device.optimizer = None # Underlying optimizer: optimizer_class
+ device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
+ device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
+ device.grad_clean = OrderedDict() # Clean gradients: var => grad
+ device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
+ device.grad_acc_count = None # Accumulation counter: tf.Variable
+ device.grad_acc = OrderedDict() # Accumulated gradients: var => grad
+
+ # Setup TensorFlow objects.
+ with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
+ if device_name not in self._shared_optimizers:
+ optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
+ self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
+ device.optimizer = self._shared_optimizers[device_name]
+ if self.use_loss_scaling:
+ device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
+
+ # Register device.
+ self._devices[device_name] = device
+ return device
+
+ def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
+ """Register the gradients of the given loss function with respect to the given variables.
+ Intended to be called once per GPU."""
+ tfutil.assert_tf_initialized()
+ assert not self._updates_applied
+ device = self._get_device(loss.device)
+
+ # Validate trainables.
+ if isinstance(trainable_vars, dict):
+ trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
+ assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
+ assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
+ assert all(var.device == device.name for var in trainable_vars)
+
+ # Validate shapes.
+ if self._gradient_shapes is None:
+ self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
+ assert len(trainable_vars) == len(self._gradient_shapes)
+ assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))
+
+ # Report memory usage if requested.
+ deps = [loss]
+ if self._report_mem_usage:
+ self._report_mem_usage = False
+ try:
+ with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
+ deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
+ except tf.errors.NotFoundError:
+ pass
+
+ # Compute gradients.
+ with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
+ loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
+ gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
+ grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)
+
+ # Register gradients.
+ for grad, var in grad_list:
+ if var not in device.grad_raw:
+ device.grad_raw[var] = []
+ device.grad_raw[var].append(grad)
+
+ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
+ """Construct training op to update the registered variables based on their gradients."""
+ tfutil.assert_tf_initialized()
+ assert not self._updates_applied
+ self._updates_applied = True
+ all_ops = []
+
+ # Check for no-op.
+ if allow_no_op and len(self._devices) == 0:
+ with tfutil.absolute_name_scope(self.scope):
+ return tf.no_op(name='TrainingOp')
+
+ # Clean up gradients.
+ for device_idx, device in enumerate(self._devices.values()):
+ with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
+ for var, grad in device.grad_raw.items():
+
+ # Filter out disconnected gradients and convert to float32.
+ grad = [g for g in grad if g is not None]
+ grad = [tf.cast(g, tf.float32) for g in grad]
+
+ # Sum within the device.
+ if len(grad) == 0:
+ grad = tf.zeros(var.shape) # No gradients => zero.
+ elif len(grad) == 1:
+ grad = grad[0] # Single gradient => use as is.
+ else:
+ grad = tf.add_n(grad) # Multiple gradients => sum.
+
+ # Scale as needed.
+ scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
+ scale = tf.constant(scale, dtype=tf.float32, name="scale")
+ if self.minibatch_multiplier is not None:
+ scale /= tf.cast(self.minibatch_multiplier, tf.float32)
+ scale = self.undo_loss_scaling(scale)
+ device.grad_clean[var] = grad * scale
+
+ # Sum gradients across devices.
+ if len(self._devices) > 1:
+ with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
+ if platform.system() == "Windows": # Windows => NCCL ops are not available.
+ self._broadcast_fallback()
+ elif tf.VERSION.startswith("1.15."): # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
+ self._broadcast_fallback()
+ else: # Otherwise => NCCL ops are safe to use.
+ self._broadcast_nccl()
+
+ # Apply updates separately on each device.
+ for device_idx, device in enumerate(self._devices.values()):
+ with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
+ # pylint: disable=cell-var-from-loop
+
+ # Accumulate gradients over time.
+ if self.minibatch_multiplier is None:
+ acc_ok = tf.constant(True, name='acc_ok')
+ device.grad_acc = OrderedDict(device.grad_clean)
+ else:
+ # Create variables.
+ with tf.control_dependencies(None):
+ for var in device.grad_clean.keys():
+ device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
+ device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")
+
+ # Track counter.
+ count_cur = device.grad_acc_count + 1.0
+ count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
+ count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
+ acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
+ all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))
+
+ # Track gradients.
+ for var, grad in device.grad_clean.items():
+ acc_var = device.grad_acc_vars[var]
+ acc_cur = acc_var + grad
+ device.grad_acc[var] = acc_cur
+ with tf.control_dependencies([acc_cur]):
+ acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
+ acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
+ all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))
+
+ # No overflow => apply gradients.
+ all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
+ apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
+ all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
+
+ # Adjust loss scaling.
+ if self.use_loss_scaling:
+ ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
+ ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
+ ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
+ all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
+
+ # Last device => report statistics.
+ if device_idx == len(self._devices) - 1:
+ all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
+ all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
+ if self.use_loss_scaling:
+ all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
+
+ # Initialize variables.
+ self.reset_optimizer_state()
+ if self.use_loss_scaling:
+ tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
+ if self.minibatch_multiplier is not None:
+ tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
+
+ # Group everything into a single op.
+ with tfutil.absolute_name_scope(self.scope):
+ return tf.group(*all_ops, name="TrainingOp")
+
+ def reset_optimizer_state(self) -> None:
+ """Reset internal state of the underlying optimizer."""
+ tfutil.assert_tf_initialized()
+ tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
+
+ def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
+ """Get or create variable representing log2 of the current dynamic loss scaling factor."""
+ return self._get_device(device).loss_scaling_var
+
+ def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
+ """Apply dynamic loss scaling for the given expression."""
+ assert tfutil.is_tf_expression(value)
+ if not self.use_loss_scaling:
+ return value
+ return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
+
+ def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
+ """Undo the effect of dynamic loss scaling for the given expression."""
+ assert tfutil.is_tf_expression(value)
+ if not self.use_loss_scaling:
+ return value
+ return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
+
+ def _broadcast_nccl(self):
+ """Sum gradients across devices using NCCL ops (fast path)."""
+ from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module
+ for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
+ if any(x.shape.num_elements() > 0 for x in all_vars):
+ all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
+ all_grads = nccl_ops.all_sum(all_grads)
+ for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
+ device.grad_clean[var] = grad
+
+ def _broadcast_fallback(self):
+ """Sum gradients across devices using TensorFlow collective ops (slow fallback path)."""
+ from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module
+ global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
+ if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
+ return
+ if not _collective_ops_warning_printed:
+ print("------------------------------------------------------------------------")
+ print("WARNING: Using slow fallback implementation for inter-GPU communication.")
+ print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
+ print("------------------------------------------------------------------------")
+ _collective_ops_warning_printed = True
+ for device in self._devices.values():
+ with tf.device(device.name):
+ combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
+ combo = tf.concat(combo, axis=0)
+ combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
+ group_size=len(self._devices), group_key=_collective_ops_group_key,
+ instance_key=_collective_ops_instance_key)
+ cur_ofs = 0
+ for var, grad_old in device.grad_clean.items():
+ grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
+ cur_ofs += grad_old.shape.num_elements()
+ device.grad_clean[var] = grad_new
+ _collective_ops_instance_key += 1
+
+
+class SimpleAdam:
+ """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
+
+ def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
+ self.name = name
+ self.learning_rate = learning_rate
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.epsilon = epsilon
+ self.all_state_vars = []
+
+ def variables(self):
+ return self.all_state_vars
+
+ def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
+ assert gate_gradients == tf.train.Optimizer.GATE_NONE
+ return list(zip(tf.gradients(loss, var_list), var_list))
+
+ def apply_gradients(self, grads_and_vars):
+ with tf.name_scope(self.name):
+ state_vars = []
+ update_ops = []
+
+ # Adjust learning rate to deal with startup bias.
+ with tf.control_dependencies(None):
+ b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
+ b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
+ state_vars += [b1pow_var, b2pow_var]
+ b1pow_new = b1pow_var * self.beta1
+ b2pow_new = b2pow_var * self.beta2
+ update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
+ lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
+
+ # Construct ops to update each variable.
+ for grad, var in grads_and_vars:
+ with tf.control_dependencies(None):
+ m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
+ v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
+ state_vars += [m_var, v_var]
+ m_new = self.beta1 * m_var + (1 - self.beta1) * grad
+ v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
+ var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
+ update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
+
+ # Group everything together.
+ self.all_state_vars += state_vars
+ return tf.group(*update_ops)
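+
+# ----------------------------------------------------------------------------
+# Editor's note: illustrative sketch, not part of the original NVIDIA file.
+# SimpleAdam exposes the minimal subset of the tf.train.Optimizer interface
+# used above: variables(), compute_gradients(), and apply_gradients(). A
+# standalone usage sketch (the graph names `loss` and `trainable_vars` are
+# hypothetical):
+#
+#   opt = SimpleAdam(learning_rate=0.002)
+#   grads_and_vars = opt.compute_gradients(loss, var_list=trainable_vars)
+#   train_op = opt.apply_gradients(grads_and_vars)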
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py b/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe21100299251492ee6d49a7fab566ffb8283702
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous helper utils for Tensorflow."""
+
+import os
+import numpy as np
+import tensorflow as tf
+
+# Silence deprecation warnings from TensorFlow 1.13 onwards
+import logging
+logging.getLogger('tensorflow').setLevel(logging.ERROR)
+import tensorflow.contrib # requires TensorFlow 1.x!
+tf.contrib = tensorflow.contrib
+
+from typing import Any, Iterable, List, Union
+
+TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
+"""A type that represents a valid Tensorflow expression."""
+
+TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
+"""A type that can be converted to a valid Tensorflow expression."""
+
+
+def run(*args, **kwargs) -> Any:
+ """Run the specified ops in the default session."""
+ assert_tf_initialized()
+ return tf.get_default_session().run(*args, **kwargs)
+
+
+def is_tf_expression(x: Any) -> bool:
+ """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
+ return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
+
+
+def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
+ """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
+ return [dim.value for dim in shape]
+
+
+def flatten(x: TfExpressionEx) -> TfExpression:
+ """Shortcut function for flattening a tensor."""
+ with tf.name_scope("Flatten"):
+ return tf.reshape(x, [-1])
+
+
+def log2(x: TfExpressionEx) -> TfExpression:
+ """Logarithm in base 2."""
+ with tf.name_scope("Log2"):
+ return tf.log(x) * np.float32(1.0 / np.log(2.0))
+
+
+def exp2(x: TfExpressionEx) -> TfExpression:
+ """Exponent in base 2."""
+ with tf.name_scope("Exp2"):
+ return tf.exp(x * np.float32(np.log(2.0)))
+
+
+def erfinv(y: TfExpressionEx) -> TfExpression:
+ """Inverse of the error function."""
+ # pylint: disable=no-name-in-module
+ from tensorflow.python.ops.distributions import special_math
+ return special_math.erfinv(y)
+
+
+def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
+ """Linear interpolation."""
+ with tf.name_scope("Lerp"):
+ return a + (b - a) * t
+
+
+def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
+ """Linear interpolation with clip."""
+ with tf.name_scope("LerpClip"):
+ return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
+
+
+def absolute_name_scope(scope: str) -> tf.name_scope:
+ """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
+ return tf.name_scope(scope + "/")
+
+
+def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
+ """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
+ return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
+
+
+def _sanitize_tf_config(config_dict: dict = None) -> dict:
+ # Defaults.
+ cfg = dict()
+ cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
+ cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
+ cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
+ cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE" # Disable HDF5 file locking to avoid concurrency issues with network shares.
+ cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
+ cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
+
+ # Remove defaults for environment variables that are already set.
+ for key in list(cfg):
+ fields = key.split(".")
+ if fields[0] == "env":
+ assert len(fields) == 2
+ if fields[1] in os.environ:
+ del cfg[key]
+
+ # User overrides.
+ if config_dict is not None:
+ cfg.update(config_dict)
+ return cfg
+
+
+def init_tf(config_dict: dict = None) -> None:
+ """Initialize TensorFlow session using good default settings."""
+ # Skip if already initialized.
+ if tf.get_default_session() is not None:
+ return
+
+ # Setup config dict and random seeds.
+ cfg = _sanitize_tf_config(config_dict)
+ np_random_seed = cfg["rnd.np_random_seed"]
+ if np_random_seed is not None:
+ np.random.seed(np_random_seed)
+ tf_random_seed = cfg["rnd.tf_random_seed"]
+ if tf_random_seed == "auto":
+ tf_random_seed = np.random.randint(1 << 31)
+ if tf_random_seed is not None:
+ tf.set_random_seed(tf_random_seed)
+
+ # Setup environment variables.
+ for key, value in cfg.items():
+ fields = key.split(".")
+ if fields[0] == "env":
+ assert len(fields) == 2
+ os.environ[fields[1]] = str(value)
+
+ # Create default TensorFlow session.
+ create_session(cfg, force_as_default=True)
+
+
+def assert_tf_initialized():
+ """Check that TensorFlow session has been initialized."""
+ if tf.get_default_session() is None:
+ raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
+
+
+def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
+ """Create tf.Session based on config dict."""
+ # Setup TensorFlow config proto.
+ cfg = _sanitize_tf_config(config_dict)
+ config_proto = tf.ConfigProto()
+ for key, value in cfg.items():
+ fields = key.split(".")
+ if fields[0] not in ["rnd", "env"]:
+ obj = config_proto
+ for field in fields[:-1]:
+ obj = getattr(obj, field)
+ setattr(obj, fields[-1], value)
+
+ # Create session.
+ session = tf.Session(config=config_proto)
+ if force_as_default:
+ # pylint: disable=protected-access
+ session._default_session = session.as_default()
+ session._default_session.enforce_nesting = False
+ session._default_session.__enter__()
+ return session
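+
+
+# Editor's note: illustrative example, not part of the original NVIDIA file.
+# Config keys are dotted paths: "rnd.*" and "env.*" entries are consumed by
+# init_tf() itself, while the remaining keys are applied to the corresponding
+# tf.ConfigProto fields by create_session(). For example:
+#
+#   init_tf({'rnd.np_random_seed': 1000,
+#            'env.TF_CPP_MIN_LOG_LEVEL': '0',
+#            'gpu_options.allow_growth': False})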
+
+
+def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
+ """Initialize all tf.Variables that have not already been initialized.
+
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
+ tf.variables_initializer(tf.report_uninitialized_variables()).run()
+ """
+ assert_tf_initialized()
+ if target_vars is None:
+ target_vars = tf.global_variables()
+
+ test_vars = []
+ test_ops = []
+
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
+ for var in target_vars:
+ assert is_tf_expression(var)
+
+ try:
+ tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
+ except KeyError:
+ # Op does not exist => variable may be uninitialized.
+ test_vars.append(var)
+
+ with absolute_name_scope(var.name.split(":")[0]):
+ test_ops.append(tf.is_variable_initialized(var))
+
+ init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
+ run([var.initializer for var in init_vars])
+
+
+def set_vars(var_to_value_dict: dict) -> None:
+ """Set the values of given tf.Variables.
+
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
+    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
+ """
+ assert_tf_initialized()
+ ops = []
+ feed_dict = {}
+
+ for var, value in var_to_value_dict.items():
+ assert is_tf_expression(var)
+
+ try:
+ setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
+ except KeyError:
+ with absolute_name_scope(var.name.split(":")[0]):
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
+ setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
+
+ ops.append(setter)
+ feed_dict[setter.op.inputs[1]] = value
+
+ run(ops, feed_dict)
+
+
+def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
+ """Create tf.Variable with large initial value without bloating the tf graph."""
+ assert_tf_initialized()
+ assert isinstance(initial_value, np.ndarray)
+ zeros = tf.zeros(initial_value.shape, initial_value.dtype)
+ var = tf.Variable(zeros, *args, **kwargs)
+ set_vars({var: initial_value})
+ return var
+
+
+def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
+ """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
+ Can be used as an input transformation for Network.run().
+ """
+ images = tf.cast(images, tf.float32)
+ if nhwc_to_nchw:
+ images = tf.transpose(images, [0, 3, 1, 2])
+ return images * ((drange[1] - drange[0]) / 255) + drange[0]
+
+
+def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
+ Can be used as an output transformation for Network.run().
+ """
+ images = tf.cast(images, tf.float32)
+ if shrink > 1:
+ ksize = [1, 1, shrink, shrink]
+ images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
+ if nchw_to_nhwc:
+ images = tf.transpose(images, [0, 2, 3, 1])
+ scale = 255 / (drange[1] - drange[0])
+ images = images * scale + (0.5 - drange[0] * scale)
+ return tf.saturate_cast(images, tf.uint8)
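+
+
+# Editor's note: illustrative example, not part of the original NVIDIA file.
+# The two converters above are intended as input/output transforms for
+# dnnlib Network.run(); e.g., assuming a loaded generator `Gs` and a batch of
+# `latents`:
+#
+#   images = Gs.run(latents, None,
+#                   output_transform=dict(func=convert_images_to_uint8,
+#                                         nchw_to_nhwc=True))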
diff --git a/PTI/dnnlib/util.py b/models/StyleCLIP/global_directions/dnnlib/util.py
similarity index 98%
rename from PTI/dnnlib/util.py
rename to models/StyleCLIP/global_directions/dnnlib/util.py
index 76725336d01e75e1c68daa88be47f4fde0bbc63b..0c35b8923bb27bcd91fd0c14234480067138a3fc 100644
--- a/PTI/dnnlib/util.py
+++ b/models/StyleCLIP/global_directions/dnnlib/util.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
@@ -75,10 +75,8 @@ class Logger(object):
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()
- def write(self, text: Union[str, bytes]) -> None:
+ def write(self, text: str) -> None:
"""Write text to stdout (and a file) and optionally flush."""
- if isinstance(text, bytes):
- text = text.decode()
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
return
@@ -109,7 +107,6 @@ class Logger(object):
if self.file is not None:
self.file.close()
- self.file = None
# Cache directories
@@ -450,8 +447,6 @@ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: b
if verbose:
print(" done")
break
- except KeyboardInterrupt:
- raise
except:
if not attempts_left:
if verbose:
diff --git a/models/StyleCLIP/global_directions/manipulate.py b/models/StyleCLIP/global_directions/manipulate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1a2480caad8016fea0c06f0bfe521b25f084436
--- /dev/null
+++ b/models/StyleCLIP/global_directions/manipulate.py
@@ -0,0 +1,278 @@
+
+
+import os
+import os.path
+import pickle
+import numpy as np
+import tensorflow as tf
+from dnnlib import tflib
+from global_directions.utils.visualizer import HtmlPageVisualizer
+
+
+def Vis(bname,suffix,out,rownames=None,colnames=None):
+ num_images=out.shape[0]
+ step=out.shape[1]
+
+ if colnames is None:
+ colnames=[f'Step {i:02d}' for i in range(1, step + 1)]
+ if rownames is None:
+ rownames=[str(i) for i in range(num_images)]
+
+
+ visualizer = HtmlPageVisualizer(
+ num_rows=num_images, num_cols=step + 1, viz_size=256)
+ visualizer.set_headers(
+ ['Name'] +colnames)
+
+ for i in range(num_images):
+ visualizer.set_cell(i, 0, text=rownames[i])
+
+ for i in range(num_images):
+ for k in range(step):
+ image=out[i,k,:,:,:]
+ visualizer.set_cell(i, 1+k, image=image)
+
+ # Save results.
+ visualizer.save(f'./html/'+bname+'_'+suffix+'.html')
+
+
+
+
+def LoadData(img_path):
+ tmp=img_path+'S'
+    with open(tmp, "rb") as fp: # load (s_names, all_s)
+ s_names,all_s=pickle.load( fp)
+ dlatents=all_s
+
+ pindexs=[]
+ mindexs=[]
+ for i in range(len(s_names)):
+ name=s_names[i]
+ if not('ToRGB' in name):
+ mindexs.append(i)
+ else:
+ pindexs.append(i)
+
+ tmp=img_path+'S_mean_std'
+    with open(tmp, "rb") as fp: # load (mean, std)
+ m,std=pickle.load( fp)
+
+ return dlatents,s_names,mindexs,pindexs,m,std
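+
+# Editor's note: clarifying comment, not part of the original file.
+# LoadData expects two pickles next to the precomputed latents: `<img_path>S`
+# holding (s_names, all_s), i.e. the per-layer style tensor names and the
+# corresponding style codes, and `<img_path>S_mean_std` holding their
+# per-channel mean and std. `mindexs` collects indices of non-ToRGB
+# (manipulable) layers, `pindexs` the ToRGB ones.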
+
+
+def LoadModel(model_path,model_name):
+ # Initialize TensorFlow.
+ tflib.init_tf()
+ tmp=os.path.join(model_path,model_name)
+ with open(tmp, 'rb') as f:
+ _, _, Gs = pickle.load(f)
+ Gs.print_layers()
+ return Gs
+
+def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False):
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
+ Can be used as an output transformation for Network.run().
+ """
+ if nchw_to_nhwc:
+ images = np.transpose(images, [0, 2, 3, 1])
+
+ scale = 255 / (drange[1] - drange[0])
+ images = images * scale + (0.5 - drange[0] * scale)
+
+ np.clip(images, 0, 255, out=images)
+ images=images.astype('uint8')
+ return images
+
+
+def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
+ """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
+ Can be used as an input transformation for Network.run().
+ """
+ if nhwc_to_nchw:
+ images=np.rollaxis(images, 3, 1)
+ return images/ 255 *(drange[1] - drange[0])+ drange[0]
+
+
+class Manipulator():
+ def __init__(self,dataset_name='ffhq'):
+ self.file_path='./'
+ self.img_path=self.file_path+'npy/'+dataset_name+'/'
+ self.model_path=self.file_path+'model/'
+ self.dataset_name=dataset_name
+ self.model_name=dataset_name+'.pkl'
+
+ self.alpha=[0] #manipulation strength
+ self.num_images=10
+ self.img_index=0 #which image to start
+ self.viz_size=256
+ self.manipulate_layers=None #which layer to manipulate, list
+
+ self.dlatents,self.s_names,self.mindexs,self.pindexs,self.code_mean,self.code_std=LoadData(self.img_path)
+
+ self.sess=tf.InteractiveSession()
+ init = tf.global_variables_initializer()
+ self.sess.run(init)
+ self.Gs=LoadModel(self.model_path,self.model_name)
+ self.num_layers=len(self.dlatents)
+
+ self.Vis=Vis
+ self.noise_constant={}
+
+ for i in range(len(self.s_names)):
+ tmp1=self.s_names[i].split('/')
+ if not 'ToRGB' in tmp1:
+ tmp1[-1]='random_normal:0'
+ size=int(tmp1[1].split('x')[0])
+ tmp1='/'.join(tmp1)
+ tmp=(1,1,size,size)
+ self.noise_constant[tmp1]=np.random.random(tmp)
+
+ tmp=self.Gs.components.synthesis.input_shape[1]
+ d={}
+ d['G_synthesis_1/dlatents_in:0']=np.zeros([1,tmp,512])
+ names=list(self.noise_constant.keys())
+ tmp=tflib.run(names,d)
+ for i in range(len(names)):
+ self.noise_constant[names[i]]=tmp[i]
+
+ self.fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
+ self.img_size=self.Gs.output_shape[-1]
+
+ def GenerateImg(self,codes):
+
+
+ num_images,step=codes[0].shape[:2]
+
+
+ out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8')
+ for i in range(num_images):
+ for k in range(step):
+ d={}
+ for m in range(len(self.s_names)):
+ d[self.s_names[m]]=codes[m][i,k][None,:] #need to change
+ d['G_synthesis_1/4x4/Const/Shape:0']=np.array([1,18, 512], dtype=np.int32)
+ d.update(self.noise_constant)
+ img=tflib.run('G_synthesis_1/images_out:0', d)
+ image=convert_images_to_uint8(img, nchw_to_nhwc=True)
+ out[i,k,:,:,:]=image[0]
+ return out
+
+
+
+ def MSCode(self,dlatent_tmp,boundary_tmp):
+
+ step=len(self.alpha)
+ dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp]
+ dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] # (10, 7, 512)
+
+ l=np.array(self.alpha)
+ l=l.reshape(
+ [step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)])
+
+ if type(self.manipulate_layers)==int:
+ tmp=[self.manipulate_layers]
+ elif type(self.manipulate_layers)==list:
+ tmp=self.manipulate_layers
+ elif self.manipulate_layers is None:
+ tmp=np.arange(len(boundary_tmp))
+ else:
+ raise ValueError('manipulate_layers is wrong')
+
+ for i in tmp:
+ dlatent_tmp2[i]+=l*boundary_tmp[i]
+
+ codes=[]
+ for i in range(len(dlatent_tmp2)):
+ tmp=list(dlatent_tmp[i].shape)
+ tmp.insert(1,step)
+ codes.append(dlatent_tmp2[i].reshape(tmp))
+ return codes
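+
+    # Editor's note: clarifying comment, not part of the original file.
+    # MSCode broadcasts the edit over manipulation strengths: for each selected
+    # layer i it returns codes[i] of shape (num_images, len(alpha), *layer_shape),
+    # where the style code of image n at step k is shifted by
+    # alpha[k] * boundary_tmp[i]. With alpha = [-5, 0, 5], the middle step
+    # reproduces the unedited image.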
+
+
+ def EditOne(self,bname,dlatent_tmp=None):
+ if dlatent_tmp==None:
+ dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]
+
+ boundary_tmp=[]
+ for i in range(len(self.boundary)):
+ tmp=self.boundary[i]
+ if len(tmp)<=bname:
+ boundary_tmp.append([])
+ else:
+ boundary_tmp.append(tmp[bname])
+
+ codes=self.MSCode(dlatent_tmp,boundary_tmp)
+
+ out=self.GenerateImg(codes)
+ return codes,out
+
+ def EditOneC(self,cindex,dlatent_tmp=None):
+ if dlatent_tmp==None:
+ dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]
+
+ boundary_tmp=[[] for i in range(len(self.dlatents))]
+
+ #'only manipulate 1 layer and one channel'
+ assert len(self.manipulate_layers)==1
+
+ ml=self.manipulate_layers[0]
+ tmp=dlatent_tmp[ml].shape[1] #ada
+ tmp1=np.zeros(tmp)
+ tmp1[cindex]=self.code_std[ml][cindex] #1
+ boundary_tmp[ml]=tmp1
+
+ codes=self.MSCode(dlatent_tmp,boundary_tmp)
+ out=self.GenerateImg(codes)
+ return codes,out
+
+
+ def W2S(self,dlatent_tmp):
+
+ all_s = self.sess.run(
+ self.s_names,
+ feed_dict={'G_synthesis_1/dlatents_in:0': dlatent_tmp})
+ return all_s
+
+
+
+
+
+
+
+
+#%%
+if __name__ == "__main__":
+
+
+ M=Manipulator(dataset_name='ffhq')
+
+
+ #%%
+ M.alpha=[-5,0,5]
+ M.num_images=20
+ lindex,cindex=6,501
+
+ M.manipulate_layers=[lindex]
+ codes,out=M.EditOneC(cindex) #dlatent_tmp
+ tmp=str(M.manipulate_layers)+'_'+str(cindex)
+ M.Vis(tmp,'c',out)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/PTI/utils/__init__.py b/models/StyleCLIP/global_directions/utils/__init__.py
similarity index 100%
rename from PTI/utils/__init__.py
rename to models/StyleCLIP/global_directions/utils/__init__.py
diff --git a/models/StyleCLIP/global_directions/utils/editor.py b/models/StyleCLIP/global_directions/utils/editor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1c2ac56fd7b4b127f948c6b8cf15874a8fe9d93
--- /dev/null
+++ b/models/StyleCLIP/global_directions/utils/editor.py
@@ -0,0 +1,507 @@
+# python 3.7
+"""Utility functions for image editing from latent space."""
+
+import os.path
+import numpy as np
+
+__all__ = [
+ 'parse_indices', 'interpolate', 'mix_style',
+ 'get_layerwise_manipulation_strength', 'manipulate', 'parse_boundary_list'
+]
+
+
+def parse_indices(obj, min_val=None, max_val=None):
+ """Parses indices.
+
+ If the input is a list or tuple, this function has no effect.
+
+ The input can also be a string, which is either a comma separated list of
+ numbers 'a, b, c', or a dash separated range 'a - c'. Space in the string will
+ be ignored.
+
+ Args:
+ obj: The input object to parse indices from.
+ min_val: If not `None`, this function will check that all indices are equal
+ to or larger than this value. (default: None)
+ max_val: If not `None`, this function will check that all indices are equal
+ to or smaller than this field. (default: None)
+
+ Returns:
+ A list of integers.
+
+ Raises:
+    ValueError: If the input is invalid, i.e., neither a list, tuple, nor a string.
+ """
+ if obj is None or obj == '':
+ indices = []
+ elif isinstance(obj, int):
+ indices = [obj]
+ elif isinstance(obj, (list, tuple, np.ndarray)):
+ indices = list(obj)
+ elif isinstance(obj, str):
+ indices = []
+ splits = obj.replace(' ', '').split(',')
+ for split in splits:
+ numbers = list(map(int, split.split('-')))
+ if len(numbers) == 1:
+ indices.append(numbers[0])
+ elif len(numbers) == 2:
+ indices.extend(list(range(numbers[0], numbers[1] + 1)))
+ else:
+ raise ValueError(f'Invalid type of input: {type(obj)}!')
+
+ assert isinstance(indices, list)
+ indices = sorted(list(set(indices)))
+ for idx in indices:
+ assert isinstance(idx, int)
+ if min_val is not None:
+ assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
+ if max_val is not None:
+ assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'
+
+ return indices
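+
+
+# Editor's note: illustrative example, not part of the original file.
+# parse_indices accepts ints, sequences, or strings mixing commas and dashes:
+#
+#   parse_indices('0, 2, 4-6')   # -> [0, 2, 4, 5, 6]
+#   parse_indices(3)             # -> [3]
+#   parse_indices(None)          # -> []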
+
+
+def interpolate(src_codes, dst_codes, step=5):
+ """Interpolates two sets of latent codes linearly.
+
+ Args:
+ src_codes: Source codes, with shape [num, *code_shape].
+ dst_codes: Target codes, with shape [num, *code_shape].
+    step: Number of interpolation steps, with source and target included. For
+ example, if `step = 5`, three more samples will be inserted. (default: 5)
+
+ Returns:
+ Interpolated codes, with shape [num, step, *code_shape].
+
+ Raises:
+ ValueError: If the input two sets of latent codes are with different shapes.
+ """
+ if not (src_codes.ndim >= 2 and src_codes.shape == dst_codes.shape):
+ raise ValueError(f'Shapes of source codes and target codes should both be '
+ f'[num, *code_shape], but {src_codes.shape} and '
+ f'{dst_codes.shape} are received!')
+ num = src_codes.shape[0]
+ code_shape = src_codes.shape[1:]
+
+ a = src_codes[:, np.newaxis]
+ b = dst_codes[:, np.newaxis]
+ l = np.linspace(0.0, 1.0, step).reshape(
+ [step if axis == 1 else 1 for axis in range(a.ndim)])
+ results = a + l * (b - a)
+ assert results.shape == (num, step, *code_shape)
+
+ return results
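+
+
+# Editor's note: illustrative example, not part of the original file.
+# With src_codes and dst_codes both of shape (num, 512) and step=5, the result
+# has shape (num, 5, 512); result[:, 0] equals src_codes, result[:, -1] equals
+# dst_codes, and the three middle codes are evenly spaced between them.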
+
+
+def mix_style(style_codes,
+ content_codes,
+ num_layers=1,
+ mix_layers=None,
+ is_style_layerwise=True,
+ is_content_layerwise=True):
+ """Mixes styles from style codes to those of content codes.
+
+ Each style code or content code consists of `num_layers` codes, each of which
+ is typically fed into a particular layer of the generator. This function mixes
+ styles by partially replacing the codes of `content_codes` from some certain
+ layers with those of `style_codes`.
+
+  For example, suppose both the style code and the content code have shape
+  [10, 512], i.e., 10 layers, each employing a 512-dimensional latent code, and
+  the 1st, 2nd, and 3rd layers are the target layers for style mixing. Then the
+  codes of those three layers in the content code (with shape [3, 512]) will be
+  replaced by the corresponding codes of the style code (also with shape [3, 512]).
+
+ NOTE: This function also supports taking single-layer latent codes as inputs,
+ i.e., setting `is_style_layerwise` or `is_content_layerwise` as False. In this
+ case, the corresponding code will be first repeated for `num_layers` before
+ performing style mixing.
+
+ Args:
+ style_codes: Style codes, with shape [num_styles, *code_shape] or
+ [num_styles, num_layers, *code_shape].
+ content_codes: Content codes, with shape [num_contents, *code_shape] or
+ [num_contents, num_layers, *code_shape].
+ num_layers: Total number of layers in the generative model. (default: 1)
+ mix_layers: Indices of the layers to perform style mixing. `None` means to
+ replace all layers, in which case the content code will be completely
+ replaced by style code. (default: None)
+ is_style_layerwise: Indicating whether the input `style_codes` are
+ layer-wise codes. (default: True)
+ is_content_layerwise: Indicating whether the input `content_codes` are
+ layer-wise codes. (default: True)
+
+ Returns:
+ Codes after style mixing, with shape [num_styles, num_contents, num_layers,
+ *code_shape].
+
+ Raises:
+ ValueError: If input `content_codes` or `style_codes` is with invalid shape.
+ """
+ if not is_style_layerwise:
+ style_codes = style_codes[:, np.newaxis]
+ style_codes = np.tile(
+ style_codes,
+ [num_layers if axis == 1 else 1 for axis in range(style_codes.ndim)])
+ if not is_content_layerwise:
+ content_codes = content_codes[:, np.newaxis]
+ content_codes = np.tile(
+ content_codes,
+ [num_layers if axis == 1 else 1 for axis in range(content_codes.ndim)])
+
+ if not (style_codes.ndim >= 3 and style_codes.shape[1] == num_layers and
+ style_codes.shape[1:] == content_codes.shape[1:]):
+ raise ValueError(f'Shapes of style codes and content codes should be '
+ f'[num_styles, num_layers, *code_shape] and '
+ f'[num_contents, num_layers, *code_shape] respectively, '
+ f'but {style_codes.shape} and {content_codes.shape} are '
+ f'received!')
+
+ layer_indices = parse_indices(mix_layers, min_val=0, max_val=num_layers - 1)
+ if not layer_indices:
+ layer_indices = list(range(num_layers))
+
+ num_styles = style_codes.shape[0]
+ num_contents = content_codes.shape[0]
+ code_shape = content_codes.shape[2:]
+
+ s = style_codes[:, np.newaxis]
+ s = np.tile(s, [num_contents if axis == 1 else 1 for axis in range(s.ndim)])
+ c = content_codes[np.newaxis]
+ c = np.tile(c, [num_styles if axis == 0 else 1 for axis in range(c.ndim)])
+
+ from_style = np.zeros(s.shape, dtype=bool)
+ from_style[:, :, layer_indices] = True
+ results = np.where(from_style, s, c)
+ assert results.shape == (num_styles, num_contents, num_layers, *code_shape)
+
+ return results
+
+
+def get_layerwise_manipulation_strength(num_layers,
+ truncation_psi,
+ truncation_layers):
+ """Gets layer-wise strength for manipulation.
+
+ Recall the truncation trick played on layer [0, truncation_layers):
+
+ w = truncation_psi * w + (1 - truncation_psi) * w_avg
+
+ So, when using the same boundary to manipulate different layers, layer
+ [0, truncation_layers) and layer [truncation_layers, num_layers) should use
+ different strength to eliminate the effect from the truncation trick. More
+ concretely, the strength for layer [0, truncation_layers) is set as
+  `truncation_psi`, while that for other layers is set to 1.
+ """
+ strength = [1.0 for _ in range(num_layers)]
+ if truncation_layers > 0:
+ for layer_idx in range(0, truncation_layers):
+ strength[layer_idx] = truncation_psi
+ return strength
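+
+
+# Editor's note: illustrative example, not part of the original file.
+# For an 18-layer generator with truncation_psi=0.7 applied to the first 8
+# layers:
+#
+#   get_layerwise_manipulation_strength(18, 0.7, 8)
+#   # -> [0.7] * 8 + [1.0] * 10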
+
+
+def manipulate(latent_codes,
+ boundary,
+ start_distance=-5.0,
+ end_distance=5.0,
+ step=21,
+ layerwise_manipulation=False,
+ num_layers=1,
+ manipulate_layers=None,
+ is_code_layerwise=False,
+ is_boundary_layerwise=False,
+ layerwise_manipulation_strength=1.0):
+ """Manipulates the given latent codes with respect to a particular boundary.
+
+ Basically, this function takes a set of latent codes and a boundary as inputs,
+ and outputs a collection of manipulated latent codes.
+
+ For example, let `step` to be 10, `latent_codes` to be with shape [num,
+ *code_shape], and `boundary` to be with shape [1, *code_shape] and unit norm.
+ Then the output will be with shape [num, 10, *code_shape]. For each 10-element
+ manipulated codes, the first code is `start_distance` away from the original
+ code (i.e., the input) along the `boundary` direction, while the last code is
+ `end_distance` away. Remaining codes are linearly interpolated. Here,
+ `distance` is sign sensitive.
+
+ NOTE: This function also supports layer-wise manipulation, in which case the
+ generator should be able to take layer-wise latent codes as inputs. For
+  example, if the generator has 18 convolutional layers in total, each taking
+  an independent latent code as input, it is possible, sometimes with even
+  better performance, to manipulate only the latent codes corresponding to
+  certain layers while keeping the others untouched.
+
+ NOTE: Boundary is assumed to be normalized to unit norm already.
+
+ Args:
+ latent_codes: The input latent codes for manipulation, with shape
+ [num, *code_shape] or [num, num_layers, *code_shape].
+ boundary: The semantic boundary as reference, with shape [1, *code_shape] or
+ [1, num_layers, *code_shape].
+ start_distance: Start point for manipulation. (default: -5.0)
+ end_distance: End point for manipulation. (default: 5.0)
+ step: Number of manipulation steps. (default: 21)
+ layerwise_manipulation: Whether to perform layer-wise manipulation.
+ (default: False)
+ num_layers: Number of layers. Only active when `layerwise_manipulation` is
+ set as `True`. Should be a positive integer. (default: 1)
+ manipulate_layers: Indices of the layers to perform manipulation. `None`
+ means to manipulate latent codes from all layers. (default: None)
+ is_code_layerwise: Whether the input latent codes are layer-wise. If set as
+ `False`, the function will first repeat the input codes for `num_layers`
+      times before performing manipulation. (default: False)
+ is_boundary_layerwise: Whether the input boundary is layer-wise. If set as
+ `False`, the function will first repeat boundary for `num_layers` times
+      before performing manipulation. (default: False)
+ layerwise_manipulation_strength: Manipulation strength for each layer. Only
+ active when `layerwise_manipulation` is set as `True`. This field can be
+ used to resolve the strength discrepancy across layers when truncation
+ trick is on. See function `get_layerwise_manipulation_strength()` for
+ details. A tuple, list, or `numpy.ndarray` is expected. If set as a single
+ number, this strength will be used for all layers. (default: 1.0)
+
+ Returns:
+ Manipulated codes, with shape [num, step, *code_shape] if
+ `layerwise_manipulation` is set as `False`, or shape [num, step,
+ num_layers, *code_shape] if `layerwise_manipulation` is set as `True`.
+
+ Raises:
+ ValueError: If the input latent codes, boundary, or strength are with
+ invalid shape.
+ """
+ if not (boundary.ndim >= 2 and boundary.shape[0] == 1):
+ raise ValueError(f'Boundary should be with shape [1, *code_shape] or '
+ f'[1, num_layers, *code_shape], but '
+ f'{boundary.shape} is received!')
+
+ if not layerwise_manipulation:
+ assert not is_code_layerwise
+ assert not is_boundary_layerwise
+ num_layers = 1
+ manipulate_layers = None
+ layerwise_manipulation_strength = 1.0
+
+ # Preprocessing for layer-wise manipulation.
+ # Parse indices of manipulation layers.
+ layer_indices = parse_indices(
+ manipulate_layers, min_val=0, max_val=num_layers - 1)
+ if not layer_indices:
+ layer_indices = list(range(num_layers))
+ # Make latent codes layer-wise if needed.
+ assert num_layers > 0
+ if not is_code_layerwise:
+ x = latent_codes[:, np.newaxis]
+ x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)])
+ else:
+ x = latent_codes
+ if x.shape[1] != num_layers:
+ raise ValueError(f'Latent codes should be with shape [num, num_layers, '
+ f'*code_shape], where `num_layers` equals to '
+ f'{num_layers}, but {x.shape} is received!')
+ # Make boundary layer-wise if needed.
+ if not is_boundary_layerwise:
+ b = boundary
+ b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)])
+ else:
+ b = boundary[0]
+ if b.shape[0] != num_layers:
+ raise ValueError(f'Boundary should be with shape [num_layers, '
+ f'*code_shape], where `num_layers` equals to '
+ f'{num_layers}, but {b.shape} is received!')
+ # Get layer-wise manipulation strength.
+ if isinstance(layerwise_manipulation_strength, (int, float)):
+ s = [float(layerwise_manipulation_strength) for _ in range(num_layers)]
+ elif isinstance(layerwise_manipulation_strength, (list, tuple)):
+ s = layerwise_manipulation_strength
+ if len(s) != num_layers:
+ raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` '
+ f'mismatches number of layers `{num_layers}`!')
+ elif isinstance(layerwise_manipulation_strength, np.ndarray):
+ s = layerwise_manipulation_strength
+ if s.size != num_layers:
+ raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` '
+ f'mismatches number of layers `{num_layers}`!')
+ else:
+ raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!')
+ s = np.array(s).reshape(
+ [num_layers if axis == 0 else 1 for axis in range(b.ndim)])
+ b = b * s
+
+ if x.shape[1:] != b.shape:
+ raise ValueError(f'Latent code shape {x.shape} and boundary shape '
+ f'{b.shape} mismatch!')
+ num = x.shape[0]
+ code_shape = x.shape[2:]
+
+ x = x[:, np.newaxis]
+ b = b[np.newaxis, np.newaxis, :]
+ l = np.linspace(start_distance, end_distance, step).reshape(
+ [step if axis == 1 else 1 for axis in range(x.ndim)])
+ results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)])
+ is_manipulatable = np.zeros(results.shape, dtype=bool)
+ is_manipulatable[:, :, layer_indices] = True
+ results = np.where(is_manipulatable, x + l * b, results)
+ assert results.shape == (num, step, num_layers, *code_shape)
+
+ return results if layerwise_manipulation else results[:, :, 0]
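+
+
+# Editor's note: illustrative example, not part of the original file.
+# Non-layer-wise case: latent codes of shape (num, 512) and a unit-norm
+# boundary of shape (1, 512) yield codes of shape (num, step, 512), swept from
+# start_distance to end_distance along the boundary:
+#
+#   codes = manipulate(latent_codes, boundary,
+#                      start_distance=-3.0, end_distance=3.0, step=7)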
+
+
+def manipulate2(latent_codes,
+ proj,
+ mindex,
+ start_distance=-5.0,
+ end_distance=5.0,
+ step=21,
+ layerwise_manipulation=False,
+ num_layers=1,
+ manipulate_layers=None,
+ is_code_layerwise=False,
+ layerwise_manipulation_strength=1.0):
+
+
+ if not layerwise_manipulation:
+ assert not is_code_layerwise
+# assert not is_boundary_layerwise
+ num_layers = 1
+ manipulate_layers = None
+ layerwise_manipulation_strength = 1.0
+
+ # Preprocessing for layer-wise manipulation.
+ # Parse indices of manipulation layers.
+ layer_indices = parse_indices(
+ manipulate_layers, min_val=0, max_val=num_layers - 1)
+ if not layer_indices:
+ layer_indices = list(range(num_layers))
+ # Make latent codes layer-wise if needed.
+ assert num_layers > 0
+ if not is_code_layerwise:
+ x = latent_codes[:, np.newaxis]
+ x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)])
+ else:
+ x = latent_codes
+ if x.shape[1] != num_layers:
+ raise ValueError(f'Latent codes should be with shape [num, num_layers, '
+ f'*code_shape], where `num_layers` equals to '
+ f'{num_layers}, but {x.shape} is received!')
+ # Make boundary layer-wise if needed.
+# if not is_boundary_layerwise:
+# b = boundary
+# b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)])
+# else:
+# b = boundary[0]
+# if b.shape[0] != num_layers:
+# raise ValueError(f'Boundary should be with shape [num_layers, '
+# f'*code_shape], where `num_layers` equals to '
+# f'{num_layers}, but {b.shape} is received!')
+ # Get layer-wise manipulation strength.
+ if isinstance(layerwise_manipulation_strength, (int, float)):
+ s = [float(layerwise_manipulation_strength) for _ in range(num_layers)]
+ elif isinstance(layerwise_manipulation_strength, (list, tuple)):
+ s = layerwise_manipulation_strength
+ if len(s) != num_layers:
+ raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` '
+ f'mismatches number of layers `{num_layers}`!')
+ elif isinstance(layerwise_manipulation_strength, np.ndarray):
+ s = layerwise_manipulation_strength
+ if s.size != num_layers:
+ raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` '
+ f'mismatches number of layers `{num_layers}`!')
+ else:
+ raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!')
+# s = np.array(s).reshape(
+# [num_layers if axis == 0 else 1 for axis in range(b.ndim)])
+# b = b * s
+
+# if x.shape[1:] != b.shape:
+# raise ValueError(f'Latent code shape {x.shape} and boundary shape '
+# f'{b.shape} mismatch!')
+ num = x.shape[0]
+ code_shape = x.shape[2:]
+
+ x = x[:, np.newaxis]
+# b = b[np.newaxis, np.newaxis, :]
+# l = np.linspace(start_distance, end_distance, step).reshape(
+# [step if axis == 1 else 1 for axis in range(x.ndim)])
+ results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)])
+ is_manipulatable = np.zeros(results.shape, dtype=bool)
+ is_manipulatable[:, :, layer_indices] = True
+
+ tmp=MPC(proj,x,mindex,start_distance,end_distance,step)
+ tmp = tmp[:, :,np.newaxis]
+ tmp1 = np.tile(tmp, [num_layers if axis == 2 else 1 for axis in range(tmp.ndim)])
+
+
+ results = np.where(is_manipulatable, tmp1, results)
+# print(results.shape)
+ assert results.shape == (num, step, num_layers, *code_shape)
+ return results if layerwise_manipulation else results[:, :, 0]
+
+def MPC(proj,x,mindex,start_distance,end_distance,step):
+ # x shape (batch_size,1,num_layers,feature)
+# print(x.shape)
+ x1=proj.transform(x[:,0,0,:]) #/np.sqrt(proj.explained_variance_) # (batch_size,num_pc)
+
+ x1 = x1[:, np.newaxis]
+ x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)])
+
+
+ l = np.linspace(start_distance, end_distance, step)[None,:]
+ x1[:,:,mindex]+=l
+
+ tmp=x1.reshape((-1,x1.shape[-1])) #*np.sqrt(proj.explained_variance_)
+# print('xxx')
+ x2=proj.inverse_transform(tmp)
+ x2=x2.reshape((x1.shape[0],x1.shape[1],-1))
+
+# x1 = x1[:, np.newaxis]
+# x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)])
+
+ return x2
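+
+# Editor's note: clarifying comment, not part of the original file.
+# MPC assumes `proj` is a fitted sklearn-style projector exposing transform()
+# and inverse_transform(), e.g. a sklearn.decomposition.PCA instance. It
+# projects the codes, sweeps principal component `mindex` over
+# np.linspace(start_distance, end_distance, step), projects back, and returns
+# an array of shape (batch_size, step, feature_dim).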
+
+
+
+
+def parse_boundary_list(boundary_list_path):
+ """Parses boundary list.
+
+ Sometimes, a text file containing a list of boundaries will significantly
+ simplify image manipulation with a large amount of boundaries. This function
+ is used to parse boundary information from such list file.
+
+ Basically, each item in the list should be with format
+ `($NAME, $SPACE_TYPE): $PATH`. `DISABLE` at the beginning of the line can
+ disable a particular boundary.
+
+ Sample:
+
+ (age, z): $AGE_BOUNDARY_PATH
+ (gender, w): $GENDER_BOUNDARY_PATH
+ DISABLE(pose, wp): $POSE_BOUNDARY_PATH
+
+ Args:
+ boundary_list_path: Path to the boundary list.
+
+ Returns:
+ A dictionary, whose key is a two-element tuple (boundary_name, space_type)
+ and value is the corresponding boundary path.
+
+ Raise:
+ ValueError: If the given boundary list does not exist.
+ """
+ if not os.path.isfile(boundary_list_path):
+    raise ValueError(f'Boundary list `{boundary_list_path}` does not exist!')
+
+ boundaries = {}
+ with open(boundary_list_path, 'r') as f:
+ for line in f:
+ if line[:len('DISABLE')] == 'DISABLE':
+ continue
+ boundary_info, boundary_path = line.strip().split(':')
+ boundary_name, space_type = boundary_info.strip()[1:-1].split(',')
+ boundary_name = boundary_name.strip()
+ space_type = space_type.strip().lower()
+ boundary_path = boundary_path.strip()
+ boundaries[(boundary_name, space_type)] = boundary_path
+ return boundaries
diff --git a/models/StyleCLIP/global_directions/utils/train_boundary.py b/models/StyleCLIP/global_directions/utils/train_boundary.py
new file mode 100644
index 0000000000000000000000000000000000000000..710d062bc4b42913fcc5b12bd545e47af00c7123
--- /dev/null
+++ b/models/StyleCLIP/global_directions/utils/train_boundary.py
@@ -0,0 +1,158 @@
+
+import numpy as np
+from sklearn import svm
+
+
+
+
+
+def train_boundary(latent_codes,
+ scores,
+ chosen_num_or_ratio=0.02,
+ split_ratio=0.7,
+ invalid_value=None,
+ logger=None,
+ logger_name='train_boundary'):
+ """Trains boundary in latent space with offline predicted attribute scores.
+
+ Given a collection of latent codes and the attribute scores predicted from the
+ corresponding images, this function will train a linear SVM by treating it as
+ a bi-classification problem. Basically, the samples with highest attribute
+ scores are treated as positive samples, while those with lowest scores as
+ negative. For now, the latent code can ONLY be with 1 dimension.
+
+ NOTE: The returned boundary is with shape (1, latent_space_dim), and also
+ normalized with unit norm.
+
+ Args:
+ latent_codes: Input latent codes as training data.
+ scores: Input attribute scores used to generate training labels.
+ chosen_num_or_ratio: How many samples will be chosen as positive (negative)
+ samples. If this field lies in range (0, 0.5], `chosen_num_or_ratio *
+ latent_codes_num` will be used. Otherwise, `min(chosen_num_or_ratio,
+ 0.5 * latent_codes_num)` will be used. (default: 0.02)
+ split_ratio: Ratio to split training and validation sets. (default: 0.7)
+ invalid_value: This field is used to filter out data. (default: None)
+ logger: Logger for recording log messages. If set as `None`, a default
+ logger, which prints messages from all levels to screen, will be created.
+ (default: None)
+
+  Returns:
+    A two-element tuple: the decision boundary (a unit-norm `numpy.ndarray` of
+    shape (1, latent_space_dim)) and the accuracy on the validation set.
+
+ Raises:
+ ValueError: If the input `latent_codes` or `scores` are with invalid format.
+ """
+# if not logger:
+# logger = setup_logger(work_dir='', logger_name=logger_name)
+
+ if (not isinstance(latent_codes, np.ndarray) or
+ not len(latent_codes.shape) == 2):
+    raise ValueError(f'Input `latent_codes` should be with type '
+ f'`numpy.ndarray`, and shape [num_samples, '
+ f'latent_space_dim]!')
+ num_samples = latent_codes.shape[0]
+ latent_space_dim = latent_codes.shape[1]
+ if (not isinstance(scores, np.ndarray) or not len(scores.shape) == 2 or
+ not scores.shape[0] == num_samples or not scores.shape[1] == 1):
+ raise ValueError(f'Input `scores` should be with type `numpy.ndarray`, and '
+ f'shape [num_samples, 1], where `num_samples` should be '
+ f'exactly same as that of input `latent_codes`!')
+ if chosen_num_or_ratio <= 0:
+ raise ValueError(f'Input `chosen_num_or_ratio` should be positive, '
+ f'but {chosen_num_or_ratio} received!')
+
+# logger.info(f'Filtering training data.')
+ print('Filtering training data.')
+ if invalid_value is not None:
+ latent_codes = latent_codes[scores[:, 0] != invalid_value]
+ scores = scores[scores[:, 0] != invalid_value]
+
+# logger.info(f'Sorting scores to get positive and negative samples.')
+ print('Sorting scores to get positive and negative samples.')
+
+ sorted_idx = np.argsort(scores, axis=0)[::-1, 0]
+ latent_codes = latent_codes[sorted_idx]
+ scores = scores[sorted_idx]
+ num_samples = latent_codes.shape[0]
+ if 0 < chosen_num_or_ratio <= 1:
+ chosen_num = int(num_samples * chosen_num_or_ratio)
+ else:
+ chosen_num = int(chosen_num_or_ratio)
+ chosen_num = min(chosen_num, num_samples // 2)
+
+#  logger.info(f'Splitting training and validation sets:')
+  print('Splitting training and validation sets:')
+
+ train_num = int(chosen_num * split_ratio)
+ val_num = chosen_num - train_num
+ # Positive samples.
+ positive_idx = np.arange(chosen_num)
+ np.random.shuffle(positive_idx)
+ positive_train = latent_codes[:chosen_num][positive_idx[:train_num]]
+ positive_val = latent_codes[:chosen_num][positive_idx[train_num:]]
+ # Negative samples.
+ negative_idx = np.arange(chosen_num)
+ np.random.shuffle(negative_idx)
+ negative_train = latent_codes[-chosen_num:][negative_idx[:train_num]]
+ negative_val = latent_codes[-chosen_num:][negative_idx[train_num:]]
+ # Training set.
+ train_data = np.concatenate([positive_train, negative_train], axis=0)
+ train_label = np.concatenate([np.ones(train_num, dtype=np.int),
+ np.zeros(train_num, dtype=np.int)], axis=0)
+# logger.info(f' Training: {train_num} positive, {train_num} negative.')
+ print(f' Training: {train_num} positive, {train_num} negative.')
+ # Validation set.
+ val_data = np.concatenate([positive_val, negative_val], axis=0)
+ val_label = np.concatenate([np.ones(val_num, dtype=np.int),
+ np.zeros(val_num, dtype=np.int)], axis=0)
+# logger.info(f' Validation: {val_num} positive, {val_num} negative.')
+ print(f' Validation: {val_num} positive, {val_num} negative.')
+
+ # Remaining set.
+ remaining_num = num_samples - chosen_num * 2
+ remaining_data = latent_codes[chosen_num:-chosen_num]
+ remaining_scores = scores[chosen_num:-chosen_num]
+ decision_value = (scores[0] + scores[-1]) / 2
+ remaining_label = np.ones(remaining_num, dtype=np.int)
+ remaining_label[remaining_scores.ravel() < decision_value] = 0
+ remaining_positive_num = np.sum(remaining_label == 1)
+ remaining_negative_num = np.sum(remaining_label == 0)
+# logger.info(f' Remaining: {remaining_positive_num} positive, '
+# f'{remaining_negative_num} negative.')
+ print(f' Remaining: {remaining_positive_num} positive, '
+ f'{remaining_negative_num} negative.')
+# logger.info(f'Training boundary.')
+ print(f'Training boundary.')
+
+ clf = svm.SVC(kernel='linear')
+ classifier = clf.fit(train_data, train_label)
+# logger.info(f'Finish training.')
+ print(f'Finish training.')
+
+
+ if val_num:
+ val_prediction = classifier.predict(val_data)
+ correct_num = np.sum(val_label == val_prediction)
+# logger.info(f'Accuracy for validation set: '
+# f'{correct_num} / {val_num * 2} = '
+# f'{correct_num / (val_num * 2):.6f}')
+ print(f'Accuracy for validation set: '
+ f'{correct_num} / {val_num * 2} = '
+ f'{correct_num / (val_num * 2):.6f}')
+ vacc=correct_num/len(val_label)
+ '''
+ if remaining_num:
+ remaining_prediction = classifier.predict(remaining_data)
+ correct_num = np.sum(remaining_label == remaining_prediction)
+ logger.info(f'Accuracy for remaining set: '
+ f'{correct_num} / {remaining_num} = '
+ f'{correct_num / remaining_num:.6f}')
+ '''
+ a = classifier.coef_.reshape(1, latent_space_dim).astype(np.float32)
+ return a / np.linalg.norm(a),vacc
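+
+# Editor's note: illustrative sketch, not part of the original file.
+# Hypothetical usage with precomputed data (file names are placeholders):
+#
+#   latent_codes = np.load('latents.npy')       # shape (N, 512)
+#   scores = np.load('attribute_scores.npy')    # shape (N, 1)
+#   boundary, val_acc = train_boundary(latent_codes, scores,
+#                                      chosen_num_or_ratio=0.02)
+#   # `boundary` has shape (1, 512) and unit norm; `val_acc` is the linear
+#   # SVM's accuracy on the held-out validation split.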
+
+
+
+
+
diff --git a/models/StyleCLIP/global_directions/utils/visualizer.py b/models/StyleCLIP/global_directions/utils/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c4a1fba06bf6bc680aa59bf645f796283f6f1c6
--- /dev/null
+++ b/models/StyleCLIP/global_directions/utils/visualizer.py
@@ -0,0 +1,605 @@
+# python 3.7
+"""Utility functions for visualizing results on html page."""
+
+import base64
+import os.path
+import cv2
+import numpy as np
+
+__all__ = [
+ 'get_grid_shape', 'get_blank_image', 'load_image', 'save_image',
+ 'resize_image', 'add_text_to_image', 'fuse_images', 'HtmlPageVisualizer',
+ 'VideoReader', 'VideoWriter', 'adjust_pixel_range'
+]
+
+
+def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'):
+ """Adjusts the pixel range of the input images.
+
+ This function assumes the input array (image batch) is with shape [batch_size,
+  This function assumes the input array (image batch) is with shape [batch_size,
+  channel, height, width] if `channel_order = NCHW`, or with shape [batch_size,
+  height, width, channel] if `channel_order = NHWC`. The returned images are
+  with shape [batch_size, height, width, channel] and pixel range [0, 255].
+
+  NOTE: The returned images always use `NHWC` channel order, regardless of the
+  input layout.
+
+ Args:
+ images: Input images to adjust pixel range.
+ min_val: Min value of the input images. (default: -1.0)
+ max_val: Max value of the input images. (default: 1.0)
+ channel_order: Channel order of the input array. (default: NCHW)
+
+ Returns:
+ The postprocessed images with dtype `numpy.uint8` and range [0, 255].
+
+ Raises:
+ ValueError: If the input `images` are not with type `numpy.ndarray` or the
+ shape is invalid according to `channel_order`.
+ """
+ if not isinstance(images, np.ndarray):
+ raise ValueError(f'Images should be with type `numpy.ndarray`!')
+
+ channel_order = channel_order.upper()
+ if channel_order not in ['NCHW', 'NHWC']:
+ raise ValueError(f'Invalid channel order `{channel_order}`!')
+
+ if images.ndim != 4:
+ raise ValueError(f'Input images are expected to be with shape `NCHW` or '
+ f'`NHWC`, but `{images.shape}` is received!')
+ if channel_order == 'NCHW' and images.shape[1] not in [1, 3]:
+ raise ValueError(f'Input images should have 1 or 3 channels under `NCHW` '
+ f'channel order!')
+ if channel_order == 'NHWC' and images.shape[3] not in [1, 3]:
+ raise ValueError(f'Input images should have 1 or 3 channels under `NHWC` '
+ f'channel order!')
+
+ images = images.astype(np.float32)
+ images = (images - min_val) * 255 / (max_val - min_val)
+ images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
+ if channel_order == 'NCHW':
+ images = images.transpose(0, 2, 3, 1)
+
+ return images
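+
+
+# Editor's note: illustrative example, not part of the original file.
+# Typical use on generator output in [-1, 1], NCHW float32:
+#
+#   fake = np.random.uniform(-1, 1, size=(4, 3, 256, 256)).astype(np.float32)
+#   imgs = adjust_pixel_range(fake)   # -> uint8 array of shape (4, 256, 256, 3)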
+
+
+def get_grid_shape(size, row=0, col=0, is_portrait=False):
+ """Gets the shape of a grid based on the size.
+
+ This function makes greatest effort on making the output grid square if
+ neither `row` nor `col` is set. If `is_portrait` is set as `False`, the height
+ will always be equal to or smaller than the width. For example, if input
+ `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, output shape
+ will be (3, 5). Otherwise, the height will always be equal to or larger than
+ the width.
+
+ Args:
+ size: Size (height * width) of the target grid.
+    is_portrait: Whether to return a portrait size or a landscape size.
+ (default: False)
+
+ Returns:
+ A two-element tuple, representing height and width respectively.
+ """
+ assert isinstance(size, int)
+ assert isinstance(row, int)
+ assert isinstance(col, int)
+ if size == 0:
+ return (0, 0)
+
+ if row > 0 and col > 0 and row * col != size:
+ row = 0
+ col = 0
+
+ if row > 0 and size % row == 0:
+ return (row, size // row)
+ if col > 0 and size % col == 0:
+ return (size // col, col)
+
+ row = int(np.sqrt(size))
+ while row > 0:
+ if size % row == 0:
+ col = size // row
+ break
+ row = row - 1
+
+ return (col, row) if is_portrait else (row, col)
+
+
+def get_blank_image(height, width, channels=3, is_black=True):
+  """Gets a blank image, either white or black.
+
+ NOTE: This function will always return an image with `RGB` channel order for
+ color image and pixel range [0, 255].
+
+ Args:
+ height: Height of the returned image.
+ width: Width of the returned image.
+ channels: Number of channels. (default: 3)
+ is_black: Whether to return a black image or white image. (default: True)
+ """
+ shape = (height, width, channels)
+ if is_black:
+ return np.zeros(shape, dtype=np.uint8)
+ return np.ones(shape, dtype=np.uint8) * 255
+
+
+def load_image(path):
+ """Loads an image from disk.
+
+ NOTE: This function will always return an image with `RGB` channel order for
+ color image and pixel range [0, 255].
+
+ Args:
+ path: Path to load the image from.
+
+ Returns:
+ An image with dtype `np.ndarray` or `None` if input `path` does not exist.
+ """
+ if not os.path.isfile(path):
+ return None
+
+ image = cv2.imread(path)
+ return image[:, :, ::-1]
+
+
+def save_image(path, image):
+ """Saves an image to disk.
+
+ NOTE: The input image (if colorful) is assumed to be with `RGB` channel order
+ and pixel range [0, 255].
+
+ Args:
+ path: Path to save the image to.
+ image: Image to save.
+ """
+ if image is None:
+ return
+
+ assert len(image.shape) == 3 and image.shape[2] in [1, 3]
+ cv2.imwrite(path, image[:, :, ::-1])
+
+
+def resize_image(image, *args, **kwargs):
+ """Resizes image.
+
+  This is a wrapper around `cv2.resize()`.
+
+  NOTE: The channel order of the input image will not be changed.
+
+ Args:
+ image: Image to resize.
+ """
+ if image is None:
+ return None
+
+ assert image.ndim == 3 and image.shape[2] in [1, 3]
+ image = cv2.resize(image, *args, **kwargs)
+ if image.ndim == 2:
+ return image[:, :, np.newaxis]
+ return image
+
+
+def add_text_to_image(image,
+ text='',
+ position=None,
+ font=cv2.FONT_HERSHEY_TRIPLEX,
+ font_size=1.0,
+ line_type=cv2.LINE_8,
+ line_width=1,
+ color=(255, 255, 255)):
+ """Overlays text on given image.
+
+ NOTE: The input image is assumed to be with `RGB` channel order.
+
+ Args:
+ image: The image to overlay text on.
+ text: Text content to overlay on the image. (default: '')
+ position: Target position (bottom-left corner) to add text. If not set,
+ center of the image will be used by default. (default: None)
+ font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
+ font_size: Font size of the text added. (default: 1.0)
+ line_type: Line type used to depict the text. (default: cv2.LINE_8)
+ line_width: Line width used to depict the text. (default: 1)
+ color: Color of the text added in `RGB` channel order. (default:
+ (255, 255, 255))
+
+ Returns:
+ An image with target text overlayed on.
+ """
+ if image is None or not text:
+ return image
+
+ cv2.putText(img=image,
+ text=text,
+ org=position,
+ fontFace=font,
+ fontScale=font_size,
+ color=color,
+ thickness=line_width,
+ lineType=line_type,
+ bottomLeftOrigin=False)
+
+ return image
+
+
+def fuse_images(images,
+ image_size=None,
+ row=0,
+ col=0,
+ is_row_major=True,
+ is_portrait=False,
+ row_spacing=0,
+ col_spacing=0,
+ border_left=0,
+ border_right=0,
+ border_top=0,
+ border_bottom=0,
+ black_background=True):
+ """Fuses a collection of images into an entire image.
+
+ Args:
+ images: A collection of images to fuse. Should be with shape [num, height,
+ width, channels].
+ image_size: Int or two-element tuple. This field is used to resize the image
+ before fusing. `None` disables resizing. (default: None)
+ row: Number of rows used for image fusion. If not set, this field will be
+ automatically assigned based on `col` and total number of images.
+      (default: 0)
+ col: Number of columns used for image fusion. If not set, this field will be
+ automatically assigned based on `row` and total number of images.
+      (default: 0)
+ is_row_major: Whether the input images should be arranged row-major or
+ column-major. (default: True)
+ is_portrait: Only active when both `row` and `col` should be assigned
+ automatically. (default: False)
+ row_spacing: Space between rows. (default: 0)
+ col_spacing: Space between columns. (default: 0)
+ border_left: Width of left border. (default: 0)
+ border_right: Width of right border. (default: 0)
+ border_top: Width of top border. (default: 0)
+ border_bottom: Width of bottom border. (default: 0)
+
+ Returns:
+ The fused image.
+
+ Raises:
+ ValueError: If the input `images` is not with shape [num, height, width,
+      channels].
+ """
+ if images is None:
+ return images
+
+ if not images.ndim == 4:
+ raise ValueError(f'Input `images` should be with shape [num, height, '
+ f'width, channels], but {images.shape} is received!')
+
+ num, image_height, image_width, channels = images.shape
+ if image_size is not None:
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ assert isinstance(image_size, (list, tuple)) and len(image_size) == 2
+ width, height = image_size
+ else:
+ height, width = image_height, image_width
+ row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait)
+ fused_height = (
+ height * row + row_spacing * (row - 1) + border_top + border_bottom)
+ fused_width = (
+ width * col + col_spacing * (col - 1) + border_left + border_right)
+ fused_image = get_blank_image(
+ fused_height, fused_width, channels=channels, is_black=black_background)
+ images = images.reshape(row, col, image_height, image_width, channels)
+ if not is_row_major:
+ images = images.transpose(1, 0, 2, 3, 4)
+
+ for i in range(row):
+ y = border_top + i * (height + row_spacing)
+ for j in range(col):
+ x = border_left + j * (width + col_spacing)
+ if image_size is not None:
+ image = cv2.resize(images[i, j], image_size)
+ else:
+ image = images[i, j]
+ fused_image[y:y + height, x:x + width] = image
+
+ return fused_image
+
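+# A minimal usage sketch for `fuse_images` (illustrative only; `batch` is a
+# hypothetical numpy array of shape [8, height, width, channels]):
+#
+#   grid = fuse_images(batch, image_size=256, row=2, col=4,
+#                      row_spacing=5, col_spacing=5)
+#   # With the default zero borders, `grid` has shape [517, 1039, channels].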
+
+def get_sortable_html_header(column_name_list, sort_by_ascending=False):
+ """Gets header for sortable html page.
+
+ Basically, the html page contains a sortable table, where the user can sort
+ the rows by a particular column by clicking the column header.
+
+ Example:
+
+ column_name_list = [name_1, name_2, name_3]
+ header = get_sortable_html_header(column_name_list)
+ footer = get_sortable_html_footer()
+ sortable_table = ...
+ html_page = header + sortable_table + footer
+
+ Args:
+ column_name_list: List of column header names.
+ sort_by_ascending: Default sorting order. If set as `True`, the html page
+ will be sorted by ascending order when the header is clicked for the first
+ time.
+
+ Returns:
+ A string which represents the header of a sortable html page.
+ """
+ header = '\n'.join([
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ ''])
+ for idx, column_name in enumerate(column_name_list):
+ header += f'<th>{column_name}</th>\n'
+ header += '</tr>\n'
+ header += '</thead>\n'
+ header += '<tbody>\n'
+
+ return header
+
+
+def get_sortable_html_footer():
+ """Gets footer for sortable html page.
+
+ Check function `get_sortable_html_header()` for more details.
+ """
+ return '\n</table>\n\n</body>\n</html>\n'
+
+
+def encode_image_to_html_str(image, image_size=None):
+ """Encodes an image to html language.
+
+ Args:
+ image: The input image to encode. Should be with `RGB` channel order.
+ image_size: Int or two-element tuple. This field is used to resize the image
+ before encoding. `None` disables resizing. (default: None)
+
+ Returns:
+ A string which represents the encoded image.
+ """
+ if image is None:
+ return ''
+
+ assert len(image.shape) == 3 and image.shape[2] in [1, 3]
+
+ # Change channel order to `BGR`, which is opencv-friendly.
+ image = image[:, :, ::-1]
+
+ # Resize the image if needed.
+ if image_size is not None:
+ if isinstance(image_size, int):
+ image_size = (image_size, image_size)
+ assert isinstance(image_size, (list, tuple)) and len(image_size) == 2
+ image = cv2.resize(image, image_size)
+
+ # Encode the image to html-format string.
+ encoded_image = cv2.imencode(".jpg", image)[1].tostring()
+ encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
+ html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'
+
+ return html_str
+
+
+class HtmlPageVisualizer(object):
+ """Defines the html page visualizer.
+
+ This class can be used to visualize image results as an html page. Basically,
+ it is based on an html-format sortable table built with the helper functions
+ `get_sortable_html_header()`, `get_sortable_html_footer()`, and
+ `encode_image_to_html_str()`. To simplify the usage, specifying the following
+ fields is enough to create a visualization page:
+
+ (1) num_rows: Number of rows of the table (header-row exclusive).
+ (2) num_cols: Number of columns of the table.
+ (3) header contents (optional): Title of each column.
+
+ NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
+ automatically.
+
+ Example:
+
+ html = HtmlPageVisualizer(num_rows, num_cols)
+ html.set_headers([...])
+ for i in range(num_rows):
+ for j in range(num_cols):
+ html.set_cell(i, j, text=..., image=...)
+ html.save('visualize.html')
+ """
+
+ def __init__(self,
+ num_rows=0,
+ num_cols=0,
+ grid_size=0,
+ is_portrait=False,
+ viz_size=None):
+ if grid_size > 0:
+ num_rows, num_cols = get_grid_shape(
+ grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
+ assert num_rows > 0 and num_cols > 0
+
+ self.num_rows = num_rows
+ self.num_cols = num_cols
+ self.viz_size = viz_size
+ self.headers = ['' for _ in range(self.num_cols)]
+ self.cells = [[{
+ 'text': '',
+ 'image': '',
+ } for _ in range(self.num_cols)] for _ in range(self.num_rows)]
+
+ def set_header(self, column_idx, content):
+ """Sets the content of a particular header by column index."""
+ self.headers[column_idx] = content
+
+ def set_headers(self, contents):
+ """Sets the contents of all headers."""
+ if isinstance(contents, str):
+ contents = [contents]
+ assert isinstance(contents, (list, tuple))
+ assert len(contents) == self.num_cols
+ for column_idx, content in enumerate(contents):
+ self.set_header(column_idx, content)
+
+ def set_cell(self, row_idx, column_idx, text='', image=None):
+ """Sets the content of a particular cell.
+
+ Basically, a cell contains some text as well as an image. Both text and
+ image can be empty.
+
+ Args:
+ row_idx: Row index of the cell to edit.
+ column_idx: Column index of the cell to edit.
+ text: Text to add into the target cell.
+ image: Image to show in the target cell. Should be with `RGB` channel
+ order.
+ """
+ self.cells[row_idx][column_idx]['text'] = text
+ self.cells[row_idx][column_idx]['image'] = encode_image_to_html_str(
+ image, self.viz_size)
+
+ def save(self, save_path):
+ """Saves the html page."""
+ html = ''
+ for i in range(self.num_rows):
+ html += f'<tr>\n'
+ for j in range(self.num_cols):
+ text = self.cells[i][j]['text']
+ image = self.cells[i][j]['image']
+ if text:
+ html += f'<td>{text}<br><br>{image}</td>\n'
+ else:
+ html += f'<td>{image}</td>\n'
+ html += f'</tr>\n'
+
+ header = get_sortable_html_header(self.headers)
+ footer = get_sortable_html_footer()
+
+ with open(save_path, 'w') as f:
+ f.write(header + html + footer)
+
+
+class VideoReader(object):
+ """Defines the video reader.
+
+ This class can be used to read frames from a given video.
+ """
+
+ def __init__(self, path):
+ """Initializes the video reader by loading the video from disk."""
+ if not os.path.isfile(path):
+ raise ValueError(f'Video `{path}` does not exist!')
+
+ self.path = path
+ self.video = cv2.VideoCapture(path)
+ assert self.video.isOpened()
+ self.position = 0
+
+ self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ self.fps = self.video.get(cv2.CAP_PROP_FPS)
+
+ def __del__(self):
+ """Releases the opened video."""
+ self.video.release()
+
+ def read(self, position=None):
+ """Reads a certain frame.
+
+ NOTE: The returned frame is assumed to be with `RGB` channel order.
+
+ Args:
+ position: Optional. If set, the reader will read the frame at the exact
+ position. Otherwise, the reader will read the next frame. (default: None)
+ """
+ if position is not None and position < self.length:
+ self.video.set(cv2.CAP_PROP_POS_FRAMES, position)
+ self.position = position
+
+ success, frame = self.video.read()
+ self.position = self.position + 1
+
+ return frame[:, :, ::-1] if success else None
+
+
+class VideoWriter(object):
+ """Defines the video writer.
+
+ This class can be used to create a video.
+
+ NOTE: The `.avi` container with the `DIVX` codec is the most recommended
+ format since it does not rely on extra dependencies.
+ """
+
+ def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'):
+ """Creates the video writer."""
+ self.path = path
+ self.frame_height = frame_height
+ self.frame_width = frame_width
+ self.fps = fps
+ self.codec = codec
+
+ self.video = cv2.VideoWriter(filename=path,
+ fourcc=cv2.VideoWriter_fourcc(*codec),
+ fps=fps,
+ frameSize=(frame_width, frame_height))
+
+ def __del__(self):
+ """Releases the opened video."""
+ self.video.release()
+
+ def write(self, frame):
+ """Writes a target frame.
+
+ NOTE: The input frame is assumed to be with `RGB` channel order.
+ """
+ self.video.write(frame[:, :, ::-1])
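+
+
+# A minimal re-encoding sketch using the two classes above (illustrative only;
+# 'in.mp4' and 'out.avi' are placeholder paths):
+#
+#   reader = VideoReader('in.mp4')
+#   writer = VideoWriter('out.avi', reader.frame_height, reader.frame_width,
+#                        fps=reader.fps)
+#   for _ in range(reader.length):
+#       frame = reader.read()  # RGB frame, or None on read failure
+#       if frame is None:
+#           break
+#       writer.write(frame)  # flipped back to BGR internally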
diff --git a/models/StyleCLIP/mapper/__init__.py b/models/StyleCLIP/mapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/mapper/datasets/__init__.py b/models/StyleCLIP/mapper/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/mapper/datasets/latents_dataset.py b/models/StyleCLIP/mapper/datasets/latents_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dde6ef52b7488e864ccd2fa2930b5100c1025c17
--- /dev/null
+++ b/models/StyleCLIP/mapper/datasets/latents_dataset.py
@@ -0,0 +1,15 @@
+from torch.utils.data import Dataset
+
+
+class LatentsDataset(Dataset):
+
+ def __init__(self, latents, opts):
+ self.latents = latents
+ self.opts = opts
+
+ def __len__(self):
+ return self.latents.shape[0]
+
+ def __getitem__(self, index):
+
+ return self.latents[index]
diff --git a/models/StyleCLIP/mapper/latent_mappers.py b/models/StyleCLIP/mapper/latent_mappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..63637adc9646986a3546edd19f4555a2f75a379f
--- /dev/null
+++ b/models/StyleCLIP/mapper/latent_mappers.py
@@ -0,0 +1,81 @@
+import torch
+from torch import nn
+from torch.nn import Module
+
+from models.StyleCLIP.models.stylegan2.model import EqualLinear, PixelNorm
+
+
+class Mapper(Module):
+
+ def __init__(self, opts):
+ super(Mapper, self).__init__()
+
+ self.opts = opts
+ layers = [PixelNorm()]
+
+ for i in range(4):
+ layers.append(
+ EqualLinear(
+ 512, 512, lr_mul=0.01, activation='fused_lrelu'
+ )
+ )
+
+ self.mapping = nn.Sequential(*layers)
+
+
+ def forward(self, x):
+ x = self.mapping(x)
+ return x
+
+
+class SingleMapper(Module):
+
+ def __init__(self, opts):
+ super(SingleMapper, self).__init__()
+
+ self.opts = opts
+
+ self.mapping = Mapper(opts)
+
+ def forward(self, x):
+ out = self.mapping(x)
+ return out
+
+
+class LevelsMapper(Module):
+
+ def __init__(self, opts):
+ super(LevelsMapper, self).__init__()
+
+ self.opts = opts
+
+ if not opts.no_coarse_mapper:
+ self.course_mapping = Mapper(opts)
+ if not opts.no_medium_mapper:
+ self.medium_mapping = Mapper(opts)
+ if not opts.no_fine_mapper:
+ self.fine_mapping = Mapper(opts)
+
+ def forward(self, x):
+ x_coarse = x[:, :4, :]
+ x_medium = x[:, 4:8, :]
+ x_fine = x[:, 8:, :]
+
+ if not self.opts.no_coarse_mapper:
+ x_coarse = self.course_mapping(x_coarse)
+ else:
+ x_coarse = torch.zeros_like(x_coarse)
+ if not self.opts.no_medium_mapper:
+ x_medium = self.medium_mapping(x_medium)
+ else:
+ x_medium = torch.zeros_like(x_medium)
+ if not self.opts.no_fine_mapper:
+ x_fine = self.fine_mapping(x_fine)
+ else:
+ x_fine = torch.zeros_like(x_fine)
+
+
+ out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
+
+ return out
+
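+
+# Note on the split above: the input is expected to be a W+ code of shape
+# [batch, num_layers, 512]; layers 0-3 are treated as coarse, 4-7 as medium,
+# and 8 onwards as fine, with each group either mapped or zeroed according to
+# the `no_*_mapper` flags. A minimal usage sketch (illustrative only):
+#
+#   mapper = LevelsMapper(opts)      # `opts` must define the three flags
+#   delta = mapper(w_plus)           # same shape as `w_plus`
+#   w_edited = w_plus + 0.1 * delta  # 0.1 at training time, 0.06 at inference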
diff --git a/models/StyleCLIP/mapper/options/__init__.py b/models/StyleCLIP/mapper/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/mapper/options/test_options.py b/models/StyleCLIP/mapper/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..aab2e5a5bba1038b97110fa6c8e8bce14de7390c
--- /dev/null
+++ b/models/StyleCLIP/mapper/options/test_options.py
@@ -0,0 +1,42 @@
+from argparse import ArgumentParser
+
+
+class TestOptions:
+
+ def __init__(self):
+ self.parser = ArgumentParser()
+ self.initialize()
+
+ def initialize(self):
+ # arguments for inference script
+ self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
+ self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint')
+ self.parser.add_argument('--couple_outputs', action='store_true',
+ help='Whether to also save inputs + outputs side-by-side')
+
+ self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
+ self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
+ self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
+ self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
+ self.parser.add_argument('--stylegan_size', default=1024, type=int)
+
+ self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
+ self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation")
+ self.parser.add_argument('--test_workers', default=2, type=int,
+ help='Number of test/inference dataloader workers')
+
+ self.parser.add_argument('--n_images', type=int, default=None,
+ help='Number of images to output. If None, run on all data')
+
+ self.parser.add_argument('--run_id', type=str, default='PKNWUQRQRKXQ',
+ help='The generator id to use')
+
+ self.parser.add_argument('--image_name', type=str, default='',
+ help='image to run on')
+
+ self.parser.add_argument('--edit_name', type=str, default='',
+ help='edit type')
+
+ def parse(self):
+ opts = self.parser.parse_args()
+ return opts
diff --git a/models/StyleCLIP/mapper/options/train_options.py b/models/StyleCLIP/mapper/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..a365217f8b76d38aaef4a42b90152ec7a8e7bf1f
--- /dev/null
+++ b/models/StyleCLIP/mapper/options/train_options.py
@@ -0,0 +1,49 @@
+from argparse import ArgumentParser
+
+
+class TrainOptions:
+
+ def __init__(self):
+ self.parser = ArgumentParser()
+ self.initialize()
+
+ def initialize(self):
+ self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
+ self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
+ self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
+ self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
+ self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
+ self.parser.add_argument('--latents_train_path', default="train_faces.pt", type=str, help="The latents for the training")
+ self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation")
+ self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given")
+ self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given")
+
+ self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training')
+ self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference')
+ self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
+ self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
+
+ self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate')
+ self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
+
+ self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
+ self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor')
+ self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor')
+
+ self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
+ self.parser.add_argument('--stylegan_size', default=1024, type=int)
+ self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
+ self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint')
+
+ self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps')
+ self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
+ self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
+ self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval')
+ self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval')
+
+ self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt')
+
+
+ def parse(self):
+ opts = self.parser.parse_args()
+ return opts
\ No newline at end of file
diff --git a/models/StyleCLIP/mapper/scripts/inference.py b/models/StyleCLIP/mapper/scripts/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..98d765b3607bc6ecf4d137ac3a876b400269c82a
--- /dev/null
+++ b/models/StyleCLIP/mapper/scripts/inference.py
@@ -0,0 +1,80 @@
+import os
+import pickle
+from argparse import Namespace
+import torchvision
+import torch
+import sys
+import time
+
+from configs import paths_config, global_config
+from models.StyleCLIP.mapper.styleclip_mapper import StyleCLIPMapper
+from utils.models_utils import load_tuned_G, load_old_G
+
+sys.path.append(".")
+sys.path.append("..")
+
+
+def run(test_opts, model_id, image_name, use_multi_id_G):
+ out_path_results = os.path.join(test_opts.exp_dir, test_opts.data_dir_name)
+ os.makedirs(out_path_results, exist_ok=True)
+ out_path_results = os.path.join(out_path_results, test_opts.image_name)
+ os.makedirs(out_path_results, exist_ok=True)
+
+ # update test configs with configs used during training
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts.update(vars(test_opts))
+ opts = Namespace(**opts)
+
+ net = StyleCLIPMapper(opts, test_opts.run_id)
+ net.eval()
+ net.to(global_config.device)
+
+ generator_type = paths_config.multi_id_model_type if use_multi_id_G else image_name
+
+ new_G = load_tuned_G(model_id, generator_type)
+ old_G = load_old_G()
+
+ run_styleclip(net, new_G, opts, paths_config.pti_results_keyword, out_path_results, test_opts)
+ run_styleclip(net, old_G, opts, paths_config.e4e_results_keyword, out_path_results, test_opts)
+
+
+def run_styleclip(net, G, opts, method, out_path_results, test_opts):
+ net.set_G(G)
+
+ out_path_results = os.path.join(out_path_results, method)
+ os.makedirs(out_path_results, exist_ok=True)
+
+ latent = torch.load(opts.latents_test_path)
+
+ global_i = 0
+ global_time = []
+ with torch.no_grad():
+ input_cuda = latent.cuda().float()
+ tic = time.time()
+ result_batch = run_on_batch(input_cuda, net, test_opts.couple_outputs)
+ toc = time.time()
+ global_time.append(toc - tic)
+
+ for i in range(opts.test_batch_size):
+ im_path = f'{test_opts.image_name}_{test_opts.edit_name}'
+ if test_opts.couple_outputs:
+ couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)])
+ torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"),
+ normalize=True, range=(-1, 1))
+ else:
+ torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"),
+ normalize=True, range=(-1, 1))
+ torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt"))
+
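+# `run_on_batch` below applies the trained mapper as a residual edit in W+,
+# w_hat = w + 0.06 * mapper(w), where 0.06 is the fixed inference step size
+# used by this script; both the edited and (optionally) the original latents
+# are decoded with whichever generator was assigned via `net.set_G(...)`.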
+
+def run_on_batch(inputs, net, couple_outputs=False):
+ w = inputs
+ with torch.no_grad():
+ w_hat = w + 0.06 * net.mapper(w)
+ x_hat = net.decoder.synthesis(w_hat, noise_mode='const', force_fp32=True)
+ result_batch = (x_hat, w_hat)
+ if couple_outputs:
+ x = net.decoder.synthesis(w, noise_mode='const', force_fp32=True)
+ result_batch = (x_hat, w_hat, x)
+ return result_batch
diff --git a/models/StyleCLIP/mapper/scripts/train.py b/models/StyleCLIP/mapper/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4141436fb3edee8ab5f7576fde0c0e53b529ef66
--- /dev/null
+++ b/models/StyleCLIP/mapper/scripts/train.py
@@ -0,0 +1,32 @@
+"""
+This file runs the main training/val loop
+"""
+import os
+import json
+import sys
+import pprint
+
+sys.path.append(".")
+sys.path.append("..")
+
+from mapper.options.train_options import TrainOptions
+from mapper.training.coach import Coach
+
+
+def main(opts):
+ if os.path.exists(opts.exp_dir):
+ raise Exception('Oops... {} already exists'.format(opts.exp_dir))
+ os.makedirs(opts.exp_dir, exist_ok=True)
+
+ opts_dict = vars(opts)
+ pprint.pprint(opts_dict)
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
+
+ coach = Coach(opts)
+ coach.train()
+
+
+if __name__ == '__main__':
+ opts = TrainOptions().parse()
+ main(opts)
diff --git a/models/StyleCLIP/mapper/styleclip_mapper.py b/models/StyleCLIP/mapper/styleclip_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..86c04bee5744a551f4c0d31359e0de1f5492ff7e
--- /dev/null
+++ b/models/StyleCLIP/mapper/styleclip_mapper.py
@@ -0,0 +1,76 @@
+import torch
+from torch import nn
+from models.StyleCLIP.mapper import latent_mappers
+from models.StyleCLIP.models.stylegan2.model import Generator
+
+
+def get_keys(d, name):
+ if 'state_dict' in d:
+ d = d['state_dict']
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+ return d_filt
+
+
+class StyleCLIPMapper(nn.Module):
+
+ def __init__(self, opts, run_id):
+ super(StyleCLIPMapper, self).__init__()
+ self.opts = opts
+ # Define architecture
+ self.mapper = self.set_mapper()
+ self.run_id = run_id
+
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+ # Load weights if needed
+ self.load_weights()
+
+ def set_mapper(self):
+ if self.opts.mapper_type == 'SingleMapper':
+ mapper = latent_mappers.SingleMapper(self.opts)
+ elif self.opts.mapper_type == 'LevelsMapper':
+ mapper = latent_mappers.LevelsMapper(self.opts)
+ else:
+ raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type))
+ return mapper
+
+ def load_weights(self):
+ if self.opts.checkpoint_path is not None:
+ print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+ self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)
+
+ def set_G(self, new_G):
+ self.decoder = new_G
+
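+ # The forward pass below first maps the input through the latent mapper
+ # (unless `input_code` is True), optionally overwrites the W+ layers listed
+ # in `latent_mask` with `inject_latent` (blended by `alpha` when given,
+ # zeroed otherwise), and then synthesizes an image with the generator
+ # assigned via `set_G`, optionally average-pooled down to 256x256.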
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+ inject_latent=None, return_latents=False, alpha=None):
+ if input_code:
+ codes = x
+ else:
+ codes = self.mapper(x)
+
+ if latent_mask is not None:
+ for i in latent_mask:
+ if inject_latent is not None:
+ if alpha is not None:
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+ else:
+ codes[:, i] = inject_latent[:, i]
+ else:
+ codes[:, i] = 0
+
+ input_is_latent = not input_code
+ images = self.decoder.synthesis(codes, noise_mode='const')
+ result_latent = None
+ # images, result_latent = self.decoder([codes],
+ # input_is_latent=input_is_latent,
+ # randomize_noise=randomize_noise,
+ # return_latents=return_latents)
+
+ if resize:
+ images = self.face_pool(images)
+
+ if return_latents:
+ return images, result_latent
+ else:
+ return images
diff --git a/models/StyleCLIP/mapper/training/__init__.py b/models/StyleCLIP/mapper/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/mapper/training/coach.py b/models/StyleCLIP/mapper/training/coach.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd38eb226106a21e19beb306cd9b0de6a1e7db04
--- /dev/null
+++ b/models/StyleCLIP/mapper/training/coach.py
@@ -0,0 +1,242 @@
+import os
+
+import clip
+import torch
+import torchvision
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import criteria.clip_loss as clip_loss
+from criteria import id_loss
+from mapper.datasets.latents_dataset import LatentsDataset
+from mapper.styleclip_mapper import StyleCLIPMapper
+from mapper.training.ranger import Ranger
+from mapper.training import train_utils
+
+
+class Coach:
+ def __init__(self, opts):
+ self.opts = opts
+
+ self.global_step = 0
+
+ self.device = 'cuda:0'
+ self.opts.device = self.device
+
+ # Initialize network
+ self.net = StyleCLIPMapper(self.opts).to(self.device)
+
+ # Initialize loss
+ if self.opts.id_lambda > 0:
+ self.id_loss = id_loss.IDLoss(self.opts).to(self.device).eval()
+ if self.opts.clip_lambda > 0:
+ self.clip_loss = clip_loss.CLIPLoss(opts)
+ if self.opts.latent_l2_lambda > 0:
+ self.latent_l2_loss = nn.MSELoss().to(self.device).eval()
+
+ # Initialize optimizer
+ self.optimizer = self.configure_optimizers()
+
+ # Initialize dataset
+ self.train_dataset, self.test_dataset = self.configure_datasets()
+ self.train_dataloader = DataLoader(self.train_dataset,
+ batch_size=self.opts.batch_size,
+ shuffle=True,
+ num_workers=int(self.opts.workers),
+ drop_last=True)
+ self.test_dataloader = DataLoader(self.test_dataset,
+ batch_size=self.opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(self.opts.test_workers),
+ drop_last=True)
+
+ self.text_inputs = torch.cat([clip.tokenize(self.opts.description)]).cuda()
+
+ # Initialize logger
+ log_dir = os.path.join(opts.exp_dir, 'logs')
+ os.makedirs(log_dir, exist_ok=True)
+ self.log_dir = log_dir
+ self.logger = SummaryWriter(log_dir=log_dir)
+
+ # Initialize checkpoint dir
+ self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.best_val_loss = None
+ if self.opts.save_interval is None:
+ self.opts.save_interval = self.opts.max_steps
+
+ def train(self):
+ self.net.train()
+ while self.global_step < self.opts.max_steps:
+ for batch_idx, batch in enumerate(self.train_dataloader):
+ self.optimizer.zero_grad()
+ w = batch
+ w = w.to(self.device)
+ with torch.no_grad():
+ x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1)
+ w_hat = w + 0.1 * self.net.mapper(w)
+ x_hat, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
+ loss, loss_dict = self.calc_loss(w, x, w_hat, x_hat)
+ loss.backward()
+ self.optimizer.step()
+
+ # Logging related
+ if self.global_step % self.opts.image_interval == 0 or (
+ self.global_step < 1000 and self.global_step % 1000 == 0):
+ self.parse_and_log_images(x, x_hat, title='images_train')
+ if self.global_step % self.opts.board_interval == 0:
+ self.print_metrics(loss_dict, prefix='train')
+ self.log_metrics(loss_dict, prefix='train')
+
+ # Validation related
+ val_loss_dict = None
+ if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
+ val_loss_dict = self.validate()
+ if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
+ self.best_val_loss = val_loss_dict['loss']
+ self.checkpoint_me(val_loss_dict, is_best=True)
+
+ if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
+ if val_loss_dict is not None:
+ self.checkpoint_me(val_loss_dict, is_best=False)
+ else:
+ self.checkpoint_me(loss_dict, is_best=False)
+
+ if self.global_step == self.opts.max_steps:
+ print('OMG, finished training!')
+ break
+
+ self.global_step += 1
+
+ def validate(self):
+ self.net.eval()
+ agg_loss_dict = []
+ for batch_idx, batch in enumerate(self.test_dataloader):
+ if batch_idx > 200:
+ break
+
+ w = batch
+
+ with torch.no_grad():
+ w = w.to(self.device).float()
+ x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=True, truncation=1)
+ w_hat = w + 0.1 * self.net.mapper(w)
+ x_hat, _ = self.net.decoder([w_hat], input_is_latent=True, randomize_noise=True, truncation=1)
+ loss, cur_loss_dict = self.calc_loss(w, x, w_hat, x_hat)
+ agg_loss_dict.append(cur_loss_dict)
+
+ # Logging related
+ self.parse_and_log_images(x, x_hat, title='images_val', index=batch_idx)
+
+ # For first step just do sanity test on small amount of data
+ if self.global_step == 0 and batch_idx >= 4:
+ self.net.train()
+ return None # Do not log, inaccurate in first batch
+
+ loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
+ self.log_metrics(loss_dict, prefix='test')
+ self.print_metrics(loss_dict, prefix='test')
+
+ self.net.train()
+ return loss_dict
+
+ def checkpoint_me(self, loss_dict, is_best):
+ save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
+ save_dict = self.__get_save_dict()
+ checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
+ torch.save(save_dict, checkpoint_path)
+ with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
+ if is_best:
+ f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
+ else:
+ f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))
+
+ def configure_optimizers(self):
+ params = list(self.net.mapper.parameters())
+ if self.opts.optim_name == 'adam':
+ optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
+ else:
+ optimizer = Ranger(params, lr=self.opts.learning_rate)
+ return optimizer
+
+ def configure_datasets(self):
+ if self.opts.latents_train_path:
+ train_latents = torch.load(self.opts.latents_train_path)
+ else:
+ train_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda()
+ train_latents = []
+ for b in range(self.opts.train_dataset_size // self.opts.batch_size):
+ with torch.no_grad():
+ _, train_latents_b = self.net.decoder([train_latents_z[b: b + self.opts.batch_size]],
+ truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True)
+ train_latents.append(train_latents_b)
+ train_latents = torch.cat(train_latents)
+
+ if self.opts.latents_test_path:
+ test_latents = torch.load(self.opts.latents_test_path)
+ else:
+ test_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda()
+ test_latents = []
+ for b in range(self.opts.test_dataset_size // self.opts.test_batch_size):
+ with torch.no_grad():
+ _, test_latents_b = self.net.decoder([test_latents_z[b: b + self.opts.test_batch_size]],
+ truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True)
+ test_latents.append(test_latents_b)
+ test_latents = torch.cat(test_latents)
+
+ train_dataset_celeba = LatentsDataset(latents=train_latents.cpu(),
+ opts=self.opts)
+ test_dataset_celeba = LatentsDataset(latents=test_latents.cpu(),
+ opts=self.opts)
+ train_dataset = train_dataset_celeba
+ test_dataset = test_dataset_celeba
+ print("Number of training samples: {}".format(len(train_dataset)))
+ print("Number of test samples: {}".format(len(test_dataset)))
+ return train_dataset, test_dataset
+
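+ # The total objective assembled below is:
+ # loss = id_lambda * L_id(x_hat, x)
+ # + clip_lambda * L_clip(x_hat, text_inputs)
+ # + latent_l2_lambda * ||w_hat - w||^2,
+ # where any term whose lambda is zero is skipped entirely.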
+ def calc_loss(self, w, x, w_hat, x_hat):
+ loss_dict = {}
+ loss = 0.0
+ if self.opts.id_lambda > 0:
+ loss_id, sim_improvement = self.id_loss(x_hat, x)
+ loss_dict['loss_id'] = float(loss_id)
+ loss_dict['id_improve'] = float(sim_improvement)
+ loss = loss_id * self.opts.id_lambda
+ if self.opts.clip_lambda > 0:
+ loss_clip = self.clip_loss(x_hat, self.text_inputs).mean()
+ loss_dict['loss_clip'] = float(loss_clip)
+ loss += loss_clip * self.opts.clip_lambda
+ if self.opts.latent_l2_lambda > 0:
+ loss_l2_latent = self.latent_l2_loss(w_hat, w)
+ loss_dict['loss_l2_latent'] = float(loss_l2_latent)
+ loss += loss_l2_latent * self.opts.latent_l2_lambda
+ loss_dict['loss'] = float(loss)
+ return loss, loss_dict
+
+ def log_metrics(self, metrics_dict, prefix):
+ for key, value in metrics_dict.items():
+ #pass
+ print(f"step: {self.global_step} \t metric: {prefix}/{key} \t value: {value}")
+ self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step)
+
+ def print_metrics(self, metrics_dict, prefix):
+ print('Metrics for {}, step {}'.format(prefix, self.global_step))
+ for key, value in metrics_dict.items():
+ print('\t{} = '.format(key), value)
+
+ def parse_and_log_images(self, x, x_hat, title, index=None):
+ if index is None:
+ path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}.jpg')
+ else:
+ path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}_{str(index).zfill(5)}.jpg')
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ torchvision.utils.save_image(torch.cat([x.detach().cpu(), x_hat.detach().cpu()]), path,
+ normalize=True, scale_each=True, range=(-1, 1), nrow=self.opts.batch_size)
+
+ def __get_save_dict(self):
+ save_dict = {
+ 'state_dict': self.net.state_dict(),
+ 'opts': vars(self.opts)
+ }
+ return save_dict
\ No newline at end of file
diff --git a/models/StyleCLIP/mapper/training/ranger.py b/models/StyleCLIP/mapper/training/ranger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9442fd10d42fcc19f4e0dd798d1573b31ed2c0a0
--- /dev/null
+++ b/models/StyleCLIP/mapper/training/ranger.py
@@ -0,0 +1,164 @@
+# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# and/or
+# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+# Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+# This version = 20.4.11
+
+# Credits:
+# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+
+# summary of changes:
+# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+# changes 8/31/19 - fix references to *self*.N_sma_threshold;
+# changed eps to 1e-5 as better default than 1e-8.
+
+import math
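+# A minimal usage sketch (illustrative only; `model` and `loss` are
+# placeholders for a module and a computed scalar loss):
+#
+#   optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)
+#   optimizer.zero_grad()
+#   loss.backward()
+#   optimizer.step()  # RAdam update, plus a lookahead sync every k steps
+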
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Ranger(Optimizer):
+
+ def __init__(self, params, lr=1e-3, # lr
+ alpha=0.5, k=6, N_sma_threshhold=5, # Ranger configs
+ betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam configs
+ use_gc=True, gc_conv_only=False
+ # Gradient centralization on or off, applied to conv layers only or conv + fc layers
+ ):
+
+ # parameter checks
+ if not 0.0 <= alpha <= 1.0:
+ raise ValueError(f'Invalid slow update rate: {alpha}')
+ if not 1 <= k:
+ raise ValueError(f'Invalid lookahead steps: {k}')
+ if not lr > 0:
+ raise ValueError(f'Invalid Learning Rate: {lr}')
+ if not eps > 0:
+ raise ValueError(f'Invalid eps: {eps}')
+
+ # parameter comments:
+ # beta1 (momentum) of .95 seems to work better than .90...
+ # N_sma_threshold of 5 seems better in testing than 4.
+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+ # prep defaults and init torch.optim base
+ defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
+ eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ # adjustable threshold
+ self.N_sma_threshhold = N_sma_threshhold
+
+ # look ahead params
+
+ self.alpha = alpha
+ self.k = k
+
+ # radam buffer for state
+ self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+ # gc on or off
+ self.use_gc = use_gc
+
+ # level of gradient centralization
+ self.gc_gradient_threshold = 3 if gc_conv_only else 1
+
+ def __setstate__(self, state):
+ super(Ranger, self).__setstate__(state)
+
+ def step(self, closure=None):
+ loss = None
+
+ # Evaluate averages and grad, update param tensors
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+
+ if grad.is_sparse:
+ raise RuntimeError('Ranger optimizer does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p] # get state dict for this param
+
+ if len(state) == 0: # if first time to run...init dictionary with our desired entries
+ # if self.first_run_check==0:
+ # self.first_run_check=1
+ # print("Initializing slow buffer...should not see this at load from saved model!")
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+
+ # look ahead weight storage now in state dict
+ state['slow_buffer'] = torch.empty_like(p.data)
+ state['slow_buffer'].copy_(p.data)
+
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ # begin computations
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # GC operation for Conv layers and FC layers
+ if grad.dim() > self.gc_gradient_threshold:
+ grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
+
+ state['step'] += 1
+
+ # compute variance mov avg
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ # compute mean moving avg
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+ buffered = self.radam_buffer[int(state['step'] % 10)]
+
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
+ if N_sma > self.N_sma_threshhold:
+ step_size = math.sqrt(
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ else:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ buffered[2] = step_size
+
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+ # apply lr
+ if N_sma > self.N_sma_threshhold:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+ p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+ else:
+ p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+
+ p.data.copy_(p_data_fp32)
+
+ # integrated look ahead...
+ # we do it at the param level instead of group level
+ if state['step'] % group['k'] == 0:
+ slow_p = state['slow_buffer'] # get access to slow param tensor
+ slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha
+ p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
+
+ return loss
\ No newline at end of file
diff --git a/models/StyleCLIP/mapper/training/train_utils.py b/models/StyleCLIP/mapper/training/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0
--- /dev/null
+++ b/models/StyleCLIP/mapper/training/train_utils.py
@@ -0,0 +1,13 @@
+
+def aggregate_loss_dict(agg_loss_dict):
+ mean_vals = {}
+ for output in agg_loss_dict:
+ for key in output:
+ mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
+ for key in mean_vals:
+ if len(mean_vals[key]) > 0:
+ mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
+ else:
+ print('{} has no value'.format(key))
+ mean_vals[key] = 0
+ return mean_vals
diff --git a/models/StyleCLIP/models/__init__.py b/models/StyleCLIP/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/models/facial_recognition/__init__.py b/models/StyleCLIP/models/facial_recognition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/models/facial_recognition/helpers.py b/models/StyleCLIP/models/facial_recognition/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51fdf97141407fcc1c9d249a086ddbfd042469f
--- /dev/null
+++ b/models/StyleCLIP/models/facial_recognition/helpers.py
@@ -0,0 +1,119 @@
+from collections import namedtuple
+import torch
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Flatten(Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+ norm = torch.norm(input, 2, axis, True)
+ output = torch.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+ return blocks
+
+
+class SEModule(Module):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class bottleneck_IR(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class bottleneck_IR_SE(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR_SE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth),
+ SEModule(depth, 16)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
diff --git a/models/StyleCLIP/models/facial_recognition/model_irse.py b/models/StyleCLIP/models/facial_recognition/model_irse.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1c79e0366e4a6fd92011e86df80f8b31ec671ae
--- /dev/null
+++ b/models/StyleCLIP/models/facial_recognition/model_irse.py
@@ -0,0 +1,84 @@
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from models.facial_recognition.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+ super(Backbone, self).__init__()
+ assert input_size in [112, 224], "input_size should be 112 or 224"
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ if input_size == 112:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 7 * 7, 512),
+ BatchNorm1d(512, affine=affine))
+ else:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 14 * 14, 512),
+ BatchNorm1d(512, affine=affine))
+
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer(x)
+ return l2_norm(x)
+
+
+def IR_50(input_size):
+ """Constructs a ir-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_101(input_size):
+ """Constructs a ir-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_152(input_size):
+ """Constructs a ir-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_50(input_size):
+ """Constructs a ir_se-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_101(input_size):
+ """Constructs a ir_se-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_152(input_size):
+ """Constructs a ir_se-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
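+
+
+# A minimal usage sketch (illustrative only): `TrainOptions` points
+# `--ir_se50_weights` at an IR-SE-50 checkpoint for the ID loss, so a typical
+# instantiation would look like:
+#
+#   facenet = IR_SE_50(input_size=112)
+#   facenet.load_state_dict(torch.load(opts.ir_se50_weights))
+#   embedding = facenet(face_crop)  # l2-normalized 512-d identity embedding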
diff --git a/models/StyleCLIP/models/stylegan2/__init__.py b/models/StyleCLIP/models/stylegan2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/StyleCLIP/models/stylegan2/model.py b/models/StyleCLIP/models/stylegan2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d5559203f4f3843fc814b090780ffa129a6fdf0
--- /dev/null
+++ b/models/StyleCLIP/models/stylegan2/model.py
@@ -0,0 +1,674 @@
+import math
+import random
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from models.StyleCLIP.models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+ )
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+ return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+ f'upsample={self.upsample}, downsample={self.downsample})'
+ )
+
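+ # The forward pass below implements StyleGAN2's weight (de)modulation:
+ # the per-sample style vector scales the input channels of the kernel
+ # (modulation), the kernel is renormalized per output channel with
+ # demod = rsqrt(sum(w^2) + 1e-8) (demodulation), and the whole batch is
+ # evaluated as one grouped convolution with groups=batch.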
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ weight = self.scale * self.weight * style
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None):
+ out = self.conv(input, style)
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+ ]
+
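+ # Truncation trick: interpolate each style toward the mean latent (truncation_latent)
+ # by the truncation factor, trading sample diversity for fidelity.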
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
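+ # A single style is broadcast to every layer; two styles are mixed by crossing
+ # over at inject_index (chosen at random if not given).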
+ if len(styles) < 2:
+ inject_index = self.n_latent
+
+ if styles[0].ndim < 3:
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+ latent = torch.cat([latent, latent2], 1)
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
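+ # Synthesis: each resolution runs two styled convolutions, and ToRGB outputs are
+ # accumulated into an upsampled skip image.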
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+
+ else:
+ return image, None
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
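+ # Divide the residual sum by sqrt(2) to keep the activation variance roughly constant.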
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
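+ # Minibatch standard deviation: append the per-group feature stddev as an extra
+ # channel so the discriminator can sense low-variety batches.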
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.view(
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+ )
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
+
diff --git a/models/StyleCLIP/models/stylegan2/op/__init__.py b/models/StyleCLIP/models/stylegan2/op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3
--- /dev/null
+++ b/models/StyleCLIP/models/stylegan2/op/__init__.py
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
diff --git a/models/StyleCLIP/models/stylegan2/op/fused_act.py b/models/StyleCLIP/models/stylegan2/op/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d575bc9198e6d46eee040eb374c6d8f64c3363c
--- /dev/null
+++ b/models/StyleCLIP/models/stylegan2/op/fused_act.py
@@ -0,0 +1,40 @@
+import os
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+module_path = os.path.dirname(__file__)
+
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
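+ # Native PyTorch version of the fused bias + leaky ReLU op: add the per-channel
+ # bias, apply leaky_relu, then rescale by `scale` (sqrt(2) by default).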
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
+ input = input.cuda()
+ if input.ndim == 3:
+ return (
+ F.leaky_relu(
+ input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
+ )
+ * scale
+ )
+ else:
+ return (
+ F.leaky_relu(
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+ )
+ * scale
+ )
+
diff --git a/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py b/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..02fc25af780868d9b883631eb6b03a25c225d745
--- /dev/null
+++ b/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py
@@ -0,0 +1,60 @@
+import os
+
+import torch
+from torch.nn import functional as F
+
+
+module_path = os.path.dirname(__file__)
+
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ out = upfirdn2d_native(
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
+ )
+
+ return out
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
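+ # Native upfirdn2d: zero-insert upsample, pad, filter with the flipped kernel via
+ # conv2d, then stride-slice to downsample.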
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
\ No newline at end of file
diff --git a/models/StyleCLIP/optimization/run_optimization.py b/models/StyleCLIP/optimization/run_optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..766d0c81400951202bed51e3f1812e1260ccf071
--- /dev/null
+++ b/models/StyleCLIP/optimization/run_optimization.py
@@ -0,0 +1,128 @@
+import argparse
+import math
+import os
+import pickle
+
+import torch
+import torchvision
+from torch import optim
+from tqdm import tqdm
+
+from StyleCLIP.criteria.clip_loss import CLIPLoss
+from StyleCLIP.models.stylegan2.model import Generator
+import clip
+from StyleCLIP.utils import ensure_checkpoint_exists
+
+
+def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
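+ # Cosine ramp-down over the last `rampdown` fraction of the run combined with a
+ # linear warm-up over the first `rampup` fraction; t is the progress in [0, 1].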
+ lr_ramp = min(1, (1 - t) / rampdown)
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
+ lr_ramp = lr_ramp * min(1, t / rampup)
+
+ return initial_lr * lr_ramp
+
+
+def main(args, use_old_G):
+ ensure_checkpoint_exists(args.ckpt)
+ text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
+ os.makedirs(args.results_dir, exist_ok=True)
+ new_generator_path = f'/disk2/danielroich/Sandbox/stylegan2_ada_pytorch/checkpoints/model_{args.run_id}_{args.image_name}.pt'
+ old_generator_path = '/disk2/danielroich/Sandbox/pretrained_models/ffhq.pkl'
+
+ if not use_old_G:
+ with open(new_generator_path, 'rb') as f:
+ G = torch.load(f).cuda().eval()
+ else:
+ with open(old_generator_path, 'rb') as f:
+ G = pickle.load(f)['G_ema'].cuda().eval()
+
+ if args.latent_path:
+ latent_code_init = torch.load(args.latent_path).cuda()
+ elif args.mode == "edit":
+ latent_code_init_not_trunc = torch.randn(1, 512).cuda()
+ with torch.no_grad():
+ latent_code_init = G.mapping(latent_code_init_not_trunc, None)
+ else:
+ # Free generation without a latent path: start from the generator's average latent,
+ # which the StyleGAN2-ADA mapping network tracks as w_avg (assumed available here).
+ latent_code_init = G.mapping.w_avg[None, None, :].repeat(1, G.mapping.num_ws, 1)
+
+ latent = latent_code_init.detach().clone()
+ latent.requires_grad = True
+
+ clip_loss = CLIPLoss(args)
+
+ optimizer = optim.Adam([latent], lr=args.lr)
+
+ pbar = tqdm(range(args.step))
+
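+ # Optimize the latent code directly: minimize the CLIP distance between the
+ # synthesized image and the text prompt, plus (in edit mode) an L2 penalty that
+ # keeps the latent close to its starting point.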
+ for i in pbar:
+ t = i / args.step
+ lr = get_lr(t, args.lr)
+ optimizer.param_groups[0]["lr"] = lr
+
+ img_gen = G.synthesis(latent, noise_mode='const')
+
+ c_loss = clip_loss(img_gen, text_inputs)
+
+ if args.mode == "edit":
+ l2_loss = ((latent_code_init - latent) ** 2).sum()
+ loss = c_loss + args.l2_lambda * l2_loss
+ else:
+ loss = c_loss
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(
+ (
+ f"loss: {loss.item():.4f};"
+ )
+ )
+ if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
+ with torch.no_grad():
+ img_gen = G.synthesis(latent, noise_mode='const')
+
+ torchvision.utils.save_image(img_gen,
+ f"/disk2/danielroich/Sandbox/StyleCLIP/results/inference_results/{str(i).zfill(5)}.png",
+ normalize=True, range=(-1, 1))
+
+ if args.mode == "edit":
+ with torch.no_grad():
+ img_orig = G.synthesis(latent_code_init, noise_mode='const')
+
+ final_result = torch.cat([img_orig, img_gen])
+ else:
+ final_result = img_gen
+
+ return final_result
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--description", type=str, default="a person with purple hair",
+ help="the text that guides the editing/generation")
+ parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt",
+ help="pretrained StyleGAN2 weights")
+ parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
+ parser.add_argument("--lr_rampup", type=float, default=0.05)
+ parser.add_argument("--lr", type=float, default=0.1)
+ parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
+ parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"],
+ help="choose between edit an image an generate a free one")
+ parser.add_argument("--l2_lambda", type=float, default=0.008,
+ help="weight of the latent distance (used for editing only)")
+ parser.add_argument("--latent_path", type=str, default=None,
+ help="starts the optimization from the given latent code if provided. Otherwose, starts from"
+ "the mean latent in a free generation, and from a random one in editing. "
+ "Expects a .pt format")
+ parser.add_argument("--truncation", type=float, default=0.7,
+ help="used only for the initial latent vector, and only when a latent code path is"
+ "not provided")
+ parser.add_argument("--save_intermediate_image_every", type=int, default=20,
+ help="if > 0 then saves intermidate results during the optimization")
+ parser.add_argument("--results_dir", type=str, default="results")
+
+ args = parser.parse_args()
+
+ result_image = main(args, use_old_G=True)  # the CLI defines no run_id/image_name, so fall back to the pretrained generator
+
+ torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"),
+ normalize=True, scale_each=True, range=(-1, 1))
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/e4e/__init__.py b/models/e4e/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/e4e/discriminator.py b/models/e4e/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bf3722c7f2e35cdc9bd177a33ed0975e67200d
--- /dev/null
+++ b/models/e4e/discriminator.py
@@ -0,0 +1,20 @@
+from torch import nn
+
+
+class LatentCodesDiscriminator(nn.Module):
+ def __init__(self, style_dim, n_mlp):
+ super().__init__()
+
+ self.style_dim = style_dim
+
+ layers = []
+ for i in range(n_mlp-1):
+ layers.append(
+ nn.Linear(style_dim, style_dim)
+ )
+ layers.append(nn.LeakyReLU(0.2))
+ layers.append(nn.Linear(style_dim, 1))
+ self.mlp = nn.Sequential(*layers)
+
+ def forward(self, w):
+ return self.mlp(w)
diff --git a/models/e4e/encoders/__init__.py b/models/e4e/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/e4e/encoders/helpers.py b/models/e4e/encoders/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a58b34ea5ca6912fe53c63dede0a8696f5c024
--- /dev/null
+++ b/models/e4e/encoders/helpers.py
@@ -0,0 +1,140 @@
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Flatten(Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+ norm = torch.norm(input, 2, axis, True)
+ output = torch.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+ return blocks
+
+
+class SEModule(Module):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class bottleneck_IR(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class bottleneck_IR_SE(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR_SE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth),
+ SEModule(depth, 16)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+def _upsample_add(x, y):
+ """Upsample and add two feature maps.
+ Args:
+ x: (Variable) top feature map to be upsampled.
+ y: (Variable) lateral feature map.
+ Returns:
+ (Variable) added feature map.
+ Note in PyTorch, when input size is odd, the upsampled feature map
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
+ maybe not equal to the lateral feature map size.
+ e.g.
+ original input size: [N,_,15,15] ->
+ conv2d feature map size: [N,_,8,8] ->
+ upsampled feature map size: [N,_,16,16]
+ So we choose bilinear upsample which supports arbitrary output sizes.
+ """
+ _, _, H, W = y.size()
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
diff --git a/models/e4e/encoders/model_irse.py b/models/e4e/encoders/model_irse.py
new file mode 100644
index 0000000000000000000000000000000000000000..976ce2c61104efdc6b0015d895830346dd01bc10
--- /dev/null
+++ b/models/e4e/encoders/model_irse.py
@@ -0,0 +1,84 @@
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from models.e4e.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+ super(Backbone, self).__init__()
+ assert input_size in [112, 224], "input_size should be 112 or 224"
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ if input_size == 112:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 7 * 7, 512),
+ BatchNorm1d(512, affine=affine))
+ else:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 14 * 14, 512),
+ BatchNorm1d(512, affine=affine))
+
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer(x)
+ return l2_norm(x)
+
+
+def IR_50(input_size):
+ """Constructs a ir-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_101(input_size):
+ """Constructs a ir-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_152(input_size):
+ """Constructs a ir-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_50(input_size):
+ """Constructs a ir_se-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_101(input_size):
+ """Constructs a ir_se-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_152(input_size):
+ """Constructs a ir_se-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
diff --git a/models/e4e/encoders/psp_encoders.py b/models/e4e/encoders/psp_encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c7c70e5e2586bd6a0de825e45a80e9116156166
--- /dev/null
+++ b/models/e4e/encoders/psp_encoders.py
@@ -0,0 +1,200 @@
+from enum import Enum
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
+
+from models.e4e.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
+from models.e4e.stylegan2.model import EqualLinear
+
+
+class ProgressiveStage(Enum):
+ WTraining = 0
+ Delta1Training = 1
+ Delta2Training = 2
+ Delta3Training = 3
+ Delta4Training = 4
+ Delta5Training = 5
+ Delta6Training = 6
+ Delta7Training = 7
+ Delta8Training = 8
+ Delta9Training = 9
+ Delta10Training = 10
+ Delta11Training = 11
+ Delta12Training = 12
+ Delta13Training = 13
+ Delta14Training = 14
+ Delta15Training = 15
+ Delta16Training = 16
+ Delta17Training = 17
+ Inference = 18
+
+
+class GradualStyleBlock(Module):
+ def __init__(self, in_c, out_c, spatial):
+ super(GradualStyleBlock, self).__init__()
+ self.out_c = out_c
+ self.spatial = spatial
+ num_pools = int(np.log2(spatial))
+ modules = []
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU()]
+ for i in range(num_pools - 1):
+ modules += [
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU()
+ ]
+ self.convs = nn.Sequential(*modules)
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+ def forward(self, x):
+ x = self.convs(x)
+ x = x.view(-1, self.out_c)
+ x = self.linear(x)
+ return x
+
+
+class GradualStyleEncoder(Module):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(GradualStyleEncoder, self).__init__()
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ self.styles = nn.ModuleList()
+ log_size = int(math.log(opts.stylegan_size, 2))
+ self.style_count = 2 * log_size - 2
+ self.coarse_ind = 3
+ self.middle_ind = 7
+ for i in range(self.style_count):
+ if i < self.coarse_ind:
+ style = GradualStyleBlock(512, 512, 16)
+ elif i < self.middle_ind:
+ style = GradualStyleBlock(512, 512, 32)
+ else:
+ style = GradualStyleBlock(512, 512, 64)
+ self.styles.append(style)
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+
+ latents = []
+ modulelist = list(self.body._modules.values())
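+ # Tap the backbone at fixed depths (blocks 6, 20, 23) to obtain feature maps at
+ # three scales; the deepest map drives the coarse styles and FPN-fused shallower
+ # maps drive the medium and fine styles.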
+ for i, l in enumerate(modulelist):
+ x = l(x)
+ if i == 6:
+ c1 = x
+ elif i == 20:
+ c2 = x
+ elif i == 23:
+ c3 = x
+
+ for j in range(self.coarse_ind):
+ latents.append(self.styles[j](c3))
+
+ p2 = _upsample_add(c3, self.latlayer1(c2))
+ for j in range(self.coarse_ind, self.middle_ind):
+ latents.append(self.styles[j](p2))
+
+ p1 = _upsample_add(p2, self.latlayer2(c1))
+ for j in range(self.middle_ind, self.style_count):
+ latents.append(self.styles[j](p1))
+
+ out = torch.stack(latents, dim=1)
+ return out
+
+
+class Encoder4Editing(Module):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(Encoder4Editing, self).__init__()
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ self.styles = nn.ModuleList()
+ log_size = int(math.log(opts.stylegan_size, 2))
+ self.style_count = 2 * log_size - 2
+ self.coarse_ind = 3
+ self.middle_ind = 7
+
+ for i in range(self.style_count):
+ if i < self.coarse_ind:
+ style = GradualStyleBlock(512, 512, 16)
+ elif i < self.middle_ind:
+ style = GradualStyleBlock(512, 512, 32)
+ else:
+ style = GradualStyleBlock(512, 512, 64)
+ self.styles.append(style)
+
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+ self.progressive_stage = ProgressiveStage.Inference
+
+ def get_deltas_starting_dimensions(self):
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
+
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
+ self.progressive_stage = new_stage
+ print('Changed progressive stage to: ', new_stage)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+
+ modulelist = list(self.body._modules.values())
+ for i, l in enumerate(modulelist):
+ x = l(x)
+ if i == 6:
+ c1 = x
+ elif i == 20:
+ c2 = x
+ elif i == 23:
+ c3 = x
+
+ # Infer main W and duplicate it
+ w0 = self.styles[0](c3)
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
+ stage = self.progressive_stage.value
+ features = c3
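+ # Styles beyond the first predict residual deltas on top of w0; only layers up to
+ # the current progressive stage are active, switching to finer FPN features at
+ # coarse_ind and middle_ind.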
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
+ if i == self.coarse_ind:
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
+ features = p2
+ elif i == self.middle_ind:
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
+ features = p1
+ delta_i = self.styles[i](features)
+ w[:, i] += delta_i
+ return w
diff --git a/models/e4e/latent_codes_pool.py b/models/e4e/latent_codes_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..0281d4b5e80f8eb26e824fa35b4f908dcb6634e6
--- /dev/null
+++ b/models/e4e/latent_codes_pool.py
@@ -0,0 +1,55 @@
+import random
+import torch
+
+
+class LatentCodesPool:
+ """This class implements latent codes buffer that stores previously generated w latent codes.
+ This buffer enables us to update discriminators using a history of generated w's
+ rather than the ones produced by the latest encoder.
+ """
+
+ def __init__(self, pool_size):
+ """Initialize the ImagePool class
+ Parameters:
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0: # create an empty pool
+ self.num_ws = 0
+ self.ws = []
+
+ def query(self, ws):
+ """Return w's from the pool.
+ Parameters:
+ ws: the latest generated w's from the generator
+ Returns w's from the buffer.
+ With 50% probability, the buffer returns the input w's.
+ With 50% probability, the buffer returns w's previously stored in the buffer
+ and inserts the current w's into the buffer.
+ """
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
+ return ws
+ return_ws = []
+ for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
+ # w = torch.unsqueeze(image.data, 0)
+ if w.ndim == 2:
+ i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
+ w = w[i]
+ self.handle_w(w, return_ws)
+ return_ws = torch.stack(return_ws, 0) # collect all the images and return
+ return return_ws
+
+ def handle_w(self, w, return_ws):
+ if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
+ self.num_ws = self.num_ws + 1
+ self.ws.append(w)
+ return_ws.append(w)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
+ tmp = self.ws[random_id].clone()
+ self.ws[random_id] = w
+ return_ws.append(tmp)
+ else: # by another 50% chance, the buffer will return the current image
+ return_ws.append(w)
diff --git a/models/e4e/psp.py b/models/e4e/psp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf9f75dbaa66997abfc1e3e0e4f19ddfec7fedac
--- /dev/null
+++ b/models/e4e/psp.py
@@ -0,0 +1,97 @@
+import matplotlib
+from configs import paths_config
+matplotlib.use('Agg')
+import torch
+from torch import nn
+from models.e4e.encoders import psp_encoders
+from models.e4e.stylegan2.model import Generator
+
+
+def get_keys(d, name):
+ if 'state_dict' in d:
+ d = d['state_dict']
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+ return d_filt
+
+
+class pSp(nn.Module):
+
+ def __init__(self, opts):
+ super(pSp, self).__init__()
+ self.opts = opts
+ # Define architecture
+ self.encoder = self.set_encoder()
+ self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+ # Load weights if needed
+ self.load_weights()
+
+ def set_encoder(self):
+ if self.opts.encoder_type == 'GradualStyleEncoder':
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
+ elif self.opts.encoder_type == 'Encoder4Editing':
+ encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
+ else:
+ raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
+ return encoder
+
+ def load_weights(self):
+ if self.opts.checkpoint_path is not None:
+ print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+ self.__load_latent_avg(ckpt)
+ else:
+ print('Loading encoders weights from irse50!')
+ encoder_ckpt = torch.load(paths_config.ir_se50)
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
+ print('Loading decoder weights from pretrained!')
+ ckpt = torch.load(self.opts.stylegan_weights)
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
+ self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
+
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+ inject_latent=None, return_latents=False, alpha=None):
+ if input_code:
+ codes = x
+ else:
+ codes = self.encoder(x)
+ # normalize with respect to the center of an average face
+ if self.opts.start_from_latent_avg:
+ if codes.ndim == 2:
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
+ else:
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
+
+ if latent_mask is not None:
+ for i in latent_mask:
+ if inject_latent is not None:
+ if alpha is not None:
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+ else:
+ codes[:, i] = inject_latent[:, i]
+ else:
+ codes[:, i] = 0
+
+ input_is_latent = not input_code
+ images, result_latent = self.decoder([codes],
+ input_is_latent=input_is_latent,
+ randomize_noise=randomize_noise,
+ return_latents=return_latents)
+
+ if resize:
+ images = self.face_pool(images)
+
+ if return_latents:
+ return images, result_latent
+ else:
+ return images
+
+ def __load_latent_avg(self, ckpt, repeat=None):
+ if 'latent_avg' in ckpt:
+ self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
+ if repeat is not None:
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
+ else:
+ self.latent_avg = None
diff --git a/models/e4e/stylegan2/__init__.py b/models/e4e/stylegan2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/e4e/stylegan2/model.py b/models/e4e/stylegan2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ede4360148e260363887662bae7fe68c987ee60e
--- /dev/null
+++ b/models/e4e/stylegan2/model.py
@@ -0,0 +1,674 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .op.fused_act import FusedLeakyReLU, fused_leaky_relu
+from .op.upfirdn2d import upfirdn2d
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+ )
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+ return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+ f'upsample={self.upsample}, downsample={self.downsample})'
+ )
+
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ weight = self.scale * self.weight * style
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
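+ # The modulated weights differ per sample, so samples are folded into the batch
+ # dimension and a single grouped convolution (one group per sample) is applied.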
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None):
+ out = self.conv(input, style)
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ return_features=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+ ]
+
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
+ if len(styles) < 2:
+ inject_index = self.n_latent
+
+ if styles[0].ndim < 3:
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+ latent = torch.cat([latent, latent2], 1)
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ elif return_features:
+ return image, out
+ else:
+ return image, None
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
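+ # Minibatch standard deviation: append the feature stddev computed over each group
+ # as an extra channel so the discriminator can sense low-variety batches.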
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.view(
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+ )
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
diff --git a/models/e4e/stylegan2/op/__init__.py b/models/e4e/stylegan2/op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3
--- /dev/null
+++ b/models/e4e/stylegan2/op/__init__.py
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
diff --git a/models/e4e/stylegan2/op/fused_act.py b/models/e4e/stylegan2/op/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..90949545ba955dabf2e17d8cf5e524d5cb190a63
--- /dev/null
+++ b/models/e4e/stylegan2/op/fused_act.py
@@ -0,0 +1,34 @@
+import os
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+
+
+module_path = os.path.dirname(__file__)
+
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
+ input = input.cuda()
+ return (
+ F.leaky_relu(
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+ )
+ * scale
+ )
+
diff --git a/models/e4e/stylegan2/op/fused_bias_act.cpp b/models/e4e/stylegan2/op/fused_bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949
--- /dev/null
+++ b/models/e4e/stylegan2/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/models/e4e/stylegan2/op/fused_bias_act_kernel.cu b/models/e4e/stylegan2/op/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8
--- /dev/null
+++ b/models/e4e/stylegan2/op/fused_bias_act_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
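+ // act * 10 + grad selects the branch: act 1 is the identity, act 3 is leaky ReLU;
+ // grad 0/1/2 give the forward value and its first/second derivatives.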
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ y.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ b.data_ptr<scalar_t>(),
+ ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
\ No newline at end of file
diff --git a/models/e4e/stylegan2/op/upfirdn2d.cpp b/models/e4e/stylegan2/op/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e
--- /dev/null
+++ b/models/e4e/stylegan2/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/models/e4e/stylegan2/op/upfirdn2d.py b/models/e4e/stylegan2/op/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..02fc25af780868d9b883631eb6b03a25c225d745
--- /dev/null
+++ b/models/e4e/stylegan2/op/upfirdn2d.py
@@ -0,0 +1,60 @@
+import os
+
+import torch
+from torch.nn import functional as F
+
+
+module_path = os.path.dirname(__file__)
+
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ out = upfirdn2d_native(
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
+ )
+
+ return out
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
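+ # Native upfirdn2d (same as the StyleCLIP copy): upsample by zero-insertion, pad,
+ # convolve with the flipped kernel, then downsample by slicing.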
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
\ No newline at end of file
diff --git a/models/e4e/stylegan2/op/upfirdn2d_kernel.cu b/models/e4e/stylegan2/op/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e
--- /dev/null
+++ b/models/e4e/stylegan2/op/upfirdn2d_kernel.cu
@@ -0,0 +1,272 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
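+// Integer division that rounds toward negative infinity; needed because padding
+// can make the intermediate tile coordinates negative.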
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+
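+// Runtime parameters shared by the host launcher and the kernel: resampling
+// factors, padding, tensor geometry and per-block loop counts.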
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+
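+// Tiled kernel: the resampling factors, kernel extent and tile size are template
+// constants so the filter loops below can be fully unrolled. Each block stages
+// the kernel taps and one input tile in shared memory before writing outputs.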
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+ #pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+ #pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
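+ // Number of output rows/cols after upsampling, padding and valid filtering,
+ // then downsampling.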
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
+
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
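+ // Select a specialized instantiation for the common up/down/kernel-size
+ // combinations; if none matches, mode stays -1 and nothing is launched.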
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
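+ // Dispatch on dtype and launch the instantiation chosen above.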
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr