diff --git a/PTI/.gitignore b/PTI/.gitignore deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/PTI/.gitignore +++ /dev/null @@ -1 +0,0 @@ - diff --git a/PTI/LICENSE b/PTI/LICENSE deleted file mode 100644 index 490821a63031f8f8d115b2edee69ae26ac5ea0b1..0000000000000000000000000000000000000000 --- a/PTI/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Daniel Roich - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/PTI/README.md b/PTI/README.md deleted file mode 100644 index e57569893ab1695dde78a5c2265553a3bb1fb50a..0000000000000000000000000000000000000000 --- a/PTI/README.md +++ /dev/null @@ -1,229 +0,0 @@ -# PTI: Pivotal Tuning for Latent-based editing of Real Images - - - - - -Inference Notebook: - -

-Pivotal Tuning Inversion (PTI) enables employing off-the-shelf latent based -semantic editing techniques on real images using StyleGAN. -PTI excels in identity preserving edits, portrayed through recognizable figures — -Serena Williams and Robert Downey Jr. (top), and in handling faces which -are clearly out-of-domain, e.g., due to heavy makeup (bottom). -
-## Description
-Official Implementation of our PTI paper + code for evaluation metrics. PTI introduces an optimization mechanism for solving the StyleGAN inversion task.
-It provides near-perfect reconstruction results while maintaining the high editing abilities of the native StyleGAN latent space W. For more details, see
-
-## Recent Updates
-**2021.07.01**: Fixed the file download phase in the inference notebook, which might have caused the notebook not to run smoothly.
-
-**2021.06.29**: Added support for CPU. In order to run PTI on CPU, please change the `device` parameter under `configs/global_config.py` to "cpu" instead of "cuda".
-
-**2021.06.25**: Added a mohawk edit using StyleCLIP+PTI to the inference notebook.
- Updated the documentation in the inference notebook due to the Google Drive rate limit being reached.
- Currently, Google Drive does not allow the pretrained models to be downloaded automatically using Colab. Manual intervention might be needed.
-
-## Getting Started
-### Prerequisites
-- Linux or macOS
-- NVIDIA GPU + CUDA CuDNN (not mandatory but recommended)
-- Python 3
-
-### Installation
-- Dependencies:
- 1. lpips
- 2. wandb
- 3. pytorch
- 4. torchvision
- 5. matplotlib
- 6. dlib
-- All dependencies can be installed using *pip install* and the package name
-
-## Pretrained Models
-Please download the pretrained models from the following links.
-
-### Auxiliary Models
-We provide various auxiliary models needed for the PTI inversion task.
-This includes the StyleGAN generator and pre-trained models used for loss computation.
-| Path | Description
-| :--- | :----------
-|[FFHQ StyleGAN](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl) | StyleGAN2-ada model trained on FFHQ with 1024x1024 output resolution.
-|[Dlib alignment](https://drive.google.com/file/d/1HKmjg6iXsWr4aFPuU0gBXPGR83wqMzq7/view?usp=sharing) | Dlib alignment model used for image preprocessing.
-|[FFHQ e4e encoder](https://drive.google.com/file/d/1ALC5CLA89Ouw40TwvxcwebhzWXM5YSCm/view?usp=sharing) | Pretrained e4e encoder. Used for StyleCLIP editing.
-
-Note: The StyleGAN model is used directly from the official [stylegan2-ada-pytorch implementation](https://github.com/NVlabs/stylegan2-ada-pytorch).
-For StyleCLIP pretrained mappers, please see [StyleCLIP's official routes](https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py).
-
-By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`.
-However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
-
-## Inversion
-### Preparing your Data
-In order to invert a real image and edit it, you should first align and crop it to the correct size. To do so, perform *one* of the following steps:
-1. Run `notebooks/align_data.ipynb` and change the "images_path" variable to the raw images path
-2. Run `utils/align_data.py` and change the "images_path" variable to the raw images path
-
-### Weights And Biases
-The project supports the [Weights And Biases](https://wandb.ai/home) framework for experiment tracking. For the inversion task, it enables visualization of the loss progression and the generator's intermediate results during the initial inversion and the *Pivotal Tuning* (PT) procedure.
-
-The log frequency can be adjusted using the parameters defined at `configs/global_config.py` under the "Logs" subsection.
-
-There is no need to have an account in order to run the code. However, to use the features provided by Weights and Biases, you first have to register on their site.
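The two-phase procedure PTI performs (an initial latent inversion to find a "pivot" code, followed by *Pivotal Tuning* of the generator around that pivot) can be illustrated with a short, self-contained sketch. This is a toy PyTorch illustration of the idea only; the generator, loss, step counts, and variable names below are stand-ins and not the repository's actual API, which lives in `scripts/run_pti.py` and the `training` folder described in the next sections.

```python
# Conceptual sketch of the two PTI phases with a toy generator stand-in.
import copy
import torch
import torch.nn.functional as F

class ToyGenerator(torch.nn.Module):
    """Stand-in for a pretrained StyleGAN generator: maps a latent w to an 'image'."""
    def __init__(self, w_dim=8, img_dim=16):
        super().__init__()
        self.fc = torch.nn.Linear(w_dim, img_dim)

    def forward(self, w):
        return torch.tanh(self.fc(w))

G = ToyGenerator()
for p in G.parameters():
    p.requires_grad_(False)             # phase 1 keeps the generator frozen
target = torch.randn(1, 16)             # toy stand-in for the aligned real image

# Phase 1: optimize a pivot latent code w so that G(w) reconstructs the target.
# (The real criteria include an LPIPS term; plain MSE keeps this sketch short.)
w_pivot = torch.zeros(1, 8, requires_grad=True)
opt_w = torch.optim.Adam([w_pivot], lr=5e-2)
for _ in range(200):
    opt_w.zero_grad()
    F.mse_loss(G(w_pivot), target).backward()
    opt_w.step()

# Phase 2 (Pivotal Tuning): freeze the pivot, fine-tune a copy of the generator
# so that it reconstructs the target almost exactly around that pivot.
G_tuned = copy.deepcopy(G)
for p in G_tuned.parameters():
    p.requires_grad_(True)
opt_g = torch.optim.Adam(G_tuned.parameters(), lr=1e-3)
w_fixed = w_pivot.detach()
for _ in range(200):
    opt_g.zero_grad()
    F.mse_loss(G_tuned(w_fixed), target).backward()
    opt_g.step()

# Edits are later applied around w_fixed and decoded with G_tuned.
```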
-
-### Running PTI
-The main training script is `scripts/run_pti.py`. The script receives aligned and cropped images from the paths configured in the "Input info" subsection of `configs/paths_config.py`.
-Results are saved to the directories listed under "Dirs for output files" in `configs/paths_config.py`. This includes inversion latent codes and tuned generators.
-The hyperparameters for the inversion task can be found at `configs/hyperparameters.py`. They are initialized to the default values used in the paper.
-
-## Editing
-By default, we assume that all auxiliary edit directions are downloaded and saved to the directory `editings`.
-However, you may use your own paths by changing the necessary values in `configs/path_configs.py` under the "Edit directions" subsection.
-
-An example of editing code can be found at `scripts/latent_editor_wrapper.py`.
-
-## Inference Notebooks
-To help visualize the results of PTI, we provide a Jupyter notebook found in `notebooks/inference_playground.ipynb`.
-The notebook will download the pretrained models and run inference on a sample image found online or on images of your choosing. It is recommended to run this in [Google Colab](https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb).
-
-The notebook demonstrates how to:
-- Invert an image using PTI
-- Visualise the inversion and use the PTI output
-- Edit the image after PTI using InterfaceGAN and StyleCLIP
-- Compare to other inversion methods
-
-## Evaluation
-Currently the repository supports qualitative evaluation of reconstruction for PTI, SG2 (*W Space*), e4e, and SG2Plus (*W+ Space*), as well as editing using InterfaceGAN and GANSpace for the same inversion methods.
-To run the evaluation, please see `evaluation/qualitative_edit_comparison.py`. Examples of the evaluation scripts are:

-Reconstruction comparison between different methods. The image order is: original image, W+ inversion, e4e inversion, W inversion, PTI inversion.
-InterfaceGAN pose edit comparison between different methods. The image order is: original, W+, e4e, W, PTI.
-One image per edit, or several edits, shown without comparison to other methods.
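The Editing section above points to `scripts/latent_editor_wrapper.py` and to InterfaceGAN/GANSpace directions; conceptually, such an edit moves the pivot latent along a learned direction and decodes the result with the tuned generator. Below is a hypothetical sketch of that step only: the W+ shape, the random direction, and the strength values are placeholders, not assets or API from this repository.

```python
# Hypothetical InterfaceGAN-style edit around the PTI pivot latent.
import torch

w_pivot = torch.randn(1, 18, 512)             # pivot latent from the inversion phase (toy values)
direction = torch.randn(1, 18, 512)           # e.g. a pose or age direction
direction = direction / direction.norm()      # normalize the direction once

for strength in (-3.0, 0.0, 3.0):
    w_edit = w_pivot + strength * direction   # move linearly along the direction
    # image = G_tuned.synthesis(w_edit)       # decode with the tuned generator (not shown here)
    print(f"strength {strength:+.1f} -> edited latent {tuple(w_edit.shape)}")
```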
- -### Coming Soon - Quantitative evaluation and StyleCLIP qualitative evaluation - -## Repository structure -| Path | Description -| :--- | :--- -| ├  configs | Folder containing configs defining Hyperparameters, paths and logging -| ├  criteria | Folder containing various loss and regularization criterias for the optimization -| ├  dnnlib | Folder containing internal utils for StyleGAN2-ada -| ├  docs | Folder containing the latent space edit directions -| ├  editings | Folder containing images displayed in the README -| ├  environment | Folder containing Anaconda environment used in our experiments -| ├  licenses | Folder containing licenses of the open source projects used in this repository -| ├  models | Folder containing models used in different editing techniques and first phase inversion -| ├  notebooks | Folder with jupyter notebooks to demonstrate the usage of PTI end-to-end -| ├  scripts | Folder with running scripts for inversion, editing and metric computations -| ├  torch_utils | Folder containing internal utils for StyleGAN2-ada -| ├  training | Folder containing the core training logic of PTI -| ├  utils | Folder with various utility functions - - -## Credits -**StyleGAN2-ada model and implementation:** -https://github.com/NVlabs/stylegan2-ada-pytorch -Copyright © 2021, NVIDIA Corporation. -Nvidia Source Code License https://nvlabs.github.io/stylegan2-ada-pytorch/license.html - -**LPIPS model and implementation:** -https://github.com/richzhang/PerceptualSimilarity -Copyright (c) 2020, Sou Uchida -License (BSD 2-Clause) https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE - -**e4e model and implementation:** -https://github.com/omertov/encoder4editing -Copyright (c) 2021 omertov -License (MIT) https://github.com/omertov/encoder4editing/blob/main/LICENSE - -**StyleCLIP model and implementation:** -https://github.com/orpatashnik/StyleCLIP -Copyright (c) 2021 orpatashnik -License (MIT) https://github.com/orpatashnik/StyleCLIP/blob/main/LICENSE - -**InterfaceGAN implementation:** -https://github.com/genforce/interfacegan -Copyright (c) 2020 genforce -License (MIT) https://github.com/genforce/interfacegan/blob/master/LICENSE - -**GANSpace implementation:** -https://github.com/harskish/ganspace -Copyright (c) 2020 harkish -License (Apache License 2.0) https://github.com/harskish/ganspace/blob/master/LICENSE - - -## Acknowledgments -This repository structure is based on [encoder4editing](https://github.com/omertov/encoder4editing) and [ReStyle](https://github.com/yuval-alaluf/restyle-encoder) repositories - -## Contact -For any inquiry please contact us at our email addresses: danielroich@gmail.com or ron.mokady@gmail.com - - -## Citation -If you use this code for your research, please cite: -``` -@article{roich2021pivotal, - title={Pivotal Tuning for Latent-based Editing of Real Images}, - author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel}, - journal={arXiv preprint arXiv:2106.05744}, - year={2021} -} -``` diff --git a/PTI/torch_utils/custom_ops.py b/PTI/torch_utils/custom_ops.py deleted file mode 100644 index 4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/custom_ops.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. 
Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import os -import glob -import torch -import torch.utils.cpp_extension -import importlib -import hashlib -import shutil -from pathlib import Path - -from torch.utils.file_baton import FileBaton - -#---------------------------------------------------------------------------- -# Global options. - -verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' - -#---------------------------------------------------------------------------- -# Internal helper funcs. - -def _find_compiler_bindir(): - patterns = [ - 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', - 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', - 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', - 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', - ] - for pattern in patterns: - matches = sorted(glob.glob(pattern)) - if len(matches): - return matches[-1] - return None - -#---------------------------------------------------------------------------- -# Main entry point for compiling and loading C++/CUDA plugins. - -_cached_plugins = dict() - -def get_plugin(module_name, sources, **build_kwargs): - assert verbosity in ['none', 'brief', 'full'] - - # Already cached? - if module_name in _cached_plugins: - return _cached_plugins[module_name] - - # Print status. - if verbosity == 'full': - print(f'Setting up PyTorch plugin "{module_name}"...') - elif verbosity == 'brief': - print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) - - try: # pylint: disable=too-many-nested-blocks - # Make sure we can find the necessary compiler binaries. - if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: - compiler_bindir = _find_compiler_bindir() - if compiler_bindir is None: - raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') - os.environ['PATH'] += ';' + compiler_bindir - - # Compile and load. - verbose_build = (verbosity == 'full') - - # Incremental build md5sum trickery. Copies all the input source files - # into a cached build directory under a combined md5 digest of the input - # source files. Copying is done only if the combined digest has changed. - # This keeps input file timestamps and filenames the same as in previous - # extension builds, allowing for fast incremental rebuilds. - # - # This optimization is done only in case all the source files reside in - # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR - # environment variable is set (we take this as a signal that the user - # actually cares about this.) - source_dirs_set = set(os.path.dirname(source) for source in sources) - if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): - all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) - - # Compute a combined hash digest for all source files in the same - # custom op directory (usually .cu, .cpp, .py and .h files). 
- hash_md5 = hashlib.md5() - for src in all_source_files: - with open(src, 'rb') as f: - hash_md5.update(f.read()) - build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access - digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) - - if not os.path.isdir(digest_build_dir): - os.makedirs(digest_build_dir, exist_ok=True) - baton = FileBaton(os.path.join(digest_build_dir, 'lock')) - if baton.try_acquire(): - try: - for src in all_source_files: - shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) - finally: - baton.release() - else: - # Someone else is copying source files under the digest dir, - # wait until done and continue. - baton.wait() - digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] - torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, - verbose=verbose_build, sources=digest_sources, **build_kwargs) - else: - torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) - module = importlib.import_module(module_name) - - except: - if verbosity == 'brief': - print('Failed!') - raise - - # Print status and add to cache. - if verbosity == 'full': - print(f'Done setting up PyTorch plugin "{module_name}".') - elif verbosity == 'brief': - print('Done.') - _cached_plugins[module_name] = module - return module - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/misc.py b/PTI/torch_utils/misc.py deleted file mode 100644 index 7829f4d9f168557ce8a9a6dec289aa964234cb8c..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/misc.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import re -import contextlib -import numpy as np -import torch -import warnings -import dnnlib - -#---------------------------------------------------------------------------- -# Cached construction of constant tensors. Avoids CPU=>GPU copy when the -# same constant is used multiple times. - -_constant_cache = dict() - -def constant(value, shape=None, dtype=None, device=None, memory_format=None): - value = np.asarray(value) - if shape is not None: - shape = tuple(shape) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device('cpu') - if memory_format is None: - memory_format = torch.contiguous_format - - key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) - tensor = _constant_cache.get(key, None) - if tensor is None: - tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) - if shape is not None: - tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) - tensor = tensor.contiguous(memory_format=memory_format) - _constant_cache[key] = tensor - return tensor - -#---------------------------------------------------------------------------- -# Replace NaN/Inf with specified numerical values. 
- -try: - nan_to_num = torch.nan_to_num # 1.8.0a0 -except AttributeError: - def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin - assert isinstance(input, torch.Tensor) - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - assert nan == 0 - return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) - -#---------------------------------------------------------------------------- -# Symbolic assert. - -try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access -except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 - -#---------------------------------------------------------------------------- -# Context manager to suppress known warnings in torch.jit.trace(). - -class suppress_tracer_warnings(warnings.catch_warnings): - def __enter__(self): - super().__enter__() - warnings.simplefilter('ignore', category=torch.jit.TracerWarning) - return self - -#---------------------------------------------------------------------------- -# Assert that the shape of a tensor matches the given list of integers. -# None indicates that the size of a dimension is allowed to vary. -# Performs symbolic assertion when used in torch.jit.trace(). - -def assert_shape(tensor, ref_shape): - if tensor.ndim != len(ref_shape): - raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') - for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): - if ref_size is None: - pass - elif isinstance(ref_size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') - elif isinstance(size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') - elif size != ref_size: - raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') - -#---------------------------------------------------------------------------- -# Function decorator that calls torch.autograd.profiler.record_function(). - -def profiled_function(fn): - def decorator(*args, **kwargs): - with torch.autograd.profiler.record_function(fn.__name__): - return fn(*args, **kwargs) - decorator.__name__ = fn.__name__ - return decorator - -#---------------------------------------------------------------------------- -# Sampler for torch.utils.data.DataLoader that loops over the dataset -# indefinitely, shuffling items as it goes. 
- -class InfiniteSampler(torch.utils.data.Sampler): - def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): - assert len(dataset) > 0 - assert num_replicas > 0 - assert 0 <= rank < num_replicas - assert 0 <= window_size <= 1 - super().__init__(dataset) - self.dataset = dataset - self.rank = rank - self.num_replicas = num_replicas - self.shuffle = shuffle - self.seed = seed - self.window_size = window_size - - def __iter__(self): - order = np.arange(len(self.dataset)) - rnd = None - window = 0 - if self.shuffle: - rnd = np.random.RandomState(self.seed) - rnd.shuffle(order) - window = int(np.rint(order.size * self.window_size)) - - idx = 0 - while True: - i = idx % order.size - if idx % self.num_replicas == self.rank: - yield order[i] - if window >= 2: - j = (i - rnd.randint(window)) % order.size - order[i], order[j] = order[j], order[i] - idx += 1 - -#---------------------------------------------------------------------------- -# Utilities for operating with torch.nn.Module parameters and buffers. - -def params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.parameters()) + list(module.buffers()) - -def named_params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.named_parameters()) + list(module.named_buffers()) - -def copy_params_and_buffers(src_module, dst_module, require_all=False): - assert isinstance(src_module, torch.nn.Module) - assert isinstance(dst_module, torch.nn.Module) - src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} - for name, tensor in named_params_and_buffers(dst_module): - assert (name in src_tensors) or (not require_all) - if name in src_tensors: - tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) - -#---------------------------------------------------------------------------- -# Context manager for easily enabling/disabling DistributedDataParallel -# synchronization. - -@contextlib.contextmanager -def ddp_sync(module, sync): - assert isinstance(module, torch.nn.Module) - if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): - yield - else: - with module.no_sync(): - yield - -#---------------------------------------------------------------------------- -# Check DistributedDataParallel consistency across processes. - -def check_ddp_consistency(module, ignore_regex=None): - assert isinstance(module, torch.nn.Module) - for name, tensor in named_params_and_buffers(module): - fullname = type(module).__name__ + '.' + name - if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): - continue - tensor = tensor.detach() - other = tensor.clone() - torch.distributed.broadcast(tensor=other, src=0) - assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname - -#---------------------------------------------------------------------------- -# Print summary table of module hierarchy. - -def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): - assert isinstance(module, torch.nn.Module) - assert not isinstance(module, torch.jit.ScriptModule) - assert isinstance(inputs, (tuple, list)) - - # Register hooks. 
- entries = [] - nesting = [0] - def pre_hook(_mod, _inputs): - nesting[0] += 1 - def post_hook(mod, _inputs, outputs): - nesting[0] -= 1 - if nesting[0] <= max_nesting: - outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] - outputs = [t for t in outputs if isinstance(t, torch.Tensor)] - entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) - hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] - hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] - - # Run module. - outputs = module(*inputs) - for hook in hooks: - hook.remove() - - # Identify unique outputs, parameters, and buffers. - tensors_seen = set() - for e in entries: - e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] - e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] - e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] - tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} - - # Filter out redundant entries. - if skip_redundant: - entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] - - # Construct table. - rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] - rows += [['---'] * len(rows[0])] - param_total = 0 - buffer_total = 0 - submodule_names = {mod: name for name, mod in module.named_modules()} - for e in entries: - name = '' if e.mod is module else submodule_names[e.mod] - param_size = sum(t.numel() for t in e.unique_params) - buffer_size = sum(t.numel() for t in e.unique_buffers) - output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] - output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] - rows += [[ - name + (':0' if len(e.outputs) >= 2 else ''), - str(param_size) if param_size else '-', - str(buffer_size) if buffer_size else '-', - (output_shapes + ['-'])[0], - (output_dtypes + ['-'])[0], - ]] - for idx in range(1, len(e.outputs)): - rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] - param_total += param_size - buffer_total += buffer_size - rows += [['---'] * len(rows[0])] - rows += [['Total', str(param_total), str(buffer_total), '-', '-']] - - # Print table. - widths = [max(len(cell) for cell in column) for column in zip(*rows)] - print() - for row in rows: - print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) - print() - return outputs - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/ops/bias_act.cpp b/PTI/torch_utils/ops/bias_act.cpp deleted file mode 100644 index 5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/bias_act.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. 
- -#include -#include -#include -#include "bias_act.h" - -//------------------------------------------------------------------------ - -static bool has_same_layout(torch::Tensor x, torch::Tensor y) -{ - if (x.dim() != y.dim()) - return false; - for (int64_t i = 0; i < x.dim(); i++) - { - if (x.size(i) != y.size(i)) - return false; - if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) - return false; - } - return true; -} - -//------------------------------------------------------------------------ - -static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) -{ - // Validate arguments. - TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); - TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); - TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); - TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); - TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); - TORCH_CHECK(b.dim() == 1, "b must have rank 1"); - TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); - TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); - TORCH_CHECK(grad >= 0, "grad must be non-negative"); - - // Validate layout. - TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); - TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); - TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); - TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); - TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); - - // Create output tensor. - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - torch::Tensor y = torch::empty_like(x); - TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); - - // Initialize CUDA kernel parameters. - bias_act_kernel_params p; - p.x = x.data_ptr(); - p.b = (b.numel()) ? b.data_ptr() : NULL; - p.xref = (xref.numel()) ? xref.data_ptr() : NULL; - p.yref = (yref.numel()) ? yref.data_ptr() : NULL; - p.dy = (dy.numel()) ? dy.data_ptr() : NULL; - p.y = y.data_ptr(); - p.grad = grad; - p.act = act; - p.alpha = alpha; - p.gain = gain; - p.clamp = clamp; - p.sizeX = (int)x.numel(); - p.sizeB = (int)b.numel(); - p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; - - // Choose CUDA kernel. - void* kernel; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] - { - kernel = choose_bias_act_kernel(p); - }); - TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); - - // Launch CUDA kernel. 
- p.loopX = 4; - int blockSize = 4 * 32; - int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; - void* args[] = {&p}; - AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); - return y; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("bias_act", &bias_act); -} - -//------------------------------------------------------------------------ diff --git a/PTI/torch_utils/ops/bias_act.cu b/PTI/torch_utils/ops/bias_act.cu deleted file mode 100644 index dd8fc4756d7d94727f94af738665b68d9c518880..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/bias_act.cu +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "bias_act.h" - -//------------------------------------------------------------------------ -// Helpers. - -template struct InternalType; -template <> struct InternalType { typedef double scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; - -//------------------------------------------------------------------------ -// CUDA kernel. - -template -__global__ void bias_act_kernel(bias_act_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - int G = p.grad; - scalar_t alpha = (scalar_t)p.alpha; - scalar_t gain = (scalar_t)p.gain; - scalar_t clamp = (scalar_t)p.clamp; - scalar_t one = (scalar_t)1; - scalar_t two = (scalar_t)2; - scalar_t expRange = (scalar_t)80; - scalar_t halfExpRange = (scalar_t)40; - scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; - scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; - - // Loop over elements. - int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; - for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) - { - // Load. - scalar_t x = (scalar_t)((const T*)p.x)[xi]; - scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; - scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; - scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; - scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; - scalar_t yy = (gain != 0) ? yref / gain : 0; - scalar_t y = 0; - - // Apply bias. - ((G == 0) ? x : xref) += b; - - // linear - if (A == 1) - { - if (G == 0) y = x; - if (G == 1) y = x; - } - - // relu - if (A == 2) - { - if (G == 0) y = (x > 0) ? x : 0; - if (G == 1) y = (yy > 0) ? x : 0; - } - - // lrelu - if (A == 3) - { - if (G == 0) y = (x > 0) ? x : x * alpha; - if (G == 1) y = (yy > 0) ? x : x * alpha; - } - - // tanh - if (A == 4) - { - if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } - if (G == 1) y = x * (one - yy * yy); - if (G == 2) y = x * (one - yy * yy) * (-two * yy); - } - - // sigmoid - if (A == 5) - { - if (G == 0) y = (x < -expRange) ? 
0 : one / (exp(-x) + one); - if (G == 1) y = x * yy * (one - yy); - if (G == 2) y = x * yy * (one - yy) * (one - two * yy); - } - - // elu - if (A == 6) - { - if (G == 0) y = (x >= 0) ? x : exp(x) - one; - if (G == 1) y = (yy >= 0) ? x : x * (yy + one); - if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); - } - - // selu - if (A == 7) - { - if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); - if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); - if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); - } - - // softplus - if (A == 8) - { - if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); - if (G == 1) y = x * (one - exp(-yy)); - if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } - } - - // swish - if (A == 9) - { - if (G == 0) - y = (x < -expRange) ? 0 : x / (exp(-x) + one); - else - { - scalar_t c = exp(xref); - scalar_t d = c + one; - if (G == 1) - y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); - else - y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); - yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; - } - } - - // Apply gain. - y *= gain * dy; - - // Clamp. - if (clamp >= 0) - { - if (G == 0) - y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; - else - y = (yref > -clamp & yref < clamp) ? y : 0; - } - - // Store. - ((T*)p.y)[xi] = (T)y; - } -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template void* choose_bias_act_kernel(const bias_act_kernel_params& p) -{ - if (p.act == 1) return (void*)bias_act_kernel; - if (p.act == 2) return (void*)bias_act_kernel; - if (p.act == 3) return (void*)bias_act_kernel; - if (p.act == 4) return (void*)bias_act_kernel; - if (p.act == 5) return (void*)bias_act_kernel; - if (p.act == 6) return (void*)bias_act_kernel; - if (p.act == 7) return (void*)bias_act_kernel; - if (p.act == 8) return (void*)bias_act_kernel; - if (p.act == 9) return (void*)bias_act_kernel; - return NULL; -} - -//------------------------------------------------------------------------ -// Template specializations. - -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/PTI/torch_utils/ops/bias_act.h b/PTI/torch_utils/ops/bias_act.h deleted file mode 100644 index a32187e1fb7e3bae509d4eceaf900866866875a4..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/bias_act.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -//------------------------------------------------------------------------ -// CUDA kernel parameters. 
- -struct bias_act_kernel_params -{ - const void* x; // [sizeX] - const void* b; // [sizeB] or NULL - const void* xref; // [sizeX] or NULL - const void* yref; // [sizeX] or NULL - const void* dy; // [sizeX] or NULL - void* y; // [sizeX] - - int grad; - int act; - float alpha; - float gain; - float clamp; - - int sizeX; - int sizeB; - int stepB; - int loopX; -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template void* choose_bias_act_kernel(const bias_act_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/PTI/torch_utils/ops/bias_act.py b/PTI/torch_utils/ops/bias_act.py deleted file mode 100644 index 4bcb409a89ccf6c6f6ecfca5962683df2d280b1f..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/bias_act.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom PyTorch ops for efficient bias and activation.""" - -import os -import warnings -import numpy as np -import torch -import dnnlib -import traceback - -from .. import custom_ops -from .. import misc - -#---------------------------------------------------------------------------- - -activation_funcs = { - 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), - 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), - 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), - 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), - 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), - 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), - 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), - 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), - 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), -} - -#---------------------------------------------------------------------------- - -_inited = False -_plugin = None -_null_tensor = torch.empty([0]) - -def _init(): - global _inited, _plugin - if not _inited: - _inited = True - sources = ['bias_act.cpp', 'bias_act.cu'] - sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - try: - _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) - except: - warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) - return _plugin is not None - -#---------------------------------------------------------------------------- - -def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): - r"""Fused bias and activation function. - - Adds bias `b` to activation tensor `x`, evaluates activation function `act`, - and scales the result by `gain`. Each of the steps is optional. In most cases, - the fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports first and second order gradients, - but not third order gradients. - - Args: - x: Input activation tensor. Can be of any shape. - b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type - as `x`. The shape must be known, and it must match the dimension of `x` - corresponding to `dim`. - dim: The dimension in `x` corresponding to the elements of `b`. - The value of `dim` is ignored if `b` is not specified. - act: Name of the activation function to evaluate, or `"linear"` to disable. - Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. - See `activation_funcs` for a full list. `None` is not allowed. - alpha: Shape parameter for the activation function, or `None` to use the default. - gain: Scaling factor for the output tensor, or `None` to use default. - See `activation_funcs` for the default scaling of each activation function. - If unsure, consider specifying 1. - clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable - the clamping (default). - impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). - - Returns: - Tensor of the same shape and datatype as `x`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and _init(): - return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) - return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Slow reference implementation of `bias_act()` using standard TensorFlow ops. - """ - assert isinstance(x, torch.Tensor) - assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Add bias. - if b is not None: - assert isinstance(b, torch.Tensor) and b.ndim == 1 - assert 0 <= dim < x.ndim - assert b.shape[0] == x.shape[dim] - x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) - - # Evaluate activation function. - alpha = float(alpha) - x = spec.func(x, alpha=alpha) - - # Scale by gain. - gain = float(gain) - if gain != 1: - x = x * gain - - # Clamp. - if clamp >= 0: - x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type - return x - -#---------------------------------------------------------------------------- - -_bias_act_cuda_cache = dict() - -def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Fast CUDA implementation of `bias_act()` using custom ops. - """ - # Parse arguments. 
- assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Lookup from cache. - key = (dim, act, alpha, gain, clamp) - if key in _bias_act_cuda_cache: - return _bias_act_cuda_cache[key] - - # Forward op. - class BiasActCuda(torch.autograd.Function): - @staticmethod - def forward(ctx, x, b): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format - x = x.contiguous(memory_format=ctx.memory_format) - b = b.contiguous() if b is not None else _null_tensor - y = x - if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: - y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - y if 'y' in spec.ref else _null_tensor) - return y - - @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ - dy = dy.contiguous(memory_format=ctx.memory_format) - x, b, y = ctx.saved_tensors - dx = None - db = None - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: - dx = dy - if act != 'linear' or gain != 1 or clamp >= 0: - dx = BiasActCudaGrad.apply(dy, x, b, y) - - if ctx.needs_input_grad[1]: - db = dx.sum([i for i in range(dx.ndim) if i != dim]) - - return dx, db - - # Backward op. - class BiasActCudaGrad(torch.autograd.Function): - @staticmethod - def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format - dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - dy if spec.has_2nd_grad else _null_tensor, - x, b, y) - return dx - - @staticmethod - def backward(ctx, d_dx): # pylint: disable=arguments-differ - d_dx = d_dx.contiguous(memory_format=ctx.memory_format) - dy, x, b, y = ctx.saved_tensors - d_dy = None - d_x = None - d_b = None - d_y = None - - if ctx.needs_input_grad[0]: - d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) - - if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): - d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) - - if spec.has_2nd_grad and ctx.needs_input_grad[2]: - d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) - - return d_dy, d_x, d_b, d_y - - # Add to cache. - _bias_act_cuda_cache[key] = BiasActCuda - return BiasActCuda - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/ops/conv2d_gradfix.py b/PTI/torch_utils/ops/conv2d_gradfix.py deleted file mode 100644 index e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/conv2d_gradfix.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
- -"""Custom replacement for `torch.nn.functional.conv2d` that supports -arbitrarily high order gradients with zero performance penalty.""" - -import warnings -import contextlib -import torch - -# pylint: disable=redefined-builtin -# pylint: disable=arguments-differ -# pylint: disable=protected-access - -#---------------------------------------------------------------------------- - -enabled = False # Enable the custom op by setting this to true. -weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. - -@contextlib.contextmanager -def no_weight_gradients(): - global weight_gradients_disabled - old = weight_gradients_disabled - weight_gradients_disabled = True - yield - weight_gradients_disabled = old - -#---------------------------------------------------------------------------- - -def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) - return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - -def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) - return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) - -#---------------------------------------------------------------------------- - -def _should_use_custom_op(input): - assert isinstance(input, torch.Tensor) - if (not enabled) or (not torch.backends.cudnn.enabled): - return False - if input.device.type != 'cuda': - return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): - return True - warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') - return False - -def _tuple_of_ints(xs, ndim): - xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim - assert len(xs) == ndim - assert all(isinstance(x, int) for x in xs) - return xs - -#---------------------------------------------------------------------------- - -_conv2d_gradfix_cache = dict() - -def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): - # Parse arguments. - ndim = 2 - weight_shape = tuple(weight_shape) - stride = _tuple_of_ints(stride, ndim) - padding = _tuple_of_ints(padding, ndim) - output_padding = _tuple_of_ints(output_padding, ndim) - dilation = _tuple_of_ints(dilation, ndim) - - # Lookup from cache. - key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) - if key in _conv2d_gradfix_cache: - return _conv2d_gradfix_cache[key] - - # Validate arguments. 
- assert groups >= 1 - assert len(weight_shape) == ndim + 2 - assert all(stride[i] >= 1 for i in range(ndim)) - assert all(padding[i] >= 0 for i in range(ndim)) - assert all(dilation[i] >= 0 for i in range(ndim)) - if not transpose: - assert all(output_padding[i] == 0 for i in range(ndim)) - else: # transpose - assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) - - # Helpers. - common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) - def calc_output_padding(input_shape, output_shape): - if transpose: - return [0, 0] - return [ - input_shape[i + 2] - - (output_shape[i + 2] - 1) * stride[i] - - (1 - 2 * padding[i]) - - dilation[i] * (weight_shape[i + 2] - 1) - for i in range(ndim) - ] - - # Forward & backward. - class Conv2d(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias): - assert weight.shape == weight_shape - if not transpose: - output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) - else: # transpose - output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) - ctx.save_for_backward(input, weight) - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - grad_input = None - grad_weight = None - grad_bias = None - - if ctx.needs_input_grad[0]: - p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) - assert grad_input.shape == input.shape - - if ctx.needs_input_grad[1] and not weight_gradients_disabled: - grad_weight = Conv2dGradWeight.apply(grad_output, input) - assert grad_weight.shape == weight_shape - - if ctx.needs_input_grad[2]: - grad_bias = grad_output.sum([0, 2, 3]) - - return grad_input, grad_weight, grad_bias - - # Gradient with respect to the weights. 
- class Conv2dGradWeight(torch.autograd.Function): - @staticmethod - def forward(ctx, grad_output, input): - op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') - flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] - grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) - assert grad_weight.shape == weight_shape - ctx.save_for_backward(grad_output, input) - return grad_weight - - @staticmethod - def backward(ctx, grad2_grad_weight): - grad_output, input = ctx.saved_tensors - grad2_grad_output = None - grad2_input = None - - if ctx.needs_input_grad[0]: - grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) - assert grad2_grad_output.shape == grad_output.shape - - if ctx.needs_input_grad[1]: - p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) - assert grad2_input.shape == input.shape - - return grad2_grad_output, grad2_input - - _conv2d_gradfix_cache[key] = Conv2d - return Conv2d - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/ops/conv2d_resample.py b/PTI/torch_utils/ops/conv2d_resample.py deleted file mode 100644 index cd4750744c83354bab78704d4ef51ad1070fcc4a..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/conv2d_resample.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""2D convolution with optional up/downsampling.""" - -import torch - -from .. import misc -from . import conv2d_gradfix -from . import upfirdn2d -from .upfirdn2d import _parse_padding -from .upfirdn2d import _get_filter_size - -#---------------------------------------------------------------------------- - -def _get_weight_shape(w): - with misc.suppress_tracer_warnings(): # this value will be treated as a constant - shape = [int(sz) for sz in w.shape] - misc.assert_shape(w, shape) - return shape - -#---------------------------------------------------------------------------- - -def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): - """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. - """ - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - - # Flip weight if requested. - if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). - w = w.flip([2, 3]) - - # Workaround performance pitfall in cuDNN 8.0.5, triggered when using - # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: - if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: - if out_channels <= 4 and groups == 1: - in_shape = x.shape - x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) - x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) - else: - x = x.to(memory_format=torch.contiguous_format) - w = w.to(memory_format=torch.contiguous_format) - x = conv2d_gradfix.conv2d(x, w, groups=groups) - return x.to(memory_format=torch.channels_last) - - # Otherwise => execute using conv2d_gradfix. - op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d - return op(x, w, stride=stride, padding=padding, groups=groups) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): - r"""2D convolution with optional up/downsampling. - - Padding is performed only once at the beginning, not between the operations. - - Args: - x: Input tensor of shape - `[batch_size, in_channels, in_height, in_width]`. - w: Weight tensor of shape - `[out_channels, in_channels//groups, kernel_height, kernel_width]`. - f: Low-pass filter for up/downsampling. Must be prepared beforehand by - calling upfirdn2d.setup_filter(). None = identity (default). - up: Integer upsampling factor (default: 1). - down: Integer downsampling factor (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - groups: Split input channels into N groups (default: 1). - flip_weight: False = convolution, True = correlation (default: True). - flip_filter: False = convolution, True = correlation (default: False). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - # Validate arguments. - assert isinstance(x, torch.Tensor) and (x.ndim == 4) - assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) - assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) - assert isinstance(up, int) and (up >= 1) - assert isinstance(down, int) and (down >= 1) - assert isinstance(groups, int) and (groups >= 1) - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - fw, fh = _get_filter_size(f) - px0, px1, py0, py1 = _parse_padding(padding) - - # Adjust padding to account for up/downsampling. - if up > 1: - px0 += (fw + up - 1) // 2 - px1 += (fw - up) // 2 - py0 += (fh + up - 1) // 2 - py1 += (fh - up) // 2 - if down > 1: - px0 += (fw - down + 1) // 2 - px1 += (fw - down) // 2 - py0 += (fh - down + 1) // 2 - py1 += (fh - down) // 2 - - # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. - if kw == 1 and kh == 1 and (down > 1 and up == 1): - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
- if kw == 1 and kh == 1 and (up > 1 and down == 1): - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) - return x - - # Fast path: downsampling only => use strided convolution. - if down > 1 and up == 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: upsampling with optional downsampling => use transpose strided convolution. - if up > 1: - if groups == 1: - w = w.transpose(0, 1) - else: - w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) - w = w.transpose(1, 2) - w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) - px0 -= kw - 1 - px1 -= kw - up - py0 -= kh - 1 - py1 -= kh - up - pxt = max(min(-px0, -px1), 0) - pyt = max(min(-py0, -py1), 0) - x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) - if down > 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - - # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. - if up == 1 and down == 1: - if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: - return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) - - # Fallback: Generic reference implementation. - x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - if down > 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/ops/fma.py b/PTI/torch_utils/ops/fma.py deleted file mode 100644 index 2eeac58a626c49231e04122b93e321ada954c5d3..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/fma.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
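A minimal usage sketch for the `conv2d_resample()` helper from `conv2d_resample.py` above (not part of the original files): the `torch_utils.ops` import paths follow the repository layout, while the tensor shapes and the `[1, 3, 3, 1]` filter taps are illustrative assumptions.

import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 64, 32, 32)            # [batch_size, in_channels, in_height, in_width]
w = torch.randn(128, 64, 3, 3)            # [out_channels, in_channels // groups, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])  # low-pass FIR filter prepared for resampling

# 3x3 convolution fused with 2x upsampling; `padding` is specified w.r.t. the upsampled image.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
# Under these assumptions y should come out with shape [1, 128, 64, 64].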
- -"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" - -import torch - -#---------------------------------------------------------------------------- - -def fma(a, b, c): # => a * b + c - return _FusedMultiplyAdd.apply(a, b, c) - -#---------------------------------------------------------------------------- - -class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c - @staticmethod - def forward(ctx, a, b, c): # pylint: disable=arguments-differ - out = torch.addcmul(c, a, b) - ctx.save_for_backward(a, b) - ctx.c_shape = c.shape - return out - - @staticmethod - def backward(ctx, dout): # pylint: disable=arguments-differ - a, b = ctx.saved_tensors - c_shape = ctx.c_shape - da = None - db = None - dc = None - - if ctx.needs_input_grad[0]: - da = _unbroadcast(dout * b, a.shape) - - if ctx.needs_input_grad[1]: - db = _unbroadcast(dout * a, b.shape) - - if ctx.needs_input_grad[2]: - dc = _unbroadcast(dout, c_shape) - - return da, db, dc - -#---------------------------------------------------------------------------- - -def _unbroadcast(x, shape): - extra_dims = x.ndim - len(shape) - assert extra_dims >= 0 - dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] - if len(dim): - x = x.sum(dim=dim, keepdim=True) - if extra_dims: - x = x.reshape(-1, *x.shape[extra_dims+1:]) - assert x.shape == shape - return x - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/ops/grid_sample_gradfix.py b/PTI/torch_utils/ops/grid_sample_gradfix.py deleted file mode 100644 index ca6b3413ea72a734703c34382c023b84523601fd..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/grid_sample_gradfix.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom replacement for `torch.nn.functional.grid_sample` that -supports arbitrarily high order gradients between the input and output. -Only works on 2D images and assumes -`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" - -import warnings -import torch - -# pylint: disable=redefined-builtin -# pylint: disable=arguments-differ -# pylint: disable=protected-access - -#---------------------------------------------------------------------------- - -enabled = False # Enable the custom op by setting this to true. - -#---------------------------------------------------------------------------- - -def grid_sample(input, grid): - if _should_use_custom_op(): - return _GridSample2dForward.apply(input, grid) - return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) - -#---------------------------------------------------------------------------- - -def _should_use_custom_op(): - if not enabled: - return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): - return True - warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.grid_sample().') - return False - -#---------------------------------------------------------------------------- - -class _GridSample2dForward(torch.autograd.Function): - @staticmethod - def forward(ctx, input, grid): - assert input.ndim == 4 - assert grid.ndim == 4 - output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) - ctx.save_for_backward(input, grid) - return output - - @staticmethod - def backward(ctx, grad_output): - input, grid = ctx.saved_tensors - grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) - return grad_input, grad_grid - -#---------------------------------------------------------------------------- - -class _GridSample2dBackward(torch.autograd.Function): - @staticmethod - def forward(ctx, grad_output, input, grid): - op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') - grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) - ctx.save_for_backward(grid) - return grad_input, grad_grid - - @staticmethod - def backward(ctx, grad2_grad_input, grad2_grad_grid): - _ = grad2_grad_grid # unused - grid, = ctx.saved_tensors - grad2_grad_output = None - grad2_input = None - grad2_grid = None - - if ctx.needs_input_grad[0]: - grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) - - assert not ctx.needs_input_grad[2] - return grad2_grad_output, grad2_input, grad2_grid - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/ops/upfirdn2d.cpp b/PTI/torch_utils/ops/upfirdn2d.cpp deleted file mode 100644 index 2d7177fc60040751d20e9a8da0301fa3ab64968a..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/upfirdn2d.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include -#include -#include "upfirdn2d.h" - -//------------------------------------------------------------------------ - -static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) -{ - // Validate arguments. - TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); - TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); - TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); - TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); - TORCH_CHECK(x.dim() == 4, "x must be rank 4"); - TORCH_CHECK(f.dim() == 2, "f must be rank 2"); - TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); - TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); - TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); - - // Create output tensor. 
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; - int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; - TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); - torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); - TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); - - // Initialize CUDA kernel parameters. - upfirdn2d_kernel_params p; - p.x = x.data_ptr(); - p.f = f.data_ptr(); - p.y = y.data_ptr(); - p.up = make_int2(upx, upy); - p.down = make_int2(downx, downy); - p.pad0 = make_int2(padx0, pady0); - p.flip = (flip) ? 1 : 0; - p.gain = gain; - p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); - p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); - p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); - p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); - p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); - p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); - p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; - p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; - - // Choose CUDA kernel. - upfirdn2d_kernel_spec spec; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] - { - spec = choose_upfirdn2d_kernel(p); - }); - - // Set looping options. - p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; - p.loopMinor = spec.loopMinor; - p.loopX = spec.loopX; - p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; - p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; - - // Compute grid size. - dim3 blockSize, gridSize; - if (spec.tileOutW < 0) // large - { - blockSize = dim3(4, 32, 1); - gridSize = dim3( - ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, - (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, - p.launchMajor); - } - else // small - { - blockSize = dim3(256, 1, 1); - gridSize = dim3( - ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, - (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, - p.launchMajor); - } - - // Launch CUDA kernel. - void* args[] = {&p}; - AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); - return y; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("upfirdn2d", &upfirdn2d); -} - -//------------------------------------------------------------------------ diff --git a/PTI/torch_utils/ops/upfirdn2d.cu b/PTI/torch_utils/ops/upfirdn2d.cu deleted file mode 100644 index ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/upfirdn2d.cu +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "upfirdn2d.h" - -//------------------------------------------------------------------------ -// Helpers. 
- -template struct InternalType; -template <> struct InternalType { typedef double scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; - -static __device__ __forceinline__ int floor_div(int a, int b) -{ - int t = 1 - a / b; - return (a + t * b) / b - t; -} - -//------------------------------------------------------------------------ -// Generic CUDA implementation for large filters. - -template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - - // Calculate thread index. - int minorBase = blockIdx.x * blockDim.x + threadIdx.x; - int outY = minorBase / p.launchMinor; - minorBase -= outY * p.launchMinor; - int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; - int majorBase = blockIdx.z * p.loopMajor; - if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) - return; - - // Setup Y receptive field. - int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; - int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); - int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; - int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; - if (p.flip) - filterY = p.filterSize.y - 1 - filterY; - - // Loop over major, minor, and X. - for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) - for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) - { - int nc = major * p.sizeMinor + minor; - int n = nc / p.inSize.z; - int c = nc - n * p.inSize.z; - for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) - { - // Setup X receptive field. - int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; - int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); - int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; - int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; - if (p.flip) - filterX = p.filterSize.x - 1 - filterX; - - // Initialize pointers. - const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; - const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; - int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; - int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; - - // Inner loop. - scalar_t v = 0; - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - v += (scalar_t)(*xp) * (scalar_t)(*fp); - xp += p.inStride.x; - fp += filterStepX; - } - xp += p.inStride.y - w * p.inStride.x; - fp += filterStepY - w * filterStepX; - } - - // Store result. - v *= p.gain; - ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; - } - } -} - -//------------------------------------------------------------------------ -// Specialized CUDA implementation for small filters. - -template -static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; - const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; - __shared__ volatile scalar_t sf[filterH][filterW]; - __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; - - // Calculate tile index. 
- int minorBase = blockIdx.x; - int tileOutY = minorBase / p.launchMinor; - minorBase -= tileOutY * p.launchMinor; - minorBase *= loopMinor; - tileOutY *= tileOutH; - int tileOutXBase = blockIdx.y * p.loopX * tileOutW; - int majorBase = blockIdx.z * p.loopMajor; - if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) - return; - - // Load filter (flipped). - for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) - { - int fy = tapIdx / filterW; - int fx = tapIdx - fy * filterW; - scalar_t v = 0; - if (fx < p.filterSize.x & fy < p.filterSize.y) - { - int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; - int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; - v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; - } - sf[fy][fx] = v; - } - - // Loop over major and X. - for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) - { - int baseNC = major * p.sizeMinor + minorBase; - int n = baseNC / p.inSize.z; - int baseC = baseNC - n * p.inSize.z; - for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) - { - // Load input pixels. - int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; - int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; - int tileInX = floor_div(tileMidX, upx); - int tileInY = floor_div(tileMidY, upy); - __syncthreads(); - for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) - { - int relC = inIdx; - int relInX = relC / loopMinor; - int relInY = relInX / tileInW; - relC -= relInX * loopMinor; - relInX -= relInY * tileInW; - int c = baseC + relC; - int inX = tileInX + relInX; - int inY = tileInY + relInY; - scalar_t v = 0; - if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) - v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; - sx[relInY][relInX][relC] = v; - } - - // Loop over output pixels. - __syncthreads(); - for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) - { - int relC = outIdx; - int relOutX = relC / loopMinor; - int relOutY = relOutX / tileOutW; - relC -= relOutX * loopMinor; - relOutX -= relOutY * tileOutW; - int c = baseC + relC; - int outX = tileOutX + relOutX; - int outY = tileOutY + relOutY; - - // Setup receptive field. - int midX = tileMidX + relOutX * downx; - int midY = tileMidY + relOutY * downy; - int inX = floor_div(midX, upx); - int inY = floor_div(midY, upy); - int relInX = inX - tileInX; - int relInY = inY - tileInY; - int filterX = (inX + 1) * upx - midX - 1; // flipped - int filterY = (inY + 1) * upy - midY - 1; // flipped - - // Inner loop. - if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) - { - scalar_t v = 0; - #pragma unroll - for (int y = 0; y < filterH / upy; y++) - #pragma unroll - for (int x = 0; x < filterW / upx; x++) - v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; - v *= p.gain; - ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; - } - } - } - } -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. 
- -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) -{ - int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; - - upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous - if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last - - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - } - if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - } - if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 6 && fy <= 6 ) spec 
= {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - } - if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - } - if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - } - if (s == 1 && p.up.x 
== 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - } - return spec; -} - -//------------------------------------------------------------------------ -// Template specializations. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/PTI/torch_utils/ops/upfirdn2d.h b/PTI/torch_utils/ops/upfirdn2d.h deleted file mode 100644 index c9e2032bcac9d2abde7a75eea4d812da348afadd..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/upfirdn2d.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include - -//------------------------------------------------------------------------ -// CUDA kernel parameters. - -struct upfirdn2d_kernel_params -{ - const void* x; - const float* f; - void* y; - - int2 up; - int2 down; - int2 pad0; - int flip; - float gain; - - int4 inSize; // [width, height, channel, batch] - int4 inStride; - int2 filterSize; // [width, height] - int2 filterStride; - int4 outSize; // [width, height, channel, batch] - int4 outStride; - int sizeMinor; - int sizeMajor; - - int loopMinor; - int loopMajor; - int loopX; - int launchMinor; - int launchMajor; -}; - -//------------------------------------------------------------------------ -// CUDA kernel specialization. 
- -struct upfirdn2d_kernel_spec -{ - void* kernel; - int tileOutW; - int tileOutH; - int loopMinor; - int loopX; -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/PTI/torch_utils/ops/upfirdn2d.py b/PTI/torch_utils/ops/upfirdn2d.py deleted file mode 100644 index ceeac2b9834e33b7c601c28bf27f32aa91c69256..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/ops/upfirdn2d.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom PyTorch ops for efficient resampling of 2D images.""" - -import os -import warnings -import numpy as np -import torch -import traceback - -from .. import custom_ops -from .. import misc -from . import conv2d_gradfix - -#---------------------------------------------------------------------------- - -_inited = False -_plugin = None - -def _init(): - global _inited, _plugin - if not _inited: - sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] - sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - try: - _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) - except: - warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) - return _plugin is not None - -def _parse_scaling(scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - assert isinstance(scaling, (list, tuple)) - assert all(isinstance(x, int) for x in scaling) - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - -def _parse_padding(padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, int) for x in padding) - if len(padding) == 2: - padx, pady = padding - padding = [padx, padx, pady, pady] - padx0, padx1, pady0, pady1 = padding - return padx0, padx1, pady0, pady1 - -def _get_filter_size(f): - if f is None: - return 1, 1 - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - fw = f.shape[-1] - fh = f.shape[0] - with misc.suppress_tracer_warnings(): - fw = int(fw) - fh = int(fh) - misc.assert_shape(f, [fh, fw][:f.ndim]) - assert fw >= 1 and fh >= 1 - return fw, fh - -#---------------------------------------------------------------------------- - -def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): - r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. - - Args: - f: Torch tensor, numpy array, or python list of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), - `[]` (impulse), or - `None` (identity). - device: Result device (default: cpu). - normalize: Normalize the filter so that it retains the magnitude - for constant input signal (DC)? (default: True). - flip_filter: Flip the filter? (default: False). 
- gain: Overall scaling factor for signal magnitude (default: 1). - separable: Return a separable filter? (default: select automatically). - - Returns: - Float32 tensor of the shape - `[filter_height, filter_width]` (non-separable) or - `[filter_taps]` (separable). - """ - # Validate. - if f is None: - f = 1 - f = torch.as_tensor(f, dtype=torch.float32) - assert f.ndim in [0, 1, 2] - assert f.numel() > 0 - if f.ndim == 0: - f = f[np.newaxis] - - # Separable? - if separable is None: - separable = (f.ndim == 1 and f.numel() >= 8) - if f.ndim == 1 and not separable: - f = f.ger(f) - assert f.ndim == (1 if separable else 2) - - # Apply normalize, flip, gain, and device. - if normalize: - f /= f.sum() - if flip_filter: - f = f.flip(list(range(f.ndim))) - f = f * (gain ** (f.ndim / 2)) - f = f.to(device=device) - return f - -#---------------------------------------------------------------------------- - -def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Pad, upsample, filter, and downsample a batch of 2D images. - - Performs the following sequence of operations for each channel: - - 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). - - 2. Pad the image with the specified number of zeros on each side (`padding`). - Negative padding corresponds to cropping the image. - - 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it - so that the footprint of all output pixels lies within the input image. - - 4. Downsample the image by keeping every Nth pixel (`down`). - - This sequence of operations bears close resemblance to scipy.signal.upfirdn(). - The fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports gradients of arbitrary order. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and _init(): - return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) - return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. - """ - # Validate arguments. 
- assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - assert f.dtype == torch.float32 and not f.requires_grad - batch_size, num_channels, in_height, in_width = x.shape - upx, upy = _parse_scaling(up) - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - - # Upsample by inserting zeros. - x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) - x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) - x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) - x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = f.to(x.dtype) - if not flip_filter: - f = f.flip(list(range(f.ndim))) - - # Convolve with the filter. - f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) - if f.ndim == 4: - x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) - else: - x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) - x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) - - # Downsample by throwing away pixels. - x = x[:, :, ::downy, ::downx] - return x - -#---------------------------------------------------------------------------- - -_upfirdn2d_cuda_cache = dict() - -def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): - """Fast CUDA implementation of `upfirdn2d()` using custom ops. - """ - # Parse arguments. - upx, upy = _parse_scaling(up) - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - - # Lookup from cache. - key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) - if key in _upfirdn2d_cuda_cache: - return _upfirdn2d_cuda_cache[key] - - # Forward op. - class Upfirdn2dCuda(torch.autograd.Function): - @staticmethod - def forward(ctx, x, f): # pylint: disable=arguments-differ - assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - y = x - if f.ndim == 2: - y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) - else: - y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) - y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) - ctx.save_for_backward(f) - ctx.x_shape = x.shape - return y - - @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ - f, = ctx.saved_tensors - _, _, ih, iw = ctx.x_shape - _, _, oh, ow = dy.shape - fw, fh = _get_filter_size(f) - p = [ - fw - padx0 - 1, - iw * upx - ow * downx + padx0 - upx + 1, - fh - pady0 - 1, - ih * upy - oh * downy + pady0 - upy + 1, - ] - dx = None - df = None - - if ctx.needs_input_grad[0]: - dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) - - assert not ctx.needs_input_grad[1] - return dx, df - - # Add to cache. 
- _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda - return Upfirdn2dCuda - -#---------------------------------------------------------------------------- - -def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Filter a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape matches the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + fw // 2, - padx1 + (fw - 1) // 2, - pady0 + fh // 2, - pady1 + (fh - 1) // 2, - ] - return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) - -#---------------------------------------------------------------------------- - -def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Upsample a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape is a multiple of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - upx, upy = _parse_scaling(up) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw + upx - 1) // 2, - padx1 + (fw - upx) // 2, - pady0 + (fh + upy - 1) // 2, - pady1 + (fh - upy) // 2, - ] - return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) - -#---------------------------------------------------------------------------- - -def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Downsample a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape is a fraction of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. 
Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the input. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw - downx + 1) // 2, - padx1 + (fw - downx) // 2, - pady0 + (fh - downy + 1) // 2, - pady1 + (fh - downy) // 2, - ] - return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/persistence.py b/PTI/torch_utils/persistence.py deleted file mode 100644 index 0186cfd97bca0fcb397a7b73643520c1d1105a02..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/persistence.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Facilities for pickling Python code alongside other data. - -The pickled code is automatically imported into a separate Python module -during unpickling. This way, any previously exported pickles will remain -usable even if the original code is no longer available, or if the current -version of the code is not consistent with what was originally pickled.""" - -import sys -import pickle -import io -import inspect -import copy -import uuid -import types -import dnnlib - -#---------------------------------------------------------------------------- - -_version = 6 # internal version number -_decorators = set() # {decorator_class, ...} -_import_hooks = [] # [hook_function, ...] -_module_to_src_dict = dict() # {module: src, ...} -_src_to_module_dict = dict() # {src: module, ...} - -#---------------------------------------------------------------------------- - -def persistent_class(orig_class): - r"""Class decorator that extends a given class to save its source code - when pickled. - - Example: - - from torch_utils import persistence - - @persistence.persistent_class - class MyNetwork(torch.nn.Module): - def __init__(self, num_inputs, num_outputs): - super().__init__() - self.fc = MyLayer(num_inputs, num_outputs) - ... - - @persistence.persistent_class - class MyLayer(torch.nn.Module): - ... - - When pickled, any instance of `MyNetwork` and `MyLayer` will save its - source code alongside other internal state (e.g., parameters, buffers, - and submodules). 
This way, any previously exported pickle will remain - usable even if the class definitions have been modified or are no - longer available. - - The decorator saves the source code of the entire Python module - containing the decorated class. It does *not* save the source code of - any imported modules. Thus, the imported modules must be available - during unpickling, also including `torch_utils.persistence` itself. - - It is ok to call functions defined in the same module from the - decorated class. However, if the decorated class depends on other - classes defined in the same module, they must be decorated as well. - This is illustrated in the above example in the case of `MyLayer`. - - It is also possible to employ the decorator just-in-time before - calling the constructor. For example: - - cls = MyLayer - if want_to_make_it_persistent: - cls = persistence.persistent_class(cls) - layer = cls(num_inputs, num_outputs) - - As an additional feature, the decorator also keeps track of the - arguments that were used to construct each instance of the decorated - class. The arguments can be queried via `obj.init_args` and - `obj.init_kwargs`, and they are automatically pickled alongside other - object state. A typical use case is to first unpickle a previous - instance of a persistent class, and then upgrade it to use the latest - version of the source code: - - with open('old_pickle.pkl', 'rb') as f: - old_net = pickle.load(f) - new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) - misc.copy_params_and_buffers(old_net, new_net, require_all=True) - """ - assert isinstance(orig_class, type) - if is_persistent(orig_class): - return orig_class - - assert orig_class.__module__ in sys.modules - orig_module = sys.modules[orig_class.__module__] - orig_module_src = _module_to_src(orig_module) - - class Decorator(orig_class): - _orig_module_src = orig_module_src - _orig_class_name = orig_class.__name__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._init_args = copy.deepcopy(args) - self._init_kwargs = copy.deepcopy(kwargs) - assert orig_class.__name__ in orig_module.__dict__ - _check_pickleable(self.__reduce__()) - - @property - def init_args(self): - return copy.deepcopy(self._init_args) - - @property - def init_kwargs(self): - return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) - - def __reduce__(self): - fields = list(super().__reduce__()) - fields += [None] * max(3 - len(fields), 0) - if fields[0] is not _reconstruct_persistent_obj: - meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) - fields[0] = _reconstruct_persistent_obj # reconstruct func - fields[1] = (meta,) # reconstruct args - fields[2] = None # state dict - return tuple(fields) - - Decorator.__name__ = orig_class.__name__ - _decorators.add(Decorator) - return Decorator - -#---------------------------------------------------------------------------- - -def is_persistent(obj): - r"""Test whether the given object or class is persistent, i.e., - whether it will save its source code when pickled. - """ - try: - if obj in _decorators: - return True - except TypeError: - pass - return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck - -#---------------------------------------------------------------------------- - -def import_hook(hook): - r"""Register an import hook that is called whenever a persistent object - is being unpickled. 
A typical use case is to patch the pickled source - code to avoid errors and inconsistencies when the API of some imported - module has changed. - - The hook should have the following signature: - - hook(meta) -> modified meta - - `meta` is an instance of `dnnlib.EasyDict` with the following fields: - - type: Type of the persistent object, e.g. `'class'`. - version: Internal version number of `torch_utils.persistence`. - module_src Original source code of the Python module. - class_name: Class name in the original Python module. - state: Internal state of the object. - - Example: - - @persistence.import_hook - def wreck_my_network(meta): - if meta.class_name == 'MyNetwork': - print('MyNetwork is being imported. I will wreck it!') - meta.module_src = meta.module_src.replace("True", "False") - return meta - """ - assert callable(hook) - _import_hooks.append(hook) - -#---------------------------------------------------------------------------- - -def _reconstruct_persistent_obj(meta): - r"""Hook that is called internally by the `pickle` module to unpickle - a persistent object. - """ - meta = dnnlib.EasyDict(meta) - meta.state = dnnlib.EasyDict(meta.state) - for hook in _import_hooks: - meta = hook(meta) - assert meta is not None - - assert meta.version == _version - module = _src_to_module(meta.module_src) - - assert meta.type == 'class' - orig_class = module.__dict__[meta.class_name] - decorator_class = persistent_class(orig_class) - obj = decorator_class.__new__(decorator_class) - - setstate = getattr(obj, '__setstate__', None) - if callable(setstate): - setstate(meta.state) # pylint: disable=not-callable - else: - obj.__dict__.update(meta.state) - return obj - -#---------------------------------------------------------------------------- - -def _module_to_src(module): - r"""Query the source code of a given Python module. - """ - src = _module_to_src_dict.get(module, None) - if src is None: - src = inspect.getsource(module) - _module_to_src_dict[module] = src - _src_to_module_dict[src] = module - return src - -def _src_to_module(src): - r"""Get or create a Python module for the given source code. - """ - module = _src_to_module_dict.get(src, None) - if module is None: - module_name = "_imported_module_" + uuid.uuid4().hex - module = types.ModuleType(module_name) - sys.modules[module_name] = module - _module_to_src_dict[module] = src - _src_to_module_dict[src] = module - exec(src, module.__dict__) # pylint: disable=exec-used - return module - -#---------------------------------------------------------------------------- - -def _check_pickleable(obj): - r"""Check that the given object is pickleable, raising an exception if - it is not. This function is expected to be considerably more efficient - than actually pickling the object. - """ - def recurse(obj): - if isinstance(obj, (list, tuple, set)): - return [recurse(x) for x in obj] - if isinstance(obj, dict): - return [[recurse(x), recurse(y)] for x, y in obj.items()] - if isinstance(obj, (str, int, float, bool, bytes, bytearray)): - return None # Python primitive types are pickleable. - if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: - return None # NumPy arrays and PyTorch tensors are pickleable. - if is_persistent(obj): - return None # Persistent objects are pickleable, by virtue of the constructor check. 
- return obj - with io.BytesIO() as f: - pickle.dump(recurse(obj), f) - -#---------------------------------------------------------------------------- diff --git a/PTI/torch_utils/training_stats.py b/PTI/torch_utils/training_stats.py deleted file mode 100644 index 26f467f9eaa074ee13de1cf2625cd7da44880847..0000000000000000000000000000000000000000 --- a/PTI/torch_utils/training_stats.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Facilities for reporting and collecting training statistics across -multiple processes and devices. The interface is designed to minimize -synchronization overhead as well as the amount of boilerplate in user -code.""" - -import re -import numpy as np -import torch -import dnnlib - -from . import misc - -#---------------------------------------------------------------------------- - -_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] -_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. -_counter_dtype = torch.float64 # Data type to use for the internal counters. -_rank = 0 # Rank of the current process. -_sync_device = None # Device to use for multiprocess communication. None = single-process. -_sync_called = False # Has _sync() been called yet? -_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor -_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor - -#---------------------------------------------------------------------------- - -def init_multiprocessing(rank, sync_device): - r"""Initializes `torch_utils.training_stats` for collecting statistics - across multiple processes. - - This function must be called after - `torch.distributed.init_process_group()` and before `Collector.update()`. - The call is not necessary if multi-process collection is not needed. - - Args: - rank: Rank of the current process. - sync_device: PyTorch device to use for inter-process - communication, or None to disable multi-process - collection. Typically `torch.device('cuda', rank)`. - """ - global _rank, _sync_device - assert not _sync_called - _rank = rank - _sync_device = sync_device - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def report(name, value): - r"""Broadcasts the given set of scalars to all interested instances of - `Collector`, across device and process boundaries. - - This function is expected to be extremely cheap and can be safely - called from anywhere in the training loop, loss function, or inside a - `torch.nn.Module`. - - Warning: The current implementation expects the set of unique names to - be consistent across processes. Please make sure that `report()` is - called at least once for each unique name by each process, and in the - same order. If a given process has no scalars to broadcast, it can do - `report(name, [])` (empty list). - - Args: - name: Arbitrary string specifying the name of the statistic. - Averages are accumulated separately for each unique name. - value: Arbitrary set of scalars. 
Can be a list, tuple, - NumPy array, PyTorch tensor, or Python scalar. - - Returns: - The same `value` that was passed in. - """ - if name not in _counters: - _counters[name] = dict() - - elems = torch.as_tensor(value) - if elems.numel() == 0: - return value - - elems = elems.detach().flatten().to(_reduce_dtype) - moments = torch.stack([ - torch.ones_like(elems).sum(), - elems.sum(), - elems.square().sum(), - ]) - assert moments.ndim == 1 and moments.shape[0] == _num_moments - moments = moments.to(_counter_dtype) - - device = moments.device - if device not in _counters[name]: - _counters[name][device] = torch.zeros_like(moments) - _counters[name][device].add_(moments) - return value - -#---------------------------------------------------------------------------- - -def report0(name, value): - r"""Broadcasts the given set of scalars by the first process (`rank = 0`), - but ignores any scalars provided by the other processes. - See `report()` for further details. - """ - report(name, value if _rank == 0 else []) - return value - -#---------------------------------------------------------------------------- - -class Collector: - r"""Collects the scalars broadcasted by `report()` and `report0()` and - computes their long-term averages (mean and standard deviation) over - user-defined periods of time. - - The averages are first collected into internal counters that are not - directly visible to the user. They are then copied to the user-visible - state as a result of calling `update()` and can then be queried using - `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the - internal counters for the next round, so that the user-visible state - effectively reflects averages collected between the last two calls to - `update()`. - - Args: - regex: Regular expression defining which statistics to - collect. The default is to collect everything. - keep_previous: Whether to retain the previous averages if no - scalars were collected on a given round - (default: True). - """ - def __init__(self, regex='.*', keep_previous=True): - self._regex = re.compile(regex) - self._keep_previous = keep_previous - self._cumulative = dict() - self._moments = dict() - self.update() - self._moments.clear() - - def names(self): - r"""Returns the names of all statistics broadcasted so far that - match the regular expression specified at construction time. - """ - return [name for name in _counters if self._regex.fullmatch(name)] - - def update(self): - r"""Copies current values of the internal counters to the - user-visible state and resets them for the next round. - - If `keep_previous=True` was specified at construction time, the - operation is skipped for statistics that have received no scalars - since the last update, retaining their previous averages. - - This method performs a number of GPU-to-CPU transfers and one - `torch.distributed.all_reduce()`. It is intended to be called - periodically in the main training loop, typically once every - N training steps. 
- """ - if not self._keep_previous: - self._moments.clear() - for name, cumulative in _sync(self.names()): - if name not in self._cumulative: - self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - delta = cumulative - self._cumulative[name] - self._cumulative[name].copy_(cumulative) - if float(delta[0]) != 0: - self._moments[name] = delta - - def _get_delta(self, name): - r"""Returns the raw moments that were accumulated for the given - statistic between the last two calls to `update()`, or zero if - no scalars were collected. - """ - assert self._regex.fullmatch(name) - if name not in self._moments: - self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - return self._moments[name] - - def num(self, name): - r"""Returns the number of scalars that were accumulated for the given - statistic between the last two calls to `update()`, or zero if - no scalars were collected. - """ - delta = self._get_delta(name) - return int(delta[0]) - - def mean(self, name): - r"""Returns the mean of the scalars that were accumulated for the - given statistic between the last two calls to `update()`, or NaN if - no scalars were collected. - """ - delta = self._get_delta(name) - if int(delta[0]) == 0: - return float('nan') - return float(delta[1] / delta[0]) - - def std(self, name): - r"""Returns the standard deviation of the scalars that were - accumulated for the given statistic between the last two calls to - `update()`, or NaN if no scalars were collected. - """ - delta = self._get_delta(name) - if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): - return float('nan') - if int(delta[0]) == 1: - return float(0) - mean = float(delta[1] / delta[0]) - raw_var = float(delta[2] / delta[0]) - return np.sqrt(max(raw_var - np.square(mean), 0)) - - def as_dict(self): - r"""Returns the averages accumulated between the last two calls to - `update()` as an `dnnlib.EasyDict`. The contents are as follows: - - dnnlib.EasyDict( - NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), - ... - ) - """ - stats = dnnlib.EasyDict() - for name in self.names(): - stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) - return stats - - def __getitem__(self, name): - r"""Convenience getter. - `collector[name]` is a synonym for `collector.mean(name)`. - """ - return self.mean(name) - -#---------------------------------------------------------------------------- - -def _sync(names): - r"""Synchronize the global cumulative counters across devices and - processes. Called internally by `Collector.update()`. - """ - if len(names) == 0: - return [] - global _sync_called - _sync_called = True - - # Collect deltas within current rank. - deltas = [] - device = _sync_device if _sync_device is not None else torch.device('cpu') - for name in names: - delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) - for counter in _counters[name].values(): - delta.add_(counter.to(device)) - counter.copy_(torch.zeros_like(counter)) - deltas.append(delta) - deltas = torch.stack(deltas) - - # Sum deltas across ranks. - if _sync_device is not None: - torch.distributed.all_reduce(deltas) - - # Update cumulative values. - deltas = deltas.cpu() - for idx, name in enumerate(names): - if name not in _cumulative: - _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - _cumulative[name].add_(deltas[idx]) - - # Return name-value pairs. 
- return [(name, _cumulative[name]) for name in names] - -#---------------------------------------------------------------------------- diff --git a/app.py b/app.py index 1d5f6bcdc75c1d098989629b34e0b8ca79145394..201c1e76c2d90821cf1c892421dd0d378b396b4f 100644 --- a/app.py +++ b/app.py @@ -1,40 +1,89 @@ import gradio as gr -import utils +import utils.utils as utils from PIL import Image import torch import math from torchvision import transforms - - +from run_pti import run_PTI device = "cpu" years = [str(y) for y in range(1880, 2020, 10)] +decades = [y + "s" for y in years] +transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + orig_models = {} for year in years: G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device) - orig_models[year] = { "G": G.eval()} + orig_models[year] = { "G": G.eval().float()} def run_alignment(image_path,idx=None): import dlib from align_all_parallel import align_face predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat") - aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx) - print("Aligned image has shape: {}".format(aligned_image.size)) + aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx) + return aligned_image -def predict(inp): +def predict(inp, in_decade): #with torch.no_grad(): inp.save("imgs/input.png") - out = run_alignment("imgs/input.png", idx=0) - return out + inversion = run_alignment("imgs/input.png", idx=0) + inversion.save("imgs/cropped/input.png") + run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False) + #inversion = Image.open("imgs/cropped/input.png") + + in_year = in_decade[:-1] + pti_models = {} + + for year in years: + G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device) + pti_models[year] = { "G": G.eval().float()} + + + pti_models[in_year]['G'] = torch.load(f"checkpoints/model_gradio_demo_input.pt", device).eval().float() + + for year in years: + if year != in_year: + for p_pti, p_orig, (names, p) in zip(pti_models[in_year]['G'].parameters(),orig_models[in_year]['G'].parameters(), pti_models[year]['G'].named_parameters()): + with torch.no_grad(): + delta = p_pti - p_orig + p += delta + + space = 0 + dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white') + + + w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device) + + border_width = 10 + #fill_color = 'red' + dst.paste(inversion, (0, 0)) + + + + for i in range(0, len(years)): + year = str(years[i]) + with torch.no_grad(): + child_tensor = pti_models[year]["G"].synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True) + img = utils.tensor2im(child_tensor.squeeze(0)) + # if year == in_year: + # img = img.crop((border_width, border_width, 256 - border_width, 256-border_width)) + # img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color) + dst.paste(img, ((256 + space) * (i+1), 0)) + dst + return dst gr.Interface(fn=predict, - inputs=gr.Image(type="pil"), - outputs=gr.Image(type="pil"), - #examples=["lion.jpg", "cheetah.jpg"] - ).launch() + inputs=[gr.Image(label="Input Image", type="pil"), gr.Dropdown(label="Input Decade", choices=decades, value="2010s")], + outputs=gr.Image(label="Decade Transformations", type="pil"), + examples=[["imgs/Steven-Yeun.jpg", "2010s"]] + + ).launch() #.launch(server_name="0.0.0.0", server_port=8098) diff --git 
a/checkpoints/model_gradio_demo_input.pt b/checkpoints/model_gradio_demo_input.pt new file mode 100644 index 0000000000000000000000000000000000000000..faddf9dcb32f85217f12c31e9b9cd5a8a3ccdead --- /dev/null +++ b/checkpoints/model_gradio_demo_input.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ee5644ec8ab0966a4eb51971995c071f8178db765750d45e32e0ed18a09738 +size 99867041 diff --git a/PTI/criteria/color_transfer_loss.py b/color_transfer_loss.py similarity index 100% rename from PTI/criteria/color_transfer_loss.py rename to color_transfer_loss.py diff --git a/PTI/configs/__init__.py b/configs/__init__.py similarity index 100% rename from PTI/configs/__init__.py rename to configs/__init__.py diff --git a/PTI/configs/evaluation_config.py b/configs/evaluation_config.py similarity index 100% rename from PTI/configs/evaluation_config.py rename to configs/evaluation_config.py diff --git a/PTI/configs/global_config.py b/configs/global_config.py similarity index 80% rename from PTI/configs/global_config.py rename to configs/global_config.py index 52dc2253c4d157fe7b7e0e630dc4a6eac9e1e142..82c1033caa6ea35404886cd17be3a510ceea4fbf 100644 --- a/PTI/configs/global_config.py +++ b/configs/global_config.py @@ -1,6 +1,6 @@ ## Device -cuda_visible_devices = "1" -device = "cuda:0" +cuda_visible_devices = "0" +device = "cpu" ## Logs training_step = 1 diff --git a/PTI/configs/hyperparameters.py b/configs/hyperparameters.py similarity index 95% rename from PTI/configs/hyperparameters.py rename to configs/hyperparameters.py index 67901452855be546851289d8fb67920933922075..62871505436118e52583a5b8e918d9d191dda73b 100644 --- a/PTI/configs/hyperparameters.py +++ b/configs/hyperparameters.py @@ -28,4 +28,4 @@ max_images_to_invert = 10 pti_learning_rate = 3e-4 first_inv_lr = 5e-3 train_batch_size = 1 -use_last_w_pivots = True +use_last_w_pivots = False diff --git a/PTI/configs/paths_config.py b/configs/paths_config.py similarity index 75% rename from PTI/configs/paths_config.py rename to configs/paths_config.py index 931d699f37b17ade6560f11520169e664ebfc96e..1513a9d7b347fc7bb56ba0a4affdfe3e893cf220 100644 --- a/PTI/configs/paths_config.py +++ b/configs/paths_config.py @@ -4,12 +4,12 @@ year = "2010" e4e = "./pretrained_models/e4e_ffhq_encode.pt" -stylegan2_ada_ffhq = f"../pretrained_models/{year}.pkl" +stylegan2_ada_ffhq = f"pretrained_models/{year}.pkl" style_clip_pretrained_mappers = "" -ir_se50 = "/share/phoenix/nfs04/S7/wikitime_models/model_ir_se50.pth" +ir_se50 = "pretrained_models/model_ir_se50.pth" dlib = "./pretrained_models/align.dat" -deeplab = "/share/phoenix/nfs04/S7/wikitime_models/deeplab_model/deeplab_model.pth" +deeplab = "pretrained_models/deeplab_model/deeplab_model.pth" ## Dirs for output files checkpoints_dir = "./checkpoints" @@ -20,7 +20,7 @@ experiments_output_dir = "./output" ## Input info ### Input dir, where the images reside input_data_path = ( - f"/share/phoenix/nfs04/S7/emc348/WikiFaces/datasets/new_crops/test/{year}" + f"imgs/cropped" ) input_data_id = f"{year}" diff --git a/PTI/criteria/__init__.py b/criteria/__init__.py similarity index 100% rename from PTI/criteria/__init__.py rename to criteria/__init__.py diff --git a/PTI/criteria/backbones/__init__.py b/criteria/backbones/__init__.py similarity index 100% rename from PTI/criteria/backbones/__init__.py rename to criteria/backbones/__init__.py diff --git a/PTI/criteria/backbones/iresnet.py b/criteria/backbones/iresnet.py similarity index 100% rename from PTI/criteria/backbones/iresnet.py rename to 
criteria/backbones/iresnet.py diff --git a/PTI/criteria/backbones/iresnet2060.py b/criteria/backbones/iresnet2060.py similarity index 100% rename from PTI/criteria/backbones/iresnet2060.py rename to criteria/backbones/iresnet2060.py diff --git a/PTI/criteria/backbones/mobilefacenet.py b/criteria/backbones/mobilefacenet.py similarity index 100% rename from PTI/criteria/backbones/mobilefacenet.py rename to criteria/backbones/mobilefacenet.py diff --git a/PTI/criteria/deeplab.py b/criteria/deeplab.py similarity index 100% rename from PTI/criteria/deeplab.py rename to criteria/deeplab.py diff --git a/PTI/criteria/helpers.py b/criteria/helpers.py similarity index 100% rename from PTI/criteria/helpers.py rename to criteria/helpers.py diff --git a/PTI/criteria/id_loss.py b/criteria/id_loss.py similarity index 100% rename from PTI/criteria/id_loss.py rename to criteria/id_loss.py diff --git a/PTI/criteria/l2_loss.py b/criteria/l2_loss.py similarity index 100% rename from PTI/criteria/l2_loss.py rename to criteria/l2_loss.py diff --git a/PTI/criteria/localitly_regulizer.py b/criteria/localitly_regulizer.py similarity index 100% rename from PTI/criteria/localitly_regulizer.py rename to criteria/localitly_regulizer.py diff --git a/PTI/criteria/mask.py b/criteria/mask.py similarity index 100% rename from PTI/criteria/mask.py rename to criteria/mask.py diff --git a/PTI/criteria/model_irse.py b/criteria/model_irse.py similarity index 100% rename from PTI/criteria/model_irse.py rename to criteria/model_irse.py diff --git a/PTI/criteria/validation.py b/criteria/validation.py similarity index 100% rename from PTI/criteria/validation.py rename to criteria/validation.py diff --git a/dnnlib/__pycache__/__init__.cpython-39.pyc b/dnnlib/__pycache__/__init__.cpython-39.pyc index 0d65d958504c94699a494d504272f86185474ecc..56545ef7c6547f8c4b2046e5ea1d76c302995f9e 100644 Binary files a/dnnlib/__pycache__/__init__.cpython-39.pyc and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ diff --git a/dnnlib/__pycache__/util.cpython-39.pyc b/dnnlib/__pycache__/util.cpython-39.pyc index f89a9685b65a3158471685f56daab3197c58be66..cdb0592847d19c19a5def909c3ff01c37300677d 100644 Binary files a/dnnlib/__pycache__/util.cpython-39.pyc and b/dnnlib/__pycache__/util.cpython-39.pyc differ diff --git a/embeddings/2010/PTI/input/0.pt b/embeddings/2010/PTI/input/0.pt new file mode 100644 index 0000000000000000000000000000000000000000..c9bed619221b4030ec4df47345c4d80f4c6267d9 --- /dev/null +++ b/embeddings/2010/PTI/input/0.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb9b3fc3d08f6dd3ee5c87299fff8cd932e4a6afaa90fe06f3d5c5f9503ebf26 +size 29419 diff --git a/imgs/Steven-Yeun.jpg b/imgs/Steven-Yeun.jpg new file mode 100644 index 0000000000000000000000000000000000000000..28583afeb8306b28c29ca83419c7ff9d75747403 --- /dev/null +++ b/imgs/Steven-Yeun.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7d9da1331c75fc2b8ac8caa024c804ac500c9c29b5ed4edf60bf30247eae8a5 +size 1627062 diff --git a/imgs/cropped/input.png b/imgs/cropped/input.png new file mode 100644 index 0000000000000000000000000000000000000000..b046f6481e637e5f91dfa89b37927ce441811877 --- /dev/null +++ b/imgs/cropped/input.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba7b8df0bffe226c723eb22c537e66ff9de844e6aae7845a6c88e696f03b6a40 +size 104592 diff --git a/imgs/input.png b/imgs/input.png new file mode 100644 index 0000000000000000000000000000000000000000..0ab9444f18b692a7cacd82159370a64b02a29bdd --- 
/dev/null +++ b/imgs/input.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f8c1b42d80f44efcf0cb03a301072284a0b7ad6ae6f11871be44b7fae79613e +size 13719370 diff --git a/PTI/training/__init__.py b/models/StyleCLIP/__init__.py similarity index 100% rename from PTI/training/__init__.py rename to models/StyleCLIP/__init__.py diff --git a/PTI/training/coaches/__init__.py b/models/StyleCLIP/criteria/__init__.py similarity index 100% rename from PTI/training/coaches/__init__.py rename to models/StyleCLIP/criteria/__init__.py diff --git a/models/StyleCLIP/criteria/clip_loss.py b/models/StyleCLIP/criteria/clip_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..18176ee8eb0d992d69d5b951d7f36e2efa92a37b --- /dev/null +++ b/models/StyleCLIP/criteria/clip_loss.py @@ -0,0 +1,17 @@ + +import torch +import clip + + +class CLIPLoss(torch.nn.Module): + + def __init__(self, opts): + super(CLIPLoss, self).__init__() + self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") + self.upsample = torch.nn.Upsample(scale_factor=7) + self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32) + + def forward(self, image, text): + image = self.avg_pool(self.upsample(image)) + similarity = 1 - self.model(image, text)[0] / 100 + return similarity \ No newline at end of file diff --git a/models/StyleCLIP/criteria/id_loss.py b/models/StyleCLIP/criteria/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a828023e115243e48918538d31b91d662cd12d0f --- /dev/null +++ b/models/StyleCLIP/criteria/id_loss.py @@ -0,0 +1,39 @@ +import torch +from torch import nn + +from models.facial_recognition.model_irse import Backbone + + +class IDLoss(nn.Module): + def __init__(self, opts): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(opts.ir_se50_weights)) + self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + self.opts = opts + + def extract_feats(self, x): + if x.shape[2] != 256: + x = self.pool(x) + x = x[:, :, 35:223, 32:220] # Crop interesting region + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + def forward(self, y_hat, y): + n_samples = y.shape[0] + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + loss += 1 - diff_target + count += 1 + + return loss / count, sim_improvement / count diff --git a/models/StyleCLIP/global_directions/GUI.py b/models/StyleCLIP/global_directions/GUI.py new file mode 100644 index 0000000000000000000000000000000000000000..19f7f8cce9305819b22664642799200d9e1cfff0 --- /dev/null +++ b/models/StyleCLIP/global_directions/GUI.py @@ -0,0 +1,103 @@ + + +from tkinter import Tk,Frame ,Label,Button,messagebox,Canvas,Text,Scale +from tkinter import HORIZONTAL + +class View(): + def __init__(self,master): + + self.width=600 + self.height=600 + + + self.root=master + self.root.geometry("600x600") + + self.left_frame=Frame(self.root,width=600) + self.left_frame.pack_propagate(0) + self.left_frame.pack(fill='both', side='left', expand='True') + + self.retrieval_frame=Frame(self.root,bg='snow3') + self.retrieval_frame.pack_propagate(0) + 
self.retrieval_frame.pack(fill='both', side='right', expand='True') + + self.bg_frame=Frame(self.left_frame,bg='snow3',height=600,width=600) + self.bg_frame.pack_propagate(0) + self.bg_frame.pack(fill='both', side='top', expand='True') + + self.command_frame=Frame(self.left_frame,bg='snow3') + self.command_frame.pack_propagate(0) + self.command_frame.pack(fill='both', side='bottom', expand='True') +# self.command_frame.grid(row=1, column=0,padx=0, pady=0) + + self.bg=Canvas(self.bg_frame,width=self.width,height=self.height, bg='gray') + self.bg.place(relx=0.5, rely=0.5, anchor='center') + + self.mani=Canvas(self.retrieval_frame,width=1024,height=1024, bg='gray') + self.mani.grid(row=0, column=0,padx=0, pady=42) + + self.SetCommand() + + + + + def run(self): + self.root.mainloop() + + def helloCallBack(self): + category=self.set_category.get() + messagebox.showinfo( "Hello Python",category) + + def SetCommand(self): + + tmp = Label(self.command_frame, text="neutral", width=10 ,bg='snow3') + tmp.grid(row=1, column=0,padx=10, pady=10) + + tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3') + tmp.grid(row=1, column=1,padx=10, pady=10) + + self.neutral = Text ( self.command_frame, height=2, width=30) + self.neutral.grid(row=1, column=2,padx=10, pady=10) + + + tmp = Label(self.command_frame, text="target", width=10 ,bg='snow3') + tmp.grid(row=2, column=0,padx=10, pady=10) + + tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3') + tmp.grid(row=2, column=1,padx=10, pady=10) + + self.target = Text ( self.command_frame, height=2, width=30) + self.target.grid(row=2, column=2,padx=10, pady=10) + + tmp = Label(self.command_frame, text="strength", width=10 ,bg='snow3') + tmp.grid(row=3, column=0,padx=10, pady=10) + + self.alpha = Scale(self.command_frame, from_=-15, to=25, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.01) + self.alpha.grid(row=3, column=2,padx=10, pady=10) + + + tmp = Label(self.command_frame, text="disentangle", width=10 ,bg='snow3') + tmp.grid(row=4, column=0,padx=10, pady=10) + + self.beta = Scale(self.command_frame, from_=0.08, to=0.4, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.001) + self.beta.grid(row=4, column=2,padx=10, pady=10) + + self.reset = Button(self.command_frame, text='Reset') + self.reset.grid(row=5, column=1,padx=10, pady=10) + + + self.set_init = Button(self.command_frame, text='Accept') + self.set_init.grid(row=5, column=2,padx=10, pady=10) + +#%% +if __name__ == "__main__": + master=Tk() + self=View(master) + self.run() + + + + + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/GenerateImg.py b/models/StyleCLIP/global_directions/GenerateImg.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6dee48f2d6d9ac37c00ee77c7a46c2cc6b25e1 --- /dev/null +++ b/models/StyleCLIP/global_directions/GenerateImg.py @@ -0,0 +1,50 @@ + +import os +import numpy as np +import argparse +from manipulate import Manipulator + +from PIL import Image +#%% + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + + args = parser.parse_args() + dataset_name=args.dataset_name + + if not os.path.isdir('./data/'+dataset_name): + os.system('mkdir ./data/'+dataset_name) + #%% + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + print(M.dataset_name) + #%% + + M.img_index=0 + M.num_images=50 + 
M.alpha=[0] + M.step=1 + lindex,bname=0,0 + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(bname) + #%% + + for i in range(len(out)): + img=out[i,0] + img=Image.fromarray(img) + img.save('./data/'+dataset_name+'/'+str(i)+'.jpg') + #%% + w=np.load('./npy/'+dataset_name+'/W.npy') + + tmp=w[:M.num_images] + tmp=tmp[:,None,:] + tmp=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1)) + + np.save('./data/'+dataset_name+'/w_plus.npy',tmp) + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/GetCode.py b/models/StyleCLIP/global_directions/GetCode.py new file mode 100644 index 0000000000000000000000000000000000000000..62e64dc8cbc5ad2bb16aef5da8f6d41c26b24170 --- /dev/null +++ b/models/StyleCLIP/global_directions/GetCode.py @@ -0,0 +1,232 @@ + + + +import os +import pickle +import numpy as np +from dnnlib import tflib +import tensorflow as tf + +import argparse + +def LoadModel(dataset_name): + # Initialize TensorFlow. + tflib.init_tf() + model_path='./model/' + model_name=dataset_name+'.pkl' + + tmp=os.path.join(model_path,model_name) + with open(tmp, 'rb') as f: + _, _, Gs = pickle.load(f) + return Gs + +def lerp(a,b,t): + return a + (b - a) * t + +#stylegan-ada +def SelectName(layer_name,suffix): + if suffix==None: + tmp1='add:0' in layer_name + tmp2='shape=(?,' in layer_name + tmp4='G_synthesis_1' in layer_name + tmp= tmp1 and tmp2 and tmp4 + else: + tmp1=('/Conv0_up'+suffix) in layer_name + tmp2=('/Conv1'+suffix) in layer_name + tmp3=('4x4/Conv'+suffix) in layer_name + tmp4='G_synthesis_1' in layer_name + tmp5=('/ToRGB'+suffix) in layer_name + tmp= (tmp1 or tmp2 or tmp3 or tmp5) and tmp4 + return tmp + + +def GetSNames(suffix): + #get style tensor name + with tf.Session() as sess: + op = sess.graph.get_operations() + layers=[m.values() for m in op] + + + select_layers=[] + for layer in layers: + layer_name=str(layer) + if SelectName(layer_name,suffix): + select_layers.append(layer[0]) + return select_layers + +def SelectName2(layer_name): + tmp1='mod_bias' in layer_name + tmp2='mod_weight' in layer_name + tmp3='ToRGB' in layer_name + + tmp= (tmp1 or tmp2) and (not tmp3) + return tmp + +def GetKName(Gs): + + layers=[var for name, var in Gs.components.synthesis.vars.items()] + + select_layers=[] + for layer in layers: + layer_name=str(layer) + if SelectName2(layer_name): + select_layers.append(layer) + return select_layers + +def GetCode(Gs,random_state,num_img,num_once,dataset_name): + rnd = np.random.RandomState(random_state) #5 + + truncation_psi=0.7 + truncation_cutoff=8 + + dlatent_avg=Gs.get_var('dlatent_avg') + + dlatents=np.zeros((num_img,512),dtype='float32') + for i in range(int(num_img/num_once)): + src_latents = rnd.randn(num_once, Gs.input_shape[1]) + src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] + + # Apply truncation trick. 
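# --- Editor's note (illustrative, not part of this diff): the "truncation trick" applied in
# --- the GetCode() lines that follow pulls the early layers of each w+ code toward the average
# --- latent. A minimal NumPy sketch of the same lerp/np.where logic, with made-up shapes for a
# --- 14-layer synthesis network (names here are hypothetical):
import numpy as np

def truncate_sketch(dlatents, dlatent_avg, truncation_psi=0.7, truncation_cutoff=8):
    # dlatents: [N, num_layers, 512]; dlatent_avg: [512]
    layer_idx = np.arange(dlatents.shape[1])[np.newaxis, :, np.newaxis]
    coefs = np.where(layer_idx < truncation_cutoff, truncation_psi, 1.0)
    return dlatent_avg + (dlatents - dlatent_avg) * coefs   # lerp(avg, w, coefs)

w_demo = np.random.randn(3, 14, 512).astype('float32')      # hypothetical w+ codes
w_avg_demo = np.zeros(512, dtype='float32')                  # hypothetical average latent
assert truncate_sketch(w_demo, w_avg_demo).shape == (3, 14, 512)
# --- End of editor's note; the diff resumes below.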
+ if truncation_psi is not None and truncation_cutoff is not None: + layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis] + ones = np.ones(layer_idx.shape, dtype=np.float32) + coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones) + src_dlatents_np=lerp(dlatent_avg, src_dlatents, coefs) + src_dlatents=src_dlatents_np[:,0,:].astype('float32') + dlatents[(i*num_once):((i+1)*num_once),:]=src_dlatents + print('get all z and w') + + tmp='./npy/'+dataset_name+'/W' + np.save(tmp,dlatents) + + +def GetImg(Gs,num_img,num_once,dataset_name,save_name='images'): + print('Generate Image') + tmp='./npy/'+dataset_name+'/W.npy' + dlatents=np.load(tmp) + fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + + all_images=[] + for i in range(int(num_img/num_once)): + print(i) + images=[] + for k in range(num_once): + tmp=dlatents[i*num_once+k] + tmp=tmp[None,None,:] + tmp=np.tile(tmp,(1,Gs.components.synthesis.input_shape[1],1)) + image2= Gs.components.synthesis.run(tmp, randomize_noise=False, output_transform=fmt) + images.append(image2) + + images=np.concatenate(images) + + all_images.append(images) + + all_images=np.concatenate(all_images) + + tmp='./npy/'+dataset_name+'/'+save_name + np.save(tmp,all_images) + +def GetS(dataset_name,num_img): + print('Generate S') + tmp='./npy/'+dataset_name+'/W.npy' + dlatents=np.load(tmp)[:num_img] + + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + + Gs=LoadModel(dataset_name) + Gs.print_layers() #for ada + select_layers1=GetSNames(suffix=None) #None,'/mul_1:0','/mod_weight/read:0','/MatMul:0' + dlatents=dlatents[:,None,:] + dlatents=np.tile(dlatents,(1,Gs.components.synthesis.input_shape[1],1)) + + all_s = sess.run( + select_layers1, + feed_dict={'G_synthesis_1/dlatents_in:0': dlatents}) + + layer_names=[layer.name for layer in select_layers1] + save_tmp=[layer_names,all_s] + return save_tmp + + + + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). 
+ """ + if nchw_to_nhwc: + images = np.transpose(images, [0, 2, 3, 1]) + + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + + np.clip(images, 0, 255, out=images) + images=images.astype('uint8') + return images + + +def GetCodeMS(dlatents): + m=[] + std=[] + for i in range(len(dlatents)): + tmp= dlatents[i] + tmp_mean=tmp.mean(axis=0) + tmp_std=tmp.std(axis=0) + m.append(tmp_mean) + std.append(tmp_std) + return m,std + + + +#%% +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + parser.add_argument('--code_type',choices=['w','s','s_mean_std'],default='w') + + args = parser.parse_args() + random_state=5 + num_img=100_000 + num_once=1_000 + dataset_name=args.dataset_name + + if not os.path.isfile('./model/'+dataset_name+'.pkl'): + url='https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/' + name='stylegan2-'+dataset_name+'-config-f.pkl' + os.system('wget ' +url+name + ' -P ./model/') + os.system('mv ./model/'+name+' ./model/'+dataset_name+'.pkl') + + if not os.path.isdir('./npy/'+dataset_name): + os.system('mkdir ./npy/'+dataset_name) + + if args.code_type=='w': + Gs=LoadModel(dataset_name=dataset_name) + GetCode(Gs,random_state,num_img,num_once,dataset_name) +# GetImg(Gs,num_img=num_img,num_once=num_once,dataset_name=dataset_name,save_name='images_100K') #no need + elif args.code_type=='s': + save_name='S' + save_tmp=GetS(dataset_name,num_img=2_000) + tmp='./npy/'+dataset_name+'/'+save_name + with open(tmp, "wb") as fp: + pickle.dump(save_tmp, fp) + + elif args.code_type=='s_mean_std': + save_tmp=GetS(dataset_name,num_img=num_img) + dlatents=save_tmp[1] + m,std=GetCodeMS(dlatents) + save_tmp=[m,std] + save_name='S_mean_std' + tmp='./npy/'+dataset_name+'/'+save_name + with open(tmp, "wb") as fp: + pickle.dump(save_tmp, fp) + + + + + diff --git a/models/StyleCLIP/global_directions/GetGUIData.py b/models/StyleCLIP/global_directions/GetGUIData.py new file mode 100644 index 0000000000000000000000000000000000000000..52f77213ab88edf8b33eff166b89b9e56ac4ff01 --- /dev/null +++ b/models/StyleCLIP/global_directions/GetGUIData.py @@ -0,0 +1,67 @@ + +import os +import numpy as np +import argparse +from manipulate import Manipulator +import torch +from PIL import Image +#%% + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + + parser.add_argument('--real', action='store_true') + + args = parser.parse_args() + dataset_name=args.dataset_name + + if not os.path.isdir('./data/'+dataset_name): + os.system('mkdir ./data/'+dataset_name) + #%% + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + print(M.dataset_name) + #%% + #remove all .jpg + names=os.listdir('./data/'+dataset_name+'/') + for name in names: + if '.jpg' in name: + os.system('rm ./data/'+dataset_name+'/'+name) + + + #%% + if args.real: + latents=torch.load('./data/'+dataset_name+'/latents.pt') + w_plus=latents.cpu().detach().numpy() + else: + w=np.load('./npy/'+dataset_name+'/W.npy') + tmp=w[:50] #only use 50 images + tmp=tmp[:,None,:] + w_plus=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1)) + np.save('./data/'+dataset_name+'/w_plus.npy',w_plus) + + #%% + tmp=M.W2S(w_plus) + M.dlatents=tmp + + M.img_index=0 + M.num_images=len(w_plus) + 
M.alpha=[0] + M.step=1 + lindex,bname=0,0 + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(bname) + #%% + + for i in range(len(out)): + img=out[i,0] + img=Image.fromarray(img) + img.save('./data/'+dataset_name+'/'+str(i)+'.jpg') + #%% + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/Inference.py b/models/StyleCLIP/global_directions/Inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a292787c88a370b15b4f0d633ac27bb5bed2b510 --- /dev/null +++ b/models/StyleCLIP/global_directions/Inference.py @@ -0,0 +1,106 @@ + + +from manipulate import Manipulator +import tensorflow as tf +import numpy as np +import torch +import clip +from MapTS import GetBoundary,GetDt + +class StyleCLIP(): + + def __init__(self,dataset_name='ffhq'): + print('load clip') + device = "cuda" if torch.cuda.is_available() else "cpu" + self.model, preprocess = clip.load("ViT-B/32", device=device) + self.LoadData(dataset_name) + + def LoadData(self, dataset_name): + tf.keras.backend.clear_session() + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + fs3=np.load('./npy/'+dataset_name+'/fs3.npy') + + self.M=M + self.fs3=fs3 + + w_plus=np.load('./data/'+dataset_name+'/w_plus.npy') + self.M.dlatents=M.W2S(w_plus) + + if dataset_name=='ffhq': + self.c_threshold=20 + else: + self.c_threshold=100 + self.SetInitP() + + def SetInitP(self): + self.M.alpha=[3] + self.M.num_images=1 + + self.target='' + self.neutral='' + self.GetDt2() + img_index=0 + self.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.M.dlatents] + + + def GetDt2(self): + classnames=[self.target,self.neutral] + dt=GetDt(classnames,self.model) + + self.dt=dt + num_cs=[] + betas=np.arange(0.1,0.3,0.01) + for i in range(len(betas)): + boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=betas[i]) + print(betas[i]) + num_cs.append(num_c) + + num_cs=np.array(num_cs) + select=num_cs>self.c_threshold + + if sum(select)==0: + self.beta=0.1 + else: + self.beta=betas[select][-1] + + + def GetCode(self): + boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=self.beta) + codes=self.M.MSCode(self.M.dlatent_tmp,boundary_tmp2) + return codes + + def GetImg(self): + + codes=self.GetCode() + out=self.M.GenerateImg(codes) + img=out[0,0] + return img + + + + +#%% +if __name__ == "__main__": + style_clip=StyleCLIP() + self=style_clip + + + + + + + + + + + + + + + + + + + + diff --git a/models/StyleCLIP/global_directions/MapTS.py b/models/StyleCLIP/global_directions/MapTS.py new file mode 100644 index 0000000000000000000000000000000000000000..2160a62cdbb0278d213076637f79b1e6f66db906 --- /dev/null +++ b/models/StyleCLIP/global_directions/MapTS.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Thu Feb 4 17:36:31 2021 + +@author: wuzongze +""" + +import os +#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +#os.environ["CUDA_VISIBLE_DEVICES"] = "1" #(or "1" or "2") + +import sys + +#sys.path=['', '/usr/local/tensorflow/avx-avx2-gpu/1.14.0/python3.7/site-packages', '/usr/local/matlab/2018b/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python37.zip', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/lib-dynload', '/usr/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/copkmeans-1.5-py3.7.egg', 
'/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/spherecluster-0.1.7-py3.7.egg', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages/IPython/extensions'] + +import tensorflow as tf + +import numpy as np +import torch +import clip +from PIL import Image +import pickle +import copy +import matplotlib.pyplot as plt + +def GetAlign(out,dt,model,preprocess): + imgs=out + imgs1=imgs.reshape([-1]+list(imgs.shape[2:])) + + tmp=[] + for i in range(len(imgs1)): + + img=Image.fromarray(imgs1[i]) + image = preprocess(img).unsqueeze(0).to(device) + tmp.append(image) + + image=torch.cat(tmp) + + with torch.no_grad(): + image_features = model.encode_image(image) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + image_features1=image_features.cpu().numpy() + + image_features1=image_features1.reshape(list(imgs.shape[:2])+[512]) + + fd=image_features1[:,1:,:]-image_features1[:,:-1,:] + + fd1=fd.reshape([-1,512]) + fd2=fd1/np.linalg.norm(fd1,axis=1)[:,None] + + tmp=np.dot(fd2,dt) + m=tmp.mean() + acc=np.sum(tmp>0)/len(tmp) + print(m,acc) + return m,acc + + +def SplitS(ds_p,M,if_std): + all_ds=[] + start=0 + for i in M.mindexs: + tmp=M.dlatents[i].shape[1] + end=start+tmp + tmp=ds_p[start:end] +# tmp=tmp*M.code_std[i] + + all_ds.append(tmp) + start=end + + all_ds2=[] + tmp_index=0 + for i in range(len(M.s_names)): + if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index])==0): + +# tmp=np.abs(all_ds[tmp_index]/M.code_std[i]) +# print(i,tmp.mean()) +# tmp=np.dot(M.latent_codes[i],all_ds[tmp_index]) +# print(tmp) + if if_std: + tmp=all_ds[tmp_index]*M.code_std[i] + else: + tmp=all_ds[tmp_index] + + all_ds2.append(tmp) + tmp_index+=1 + else: + tmp=np.zeros(len(M.dlatents[i][0])) + all_ds2.append(tmp) + return all_ds2 + + +imagenet_templates = [ + 'a bad photo of a {}.', +# 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo 
of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', +] + + +def zeroshot_classifier(classnames, templates,model): + with torch.no_grad(): + zeroshot_weights = [] + for classname in classnames: + texts = [template.format(classname) for template in templates] #format with class + texts = clip.tokenize(texts).cuda() #tokenize + class_embeddings = model.encode_text(texts) #embed with text encoder + class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) + class_embedding = class_embeddings.mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() + return zeroshot_weights + + +def GetDt(classnames,model): + text_features=zeroshot_classifier(classnames, imagenet_templates,model).t() + + dt=text_features[0]-text_features[1] + dt=dt.cpu().numpy() + +# t_m1=t_m/np.linalg.norm(t_m) +# dt=text_features.cpu().numpy()[0]-t_m1 + print(np.linalg.norm(dt)) + dt=dt/np.linalg.norm(dt) + return dt + + +def GetBoundary(fs3,dt,M,threshold): + tmp=np.dot(fs3,dt) + + ds_imp=copy.copy(tmp) + select=np.abs(tmp)", self.text_n) + self.view.target.bind("", self.text_t) + self.view.alpha.bind('', self.ChangeAlpha) + self.view.beta.bind('', self.ChangeBeta) + self.view.set_init.bind('', self.SetInit) + self.view.reset.bind('', self.Reset) + self.view.bg.bind('', self.open_img) + + + self.drawn = None + + self.view.target.delete(1.0, "end") + self.view.target.insert("end", self.style_clip.target) +# + self.view.neutral.delete(1.0, "end") + self.view.neutral.insert("end", self.style_clip.neutral) + + + def Reset(self,event): + self.style_clip.GetDt2() + self.style_clip.M.alpha=[0] + + self.view.beta.set(self.style_clip.beta) + self.view.alpha.set(0) + + img=self.style_clip.GetImg() + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + + def SetInit(self,event): + codes=self.style_clip.GetCode() + self.style_clip.M.dlatent_tmp=[tmp[:,0] for tmp in codes] + print('set init') + + def ChangeAlpha(self,event): + tmp=self.view.alpha.get() + self.style_clip.M.alpha=[float(tmp)] + + img=self.style_clip.GetImg() + print('manipulate one') + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + def ChangeBeta(self,event): + tmp=self.view.beta.get() + self.style_clip.beta=float(tmp) + + img=self.style_clip.GetImg() + print('manipulate one') + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + def ChangeDataset(self,event): + + dataset_name=self.view.set_category.get() + + self.style_clip.LoadData(dataset_name) + + self.view.target.delete(1.0, "end") + self.view.target.insert("end", self.style_clip.target) + + self.view.neutral.delete(1.0, "end") + self.view.neutral.insert("end", self.style_clip.neutral) + + def text_t(self,event): + tmp=self.view.target.get("1.0",'end') + tmp=tmp.replace('\n','') + + self.view.target.delete(1.0, "end") + self.view.target.insert("end", tmp) + + print('target',tmp,'###') + self.style_clip.target=tmp + self.style_clip.GetDt2() + self.view.beta.set(self.style_clip.beta) + self.view.alpha.set(3) + 
self.style_clip.M.alpha=[3] + + img=self.style_clip.GetImg() + print('manipulate one') + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + + def text_n(self,event): + tmp=self.view.neutral.get("1.0",'end') + tmp=tmp.replace('\n','') + + self.view.neutral.delete(1.0, "end") + self.view.neutral.insert("end", tmp) + + print('neutral',tmp,'###') + self.style_clip.neutral=tmp + self.view.target.delete(1.0, "end") + self.view.target.insert("end", tmp) + + + def run(self): + self.root.mainloop() + + def addImage(self,img): + self.view.bg.create_image(self.view.width/2, self.view.height/2, image=img, anchor='center') + self.image=img #save a copy of image. if not the image will disappear + + def addImage_m(self,img): + self.view.mani.create_image(512, 512, image=img, anchor='center') + self.image2=img + + + def openfn(self): + filename = askopenfilename(title='open',initialdir='./data/'+self.style_clip.M.dataset_name+'/',filetypes=[("all image format", ".jpg"),("all image format", ".png")]) + return filename + + def open_img(self,event): + x = self.openfn() + print(x) + + + img = Image.open(x) + img2 = img.resize(( 512,512), Image.ANTIALIAS) + img2 = ImageTk.PhotoImage(img2) + self.addImage(img2) + + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + img_index=x.split('/')[-1].split('.')[0] + img_index=int(img_index) + print(img_index) + self.style_clip.M.img_index=img_index + self.style_clip.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.style_clip.M.dlatents] + + + self.style_clip.GetDt2() + self.view.beta.set(self.style_clip.beta) + self.view.alpha.set(3) + + #%% +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + + args = parser.parse_args() + dataset_name=args.dataset_name + + self=PlayInteractively(dataset_name) + self.run() + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/SingleChannel.py b/models/StyleCLIP/global_directions/SingleChannel.py new file mode 100644 index 0000000000000000000000000000000000000000..ecaa7ec7898d37f8f5db171f9141a5253af3fa73 --- /dev/null +++ b/models/StyleCLIP/global_directions/SingleChannel.py @@ -0,0 +1,109 @@ + + + +import numpy as np +import torch +import clip +from PIL import Image +import copy +from manipulate import Manipulator +import argparse + +def GetImgF(out,model,preprocess): + imgs=out + imgs1=imgs.reshape([-1]+list(imgs.shape[2:])) + + tmp=[] + for i in range(len(imgs1)): + + img=Image.fromarray(imgs1[i]) + image = preprocess(img).unsqueeze(0).to(device) + tmp.append(image) + + image=torch.cat(tmp) + with torch.no_grad(): + image_features = model.encode_image(image) + + image_features1=image_features.cpu().numpy() + image_features1=image_features1.reshape(list(imgs.shape[:2])+[512]) + + return image_features1 + +def GetFs(fs): + tmp=np.linalg.norm(fs,axis=-1) + fs1=fs/tmp[:,:,:,None] + fs2=fs1[:,:,1,:]-fs1[:,:,0,:] # 5*sigma - (-5)* sigma + fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None] + fs3=fs3.mean(axis=1) + fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None] + return fs3 + +#%% +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='cat', + help='name of dataset, for example, ffhq') + args = parser.parse_args() + dataset_name=args.dataset_name + + #%% + device = "cuda" if 
torch.cuda.is_available() else "cpu" + model, preprocess = clip.load("ViT-B/32", device=device) + #%% + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + print(M.dataset_name) + #%% + img_sindex=0 + num_images=100 + dlatents_o=[] + tmp=img_sindex*num_images + for i in range(len(M.dlatents)): + tmp1=M.dlatents[i][tmp:(tmp+num_images)] + dlatents_o.append(tmp1) + #%% + + all_f=[] + M.alpha=[-5,5] #ffhq 5 + M.step=2 + M.num_images=num_images + select=np.array(M.mindexs)<=16 #below or equal to 128 resolution + mindexs2=np.array(M.mindexs)[select] + for lindex in mindexs2: #ignore ToRGB layers + print(lindex) + num_c=M.dlatents[lindex].shape[1] + for cindex in range(num_c): + + M.dlatents=copy.copy(dlatents_o) + M.dlatents[lindex][:,cindex]=M.code_mean[lindex][cindex] + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(cindex) + image_features1=GetImgF(out,model,preprocess) + all_f.append(image_features1) + + all_f=np.array(all_f) + + fs3=GetFs(all_f) + + #%% + file_path='./npy/'+M.dataset_name+'/' + np.save(file_path+'fs3',fs3) + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/PTI/training/projectors/__init__.py b/models/StyleCLIP/global_directions/__init__.py similarity index 100% rename from PTI/training/projectors/__init__.py rename to models/StyleCLIP/global_directions/__init__.py diff --git a/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy b/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy new file mode 100644 index 0000000000000000000000000000000000000000..db524aae88e16239679a8f72ccb3403fd16c95a9 --- /dev/null +++ b/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:394f0f166305654f49cd1b0cd3d4f2b7a51e740a449a1ebfa1c69f79d01399fa +size 2506880 diff --git a/PTI/dnnlib/__init__.py b/models/StyleCLIP/global_directions/dnnlib/__init__.py similarity index 86% rename from PTI/dnnlib/__init__.py rename to models/StyleCLIP/global_directions/dnnlib/__init__.py index 2f08cf36f11f9b0fd94c1b7caeadf69b98375b04..c73940d81233142ae3dcd9a37b7ec2185c5d5fc5 100644 --- a/PTI/dnnlib/__init__.py +++ b/models/StyleCLIP/global_directions/dnnlib/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation diff --git a/PTI/torch_utils/ops/__init__.py b/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py similarity index 54% rename from PTI/torch_utils/ops/__init__.py rename to models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..ca852844ec488c0134bffa647e25a40646ff4718 100644 --- a/PTI/torch_utils/ops/__init__.py +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation @@ -6,4 +6,15 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. -# empty +from . import autosummary +from . import network +from . import optimizer +from . import tfutil +from . 
import custom_ops + +from .tfutil import * +from .network import Network + +from .optimizer import Optimizer + +from .custom_ops import get_plugin diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py b/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py new file mode 100644 index 0000000000000000000000000000000000000000..56dfb96093bb5b1129a99585b4ce655b98d80009 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py @@ -0,0 +1,193 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Helper for adding automatically tracked values to Tensorboard. + +Autosummary creates an identity op that internally keeps track of the input +values and automatically shows up in TensorBoard. The reported value +represents an average over input components. The average is accumulated +constantly over time and flushed when save_summaries() is called. + +Notes: +- The output tensor must be used as an input for something else in the + graph. Otherwise, the autosummary op will not get executed, and the average + value will not get accumulated. +- It is perfectly fine to include autosummaries with the same name in + several places throughout the graph, even if they are executed concurrently. +- It is ok to also pass in a python scalar or numpy array. In this case, it + is added to the average immediately. +""" + +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorboard import summary as summary_lib +from tensorboard.plugins.custom_scalar import layout_pb2 + +from . import tfutil +from .tfutil import TfExpression +from .tfutil import TfExpressionEx + +# Enable "Custom scalars" tab in TensorBoard for advanced formatting. +# Disabled by default to reduce tfevents file size. +enable_custom_scalars = False + +_dtype = tf.float64 +_vars = OrderedDict() # name => [var, ...] 
+_immediate = OrderedDict() # name => update_op, update_value +_finalized = False +_merge_op = None + + +def _create_var(name: str, value_expr: TfExpression) -> TfExpression: + """Internal helper for creating autosummary accumulators.""" + assert not _finalized + name_id = name.replace("/", "_") + v = tf.cast(value_expr, _dtype) + + if v.shape.is_fully_defined(): + size = np.prod(v.shape.as_list()) + size_expr = tf.constant(size, dtype=_dtype) + else: + size = None + size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) + + if size == 1: + if v.shape.ndims != 0: + v = tf.reshape(v, []) + v = [size_expr, v, tf.square(v)] + else: + v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] + v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) + + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): + var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] + update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) + + if name in _vars: + _vars[name].append(var) + else: + _vars[name] = [var] + return update_op + + +def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx: + """Create a new autosummary. + + Args: + name: Name to use in TensorBoard + value: TensorFlow expression or python value to track + passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. + + Example use of the passthru mechanism: + + n = autosummary('l2loss', loss, passthru=n) + + This is a shorthand for the following code: + + with tf.control_dependencies([autosummary('l2loss', loss)]): + n = tf.identity(n) + """ + tfutil.assert_tf_initialized() + name_id = name.replace("/", "_") + + if tfutil.is_tf_expression(value): + with tf.name_scope("summary_" + name_id), tf.device(value.device): + condition = tf.convert_to_tensor(condition, name='condition') + update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op) + with tf.control_dependencies([update_op]): + return tf.identity(value if passthru is None else passthru) + + else: # python scalar or numpy array + assert not tfutil.is_tf_expression(passthru) + assert not tfutil.is_tf_expression(condition) + if condition: + if name not in _immediate: + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): + update_value = tf.placeholder(_dtype) + update_op = _create_var(name, update_value) + _immediate[name] = update_op, update_value + update_op, update_value = _immediate[name] + tfutil.run(update_op, {update_value: value}) + return value if passthru is None else passthru + + +def finalize_autosummaries() -> None: + """Create the necessary ops to include autosummaries in TensorBoard report. + Note: This should be done only once per graph. + """ + global _finalized + tfutil.assert_tf_initialized() + + if _finalized: + return None + + _finalized = True + tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) + + # Create summary ops. 
+ with tf.device(None), tf.control_dependencies(None): + for name, vars_list in _vars.items(): + name_id = name.replace("/", "_") + with tfutil.absolute_name_scope("Autosummary/" + name_id): + moments = tf.add_n(vars_list) + moments /= moments[0] + with tf.control_dependencies([moments]): # read before resetting + reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] + with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting + mean = moments[1] + std = tf.sqrt(moments[2] - tf.square(moments[1])) + tf.summary.scalar(name, mean) + if enable_custom_scalars: + tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) + tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) + + # Setup layout for custom scalars. + layout = None + if enable_custom_scalars: + cat_dict = OrderedDict() + for series_name in sorted(_vars.keys()): + p = series_name.split("/") + cat = p[0] if len(p) >= 2 else "" + chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] + if cat not in cat_dict: + cat_dict[cat] = OrderedDict() + if chart not in cat_dict[cat]: + cat_dict[cat][chart] = [] + cat_dict[cat][chart].append(series_name) + categories = [] + for cat_name, chart_dict in cat_dict.items(): + charts = [] + for chart_name, series_names in chart_dict.items(): + series = [] + for series_name in series_names: + series.append(layout_pb2.MarginChartContent.Series( + value=series_name, + lower="xCustomScalars/" + series_name + "/margin_lo", + upper="xCustomScalars/" + series_name + "/margin_hi")) + margin = layout_pb2.MarginChartContent(series=series) + charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) + categories.append(layout_pb2.Category(title=cat_name, chart=charts)) + layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) + return layout + +def save_summaries(file_writer, global_step=None): + """Call FileWriter.add_summary() with all summaries in the default graph, + automatically finalizing and merging them on the first call. + """ + global _merge_op + tfutil.assert_tf_initialized() + + if _merge_op is None: + layout = finalize_autosummaries() + if layout is not None: + file_writer.add_summary(layout) + with tf.device(None), tf.control_dependencies(None): + _merge_op = tf.summary.merge_all() + + file_writer.add_summary(_merge_op.eval(), global_step) diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py b/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..702471e2006af6858345c1225c1e55b0acd17d32 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py @@ -0,0 +1,181 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""TensorFlow custom ops builder. +""" + +import glob +import os +import re +import uuid +import hashlib +import tempfile +import shutil +import tensorflow as tf +from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module + +from .. import util + +#---------------------------------------------------------------------------- +# Global configs. 
+ +cuda_cache_path = None +cuda_cache_version_tag = 'v1' +do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change. +verbose = True # Print status messages to stdout. + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin' + if os.path.isdir(vc_bin_dir): + return vc_bin_dir + return None + +def _get_compute_cap(device): + caps_str = device.physical_device_desc + m = re.search('compute capability: (\\d+).(\\d+)', caps_str) + major = m.group(1) + minor = m.group(2) + return (major, minor) + +def _get_cuda_gpu_arch_string(): + gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] + if len(gpus) == 0: + raise RuntimeError('No GPU devices found') + (major, minor) = _get_compute_cap(gpus[0]) + return 'sm_%s%s' % (major, minor) + +def _run_cmd(cmd): + with os.popen(cmd) as pipe: + output = pipe.read() + status = pipe.close() + if status is not None: + raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) + +def _prepare_nvcc_cli(opts): + cmd = 'nvcc ' + opts.strip() + cmd += ' --disable-warnings' + cmd += ' --include-path "%s"' % tf.sysconfig.get_include() + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') + + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + # Require that _find_compiler_bindir succeeds on Windows. Allow + # nvcc to use whatever is the default on Linux. + if os.name == 'nt': + raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) + else: + cmd += ' --compiler-bindir "%s"' % compiler_bindir + cmd += ' 2>&1' + return cmd + +#---------------------------------------------------------------------------- +# Main entry point. + +_plugin_cache = dict() + +def get_plugin(cuda_file, extra_nvcc_options=[]): + cuda_file_base = os.path.basename(cuda_file) + cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) + + # Already in cache? + if cuda_file in _plugin_cache: + return _plugin_cache[cuda_file] + + # Setup plugin. + if verbose: + print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) + try: + # Hash CUDA source. + md5 = hashlib.md5() + with open(cuda_file, 'rb') as f: + md5.update(f.read()) + md5.update(b'\n') + + # Hash headers included by the CUDA code by running it through the preprocessor. + if not do_not_hash_included_headers: + if verbose: + print('Preprocessing... 
', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) + _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) + with open(tmp_file, 'rb') as f: + bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros + good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') + for ln in f: + if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas + ln = ln.replace(bad_file_str, good_file_str) + md5.update(ln) + md5.update(b'\n') + + # Select compiler configs. + compile_opts = '' + if os.name == 'nt': + compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') + elif os.name == 'posix': + compile_opts += f' --compiler-options \'-fPIC\'' + compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\'' + compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\'' + else: + assert False # not Windows or Linux, w00t? + compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}' + compile_opts += ' --use_fast_math' + for opt in extra_nvcc_options: + compile_opts += ' ' + opt + nvcc_cmd = _prepare_nvcc_cli(compile_opts) + + # Hash build configuration. + md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') + md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') + md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') + + # Compile if not already compiled. + cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path + bin_file_ext = '.dll' if os.name == 'nt' else '.so' + bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) + if not os.path.isfile(bin_file): + if verbose: + print('Compiling... ', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) + _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) + os.makedirs(cache_dir, exist_ok=True) + intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) + shutil.copyfile(tmp_file, intermediate_file) + os.rename(intermediate_file, bin_file) # atomic + + # Load. + if verbose: + print('Loading... ', end='', flush=True) + plugin = tf.load_op_library(bin_file) + + # Add to cache. + _plugin_cache[cuda_file] = plugin + if verbose: + print('Done.', flush=True) + return plugin + + except: + if verbose: + print('Failed!', flush=True) + raise + +#---------------------------------------------------------------------------- diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/network.py b/models/StyleCLIP/global_directions/dnnlib/tflib/network.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0c169eabdc579041dac0650fbc6da956646594 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/network.py @@ -0,0 +1,781 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Helper for managing networks.""" + +import types +import inspect +import re +import uuid +import sys +import copy +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import Any, List, Tuple, Union, Callable + +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +# pylint: disable=protected-access +# pylint: disable=attribute-defined-outside-init +# pylint: disable=too-many-public-methods + +_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. +_import_module_src = dict() # Source code for temporary modules created during pickle import. + + +def import_handler(handler_func): + """Function decorator for declaring custom import handlers.""" + _import_handlers.append(handler_func) + return handler_func + + +class Network: + """Generic network abstraction. + + Acts as a convenience wrapper for a parameterized network construction + function, providing several utility methods and convenient access to + the inputs/outputs/weights. + + Network objects can be safely pickled and unpickled for long-term + archival purposes. The pickling works reliably as long as the underlying + network construction function is defined in a standalone Python module + that has no side effects or application-specific imports. + + Args: + name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None. + func_name: Fully qualified name of the underlying network construction function, or a top-level function object. + static_kwargs: Keyword arguments to be passed in to the network construction function. + """ + + def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): + # Locate the user-specified build function. + assert isinstance(func_name, str) or util.is_top_level_function(func_name) + if util.is_top_level_function(func_name): + func_name = util.get_top_level_function_name(func_name) + module, func_name = util.get_module_from_obj_name(func_name) + func = util.get_obj_from_module(module, func_name) + + # Dig up source code for the module containing the build function. + module_src = _import_module_src.get(module, None) + if module_src is None: + module_src = inspect.getsource(module) + + # Initialize fields. + self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src) + + def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None: + tfutil.assert_tf_initialized() + assert isinstance(name, str) + assert len(name) >= 1 + assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name) + assert isinstance(static_kwargs, dict) + assert util.is_pickleable(static_kwargs) + assert callable(build_func) + assert isinstance(build_func_name, str) + assert isinstance(build_module_src, str) + + # Choose TensorFlow name scope. + with tf.name_scope(None): + scope = tf.get_default_graph().unique_name(name, mark_as_used=True) + + # Query current TensorFlow device. + with tfutil.absolute_name_scope(scope), tf.control_dependencies(None): + device = tf.no_op(name="_QueryDevice").device + + # Immutable state. 
+ self._name = name + self._scope = scope + self._device = device + self._static_kwargs = util.EasyDict(copy.deepcopy(static_kwargs)) + self._build_func = build_func + self._build_func_name = build_func_name + self._build_module_src = build_module_src + + # State before _init_graph(). + self._var_inits = dict() # var_name => initial_value, set to None by _init_graph() + self._all_inits_known = False # Do we know for sure that _var_inits covers all the variables? + self._components = None # subnet_name => Network, None if the components are not known yet + + # Initialized by _init_graph(). + self._input_templates = None + self._output_templates = None + self._own_vars = None + + # Cached values initialized the respective methods. + self._input_shapes = None + self._output_shapes = None + self._input_names = None + self._output_names = None + self._vars = None + self._trainables = None + self._var_global_to_local = None + self._run_cache = dict() + + def _init_graph(self) -> None: + assert self._var_inits is not None + assert self._input_templates is None + assert self._output_templates is None + assert self._own_vars is None + + # Initialize components. + if self._components is None: + self._components = util.EasyDict() + + # Choose build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs["is_template_graph"] = True + build_kwargs["components"] = self._components + + # Override scope and device, and ignore surrounding control dependencies. + with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None): + assert tf.get_variable_scope().name == self.scope + assert tf.get_default_graph().get_name_scope() == self.scope + + # Create input templates. + self._input_templates = [] + for param in inspect.signature(self._build_func).parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: + self._input_templates.append(tf.placeholder(tf.float32, name=param.name)) + + # Call build func. + out_expr = self._build_func(*self._input_templates, **build_kwargs) + + # Collect output templates and variables. + assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) + + # Check for errors. + if len(self._input_templates) == 0: + raise ValueError("Network build func did not list any inputs.") + if len(self._output_templates) == 0: + raise ValueError("Network build func did not return any outputs.") + if any(not tfutil.is_tf_expression(t) for t in self._output_templates): + raise ValueError("Network outputs must be TensorFlow expressions.") + if any(t.shape.ndims is None for t in self._input_templates): + raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") + if any(t.shape.ndims is None for t in self._output_templates): + raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.") + if any(not isinstance(comp, Network) for comp in self._components.values()): + raise ValueError("Components of a Network must be Networks themselves.") + if len(self._components) != len(set(comp.name for comp in self._components.values())): + raise ValueError("Components of a Network must have unique names.") + + # Initialize variables. 
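To make the template-graph construction above concrete, here is a hedged sketch of a build function and of wrapping it in a `Network`. The function name, input size, and layer are hypothetical; the key contract is that positional parameters without defaults become input placeholders and that every input and output shape must be fully defined.

```python
# Sketch (assumptions: TF1.x, dnnlib importable, run as a script so the module
# source can be retrieved, build function defined at module top level).
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib.network import Network

def my_generator(latents_in, is_template_graph=False, components=None, fmaps=64, **_kwargs):
    latents_in.set_shape([None, 128])                       # inputs must have defined shapes
    x = tf.layers.dense(latents_in, fmaps, name='dense0')   # hypothetical single layer
    return tf.identity(x, name='images_out')                # outputs must be TF expressions

tflib.init_tf()
net = Network('MyGenerator', func_name=my_generator, fmaps=64)  # fmaps becomes a static kwarg
print(net.input_shapes, net.output_shapes)                      # triggers _init_graph()
```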
+ if len(self._var_inits): + tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()}) + remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits] + if self._all_inits_known: + assert len(remaining_inits) == 0 + else: + tfutil.run(remaining_inits) + self._var_inits = None + + @property + def name(self): + """User-specified name string.""" + return self._name + + @property + def scope(self): + """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.""" + return self._scope + + @property + def device(self): + """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time.""" + return self._device + + @property + def static_kwargs(self): + """EasyDict of arguments passed to the user-supplied build func.""" + return copy.deepcopy(self._static_kwargs) + + @property + def components(self): + """EasyDict of sub-networks created by the build func.""" + return copy.copy(self._get_components()) + + def _get_components(self): + if self._components is None: + self._init_graph() + assert self._components is not None + return self._components + + @property + def input_shapes(self): + """List of input tensor shapes, including minibatch dimension.""" + if self._input_shapes is None: + self._input_shapes = [t.shape.as_list() for t in self.input_templates] + return copy.deepcopy(self._input_shapes) + + @property + def output_shapes(self): + """List of output tensor shapes, including minibatch dimension.""" + if self._output_shapes is None: + self._output_shapes = [t.shape.as_list() for t in self.output_templates] + return copy.deepcopy(self._output_shapes) + + @property + def input_shape(self): + """Short-hand for input_shapes[0].""" + return self.input_shapes[0] + + @property + def output_shape(self): + """Short-hand for output_shapes[0].""" + return self.output_shapes[0] + + @property + def num_inputs(self): + """Number of input tensors.""" + return len(self.input_shapes) + + @property + def num_outputs(self): + """Number of output tensors.""" + return len(self.output_shapes) + + @property + def input_names(self): + """Name string for each input.""" + if self._input_names is None: + self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates] + return copy.copy(self._input_names) + + @property + def output_names(self): + """Name string for each output.""" + if self._output_names is None: + self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] + return copy.copy(self._output_names) + + @property + def input_templates(self): + """Input placeholders in the template graph.""" + if self._input_templates is None: + self._init_graph() + assert self._input_templates is not None + return copy.copy(self._input_templates) + + @property + def output_templates(self): + """Output tensors in the template graph.""" + if self._output_templates is None: + self._init_graph() + assert self._output_templates is not None + return copy.copy(self._output_templates) + + @property + def own_vars(self): + """Variables defined by this network (local_name => var), excluding sub-networks.""" + return copy.copy(self._get_own_vars()) + + def _get_own_vars(self): + if self._own_vars is None: + self._init_graph() + assert self._own_vars is not None + return self._own_vars + + @property + def vars(self): + """All variables (local_name => var).""" + 
return copy.copy(self._get_vars()) + + def _get_vars(self): + if self._vars is None: + self._vars = OrderedDict(self._get_own_vars()) + for comp in self._get_components().values(): + self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items()) + return self._vars + + @property + def trainables(self): + """All trainable variables (local_name => var).""" + return copy.copy(self._get_trainables()) + + def _get_trainables(self): + if self._trainables is None: + self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) + return self._trainables + + @property + def var_global_to_local(self): + """Mapping from variable global names to local names.""" + return copy.copy(self._get_var_global_to_local()) + + def _get_var_global_to_local(self): + if self._var_global_to_local is None: + self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) + return self._var_global_to_local + + def reset_own_vars(self) -> None: + """Re-initialize all variables of this network, excluding sub-networks.""" + if self._var_inits is None or self._components is None: + tfutil.run([var.initializer for var in self._get_own_vars().values()]) + else: + self._var_inits.clear() + self._all_inits_known = False + + def reset_vars(self) -> None: + """Re-initialize all variables of this network, including sub-networks.""" + if self._var_inits is None: + tfutil.run([var.initializer for var in self._get_vars().values()]) + else: + self._var_inits.clear() + self._all_inits_known = False + if self._components is not None: + for comp in self._components.values(): + comp.reset_vars() + + def reset_trainables(self) -> None: + """Re-initialize all trainable variables of this network, including sub-networks.""" + tfutil.run([var.initializer for var in self._get_trainables().values()]) + + def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: + """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s). + The graph is placed on the current TensorFlow device.""" + assert len(in_expr) == self.num_inputs + assert not all(expr is None for expr in in_expr) + self._get_vars() # ensure that all variables have been created + + # Choose build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs.update(dynamic_kwargs) + build_kwargs["is_template_graph"] = False + build_kwargs["components"] = self._components + + # Build TensorFlow graph to evaluate the network. + with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): + assert tf.get_variable_scope().name == self.scope + valid_inputs = [expr for expr in in_expr if expr is not None] + final_inputs = [] + for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): + if expr is not None: + expr = tf.identity(expr, name=name) + else: + expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) + final_inputs.append(expr) + out_expr = self._build_func(*final_inputs, **build_kwargs) + + # Propagate input shapes back to the user-specified expressions. + for expr, final in zip(in_expr, final_inputs): + if isinstance(expr, tf.Tensor): + expr.set_shape(final.shape) + + # Express outputs in the desired format. 
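Continuing the hypothetical `net` from the construction sketch above, a brief illustration of `get_output_for()`, which re-runs the build function with shared variables so the network can be wired into a larger training graph; `is_training` is a hypothetical dynamic kwarg.

```python
# Sketch (assumes the `net` built earlier and the default session from tflib.init_tf()).
import tensorflow as tf

minibatch_size = 8
latents = tf.random_normal([minibatch_size] + net.input_shape[1:])
out = net.get_output_for(latents, is_training=True)   # dynamic kwarg forwarded to the build func
loss = tf.reduce_mean(tf.square(out))                 # hypothetical loss
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, var_list=list(net.trainables.values()))
```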
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + if return_as_list: + out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + return out_expr + + def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: + """Get the local name of a given variable, without any surrounding name scopes.""" + assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) + global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name + return self._get_var_global_to_local()[global_name] + + def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: + """Find variable by local or global name.""" + assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) + return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name + + def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: + """Get the value of a given variable as NumPy array. + Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" + return self.find_var(var_or_local_name).eval() + + def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: + """Set the value of a given variable based on the given NumPy array. + Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" + tfutil.set_vars({self.find_var(var_or_local_name): new_value}) + + def __getstate__(self) -> dict: + """Pickle export.""" + state = dict() + state["version"] = 5 + state["name"] = self.name + state["static_kwargs"] = dict(self.static_kwargs) + state["components"] = dict(self.components) + state["build_module_src"] = self._build_module_src + state["build_func_name"] = self._build_func_name + state["variables"] = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values())))) + state["input_shapes"] = self.input_shapes + state["output_shapes"] = self.output_shapes + state["input_names"] = self.input_names + state["output_names"] = self.output_names + return state + + def __setstate__(self, state: dict) -> None: + """Pickle import.""" + + # Execute custom import handlers. + for handler in _import_handlers: + state = handler(state) + + # Get basic fields. + assert state["version"] in [2, 3, 4, 5] + name = state["name"] + static_kwargs = state["static_kwargs"] + build_module_src = state["build_module_src"] + build_func_name = state["build_func_name"] + + # Create temporary module from the imported source code. + module_name = "_tflib_network_import_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _import_module_src[module] = build_module_src + exec(build_module_src, module.__dict__) # pylint: disable=exec-used + build_func = util.get_obj_from_module(module, build_func_name) + + # Initialize fields. 
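Because `__getstate__()`/`__setstate__()` above carry the build-function source code and all weight values, a plain pickle round-trip is enough to archive and restore a network. This sketch assumes `tflib.init_tf()` has been called before unpickling; the file name and variable name are hypothetical.

```python
# Sketch (assumption: a default TF session exists on both the save and the load side).
import pickle

with open('my_net.pkl', 'wb') as f:
    pickle.dump(net, f)                   # __getstate__ captures source code + weights

with open('my_net.pkl', 'rb') as f:
    net_copy = pickle.load(f)             # __setstate__ rebuilds the module and the graph

print(net_copy.get_var('dense0/kernel'))  # hypothetical local name; prefer tflib.run() in bulk
```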
+ self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src) + self._var_inits.update(copy.deepcopy(state["variables"])) + self._all_inits_known = True + self._components = util.EasyDict(state.get("components", {})) + self._input_shapes = copy.deepcopy(state.get("input_shapes", None)) + self._output_shapes = copy.deepcopy(state.get("output_shapes", None)) + self._input_names = copy.deepcopy(state.get("input_names", None)) + self._output_names = copy.deepcopy(state.get("output_names", None)) + + def clone(self, name: str = None, **new_static_kwargs) -> "Network": + """Create a clone of this network with its own copy of the variables.""" + static_kwargs = dict(self.static_kwargs) + static_kwargs.update(new_static_kwargs) + net = object.__new__(Network) + net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src) + net.copy_vars_from(self) + return net + + def copy_own_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, excluding sub-networks.""" + + # Source has unknown variables or unknown components => init now. + if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None: + src_net._get_vars() + + # Both networks are inited => copy directly. + if src_net._var_inits is None and self._var_inits is None: + names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()] + tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names})) + return + + # Read from source. + if src_net._var_inits is None: + value_dict = tfutil.run(src_net._get_own_vars()) + else: + value_dict = src_net._var_inits + + # Write to destination. + if self._var_inits is None: + tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()}) + else: + self._var_inits.update(value_dict) + + def copy_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, including sub-networks.""" + + # Source has unknown variables or unknown components => init now. + if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None: + src_net._get_vars() + + # Source is inited, but destination components have not been created yet => set as initial values. + if src_net._var_inits is None and self._components is None: + self._var_inits.update(tfutil.run(src_net._get_vars())) + return + + # Destination has unknown components => init now. + if self._components is None: + self._get_vars() + + # Both networks are inited => copy directly. + if src_net._var_inits is None and self._var_inits is None: + names = [name for name in self._get_vars().keys() if name in src_net._get_vars()] + tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names})) + return + + # Copy recursively, component by component. 
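A common pattern built on `clone()` and the variable-copying helpers above, together with `setup_as_moving_average_of()` defined just below, is keeping an exponential-moving-average copy of a generator, as the StyleGAN training loop does. A hedged sketch, with `net` and the names hypothetical:

```python
# Sketch (assumes the hypothetical `net` from earlier and dnnlib.tflib as tflib).
Gs = net.clone('MyGeneratorEMA')                              # independent copy of the weights
Gs_update_op = Gs.setup_as_moving_average_of(net, beta=0.99)  # EMA update op
# Inside the training loop, after each optimizer step:
# tflib.run(Gs_update_op)
```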
+ self.copy_own_vars_from(src_net) + for name, src_comp in src_net._components.items(): + if name in self._components: + self._components[name].copy_vars_from(src_comp) + + def copy_trainables_from(self, src_net: "Network") -> None: + """Copy the values of all trainable variables from the given network, including sub-networks.""" + names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()] + tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names})) + + def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network": + """Create new network with the given parameters, and copy all variables from this network.""" + if new_name is None: + new_name = self.name + static_kwargs = dict(self.static_kwargs) + static_kwargs.update(new_static_kwargs) + net = Network(name=new_name, func_name=new_func_name, **static_kwargs) + net.copy_vars_from(self) + return net + + def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation: + """Construct a TensorFlow op that updates the variables of this network + to be slightly closer to those of the given network.""" + with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"): + ops = [] + for name, var in self._get_vars().items(): + if name in src_net._get_vars(): + cur_beta = beta if var.trainable else beta_nontrainable + new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta) + ops.append(var.assign(new_value)) + return tf.group(*ops) + + def run(self, + *in_arrays: Tuple[Union[np.ndarray, None], ...], + input_transform: dict = None, + output_transform: dict = None, + return_as_list: bool = False, + print_progress: bool = False, + minibatch_size: int = None, + num_gpus: int = 1, + assume_frozen: bool = False, + **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]: + """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s). + + Args: + input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network. + The dict must contain a 'func' field that points to a top-level function. The function is called with the input + TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. + output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network. + The dict must contain a 'func' field that points to a top-level function. The function is called with the output + TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. + return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs. + print_progress: Print progress to the console? Useful for very large input arrays. + minibatch_size: Maximum minibatch size to use, None = disable batching. + num_gpus: Number of GPUs to use. + assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls. + dynamic_kwargs: Additional keyword arguments to be passed into the network build function. 
+ """ + assert len(in_arrays) == self.num_inputs + assert not all(arr is None for arr in in_arrays) + assert input_transform is None or util.is_top_level_function(input_transform["func"]) + assert output_transform is None or util.is_top_level_function(output_transform["func"]) + output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) + num_items = in_arrays[0].shape[0] + if minibatch_size is None: + minibatch_size = num_items + + # Construct unique hash key from all arguments that affect the TensorFlow graph. + key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) + def unwind_key(obj): + if isinstance(obj, dict): + return [(key, unwind_key(value)) for key, value in sorted(obj.items())] + if callable(obj): + return util.get_top_level_function_name(obj) + return obj + key = repr(unwind_key(key)) + + # Build graph. + if key not in self._run_cache: + with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): + with tf.device("/cpu:0"): + in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] + in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) + + out_split = [] + for gpu in range(num_gpus): + with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu): + net_gpu = self.clone() if assume_frozen else self + in_gpu = in_split[gpu] + + if input_transform is not None: + in_kwargs = dict(input_transform) + in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) + in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) + + assert len(in_gpu) == self.num_inputs + out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) + + if output_transform is not None: + out_kwargs = dict(output_transform) + out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) + out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) + + assert len(out_gpu) == self.num_outputs + out_split.append(out_gpu) + + with tf.device("/cpu:0"): + out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] + self._run_cache[key] = in_expr, out_expr + + # Run minibatches. + in_expr, out_expr = self._run_cache[key] + out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr] + + for mb_begin in range(0, num_items, minibatch_size): + if print_progress: + print("\r%d / %d" % (mb_begin, num_items), end="") + + mb_end = min(mb_begin + minibatch_size, num_items) + mb_num = mb_end - mb_begin + mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] + mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) + + for dst, src in zip(out_arrays, mb_out): + dst[mb_begin: mb_end] = src + + # Done. 
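A hedged sketch of the `run()` entry point documented above: it accepts NumPy arrays, handles minibatching internally, and returns NumPy arrays; the latent size and batch sizes are hypothetical.

```python
# Sketch (assumes the hypothetical `net` from earlier; inputs and sizes are made up).
import numpy as np

latents = np.random.randn(32, 128).astype(np.float32)
out = net.run(latents, minibatch_size=8, print_progress=True)  # single output -> single ndarray
print(out.shape)
```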
+ if print_progress: + print("\r%d / %d" % (num_items, num_items)) + + if not return_as_list: + out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) + return out_arrays + + def list_ops(self) -> List[TfExpression]: + _ = self.output_templates # ensure that the template graph has been created + include_prefix = self.scope + "/" + exclude_prefix = include_prefix + "_" + ops = tf.get_default_graph().get_operations() + ops = [op for op in ops if op.name.startswith(include_prefix)] + ops = [op for op in ops if not op.name.startswith(exclude_prefix)] + return ops + + def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: + """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to + individual layers of the network. Mainly intended to be used for reporting.""" + layers = [] + + def recurse(scope, parent_ops, parent_vars, level): + if len(parent_ops) == 0 and len(parent_vars) == 0: + return + + # Ignore specific patterns. + if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): + return + + # Filter ops and vars by scope. + global_prefix = scope + "/" + local_prefix = global_prefix[len(self.scope) + 1:] + cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] + cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] + if not cur_ops and not cur_vars: + return + + # Filter out all ops related to variables. + for var in [op for op in cur_ops if op.type.startswith("Variable")]: + var_prefix = var.name + "/" + cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] + + # Scope does not contain ops as immediate children => recurse deeper. + contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops) + if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0): + visited = set() + for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: + token = rel_name.split("/")[0] + if token not in visited: + recurse(global_prefix + token, cur_ops, cur_vars, level + 1) + visited.add(token) + return + + # Report layer. 
+ layer_name = scope[len(self.scope) + 1:] + layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] + layer_trainables = [var for _name, var in cur_vars if var.trainable] + layers.append((layer_name, layer_output, layer_trainables)) + + recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0) + return layers + + def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: + """Print a summary table of the network structure.""" + rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] + rows += [["---"] * 4] + total_params = 0 + + for layer_name, layer_output, layer_trainables in self.list_layers(): + num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables) + weights = [var for var in layer_trainables if var.name.endswith("/weight:0")] + weights.sort(key=lambda x: len(x.name)) + if len(weights) == 0 and len(layer_trainables) == 1: + weights = layer_trainables + total_params += num_params + + if not hide_layers_with_no_params or num_params != 0: + num_params_str = str(num_params) if num_params > 0 else "-" + output_shape_str = str(layer_output.shape) + weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" + rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] + + rows += [["---"] * 4] + rows += [["Total", str(total_params), "", ""]] + + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) + print() + + def setup_weight_histograms(self, title: str = None) -> None: + """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" + if title is None: + title = self.name + + with tf.name_scope(None), tf.device(None), tf.control_dependencies(None): + for local_name, var in self._get_trainables().items(): + if "/" in local_name: + p = local_name.split("/") + name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) + else: + name = title + "_toplevel/" + local_name + + tf.summary.histogram(name, var) + +#---------------------------------------------------------------------------- +# Backwards-compatible emulation of legacy output transformation in Network.run(). 
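Two reporting helpers defined above, shown in a brief hedged sketch; the histogram summaries are picked up later by `tf.summary.merge_all()`, and hence by `autosummary.save_summaries()`.

```python
# Sketch (assumes the hypothetical `net` from earlier).
net.print_layers()              # table: layer name, parameter count, output shape, weight shape
net.setup_weight_histograms()   # adds tf.summary.histogram ops for every trainable variable
```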
+ +_print_legacy_warning = True + +def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): + global _print_legacy_warning + legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] + if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): + return output_transform, dynamic_kwargs + + if _print_legacy_warning: + _print_legacy_warning = False + print() + print("WARNING: Old-style output transformations in Network.run() are deprecated.") + print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") + print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") + print() + assert output_transform is None + + new_kwargs = dict(dynamic_kwargs) + new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} + new_transform["func"] = _legacy_output_transform_func + return new_transform, new_kwargs + +def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): + if out_mul != 1.0: + expr = [x * out_mul for x in expr] + + if out_add != 0.0: + expr = [x + out_add for x in expr] + + if out_shrink > 1: + ksize = [1, 1, out_shrink, out_shrink] + expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] + + if out_dtype is not None: + if tf.as_dtype(out_dtype).is_integer: + expr = [tf.round(x) for x in expr] + expr = [tf.saturate_cast(x, out_dtype) for x in expr] + return expr diff --git a/PTI/torch_utils/__init__.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py similarity index 85% rename from PTI/torch_utils/__init__.py rename to models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..43cce37364064146fd30e18612b1d9e3a84f513a 100644 --- a/PTI/torch_utils/__init__.py +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..0268f14395319003240b4a5a59141d703e9a4257 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu @@ -0,0 +1,220 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
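The legacy-kwarg shim above rewrites `out_mul`/`out_add`/`out_shrink`/`out_dtype` into an `output_transform` dict. Rounding out network.py, a hedged sketch of the recommended new-style call, using the `tflib.convert_images_to_uint8` helper named in the warning text; `Gs` and the latent size are hypothetical.

```python
# Sketch (assumes a generator network `Gs` with image outputs).
import numpy as np
import dnnlib.tflib as tflib

latents = np.random.randn(8, 512).astype(np.float32)           # hypothetical latent size
images = Gs.run(latents, minibatch_size=4,
                output_transform=dict(func=tflib.convert_images_to_uint8))
# Deprecated equivalent: Gs.run(latents, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
```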
+ +#define EIGEN_USE_GPU +#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include + +using namespace tensorflow; +using namespace tensorflow::shape_inference; + +#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +struct FusedBiasActKernelParams +{ + const T* x; // [sizeX] + const T* b; // [sizeB] or NULL + const T* xref; // [sizeX] or NULL + const T* yref; // [sizeX] or NULL + T* y; // [sizeX] + + int grad; + int axis; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +template +static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams p) +{ + const float expRange = 80.0f; + const float halfExpRange = 40.0f; + const float seluScale = 1.0507009873554804934193349852946f; + const float seluAlpha = 1.6732632423543772848170429916717f; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load and apply bias. + float x = (float)p.x[xi]; + if (p.b) + x += (float)p.b[(xi / p.stepB) % p.sizeB]; + float xref = (p.xref) ? (float)p.xref[xi] : 0.0f; + float yref = (p.yref) ? (float)p.yref[xi] : 0.0f; + float yy = (p.gain != 0.0f) ? yref / p.gain : 0.0f; + + // Evaluate activation func. + float y; + switch (p.act * 10 + p.grad) + { + // linear + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0f; break; + + // relu + case 20: y = (x > 0.0f) ? x : 0.0f; break; + case 21: y = (yy > 0.0f) ? x : 0.0f; break; + case 22: y = 0.0f; break; + + // lrelu + case 30: y = (x > 0.0f) ? x : x * p.alpha; break; + case 31: y = (yy > 0.0f) ? x : x * p.alpha; break; + case 32: y = 0.0f; break; + + // tanh + case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; + case 41: y = x * (1.0f - yy * yy); break; + case 42: y = x * (1.0f - yy * yy) * (-2.0f * yy); break; + + // sigmoid + case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break; + case 51: y = x * yy * (1.0f - yy); break; + case 52: y = x * yy * (1.0f - yy) * (1.0f - 2.0f * yy); break; + + // elu + case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; + case 61: y = (yy >= 0.0f) ? x : x * (yy + 1.0f); break; + case 62: y = (yy >= 0.0f) ? 0.0f : x * (yy + 1.0f); break; + + // selu + case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; + case 71: y = (yy >= 0.0f) ? x * seluScale : x * (yy + seluScale * seluAlpha); break; + case 72: y = (yy >= 0.0f) ? 0.0f : x * (yy + seluScale * seluAlpha); break; + + // softplus + case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break; + case 81: y = x * (1.0f - expf(-yy)); break; + case 82: { float c = expf(-yy); y = x * c * (1.0f - c); } break; + + // swish + case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; + case 91: + case 92: + { + float c = expf(xref); + float d = c + 1.0f; + if (p.grad == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 
0.0f : x * c * (xref * (2.0f - d) + 2.0f * d) / (d * d * d); + yref = (xref < -expRange) ? 0.0f : xref / (expf(-xref) + 1.0f) * p.gain; + } + break; + } + + // Apply gain. + y *= p.gain; + + // Clamp. + if (p.clamp >= 0.0f) + { + if (p.grad == 0) + y = (fabsf(y) < p.clamp) ? y : (y >= 0.0f) ? p.clamp : -p.clamp; + else + y = (fabsf(yref) < p.clamp) ? y : 0.0f; + } + + // Store. + p.y[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// TensorFlow op. + +template +struct FusedBiasActOp : public OpKernel +{ + FusedBiasActKernelParams m_attribs; + + FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("clamp", &m_attribs.clamp)); + OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative")); + OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative")); + OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative")); + } + + void Compute(OpKernelContext* ctx) + { + FusedBiasActKernelParams p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + const Tensor& x = ctx->input(0); // [...] + const Tensor& b = ctx->input(1); // [sizeB] or [0] + const Tensor& xref = ctx->input(2); // x.shape or [0] + const Tensor& yref = ctx->input(3); // x.shape or [0] + p.x = x.flat().data(); + p.b = (b.NumElements()) ? b.flat().data() : NULL; + p.xref = (xref.NumElements()) ? xref.flat().data() : NULL; + p.yref = (yref.NumElements()) ? 
yref.flat().data() : NULL; + OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds")); + OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1")); + OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements")); + OP_REQUIRES(ctx, xref.NumElements() == 0 || xref.NumElements() == x.NumElements(), errors::InvalidArgument("xref has wrong number of elements")); + OP_REQUIRES(ctx, yref.NumElements() == 0 || yref.NumElements() == x.NumElements(), errors::InvalidArgument("yref has wrong number of elements")); + OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large")); + + p.sizeX = (int)x.NumElements(); + p.sizeB = (int)b.NumElements(); + p.stepB = 1; + for (int i = m_attribs.axis + 1; i < x.dims(); i++) + p.stepB *= (int)x.dim_size(i); + + Tensor* y = NULL; // x.shape + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); + p.y = y->flat().data(); + + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("FusedBiasAct") + .Input ("x: T") + .Input ("b: T") + .Input ("xref: T") + .Input ("yref: T") + .Output ("y: T") + .Attr ("T: {float, half}") + .Attr ("grad: int = 0") + .Attr ("axis: int = 1") + .Attr ("act: int = 0") + .Attr ("alpha: float = 0.0") + .Attr ("gain: float = 1.0") + .Attr ("clamp: float = -1.0"); +REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); +REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); + +//------------------------------------------------------------------------ diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..79991b0497d3d92f25194a31668b9568048163f8 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py @@ -0,0 +1,211 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom TensorFlow ops for efficient bias and activation.""" + +import os +import numpy as np +import tensorflow as tf +from .. 
import custom_ops +from ...util import EasyDict + +def _get_plugin(): + return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), + 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), + 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), + 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), + 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), + 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), + 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), + 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), + 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), +} + +#---------------------------------------------------------------------------- + +def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can have any shape, but if `b` is defined, the + dimension corresponding to `axis`, as well as the rank, must be known. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `axis`. + axis: The dimension in `x` corresponding to the elements of `b`. + The value of `axis` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying `1.0`. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. 
+ """ + + impl_dict = { + 'ref': _fused_bias_act_ref, + 'cuda': _fused_bias_act_cuda, + } + return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +def _fused_bias_act_ref(x, b, axis, act, alpha, gain, clamp): + """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" + + # Validate arguments. + x = tf.convert_to_tensor(x) + b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) + act_spec = activation_funcs[act] + assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) + assert b.shape[0] == 0 or 0 <= axis < x.shape.rank + if alpha is None: + alpha = act_spec.def_alpha + if gain is None: + gain = act_spec.def_gain + + # Add bias. + if b.shape[0] != 0: + x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) + + # Evaluate activation function. + x = act_spec.func(x, alpha=alpha) + + # Scale by gain. + if gain != 1: + x *= gain + + # Clamp. + if clamp is not None: + clamp = np.asarray(clamp, dtype=x.dtype.name) + assert clamp.shape == () and clamp >= 0 + x = tf.clip_by_value(x, -clamp, clamp) + return x + +#---------------------------------------------------------------------------- + +def _fused_bias_act_cuda(x, b, axis, act, alpha, gain, clamp): + """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" + + # Validate arguments. + x = tf.convert_to_tensor(x) + empty_tensor = tf.constant([], dtype=x.dtype) + b = tf.convert_to_tensor(b) if b is not None else empty_tensor + act_spec = activation_funcs[act] + assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) + assert b.shape[0] == 0 or 0 <= axis < x.shape.rank + if alpha is None: + alpha = act_spec.def_alpha + if gain is None: + gain = act_spec.def_gain + + # Special cases. + if act == 'linear' and b is None and gain == 1.0: + return x + if act_spec.cuda_idx is None: + return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp) + + # CUDA op. + cuda_op = _get_plugin().fused_bias_act + cuda_kwargs = dict(axis=int(axis), act=int(act_spec.cuda_idx), gain=float(gain)) + if alpha is not None: + cuda_kwargs['alpha'] = float(alpha) + if clamp is not None: + clamp = np.asarray(clamp, dtype=x.dtype.name) + assert clamp.shape == () and clamp >= 0 + cuda_kwargs['clamp'] = float(clamp.astype(np.float32)) + def ref(tensor, name): + return tensor if act_spec.ref == name else empty_tensor + + # Forward pass: y = func(x, b). 
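A hedged usage sketch for `fused_bias_act()` as documented above, applying a per-channel bias and leaky ReLU to an NCHW tensor; `impl='ref'` selects the plain-TensorFlow fallback so no CUDA compilation is needed. The shapes are hypothetical.

```python
# Sketch (TF1.x; impl='ref' avoids the custom CUDA op, impl='cuda' would use it).
import numpy as np
import tensorflow as tf
from dnnlib.tflib.ops.fused_bias_act import fused_bias_act

x = tf.placeholder(tf.float32, [None, 64, 32, 32])  # NCHW activations, 64 channels (hypothetical)
b = tf.Variable(tf.zeros([64]))                      # one bias per channel along axis=1
y = fused_bias_act(x, b, axis=1, act='lrelu', gain=np.sqrt(2), impl='ref')
```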
+ def func_y(x, b): + y = cuda_op(x=x, b=b, xref=empty_tensor, yref=empty_tensor, grad=0, **cuda_kwargs) + y.set_shape(x.shape) + return y + + # Backward pass: dx, db = grad(dy, x, y) + def grad_dx(dy, x, y): + dx = cuda_op(x=dy, b=empty_tensor, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs) + dx.set_shape(x.shape) + return dx + def grad_db(dx): + if b.shape[0] == 0: + return empty_tensor + db = dx + if axis < x.shape.rank - 1: + db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) + if axis > 0: + db = tf.reduce_sum(db, list(range(axis))) + db.set_shape(b.shape) + return db + + # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) + def grad2_d_dy(d_dx, d_db, x, y): + d_dy = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs) + d_dy.set_shape(x.shape) + return d_dy + def grad2_d_x(d_dx, d_db, x, y): + d_x = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=2, **cuda_kwargs) + d_x.set_shape(x.shape) + return d_x + + # Fast version for piecewise-linear activation funcs. + @tf.custom_gradient + def func_zero_2nd_grad(x, b): + y = func_y(x, b) + @tf.custom_gradient + def grad(dy): + dx = grad_dx(dy, x, y) + db = grad_db(dx) + def grad2(d_dx, d_db): + d_dy = grad2_d_dy(d_dx, d_db, x, y) + return d_dy + return (dx, db), grad2 + return y, grad + + # Slow version for general activation funcs. + @tf.custom_gradient + def func_nonzero_2nd_grad(x, b): + y = func_y(x, b) + def grad_wrap(dy): + @tf.custom_gradient + def grad_impl(dy, x): + dx = grad_dx(dy, x, y) + db = grad_db(dx) + def grad2(d_dx, d_db): + d_dy = grad2_d_dy(d_dx, d_db, x, y) + d_x = grad2_d_x(d_dx, d_db, x, y) + return d_dy, d_x + return (dx, db), grad2 + return grad_impl(dy, x) + return y, grad_wrap + + # Which version to use? + if act_spec.zero_2nd_grad: + return func_zero_2nd_grad(x, b) + return func_nonzero_2nd_grad(x, b) + +#---------------------------------------------------------------------------- diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..7aad60d53e57d4f3e60f36a24df80a6278f1bb63 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu @@ -0,0 +1,359 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#define EIGEN_USE_GPU +#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include + +using namespace tensorflow; +using namespace tensorflow::shape_inference; + +//------------------------------------------------------------------------ +// Helpers. 
+
+#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
+
+static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
+{
+    int t = 1 - a / b;
+    return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel params.
+
+template <class T>
+struct UpFirDn2DKernelParams
+{
+    const T* x; // [majorDim, inH, inW, minorDim]
+    const T* k; // [kernelH, kernelW]
+    T* y; // [majorDim, outH, outW, minorDim]
+
+    int upx;
+    int upy;
+    int downx;
+    int downy;
+    int padx0;
+    int padx1;
+    int pady0;
+    int pady1;
+
+    int majorDim;
+    int inH;
+    int inW;
+    int minorDim;
+    int kernelH;
+    int kernelW;
+    int outH;
+    int outW;
+    int loopMajor;
+    int loopX;
+};
+
+//------------------------------------------------------------------------
+// General CUDA implementation for large filter kernels.
+
+template <class T>
+static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
+{
+    // Calculate thread index.
+    int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
+    int outY = minorIdx / p.minorDim;
+    minorIdx -= outY * p.minorDim;
+    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+    int majorIdxBase = blockIdx.z * p.loopMajor;
+    if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
+        return;
+
+    // Setup Y receptive field.
+    int midY = outY * p.downy + p.upy - 1 - p.pady0;
+    int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
+    int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
+    int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
+
+    // Loop over majorDim and outX.
+    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
+    for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
+    {
+        // Setup X receptive field.
+        int midX = outX * p.downx + p.upx - 1 - p.padx0;
+        int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
+        int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
+        int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
+
+        // Initialize pointers.
+        const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
+        const T* kp = &p.k[kernelY * p.kernelW + kernelX];
+        int xpx = p.minorDim;
+        int kpx = -p.upx;
+        int xpy = p.inW * p.minorDim;
+        int kpy = -p.upy * p.kernelW;
+
+        // Inner loop.
+        float v = 0.0f;
+        for (int y = 0; y < h; y++)
+        {
+            for (int x = 0; x < w; x++)
+            {
+                v += (float)(*xp) * (float)(*kp);
+                xp += xpx;
+                kp += kpx;
+            }
+            xp += xpy - w * xpx;
+            kp += kpy - w * kpx;
+        }
+
+        // Store result.
+        p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
+    }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filter kernels.
+
+template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
+static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
+{
+    //assert(kernelW % upx == 0);
+    //assert(kernelH % upy == 0);
+    const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
+    const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
+    __shared__ volatile float sk[kernelH][kernelW];
+    __shared__ volatile float sx[tileInH][tileInW];
+
+    // Calculate tile index.
+ int minorIdx = blockIdx.x; + int tileOutY = minorIdx / p.minorDim; + minorIdx -= tileOutY * p.minorDim; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorIdxBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim) + return; + + // Load filter kernel (flipped). + for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x) + { + int ky = tapIdx / kernelW; + int kx = tapIdx - ky * kernelW; + float v = 0.0f; + if (kx < p.kernelW & ky < p.kernelH) + v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)]; + sk[ky][kx] = v; + } + + // Loop over majorDim and outX. + for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++) + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.padx0; + int tileMidY = tileOutY * downy + upy - 1 - p.pady0; + int tileInX = floorDiv(tileMidX, upx); + int tileInY = floorDiv(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x) + { + int relInY = inIdx / tileInW; + int relInX = inIdx - relInY * tileInW; + int inX = relInX + tileInX; + int inY = relInY + tileInY; + float v = 0.0f; + if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH) + v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; + sx[relInY][relInX] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x) + { + int relOutY = outIdx / tileOutW; + int relOutX = outIdx - relOutY * tileOutW; + int outX = relOutX + tileOutX; + int outY = relOutY + tileOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floorDiv(midX, upx); + int inY = floorDiv(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int kernelX = (inX + 1) * upx - midX - 1; // flipped + int kernelY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + float v = 0.0f; + #pragma unroll + for (int y = 0; y < kernelH / upy; y++) + #pragma unroll + for (int x = 0; x < kernelW / upx; x++) + v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx]; + + // Store result. + if (outX < p.outW & outY < p.outH) + p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// TensorFlow op. 
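To make the two kernels above easier to follow, here is a minimal single-image, single-channel NumPy sketch of the upfirdn computation they implement (zero-insertion upsampling, zero padding, 'valid' FIR filtering, then decimation). It assumes equal x/y factors and symmetric padding, which the real op does not require.

```python
import numpy as np

def upfirdn_2d_np(x, k, up=1, down=1, pad0=0, pad1=0):
    """Reference sketch: x is a [H, W] image, k is a [kh, kw] FIR filter."""
    kh, kw = k.shape
    # 1. Upsample by inserting (up - 1) zeros after each pixel.
    up_x = np.zeros((x.shape[0] * up, x.shape[1] * up), dtype=np.float32)
    up_x[::up, ::up] = x
    # 2. Pad with zeros (cropping via negative padding is omitted in this sketch).
    up_x = np.pad(up_x, ((pad0, pad1), (pad0, pad1)), mode='constant')
    # 3. 'Valid' convolution with the filter k.
    out_h = up_x.shape[0] - kh + 1
    out_w = up_x.shape[1] - kw + 1
    kf = k[::-1, ::-1]  # cross-correlating with the flipped filter == convolving with k
    y = np.zeros((out_h, out_w), dtype=np.float32)
    for i in range(kh):
        for j in range(kw):
            y += kf[i, j] * up_x[i:i + out_h, j:j + out_w]
    # 4. Downsample by keeping every down-th pixel.
    return y[::down, ::down]
```

With `up=2`, a 2x2 all-ones filter, `pad0=1`, and `pad1=0`, this reproduces nearest-neighbor 2x upsampling.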
+
+template <class T>
+struct UpFirDn2DOp : public OpKernel
+{
+    UpFirDn2DKernelParams<T> m_attribs;
+
+    UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
+    {
+        memset(&m_attribs, 0, sizeof(m_attribs));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
+        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
+        OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
+        OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
+    }
+
+    void Compute(OpKernelContext* ctx)
+    {
+        UpFirDn2DKernelParams<T> p = m_attribs;
+        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
+
+        const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
+        const Tensor& k = ctx->input(1); // [kernelH, kernelW]
+        p.x = x.flat<T>().data();
+        p.k = k.flat<T>().data();
+        OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
+        OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
+        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
+        OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
+
+        p.majorDim = (int)x.dim_size(0);
+        p.inH = (int)x.dim_size(1);
+        p.inW = (int)x.dim_size(2);
+        p.minorDim = (int)x.dim_size(3);
+        p.kernelH = (int)k.dim_size(0);
+        p.kernelW = (int)k.dim_size(1);
+        OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
+
+        p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
+        p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
+        OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
+
+        Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
+        TensorShape ys;
+        ys.AddDim(p.majorDim);
+        ys.AddDim(p.outH);
+        ys.AddDim(p.outW);
+        ys.AddDim(p.minorDim);
+        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
+        p.y = y->flat<T>().data();
+        OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
+
+        // Choose CUDA kernel to use.
+        void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
+        int tileOutW = -1;
+        int tileOutH = -1;
+
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7  && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5  && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3  && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 8,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,24, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,20, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,16, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,12, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,8, 32,32>; tileOutW = 32; tileOutH = 32; }
+
+        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
+        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 8,1, 128,8>; tileOutW = 128; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,24, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,20, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,16, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,12, 32,32>; tileOutW = 32; tileOutH = 32; }
+        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,8, 32,32>; tileOutW = 32; tileOutH = 32; }
+
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 24,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 20,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 16,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 12,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 8,1, 64,8>; tileOutW = 64; tileOutH = 8; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,24, 32,16>; tileOutW = 32; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,20, 32,16>; tileOutW = 32; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,16, 32,16>; tileOutW = 32; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,12, 32,16>; tileOutW = 32; tileOutH = 16; }
+        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,8, 32,16>; tileOutW = 32; tileOutH = 16; }
+
+        // Choose launch params.
+        dim3 blockSize;
+        dim3 gridSize;
+        if (tileOutW > 0 && tileOutH > 0) // small
+        {
+            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
+            p.loopX = 1;
+            blockSize = dim3(32 * 8, 1, 1);
+            gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
+        }
+        else // large
+        {
+            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
+            p.loopX = 4;
+            blockSize = dim3(4, 32, 1);
+            gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
+        }
+
+        // Launch CUDA kernel.
+        void* args[] = {&p};
+        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
+    }
+};
+
+REGISTER_OP("UpFirDn2D")
+    .Input ("x: T")
+    .Input ("k: T")
+    .Output ("y: T")
+    .Attr ("T: {float, half}")
+    .Attr ("upx: int = 1")
+    .Attr ("upy: int = 1")
+    .Attr ("downx: int = 1")
+    .Attr ("downy: int = 1")
+    .Attr ("padx0: int = 0")
+    .Attr ("padx1: int = 0")
+    .Attr ("pady0: int = 0")
+    .Attr ("pady1: int = 0");
+REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
+REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
+
+//------------------------------------------------------------------------
diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a31af7e146da7afeb964db018f14aca3134920
--- /dev/null
+++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py
@@ -0,0 +1,418 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom TensorFlow ops for efficient resampling of 2D images."""
+
+import os
+import numpy as np
+import tensorflow as tf
+from .. import custom_ops
+
+def _get_plugin():
+    return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
+
+#----------------------------------------------------------------------------
+
+def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
+    r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
+
+    Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
+    and performs the following operations for each image, batched across
+    `majorDim` and `minorDim`:
+
+    1.
Upsample the image by inserting the zeros after each pixel (`upx`, `upy`). + + 2. Pad the image with zeros by the specified number of pixels on each side + (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value + corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the + image so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by throwing away pixels (`downx`, `downy`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + + Args: + x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`. + k: 2D FIR filter of the shape `[firH, firW]`. + upx: Integer upsampling factor along the X-axis (default: 1). + upy: Integer upsampling factor along the Y-axis (default: 1). + downx: Integer downsampling factor along the X-axis (default: 1). + downy: Integer downsampling factor along the Y-axis (default: 1). + padx0: Number of pixels to pad on the left side (default: 0). + padx1: Number of pixels to pad on the right side (default: 0). + pady0: Number of pixels to pad on the top side (default: 0). + pady1: Number of pixels to pad on the bottom side (default: 0). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`. + """ + + impl_dict = { + 'ref': _upfirdn_2d_ref, + 'cuda': _upfirdn_2d_cuda, + } + return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1) + +#---------------------------------------------------------------------------- + +def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): + """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops.""" + + x = tf.convert_to_tensor(x) + k = np.asarray(k, dtype=np.float32) + assert x.shape.rank == 4 + inH = x.shape[1].value + inW = x.shape[2].value + minorDim = _shape(x, 3) + kernelH, kernelW = k.shape + assert inW >= 1 and inH >= 1 + assert kernelW >= 1 and kernelH >= 1 + assert isinstance(upx, int) and isinstance(upy, int) + assert isinstance(downx, int) and isinstance(downy, int) + assert isinstance(padx0, int) and isinstance(padx1, int) + assert isinstance(pady0, int) and isinstance(pady1, int) + + # Upsample (insert zeros). + x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim]) + x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]]) + x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim]) + + # Pad (crop if negative). + x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]]) + x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :] + + # Convolve with filter. + x = tf.transpose(x, [0, 3, 1, 2]) + x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1]) + w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype) + x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW') + x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1]) + x = tf.transpose(x, [0, 2, 3, 1]) + + # Downsample (throw away pixels). 
+ return x[:, ::downy, ::downx, :] + +#---------------------------------------------------------------------------- + +def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): + """Fast CUDA implementation of `upfirdn_2d()` using custom ops.""" + + x = tf.convert_to_tensor(x) + k = np.asarray(k, dtype=np.float32) + majorDim, inH, inW, minorDim = x.shape.as_list() + kernelH, kernelW = k.shape + assert inW >= 1 and inH >= 1 + assert kernelW >= 1 and kernelH >= 1 + assert isinstance(upx, int) and isinstance(upy, int) + assert isinstance(downx, int) and isinstance(downy, int) + assert isinstance(padx0, int) and isinstance(padx1, int) + assert isinstance(pady0, int) and isinstance(pady1, int) + + outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1 + outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1 + assert outW >= 1 and outH >= 1 + + cuda_op = _get_plugin().up_fir_dn2d + kc = tf.constant(k, dtype=x.dtype) + gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype) + gpadx0 = kernelW - padx0 - 1 + gpady0 = kernelH - pady0 - 1 + gpadx1 = inW * upx - outW * downx + padx0 - upx + 1 + gpady1 = inH * upy - outH * downy + pady0 - upy + 1 + + @tf.custom_gradient + def func(x): + y = cuda_op(x=x, k=kc, upx=int(upx), upy=int(upy), downx=int(downx), downy=int(downy), padx0=int(padx0), padx1=int(padx1), pady0=int(pady0), pady1=int(pady1)) + y.set_shape([majorDim, outH, outW, minorDim]) + @tf.custom_gradient + def grad(dy): + dx = cuda_op(x=dy, k=gkc, upx=int(downx), upy=int(downy), downx=int(upx), downy=int(upy), padx0=int(gpadx0), padx1=int(gpadx1), pady0=int(gpady0), pady1=int(gpady1)) + dx.set_shape([majorDim, inH, inW, minorDim]) + return dx, func + return y, grad + return func(x) + +#---------------------------------------------------------------------------- + +def filter_2d(x, k, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Filter a batch of 2D images with the given FIR filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and filters each image with the given filter. The filter is normalized so that + if the input pixels are constant, they will be scaled by the specified `gain`. + Pixels outside the image are assumed to be zero. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + + assert isinstance(padding, int) + k = _FilterKernel(k=k, gain=gain) + assert k.w == k.h + pad0 = k.w // 2 + padding + pad1 = (k.w - 1) // 2 + padding + return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Upsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so that + if the input pixels are constant, they will be scaled by the specified `gain`. 
+ Pixels outside the image are assumed to be zero, and the filter is padded with + zeros so that its shape is a multiple of the upsampling factor. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to nearest-neighbor + upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2)) + assert k.w == k.h + pad0 = (k.w + factor - 1) // 2 + padding + pad1 = (k.w - factor) // 2 + padding + return _simple_upfirdn_2d(x, k, up=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Downsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized so that + if the input pixels are constant, they will be scaled by the specified `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded with + zeros so that its shape is a multiple of the downsampling factor. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + k = _FilterKernel(k if k is not None else [1] * factor, gain) + assert k.w == k.h + pad0 = (k.w - factor + 1) // 2 + padding * factor + pad1 = (k.w - factor) // 2 + padding * factor + return _simple_upfirdn_2d(x, k, down=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. 
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to nearest-neighbor + upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + + # Check weight shape. + w = tf.convert_to_tensor(w) + ch, cw, _inC, _outC = w.shape.as_list() + inC = _shape(w, 2) + outC = _shape(w, 3) + assert cw == ch + + # Fast path for 1x1 convolution. + if cw == 1 and ch == 1: + x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID') + x = upsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl) + return x + + # Setup filter kernel. + k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2)) + assert k.w == k.h + + # Determine data dimensions. + if data_format == 'NCHW': + stride = [1, 1, factor, factor] + output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + ch, (_shape(x, 3) - 1) * factor + cw] + num_groups = _shape(x, 1) // inC + else: + stride = [1, factor, factor, 1] + output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + ch, (_shape(x, 2) - 1) * factor + cw, outC] + num_groups = _shape(x, 3) // inC + + # Transpose weights. + w = tf.reshape(w, [ch, cw, inC, num_groups, -1]) + w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2]) + w = tf.reshape(w, [ch, cw, -1, num_groups * inC]) + + # Execute. + x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format) + pad0 = (k.w + factor - cw) // 2 + padding + pad1 = (k.w - factor - cw + 3) // 2 + padding + return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. + Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + + # Check weight shape. + w = tf.convert_to_tensor(w) + ch, cw, _inC, _outC = w.shape.as_list() + assert cw == ch + + # Fast path for 1x1 convolution. + if cw == 1 and ch == 1: + x = downsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl) + x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID') + return x + + # Setup filter kernel. + k = _FilterKernel(k if k is not None else [1] * factor, gain) + assert k.w == k.h + + # Determine stride. + if data_format == 'NCHW': + s = [1, 1, factor, factor] + else: + s = [1, factor, factor, 1] + + # Execute. + pad0 = (k.w - factor + cw) // 2 + padding * factor + pad1 = (k.w - factor + cw - 1) // 2 + padding * factor + x = _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format) + +#---------------------------------------------------------------------------- +# Internal helpers. + +class _FilterKernel: + def __init__(self, k, gain=1): + k = np.asarray(k, dtype=np.float32) + k /= np.sum(k) + + # Separable. + if k.ndim == 1 and k.size >= 8: + self.w = k.size + self.h = k.size + self.kx = k[np.newaxis, :] + self.ky = k[:, np.newaxis] * gain + self.kxy = None + + # Non-separable. + else: + if k.ndim == 1: + k = np.outer(k, k) + assert k.ndim == 2 + self.w = k.shape[1] + self.h = k.shape[0] + self.kx = None + self.ky = None + self.kxy = k * gain + +def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'): + assert isinstance(k, _FilterKernel) + assert data_format in ['NCHW', 'NHWC'] + assert x.shape.rank == 4 + y = x + if data_format == 'NCHW': + y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1]) + if k.kx is not None: + y = upfirdn_2d(y, k.kx, upx=up, downx=down, padx0=pad0, padx1=pad1, impl=impl) + if k.ky is not None: + y = upfirdn_2d(y, k.ky, upy=up, downy=down, pady0=pad0, pady1=pad1, impl=impl) + if k.kxy is not None: + y = upfirdn_2d(y, k.kxy, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl) + if data_format == 'NCHW': + y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)]) + return y + +def _shape(tf_expr, dim_idx): + if tf_expr.shape.rank is not None: + dim = tf_expr.shape[dim_idx].value + if dim is not None: + return dim + return tf.shape(tf_expr)[dim_idx] + +#---------------------------------------------------------------------------- diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py b/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cae5ffff3d11aaccd705d6936e080175ab97dd0e --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py @@ -0,0 +1,372 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
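Returning to the resampling wrappers defined in `upfirdn_2d.py` above, a short usage sketch; the shapes are illustrative, `impl='ref'` avoids the compiled CUDA plugin, and `[1, 3, 3, 1]` is the binomial FIR filter commonly passed to these functions in StyleGAN2.

```python
import tensorflow as tf

# Illustrative NCHW batch of feature maps (TensorFlow 1.x graph mode).
x = tf.placeholder(tf.float32, [None, 3, 128, 128])

# 2x upsample followed by 2x downsample with a 4-tap binomial filter.
y = upsample_2d(x, k=[1, 3, 3, 1], factor=2, impl='ref')    # -> [None, 3, 256, 256]
z = downsample_2d(y, k=[1, 3, 3, 1], factor=2, impl='ref')  # -> [None, 3, 128, 128]
```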
+ +"""Helper wrapper for a Tensorflow optimizer.""" + +import platform +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import List, Union + +from . import autosummary +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +_collective_ops_warning_printed = False +_collective_ops_group_key = 831766147 +_collective_ops_instance_key = 436340067 + +class Optimizer: + """A Wrapper for tf.train.Optimizer. + + Automatically takes care of: + - Gradient averaging for multi-GPU training. + - Gradient accumulation for arbitrarily large minibatches. + - Dynamic loss scaling and typecasts for FP16 training. + - Ignoring corrupted gradients that contain NaNs/Infs. + - Reporting statistics. + - Well-chosen default settings. + """ + + def __init__(self, + name: str = "Train", # Name string that will appear in TensorFlow graph. + tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class. + learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time. + minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients. + share: "Optimizer" = None, # Share internal state with a previously created optimizer? + use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training? + loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor. + loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow. + loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow. + report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard? + **kwargs): + + # Public fields. + self.name = name + self.learning_rate = learning_rate + self.minibatch_multiplier = minibatch_multiplier + self.id = self.name.replace("/", ".") + self.scope = tf.get_default_graph().unique_name(self.id) + self.optimizer_class = util.get_obj_by_name(tf_optimizer) + self.optimizer_kwargs = dict(kwargs) + self.use_loss_scaling = use_loss_scaling + self.loss_scaling_init = loss_scaling_init + self.loss_scaling_inc = loss_scaling_inc + self.loss_scaling_dec = loss_scaling_dec + + # Private fields. + self._updates_applied = False + self._devices = OrderedDict() # device_name => EasyDict() + self._shared_optimizers = OrderedDict() # device_name => optimizer_class + self._gradient_shapes = None # [shape, ...] + self._report_mem_usage = report_mem_usage + + # Validate arguments. + assert callable(self.optimizer_class) + + # Share internal state if requested. + if share is not None: + assert isinstance(share, Optimizer) + assert self.optimizer_class is share.optimizer_class + assert self.learning_rate is share.learning_rate + assert self.optimizer_kwargs == share.optimizer_kwargs + self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access + + def _get_device(self, device_name: str): + """Get internal state for the given TensorFlow device.""" + tfutil.assert_tf_initialized() + if device_name in self._devices: + return self._devices[device_name] + + # Initialize fields. + device = util.EasyDict() + device.name = device_name + device.optimizer = None # Underlying optimizer: optimizer_class + device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable + device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...] 
+ device.grad_clean = OrderedDict() # Clean gradients: var => grad + device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable + device.grad_acc_count = None # Accumulation counter: tf.Variable + device.grad_acc = OrderedDict() # Accumulated gradients: var => grad + + # Setup TensorFlow objects. + with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None): + if device_name not in self._shared_optimizers: + optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers) + self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) + device.optimizer = self._shared_optimizers[device_name] + if self.use_loss_scaling: + device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var") + + # Register device. + self._devices[device_name] = device + return device + + def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: + """Register the gradients of the given loss function with respect to the given variables. + Intended to be called once per GPU.""" + tfutil.assert_tf_initialized() + assert not self._updates_applied + device = self._get_device(loss.device) + + # Validate trainables. + if isinstance(trainable_vars, dict): + trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars + assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 + assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) + assert all(var.device == device.name for var in trainable_vars) + + # Validate shapes. + if self._gradient_shapes is None: + self._gradient_shapes = [var.shape.as_list() for var in trainable_vars] + assert len(trainable_vars) == len(self._gradient_shapes) + assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes)) + + # Report memory usage if requested. + deps = [loss] + if self._report_mem_usage: + self._report_mem_usage = False + try: + with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]): + deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30)) + except tf.errors.NotFoundError: + pass + + # Compute gradients. + with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps): + loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) + gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage + grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate) + + # Register gradients. + for grad, var in grad_list: + if var not in device.grad_raw: + device.grad_raw[var] = [] + device.grad_raw[var].append(grad) + + def apply_updates(self, allow_no_op: bool = False) -> tf.Operation: + """Construct training op to update the registered variables based on their gradients.""" + tfutil.assert_tf_initialized() + assert not self._updates_applied + self._updates_applied = True + all_ops = [] + + # Check for no-op. + if allow_no_op and len(self._devices) == 0: + with tfutil.absolute_name_scope(self.scope): + return tf.no_op(name='TrainingOp') + + # Clean up gradients. 
+ for device_idx, device in enumerate(self._devices.values()): + with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name): + for var, grad in device.grad_raw.items(): + + # Filter out disconnected gradients and convert to float32. + grad = [g for g in grad if g is not None] + grad = [tf.cast(g, tf.float32) for g in grad] + + # Sum within the device. + if len(grad) == 0: + grad = tf.zeros(var.shape) # No gradients => zero. + elif len(grad) == 1: + grad = grad[0] # Single gradient => use as is. + else: + grad = tf.add_n(grad) # Multiple gradients => sum. + + # Scale as needed. + scale = 1.0 / len(device.grad_raw[var]) / len(self._devices) + scale = tf.constant(scale, dtype=tf.float32, name="scale") + if self.minibatch_multiplier is not None: + scale /= tf.cast(self.minibatch_multiplier, tf.float32) + scale = self.undo_loss_scaling(scale) + device.grad_clean[var] = grad * scale + + # Sum gradients across devices. + if len(self._devices) > 1: + with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None): + if platform.system() == "Windows": # Windows => NCCL ops are not available. + self._broadcast_fallback() + elif tf.VERSION.startswith("1.15."): # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539 + self._broadcast_fallback() + else: # Otherwise => NCCL ops are safe to use. + self._broadcast_nccl() + + # Apply updates separately on each device. + for device_idx, device in enumerate(self._devices.values()): + with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name): + # pylint: disable=cell-var-from-loop + + # Accumulate gradients over time. + if self.minibatch_multiplier is None: + acc_ok = tf.constant(True, name='acc_ok') + device.grad_acc = OrderedDict(device.grad_clean) + else: + # Create variables. + with tf.control_dependencies(None): + for var in device.grad_clean.keys(): + device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var") + device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count") + + # Track counter. + count_cur = device.grad_acc_count + 1.0 + count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur) + count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([])) + acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32)) + all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op)) + + # Track gradients. + for var, grad in device.grad_clean.items(): + acc_var = device.grad_acc_vars[var] + acc_cur = acc_var + grad + device.grad_acc[var] = acc_cur + with tf.control_dependencies([acc_cur]): + acc_inc_op = lambda: tf.assign(acc_var, acc_cur) + acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape)) + all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op)) + + # No overflow => apply gradients. + all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()])) + apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()]) + all_ops.append(tf.cond(all_ok, apply_op, tf.no_op)) + + # Adjust loss scaling. 
+ if self.use_loss_scaling: + ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc) + ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec) + ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op)) + all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op)) + + # Last device => report statistics. + if device_idx == len(self._devices) - 1: + all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate))) + all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok)) + if self.use_loss_scaling: + all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var)) + + # Initialize variables. + self.reset_optimizer_state() + if self.use_loss_scaling: + tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()]) + if self.minibatch_multiplier is not None: + tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]]) + + # Group everything into a single op. + with tfutil.absolute_name_scope(self.scope): + return tf.group(*all_ops, name="TrainingOp") + + def reset_optimizer_state(self) -> None: + """Reset internal state of the underlying optimizer.""" + tfutil.assert_tf_initialized() + tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()]) + + def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: + """Get or create variable representing log2 of the current dynamic loss scaling factor.""" + return self._get_device(device).loss_scaling_var + + def apply_loss_scaling(self, value: TfExpression) -> TfExpression: + """Apply dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + if not self.use_loss_scaling: + return value + return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) + + def undo_loss_scaling(self, value: TfExpression) -> TfExpression: + """Undo the effect of dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + if not self.use_loss_scaling: + return value + return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type + + def _broadcast_nccl(self): + """Sum gradients across devices using NCCL ops (fast path).""" + from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module + for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]): + if any(x.shape.num_elements() > 0 for x in all_vars): + all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)] + all_grads = nccl_ops.all_sum(all_grads) + for device, var, grad in zip(self._devices.values(), all_vars, all_grads): + device.grad_clean[var] = grad + + def _broadcast_fallback(self): + """Sum gradients across devices using TensorFlow collective ops (slow fallback path).""" + from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module + global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key + if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()): + return + if not _collective_ops_warning_printed: + print("------------------------------------------------------------------------") + print("WARNING: Using slow fallback implementation 
for inter-GPU communication.") + print("Please use TensorFlow 1.14 on Linux for optimal training performance.") + print("------------------------------------------------------------------------") + _collective_ops_warning_printed = True + for device in self._devices.values(): + with tf.device(device.name): + combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()] + combo = tf.concat(combo, axis=0) + combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id', + group_size=len(self._devices), group_key=_collective_ops_group_key, + instance_key=_collective_ops_instance_key) + cur_ofs = 0 + for var, grad_old in device.grad_clean.items(): + grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape) + cur_ofs += grad_old.shape.num_elements() + device.grad_clean[var] = grad_new + _collective_ops_instance_key += 1 + + +class SimpleAdam: + """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer.""" + + def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8): + self.name = name + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.all_state_vars = [] + + def variables(self): + return self.all_state_vars + + def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE): + assert gate_gradients == tf.train.Optimizer.GATE_NONE + return list(zip(tf.gradients(loss, var_list), var_list)) + + def apply_gradients(self, grads_and_vars): + with tf.name_scope(self.name): + state_vars = [] + update_ops = [] + + # Adjust learning rate to deal with startup bias. + with tf.control_dependencies(None): + b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False) + b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False) + state_vars += [b1pow_var, b2pow_var] + b1pow_new = b1pow_var * self.beta1 + b2pow_new = b2pow_var * self.beta2 + update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)] + lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new) + + # Construct ops to update each variable. + for grad, var in grads_and_vars: + with tf.control_dependencies(None): + m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False) + v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False) + state_vars += [m_var, v_var] + m_new = self.beta1 * m_var + (1 - self.beta1) * grad + v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad) + var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon) + update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)] + + # Group everything together. + self.all_state_vars += state_vars + return tf.group(*update_ops) diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py b/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py new file mode 100644 index 0000000000000000000000000000000000000000..fe21100299251492ee6d49a7fab566ffb8283702 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py @@ -0,0 +1,262 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous helper utils for Tensorflow.""" + +import os +import numpy as np +import tensorflow as tf + +# Silence deprecation warnings from TensorFlow 1.13 onwards +import logging +logging.getLogger('tensorflow').setLevel(logging.ERROR) +import tensorflow.contrib # requires TensorFlow 1.x! +tf.contrib = tensorflow.contrib + +from typing import Any, Iterable, List, Union + +TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] +"""A type that represents a valid Tensorflow expression.""" + +TfExpressionEx = Union[TfExpression, int, float, np.ndarray] +"""A type that can be converted to a valid Tensorflow expression.""" + + +def run(*args, **kwargs) -> Any: + """Run the specified ops in the default session.""" + assert_tf_initialized() + return tf.get_default_session().run(*args, **kwargs) + + +def is_tf_expression(x: Any) -> bool: + """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" + return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) + + +def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: + """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code.""" + return [dim.value for dim in shape] + + +def flatten(x: TfExpressionEx) -> TfExpression: + """Shortcut function for flattening a tensor.""" + with tf.name_scope("Flatten"): + return tf.reshape(x, [-1]) + + +def log2(x: TfExpressionEx) -> TfExpression: + """Logarithm in base 2.""" + with tf.name_scope("Log2"): + return tf.log(x) * np.float32(1.0 / np.log(2.0)) + + +def exp2(x: TfExpressionEx) -> TfExpression: + """Exponent in base 2.""" + with tf.name_scope("Exp2"): + return tf.exp(x * np.float32(np.log(2.0))) + + +def erfinv(y: TfExpressionEx) -> TfExpression: + """Inverse of the error function.""" + # pylint: disable=no-name-in-module + from tensorflow.python.ops.distributions import special_math + return special_math.erfinv(y) + + +def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: + """Linear interpolation.""" + with tf.name_scope("Lerp"): + return a + (b - a) * t + + +def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: + """Linear interpolation with clip.""" + with tf.name_scope("LerpClip"): + return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) + + +def absolute_name_scope(scope: str) -> tf.name_scope: + """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" + return tf.name_scope(scope + "/") + + +def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: + """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" + return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) + + +def _sanitize_tf_config(config_dict: dict = None) -> dict: + # Defaults. + cfg = dict() + cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. + cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. + cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 
+ cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE" # Disable HDF5 file locking to avoid concurrency issues with network shares. + cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. + cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. + + # Remove defaults for environment variables that are already set. + for key in list(cfg): + fields = key.split(".") + if fields[0] == "env": + assert len(fields) == 2 + if fields[1] in os.environ: + del cfg[key] + + # User overrides. + if config_dict is not None: + cfg.update(config_dict) + return cfg + + +def init_tf(config_dict: dict = None) -> None: + """Initialize TensorFlow session using good default settings.""" + # Skip if already initialized. + if tf.get_default_session() is not None: + return + + # Setup config dict and random seeds. + cfg = _sanitize_tf_config(config_dict) + np_random_seed = cfg["rnd.np_random_seed"] + if np_random_seed is not None: + np.random.seed(np_random_seed) + tf_random_seed = cfg["rnd.tf_random_seed"] + if tf_random_seed == "auto": + tf_random_seed = np.random.randint(1 << 31) + if tf_random_seed is not None: + tf.set_random_seed(tf_random_seed) + + # Setup environment variables. + for key, value in cfg.items(): + fields = key.split(".") + if fields[0] == "env": + assert len(fields) == 2 + os.environ[fields[1]] = str(value) + + # Create default TensorFlow session. + create_session(cfg, force_as_default=True) + + +def assert_tf_initialized(): + """Check that TensorFlow session has been initialized.""" + if tf.get_default_session() is None: + raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") + + +def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: + """Create tf.Session based on config dict.""" + # Setup TensorFlow config proto. + cfg = _sanitize_tf_config(config_dict) + config_proto = tf.ConfigProto() + for key, value in cfg.items(): + fields = key.split(".") + if fields[0] not in ["rnd", "env"]: + obj = config_proto + for field in fields[:-1]: + obj = getattr(obj, field) + setattr(obj, fields[-1], value) + + # Create session. + session = tf.Session(config=config_proto) + if force_as_default: + # pylint: disable=protected-access + session._default_session = session.as_default() + session._default_session.enforce_nesting = False + session._default_session.__enter__() + return session + + +def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: + """Initialize all tf.Variables that have not already been initialized. + + Equivalent to the following, but more efficient and does not bloat the tf graph: + tf.variables_initializer(tf.report_uninitialized_variables()).run() + """ + assert_tf_initialized() + if target_vars is None: + target_vars = tf.global_variables() + + test_vars = [] + test_ops = [] + + with tf.control_dependencies(None): # ignore surrounding control_dependencies + for var in target_vars: + assert is_tf_expression(var) + + try: + tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) + except KeyError: + # Op does not exist => variable may be uninitialized. 
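[Editor's aside] A minimal, illustrative sketch of how these config keys are meant to be used (key names taken from `_sanitize_tf_config` above; the call pattern mirrors the `tflib.init_tf()` call in `manipulate.py`): `env.*` entries are exported as environment variables, `rnd.*` entries seed NumPy/TensorFlow, and the remaining dotted keys are set on `tf.ConfigProto`.

    from dnnlib import tflib

    # Illustrative overrides only -- every key below is defined in _sanitize_tf_config.
    tflib.init_tf({
        "rnd.np_random_seed": 123,          # seed NumPy; the TF seed is then derived ("auto")
        "gpu_options.allow_growth": False,  # allocate all GPU memory up front
        "env.TF_CPP_MIN_LOG_LEVEL": "0",    # print all TensorFlow debug info
    })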
+ test_vars.append(var) + + with absolute_name_scope(var.name.split(":")[0]): + test_ops.append(tf.is_variable_initialized(var)) + + init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] + run([var.initializer for var in init_vars]) + + +def set_vars(var_to_value_dict: dict) -> None: + """Set the values of given tf.Variables. + + Equivalent to the following, but more efficient and does not bloat the tf graph: + tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] + """ + assert_tf_initialized() + ops = [] + feed_dict = {} + + for var, value in var_to_value_dict.items(): + assert is_tf_expression(var) + + try: + setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op + except KeyError: + with absolute_name_scope(var.name.split(":")[0]): + with tf.control_dependencies(None): # ignore surrounding control_dependencies + setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter + + ops.append(setter) + feed_dict[setter.op.inputs[1]] = value + + run(ops, feed_dict) + + +def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): + """Create tf.Variable with large initial value without bloating the tf graph.""" + assert_tf_initialized() + assert isinstance(initial_value, np.ndarray) + zeros = tf.zeros(initial_value.shape, initial_value.dtype) + var = tf.Variable(zeros, *args, **kwargs) + set_vars({var: initial_value}) + return var + + +def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): + """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. + Can be used as an input transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if nhwc_to_nchw: + images = tf.transpose(images, [0, 3, 1, 2]) + return images * ((drange[1] - drange[0]) / 255) + drange[0] + + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if shrink > 1: + ksize = [1, 1, shrink, shrink] + images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") + if nchw_to_nhwc: + images = tf.transpose(images, [0, 2, 3, 1]) + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + return tf.saturate_cast(images, tf.uint8) diff --git a/PTI/dnnlib/util.py b/models/StyleCLIP/global_directions/dnnlib/util.py similarity index 98% rename from PTI/dnnlib/util.py rename to models/StyleCLIP/global_directions/dnnlib/util.py index 76725336d01e75e1c68daa88be47f4fde0bbc63b..0c35b8923bb27bcd91fd0c14234480067138a3fc 100644 --- a/PTI/dnnlib/util.py +++ b/models/StyleCLIP/global_directions/dnnlib/util.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
# # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation @@ -75,10 +75,8 @@ class Logger(object): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.close() - def write(self, text: Union[str, bytes]) -> None: + def write(self, text: str) -> None: """Write text to stdout (and a file) and optionally flush.""" - if isinstance(text, bytes): - text = text.decode() if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash return @@ -109,7 +107,6 @@ class Logger(object): if self.file is not None: self.file.close() - self.file = None # Cache directories @@ -450,8 +447,6 @@ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: b if verbose: print(" done") break - except KeyboardInterrupt: - raise except: if not attempts_left: if verbose: diff --git a/models/StyleCLIP/global_directions/manipulate.py b/models/StyleCLIP/global_directions/manipulate.py new file mode 100644 index 0000000000000000000000000000000000000000..e1a2480caad8016fea0c06f0bfe521b25f084436 --- /dev/null +++ b/models/StyleCLIP/global_directions/manipulate.py @@ -0,0 +1,278 @@ + + +import os +import os.path +import pickle +import numpy as np +import tensorflow as tf +from dnnlib import tflib +from global_directions.utils.visualizer import HtmlPageVisualizer + + +def Vis(bname,suffix,out,rownames=None,colnames=None): + num_images=out.shape[0] + step=out.shape[1] + + if colnames is None: + colnames=[f'Step {i:02d}' for i in range(1, step + 1)] + if rownames is None: + rownames=[str(i) for i in range(num_images)] + + + visualizer = HtmlPageVisualizer( + num_rows=num_images, num_cols=step + 1, viz_size=256) + visualizer.set_headers( + ['Name'] +colnames) + + for i in range(num_images): + visualizer.set_cell(i, 0, text=rownames[i]) + + for i in range(num_images): + for k in range(step): + image=out[i,k,:,:,:] + visualizer.set_cell(i, 1+k, image=image) + + # Save results. + visualizer.save(f'./html/'+bname+'_'+suffix+'.html') + + + + +def LoadData(img_path): + tmp=img_path+'S' + with open(tmp, "rb") as fp: #Pickling + s_names,all_s=pickle.load( fp) + dlatents=all_s + + pindexs=[] + mindexs=[] + for i in range(len(s_names)): + name=s_names[i] + if not('ToRGB' in name): + mindexs.append(i) + else: + pindexs.append(i) + + tmp=img_path+'S_mean_std' + with open(tmp, "rb") as fp: #Pickling + m,std=pickle.load( fp) + + return dlatents,s_names,mindexs,pindexs,m,std + + +def LoadModel(model_path,model_name): + # Initialize TensorFlow. + tflib.init_tf() + tmp=os.path.join(model_path,model_name) + with open(tmp, 'rb') as f: + _, _, Gs = pickle.load(f) + Gs.print_layers() + return Gs + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). + """ + if nchw_to_nhwc: + images = np.transpose(images, [0, 2, 3, 1]) + + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + + np.clip(images, 0, 255, out=images) + images=images.astype('uint8') + return images + + +def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): + """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. + Can be used as an input transformation for Network.run(). 
+ """ + if nhwc_to_nchw: + images=np.rollaxis(images, 3, 1) + return images/ 255 *(drange[1] - drange[0])+ drange[0] + + +class Manipulator(): + def __init__(self,dataset_name='ffhq'): + self.file_path='./' + self.img_path=self.file_path+'npy/'+dataset_name+'/' + self.model_path=self.file_path+'model/' + self.dataset_name=dataset_name + self.model_name=dataset_name+'.pkl' + + self.alpha=[0] #manipulation strength + self.num_images=10 + self.img_index=0 #which image to start + self.viz_size=256 + self.manipulate_layers=None #which layer to manipulate, list + + self.dlatents,self.s_names,self.mindexs,self.pindexs,self.code_mean,self.code_std=LoadData(self.img_path) + + self.sess=tf.InteractiveSession() + init = tf.global_variables_initializer() + self.sess.run(init) + self.Gs=LoadModel(self.model_path,self.model_name) + self.num_layers=len(self.dlatents) + + self.Vis=Vis + self.noise_constant={} + + for i in range(len(self.s_names)): + tmp1=self.s_names[i].split('/') + if not 'ToRGB' in tmp1: + tmp1[-1]='random_normal:0' + size=int(tmp1[1].split('x')[0]) + tmp1='/'.join(tmp1) + tmp=(1,1,size,size) + self.noise_constant[tmp1]=np.random.random(tmp) + + tmp=self.Gs.components.synthesis.input_shape[1] + d={} + d['G_synthesis_1/dlatents_in:0']=np.zeros([1,tmp,512]) + names=list(self.noise_constant.keys()) + tmp=tflib.run(names,d) + for i in range(len(names)): + self.noise_constant[names[i]]=tmp[i] + + self.fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + self.img_size=self.Gs.output_shape[-1] + + def GenerateImg(self,codes): + + + num_images,step=codes[0].shape[:2] + + + out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8') + for i in range(num_images): + for k in range(step): + d={} + for m in range(len(self.s_names)): + d[self.s_names[m]]=codes[m][i,k][None,:] #need to change + d['G_synthesis_1/4x4/Const/Shape:0']=np.array([1,18, 512], dtype=np.int32) + d.update(self.noise_constant) + img=tflib.run('G_synthesis_1/images_out:0', d) + image=convert_images_to_uint8(img, nchw_to_nhwc=True) + out[i,k,:,:,:]=image[0] + return out + + + + def MSCode(self,dlatent_tmp,boundary_tmp): + + step=len(self.alpha) + dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp] + dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] # (10, 7, 512) + + l=np.array(self.alpha) + l=l.reshape( + [step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)]) + + if type(self.manipulate_layers)==int: + tmp=[self.manipulate_layers] + elif type(self.manipulate_layers)==list: + tmp=self.manipulate_layers + elif self.manipulate_layers is None: + tmp=np.arange(len(boundary_tmp)) + else: + raise ValueError('manipulate_layers is wrong') + + for i in tmp: + dlatent_tmp2[i]+=l*boundary_tmp[i] + + codes=[] + for i in range(len(dlatent_tmp2)): + tmp=list(dlatent_tmp[i].shape) + tmp.insert(1,step) + codes.append(dlatent_tmp2[i].reshape(tmp)) + return codes + + + def EditOne(self,bname,dlatent_tmp=None): + if dlatent_tmp==None: + dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] + + boundary_tmp=[] + for i in range(len(self.boundary)): + tmp=self.boundary[i] + if len(tmp)<=bname: + boundary_tmp.append([]) + else: + boundary_tmp.append(tmp[bname]) + + codes=self.MSCode(dlatent_tmp,boundary_tmp) + + out=self.GenerateImg(codes) + return codes,out + + def EditOneC(self,cindex,dlatent_tmp=None): + if dlatent_tmp==None: + dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] + + 
boundary_tmp=[[] for i in range(len(self.dlatents))] + + #'only manipulate 1 layer and one channel' + assert len(self.manipulate_layers)==1 + + ml=self.manipulate_layers[0] + tmp=dlatent_tmp[ml].shape[1] #ada + tmp1=np.zeros(tmp) + tmp1[cindex]=self.code_std[ml][cindex] #1 + boundary_tmp[ml]=tmp1 + + codes=self.MSCode(dlatent_tmp,boundary_tmp) + out=self.GenerateImg(codes) + return codes,out + + + def W2S(self,dlatent_tmp): + + all_s = self.sess.run( + self.s_names, + feed_dict={'G_synthesis_1/dlatents_in:0': dlatent_tmp}) + return all_s + + + + + + + + +#%% +if __name__ == "__main__": + + + M=Manipulator(dataset_name='ffhq') + + + #%% + M.alpha=[-5,0,5] + M.num_images=20 + lindex,cindex=6,501 + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(cindex) #dlatent_tmp + tmp=str(M.manipulate_layers)+'_'+str(cindex) + M.Vis(tmp,'c',out) + + + + + + + + + + + + + + + + + + + + diff --git a/PTI/utils/__init__.py b/models/StyleCLIP/global_directions/utils/__init__.py similarity index 100% rename from PTI/utils/__init__.py rename to models/StyleCLIP/global_directions/utils/__init__.py diff --git a/models/StyleCLIP/global_directions/utils/editor.py b/models/StyleCLIP/global_directions/utils/editor.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c2ac56fd7b4b127f948c6b8cf15874a8fe9d93 --- /dev/null +++ b/models/StyleCLIP/global_directions/utils/editor.py @@ -0,0 +1,507 @@ +# python 3.7 +"""Utility functions for image editing from latent space.""" + +import os.path +import numpy as np + +__all__ = [ + 'parse_indices', 'interpolate', 'mix_style', + 'get_layerwise_manipulation_strength', 'manipulate', 'parse_boundary_list' +] + + +def parse_indices(obj, min_val=None, max_val=None): + """Parses indices. + + If the input is a list or tuple, this function has no effect. + + The input can also be a string, which is either a comma separated list of + numbers 'a, b, c', or a dash separated range 'a - c'. Space in the string will + be ignored. + + Args: + obj: The input object to parse indices from. + min_val: If not `None`, this function will check that all indices are equal + to or larger than this value. (default: None) + max_val: If not `None`, this function will check that all indices are equal + to or smaller than this field. (default: None) + + Returns: + A list of integers. + + Raises: + If the input is invalid, i.e., neither a list or tuple, nor a string. + """ + if obj is None or obj == '': + indices = [] + elif isinstance(obj, int): + indices = [obj] + elif isinstance(obj, (list, tuple, np.ndarray)): + indices = list(obj) + elif isinstance(obj, str): + indices = [] + splits = obj.replace(' ', '').split(',') + for split in splits: + numbers = list(map(int, split.split('-'))) + if len(numbers) == 1: + indices.append(numbers[0]) + elif len(numbers) == 2: + indices.extend(list(range(numbers[0], numbers[1] + 1))) + else: + raise ValueError(f'Invalid type of input: {type(obj)}!') + + assert isinstance(indices, list) + indices = sorted(list(set(indices))) + for idx in indices: + assert isinstance(idx, int) + if min_val is not None: + assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!' + if max_val is not None: + assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!' + + return indices + + +def interpolate(src_codes, dst_codes, step=5): + """Interpolates two sets of latent codes linearly. + + Args: + src_codes: Source codes, with shape [num, *code_shape]. + dst_codes: Target codes, with shape [num, *code_shape]. 
+ step: Number of interplolation steps, with source and target included. For + example, if `step = 5`, three more samples will be inserted. (default: 5) + + Returns: + Interpolated codes, with shape [num, step, *code_shape]. + + Raises: + ValueError: If the input two sets of latent codes are with different shapes. + """ + if not (src_codes.ndim >= 2 and src_codes.shape == dst_codes.shape): + raise ValueError(f'Shapes of source codes and target codes should both be ' + f'[num, *code_shape], but {src_codes.shape} and ' + f'{dst_codes.shape} are received!') + num = src_codes.shape[0] + code_shape = src_codes.shape[1:] + + a = src_codes[:, np.newaxis] + b = dst_codes[:, np.newaxis] + l = np.linspace(0.0, 1.0, step).reshape( + [step if axis == 1 else 1 for axis in range(a.ndim)]) + results = a + l * (b - a) + assert results.shape == (num, step, *code_shape) + + return results + + +def mix_style(style_codes, + content_codes, + num_layers=1, + mix_layers=None, + is_style_layerwise=True, + is_content_layerwise=True): + """Mixes styles from style codes to those of content codes. + + Each style code or content code consists of `num_layers` codes, each of which + is typically fed into a particular layer of the generator. This function mixes + styles by partially replacing the codes of `content_codes` from some certain + layers with those of `style_codes`. + + For example, if both style code and content code are with shape [10, 512], + meaning to have 10 layers and each employs a 512-dimensional latent code. And + the 1st, 2nd, and 3rd layers are the target layers to perform style mixing. + Then the top half of the content code (with shape [3, 512]) will be replaced + by the top half of the style code (also with shape [3, 512]). + + NOTE: This function also supports taking single-layer latent codes as inputs, + i.e., setting `is_style_layerwise` or `is_content_layerwise` as False. In this + case, the corresponding code will be first repeated for `num_layers` before + performing style mixing. + + Args: + style_codes: Style codes, with shape [num_styles, *code_shape] or + [num_styles, num_layers, *code_shape]. + content_codes: Content codes, with shape [num_contents, *code_shape] or + [num_contents, num_layers, *code_shape]. + num_layers: Total number of layers in the generative model. (default: 1) + mix_layers: Indices of the layers to perform style mixing. `None` means to + replace all layers, in which case the content code will be completely + replaced by style code. (default: None) + is_style_layerwise: Indicating whether the input `style_codes` are + layer-wise codes. (default: True) + is_content_layerwise: Indicating whether the input `content_codes` are + layer-wise codes. (default: True) + num_layers + + Returns: + Codes after style mixing, with shape [num_styles, num_contents, num_layers, + *code_shape]. + + Raises: + ValueError: If input `content_codes` or `style_codes` is with invalid shape. 
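[Editor's aside] A quick, hypothetical shape check for `interpolate()` and `mix_style()` on toy arrays (assuming an 18-layer, 512-dimensional latent space), matching the shapes documented above:

    import numpy as np

    src = np.random.randn(3, 512)
    dst = np.random.randn(3, 512)
    print(interpolate(src, dst, step=5).shape)          # (3, 5, 512)

    style_codes = np.random.randn(2, 18, 512)           # 2 style sources
    content_codes = np.random.randn(5, 18, 512)         # 5 content images
    mixed = mix_style(style_codes, content_codes,
                      num_layers=18, mix_layers='0-2')  # replace the three coarsest layers
    print(mixed.shape)                                  # (2, 5, 18, 512)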
+ """ + if not is_style_layerwise: + style_codes = style_codes[:, np.newaxis] + style_codes = np.tile( + style_codes, + [num_layers if axis == 1 else 1 for axis in range(style_codes.ndim)]) + if not is_content_layerwise: + content_codes = content_codes[:, np.newaxis] + content_codes = np.tile( + content_codes, + [num_layers if axis == 1 else 1 for axis in range(content_codes.ndim)]) + + if not (style_codes.ndim >= 3 and style_codes.shape[1] == num_layers and + style_codes.shape[1:] == content_codes.shape[1:]): + raise ValueError(f'Shapes of style codes and content codes should be ' + f'[num_styles, num_layers, *code_shape] and ' + f'[num_contents, num_layers, *code_shape] respectively, ' + f'but {style_codes.shape} and {content_codes.shape} are ' + f'received!') + + layer_indices = parse_indices(mix_layers, min_val=0, max_val=num_layers - 1) + if not layer_indices: + layer_indices = list(range(num_layers)) + + num_styles = style_codes.shape[0] + num_contents = content_codes.shape[0] + code_shape = content_codes.shape[2:] + + s = style_codes[:, np.newaxis] + s = np.tile(s, [num_contents if axis == 1 else 1 for axis in range(s.ndim)]) + c = content_codes[np.newaxis] + c = np.tile(c, [num_styles if axis == 0 else 1 for axis in range(c.ndim)]) + + from_style = np.zeros(s.shape, dtype=bool) + from_style[:, :, layer_indices] = True + results = np.where(from_style, s, c) + assert results.shape == (num_styles, num_contents, num_layers, *code_shape) + + return results + + +def get_layerwise_manipulation_strength(num_layers, + truncation_psi, + truncation_layers): + """Gets layer-wise strength for manipulation. + + Recall the truncation trick played on layer [0, truncation_layers): + + w = truncation_psi * w + (1 - truncation_psi) * w_avg + + So, when using the same boundary to manipulate different layers, layer + [0, truncation_layers) and layer [truncation_layers, num_layers) should use + different strength to eliminate the effect from the truncation trick. More + concretely, the strength for layer [0, truncation_layers) is set as + `truncation_psi`, while that for other layers are set as 1. + """ + strength = [1.0 for _ in range(num_layers)] + if truncation_layers > 0: + for layer_idx in range(0, truncation_layers): + strength[layer_idx] = truncation_psi + return strength + + +def manipulate(latent_codes, + boundary, + start_distance=-5.0, + end_distance=5.0, + step=21, + layerwise_manipulation=False, + num_layers=1, + manipulate_layers=None, + is_code_layerwise=False, + is_boundary_layerwise=False, + layerwise_manipulation_strength=1.0): + """Manipulates the given latent codes with respect to a particular boundary. + + Basically, this function takes a set of latent codes and a boundary as inputs, + and outputs a collection of manipulated latent codes. + + For example, let `step` to be 10, `latent_codes` to be with shape [num, + *code_shape], and `boundary` to be with shape [1, *code_shape] and unit norm. + Then the output will be with shape [num, 10, *code_shape]. For each 10-element + manipulated codes, the first code is `start_distance` away from the original + code (i.e., the input) along the `boundary` direction, while the last code is + `end_distance` away. Remaining codes are linearly interpolated. Here, + `distance` is sign sensitive. + + NOTE: This function also supports layer-wise manipulation, in which case the + generator should be able to take layer-wise latent codes as inputs. 
For + example, if the generator has 18 convolutional layers in total, and each of + which takes an independent latent code as input. It is possible, sometimes + with even better performance, to only partially manipulate these latent codes + corresponding to some certain layers yet keeping others untouched. + + NOTE: Boundary is assumed to be normalized to unit norm already. + + Args: + latent_codes: The input latent codes for manipulation, with shape + [num, *code_shape] or [num, num_layers, *code_shape]. + boundary: The semantic boundary as reference, with shape [1, *code_shape] or + [1, num_layers, *code_shape]. + start_distance: Start point for manipulation. (default: -5.0) + end_distance: End point for manipulation. (default: 5.0) + step: Number of manipulation steps. (default: 21) + layerwise_manipulation: Whether to perform layer-wise manipulation. + (default: False) + num_layers: Number of layers. Only active when `layerwise_manipulation` is + set as `True`. Should be a positive integer. (default: 1) + manipulate_layers: Indices of the layers to perform manipulation. `None` + means to manipulate latent codes from all layers. (default: None) + is_code_layerwise: Whether the input latent codes are layer-wise. If set as + `False`, the function will first repeat the input codes for `num_layers` + times before perform manipulation. (default: False) + is_boundary_layerwise: Whether the input boundary is layer-wise. If set as + `False`, the function will first repeat boundary for `num_layers` times + before perform manipulation. (default: False) + layerwise_manipulation_strength: Manipulation strength for each layer. Only + active when `layerwise_manipulation` is set as `True`. This field can be + used to resolve the strength discrepancy across layers when truncation + trick is on. See function `get_layerwise_manipulation_strength()` for + details. A tuple, list, or `numpy.ndarray` is expected. If set as a single + number, this strength will be used for all layers. (default: 1.0) + + Returns: + Manipulated codes, with shape [num, step, *code_shape] if + `layerwise_manipulation` is set as `False`, or shape [num, step, + num_layers, *code_shape] if `layerwise_manipulation` is set as `True`. + + Raises: + ValueError: If the input latent codes, boundary, or strength are with + invalid shape. + """ + if not (boundary.ndim >= 2 and boundary.shape[0] == 1): + raise ValueError(f'Boundary should be with shape [1, *code_shape] or ' + f'[1, num_layers, *code_shape], but ' + f'{boundary.shape} is received!') + + if not layerwise_manipulation: + assert not is_code_layerwise + assert not is_boundary_layerwise + num_layers = 1 + manipulate_layers = None + layerwise_manipulation_strength = 1.0 + + # Preprocessing for layer-wise manipulation. + # Parse indices of manipulation layers. + layer_indices = parse_indices( + manipulate_layers, min_val=0, max_val=num_layers - 1) + if not layer_indices: + layer_indices = list(range(num_layers)) + # Make latent codes layer-wise if needed. + assert num_layers > 0 + if not is_code_layerwise: + x = latent_codes[:, np.newaxis] + x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)]) + else: + x = latent_codes + if x.shape[1] != num_layers: + raise ValueError(f'Latent codes should be with shape [num, num_layers, ' + f'*code_shape], where `num_layers` equals to ' + f'{num_layers}, but {x.shape} is received!') + # Make boundary layer-wise if needed. 
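[Editor's aside] The `parse_indices()` helper used just above accepts ints, sequences, or comma/dash strings; a few sanity checks of its documented behaviour:

    assert parse_indices('0, 2, 5-7', min_val=0, max_val=17) == [0, 2, 5, 6, 7]
    assert parse_indices(3) == [3]
    assert parse_indices(None) == []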
+ if not is_boundary_layerwise: + b = boundary + b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) + else: + b = boundary[0] + if b.shape[0] != num_layers: + raise ValueError(f'Boundary should be with shape [num_layers, ' + f'*code_shape], where `num_layers` equals to ' + f'{num_layers}, but {b.shape} is received!') + # Get layer-wise manipulation strength. + if isinstance(layerwise_manipulation_strength, (int, float)): + s = [float(layerwise_manipulation_strength) for _ in range(num_layers)] + elif isinstance(layerwise_manipulation_strength, (list, tuple)): + s = layerwise_manipulation_strength + if len(s) != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` ' + f'mismatches number of layers `{num_layers}`!') + elif isinstance(layerwise_manipulation_strength, np.ndarray): + s = layerwise_manipulation_strength + if s.size != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` ' + f'mismatches number of layers `{num_layers}`!') + else: + raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!') + s = np.array(s).reshape( + [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) + b = b * s + + if x.shape[1:] != b.shape: + raise ValueError(f'Latent code shape {x.shape} and boundary shape ' + f'{b.shape} mismatch!') + num = x.shape[0] + code_shape = x.shape[2:] + + x = x[:, np.newaxis] + b = b[np.newaxis, np.newaxis, :] + l = np.linspace(start_distance, end_distance, step).reshape( + [step if axis == 1 else 1 for axis in range(x.ndim)]) + results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)]) + is_manipulatable = np.zeros(results.shape, dtype=bool) + is_manipulatable[:, :, layer_indices] = True + results = np.where(is_manipulatable, x + l * b, results) + assert results.shape == (num, step, num_layers, *code_shape) + + return results if layerwise_manipulation else results[:, :, 0] + + +def manipulate2(latent_codes, + proj, + mindex, + start_distance=-5.0, + end_distance=5.0, + step=21, + layerwise_manipulation=False, + num_layers=1, + manipulate_layers=None, + is_code_layerwise=False, + layerwise_manipulation_strength=1.0): + + + if not layerwise_manipulation: + assert not is_code_layerwise +# assert not is_boundary_layerwise + num_layers = 1 + manipulate_layers = None + layerwise_manipulation_strength = 1.0 + + # Preprocessing for layer-wise manipulation. + # Parse indices of manipulation layers. + layer_indices = parse_indices( + manipulate_layers, min_val=0, max_val=num_layers - 1) + if not layer_indices: + layer_indices = list(range(num_layers)) + # Make latent codes layer-wise if needed. + assert num_layers > 0 + if not is_code_layerwise: + x = latent_codes[:, np.newaxis] + x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)]) + else: + x = latent_codes + if x.shape[1] != num_layers: + raise ValueError(f'Latent codes should be with shape [num, num_layers, ' + f'*code_shape], where `num_layers` equals to ' + f'{num_layers}, but {x.shape} is received!') + # Make boundary layer-wise if needed. +# if not is_boundary_layerwise: +# b = boundary +# b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) +# else: +# b = boundary[0] +# if b.shape[0] != num_layers: +# raise ValueError(f'Boundary should be with shape [num_layers, ' +# f'*code_shape], where `num_layers` equals to ' +# f'{num_layers}, but {b.shape} is received!') + # Get layer-wise manipulation strength. 
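[Editor's aside] For reference, a hedged usage sketch of `manipulate()` on toy W+ codes (4 samples, an assumed 18 layers of 512 dimensions), combined with `get_layerwise_manipulation_strength()`:

    import numpy as np

    codes = np.random.randn(4, 18, 512)                 # [num, num_layers, *code_shape]
    boundary = np.random.randn(1, 18, 512)
    boundary /= np.linalg.norm(boundary)                # boundary is assumed unit-norm
    strength = get_layerwise_manipulation_strength(num_layers=18,
                                                    truncation_psi=0.7,
                                                    truncation_layers=8)
    edited = manipulate(codes, boundary,
                        start_distance=-3.0, end_distance=3.0, step=7,
                        layerwise_manipulation=True, num_layers=18,
                        manipulate_layers='4-7',        # only touch layers 4..7
                        is_code_layerwise=True, is_boundary_layerwise=True,
                        layerwise_manipulation_strength=strength)
    print(edited.shape)                                 # (4, 7, 18, 512)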
+ if isinstance(layerwise_manipulation_strength, (int, float)): + s = [float(layerwise_manipulation_strength) for _ in range(num_layers)] + elif isinstance(layerwise_manipulation_strength, (list, tuple)): + s = layerwise_manipulation_strength + if len(s) != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` ' + f'mismatches number of layers `{num_layers}`!') + elif isinstance(layerwise_manipulation_strength, np.ndarray): + s = layerwise_manipulation_strength + if s.size != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` ' + f'mismatches number of layers `{num_layers}`!') + else: + raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!') +# s = np.array(s).reshape( +# [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) +# b = b * s + +# if x.shape[1:] != b.shape: +# raise ValueError(f'Latent code shape {x.shape} and boundary shape ' +# f'{b.shape} mismatch!') + num = x.shape[0] + code_shape = x.shape[2:] + + x = x[:, np.newaxis] +# b = b[np.newaxis, np.newaxis, :] +# l = np.linspace(start_distance, end_distance, step).reshape( +# [step if axis == 1 else 1 for axis in range(x.ndim)]) + results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)]) + is_manipulatable = np.zeros(results.shape, dtype=bool) + is_manipulatable[:, :, layer_indices] = True + + tmp=MPC(proj,x,mindex,start_distance,end_distance,step) + tmp = tmp[:, :,np.newaxis] + tmp1 = np.tile(tmp, [num_layers if axis == 2 else 1 for axis in range(tmp.ndim)]) + + + results = np.where(is_manipulatable, tmp1, results) +# print(results.shape) + assert results.shape == (num, step, num_layers, *code_shape) + return results if layerwise_manipulation else results[:, :, 0] + +def MPC(proj,x,mindex,start_distance,end_distance,step): + # x shape (batch_size,1,num_layers,feature) +# print(x.shape) + x1=proj.transform(x[:,0,0,:]) #/np.sqrt(proj.explained_variance_) # (batch_size,num_pc) + + x1 = x1[:, np.newaxis] + x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)]) + + + l = np.linspace(start_distance, end_distance, step)[None,:] + x1[:,:,mindex]+=l + + tmp=x1.reshape((-1,x1.shape[-1])) #*np.sqrt(proj.explained_variance_) +# print('xxx') + x2=proj.inverse_transform(tmp) + x2=x2.reshape((x1.shape[0],x1.shape[1],-1)) + +# x1 = x1[:, np.newaxis] +# x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)]) + + return x2 + + + + +def parse_boundary_list(boundary_list_path): + """Parses boundary list. + + Sometimes, a text file containing a list of boundaries will significantly + simplify image manipulation with a large amount of boundaries. This function + is used to parse boundary information from such list file. + + Basically, each item in the list should be with format + `($NAME, $SPACE_TYPE): $PATH`. `DISABLE` at the beginning of the line can + disable a particular boundary. + + Sample: + + (age, z): $AGE_BOUNDARY_PATH + (gender, w): $GENDER_BOUNDARY_PATH + DISABLE(pose, wp): $POSE_BOUNDARY_PATH + + Args: + boundary_list_path: Path to the boundary list. + + Returns: + A dictionary, whose key is a two-element tuple (boundary_name, space_type) + and value is the corresponding boundary path. + + Raise: + ValueError: If the given boundary list does not exist. 
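[Editor's aside] For the sample list quoted in the docstring above, the returned dictionary would look like this (paths kept as the placeholders they are; the file name is hypothetical):

    boundaries = parse_boundary_list('boundary_list.txt')   # hypothetical path
    # {('age', 'z'):    '$AGE_BOUNDARY_PATH',
    #  ('gender', 'w'): '$GENDER_BOUNDARY_PATH'}
    # The DISABLE'd pose entry is skipped.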
+ """ + if not os.path.isfile(boundary_list_path): + raise ValueError(f'Boundary list `boundary_list_path` does not exist!') + + boundaries = {} + with open(boundary_list_path, 'r') as f: + for line in f: + if line[:len('DISABLE')] == 'DISABLE': + continue + boundary_info, boundary_path = line.strip().split(':') + boundary_name, space_type = boundary_info.strip()[1:-1].split(',') + boundary_name = boundary_name.strip() + space_type = space_type.strip().lower() + boundary_path = boundary_path.strip() + boundaries[(boundary_name, space_type)] = boundary_path + return boundaries diff --git a/models/StyleCLIP/global_directions/utils/train_boundary.py b/models/StyleCLIP/global_directions/utils/train_boundary.py new file mode 100644 index 0000000000000000000000000000000000000000..710d062bc4b42913fcc5b12bd545e47af00c7123 --- /dev/null +++ b/models/StyleCLIP/global_directions/utils/train_boundary.py @@ -0,0 +1,158 @@ + +import numpy as np +from sklearn import svm + + + + + +def train_boundary(latent_codes, + scores, + chosen_num_or_ratio=0.02, + split_ratio=0.7, + invalid_value=None, + logger=None, + logger_name='train_boundary'): + """Trains boundary in latent space with offline predicted attribute scores. + + Given a collection of latent codes and the attribute scores predicted from the + corresponding images, this function will train a linear SVM by treating it as + a bi-classification problem. Basically, the samples with highest attribute + scores are treated as positive samples, while those with lowest scores as + negative. For now, the latent code can ONLY be with 1 dimension. + + NOTE: The returned boundary is with shape (1, latent_space_dim), and also + normalized with unit norm. + + Args: + latent_codes: Input latent codes as training data. + scores: Input attribute scores used to generate training labels. + chosen_num_or_ratio: How many samples will be chosen as positive (negative) + samples. If this field lies in range (0, 0.5], `chosen_num_or_ratio * + latent_codes_num` will be used. Otherwise, `min(chosen_num_or_ratio, + 0.5 * latent_codes_num)` will be used. (default: 0.02) + split_ratio: Ratio to split training and validation sets. (default: 0.7) + invalid_value: This field is used to filter out data. (default: None) + logger: Logger for recording log messages. If set as `None`, a default + logger, which prints messages from all levels to screen, will be created. + (default: None) + + Returns: + A decision boundary with type `numpy.ndarray`. + + Raises: + ValueError: If the input `latent_codes` or `scores` are with invalid format. 
+ """ +# if not logger: +# logger = setup_logger(work_dir='', logger_name=logger_name) + + if (not isinstance(latent_codes, np.ndarray) or + not len(latent_codes.shape) == 2): + raise ValueError(f'Input `latent_codes` should be with type' + f'`numpy.ndarray`, and shape [num_samples, ' + f'latent_space_dim]!') + num_samples = latent_codes.shape[0] + latent_space_dim = latent_codes.shape[1] + if (not isinstance(scores, np.ndarray) or not len(scores.shape) == 2 or + not scores.shape[0] == num_samples or not scores.shape[1] == 1): + raise ValueError(f'Input `scores` should be with type `numpy.ndarray`, and ' + f'shape [num_samples, 1], where `num_samples` should be ' + f'exactly same as that of input `latent_codes`!') + if chosen_num_or_ratio <= 0: + raise ValueError(f'Input `chosen_num_or_ratio` should be positive, ' + f'but {chosen_num_or_ratio} received!') + +# logger.info(f'Filtering training data.') + print('Filtering training data.') + if invalid_value is not None: + latent_codes = latent_codes[scores[:, 0] != invalid_value] + scores = scores[scores[:, 0] != invalid_value] + +# logger.info(f'Sorting scores to get positive and negative samples.') + print('Sorting scores to get positive and negative samples.') + + sorted_idx = np.argsort(scores, axis=0)[::-1, 0] + latent_codes = latent_codes[sorted_idx] + scores = scores[sorted_idx] + num_samples = latent_codes.shape[0] + if 0 < chosen_num_or_ratio <= 1: + chosen_num = int(num_samples * chosen_num_or_ratio) + else: + chosen_num = int(chosen_num_or_ratio) + chosen_num = min(chosen_num, num_samples // 2) + +# logger.info(f'Spliting training and validation sets:') + print('Filtering training data.') + + train_num = int(chosen_num * split_ratio) + val_num = chosen_num - train_num + # Positive samples. + positive_idx = np.arange(chosen_num) + np.random.shuffle(positive_idx) + positive_train = latent_codes[:chosen_num][positive_idx[:train_num]] + positive_val = latent_codes[:chosen_num][positive_idx[train_num:]] + # Negative samples. + negative_idx = np.arange(chosen_num) + np.random.shuffle(negative_idx) + negative_train = latent_codes[-chosen_num:][negative_idx[:train_num]] + negative_val = latent_codes[-chosen_num:][negative_idx[train_num:]] + # Training set. + train_data = np.concatenate([positive_train, negative_train], axis=0) + train_label = np.concatenate([np.ones(train_num, dtype=np.int), + np.zeros(train_num, dtype=np.int)], axis=0) +# logger.info(f' Training: {train_num} positive, {train_num} negative.') + print(f' Training: {train_num} positive, {train_num} negative.') + # Validation set. + val_data = np.concatenate([positive_val, negative_val], axis=0) + val_label = np.concatenate([np.ones(val_num, dtype=np.int), + np.zeros(val_num, dtype=np.int)], axis=0) +# logger.info(f' Validation: {val_num} positive, {val_num} negative.') + print(f' Validation: {val_num} positive, {val_num} negative.') + + # Remaining set. 
+ remaining_num = num_samples - chosen_num * 2 + remaining_data = latent_codes[chosen_num:-chosen_num] + remaining_scores = scores[chosen_num:-chosen_num] + decision_value = (scores[0] + scores[-1]) / 2 + remaining_label = np.ones(remaining_num, dtype=np.int) + remaining_label[remaining_scores.ravel() < decision_value] = 0 + remaining_positive_num = np.sum(remaining_label == 1) + remaining_negative_num = np.sum(remaining_label == 0) +# logger.info(f' Remaining: {remaining_positive_num} positive, ' +# f'{remaining_negative_num} negative.') + print(f' Remaining: {remaining_positive_num} positive, ' + f'{remaining_negative_num} negative.') +# logger.info(f'Training boundary.') + print(f'Training boundary.') + + clf = svm.SVC(kernel='linear') + classifier = clf.fit(train_data, train_label) +# logger.info(f'Finish training.') + print(f'Finish training.') + + + if val_num: + val_prediction = classifier.predict(val_data) + correct_num = np.sum(val_label == val_prediction) +# logger.info(f'Accuracy for validation set: ' +# f'{correct_num} / {val_num * 2} = ' +# f'{correct_num / (val_num * 2):.6f}') + print(f'Accuracy for validation set: ' + f'{correct_num} / {val_num * 2} = ' + f'{correct_num / (val_num * 2):.6f}') + vacc=correct_num/len(val_label) + ''' + if remaining_num: + remaining_prediction = classifier.predict(remaining_data) + correct_num = np.sum(remaining_label == remaining_prediction) + logger.info(f'Accuracy for remaining set: ' + f'{correct_num} / {remaining_num} = ' + f'{correct_num / remaining_num:.6f}') + ''' + a = classifier.coef_.reshape(1, latent_space_dim).astype(np.float32) + return a / np.linalg.norm(a),vacc + + + + + diff --git a/models/StyleCLIP/global_directions/utils/visualizer.py b/models/StyleCLIP/global_directions/utils/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4a1fba06bf6bc680aa59bf645f796283f6f1c6 --- /dev/null +++ b/models/StyleCLIP/global_directions/utils/visualizer.py @@ -0,0 +1,605 @@ +# python 3.7 +"""Utility functions for visualizing results on html page.""" + +import base64 +import os.path +import cv2 +import numpy as np + +__all__ = [ + 'get_grid_shape', 'get_blank_image', 'load_image', 'save_image', + 'resize_image', 'add_text_to_image', 'fuse_images', 'HtmlPageVisualizer', + 'VideoReader', 'VideoWriter', 'adjust_pixel_range' +] + + +def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'): + """Adjusts the pixel range of the input images. + + This function assumes the input array (image batch) is with shape [batch_size, + channel, height, width] if `channel_order = NCHW`, or with shape [batch_size, + height, width] if `channel_order = NHWC`. The returned images are with shape + [batch_size, height, width, channel] and pixel range [0, 255]. + + NOTE: The channel order of output images will remain the same as the input. + + Args: + images: Input images to adjust pixel range. + min_val: Min value of the input images. (default: -1.0) + max_val: Max value of the input images. (default: 1.0) + channel_order: Channel order of the input array. (default: NCHW) + + Returns: + The postprocessed images with dtype `numpy.uint8` and range [0, 255]. + + Raises: + ValueError: If the input `images` are not with type `numpy.ndarray` or the + shape is invalid according to `channel_order`. 
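[Editor's aside] A minimal sketch of `adjust_pixel_range()` on a toy batch of generator outputs (NCHW float32 in [-1, 1]), matching the behaviour described in the docstring above:

    import numpy as np

    fake = np.random.uniform(-1.0, 1.0, size=(8, 3, 256, 256)).astype(np.float32)
    vis = adjust_pixel_range(fake, min_val=-1.0, max_val=1.0, channel_order='NCHW')
    print(vis.dtype, vis.shape)   # uint8 (8, 256, 256, 3)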
+ """ + if not isinstance(images, np.ndarray): + raise ValueError(f'Images should be with type `numpy.ndarray`!') + + channel_order = channel_order.upper() + if channel_order not in ['NCHW', 'NHWC']: + raise ValueError(f'Invalid channel order `{channel_order}`!') + + if images.ndim != 4: + raise ValueError(f'Input images are expected to be with shape `NCHW` or ' + f'`NHWC`, but `{images.shape}` is received!') + if channel_order == 'NCHW' and images.shape[1] not in [1, 3]: + raise ValueError(f'Input images should have 1 or 3 channels under `NCHW` ' + f'channel order!') + if channel_order == 'NHWC' and images.shape[3] not in [1, 3]: + raise ValueError(f'Input images should have 1 or 3 channels under `NHWC` ' + f'channel order!') + + images = images.astype(np.float32) + images = (images - min_val) * 255 / (max_val - min_val) + images = np.clip(images + 0.5, 0, 255).astype(np.uint8) + if channel_order == 'NCHW': + images = images.transpose(0, 2, 3, 1) + + return images + + +def get_grid_shape(size, row=0, col=0, is_portrait=False): + """Gets the shape of a grid based on the size. + + This function makes greatest effort on making the output grid square if + neither `row` nor `col` is set. If `is_portrait` is set as `False`, the height + will always be equal to or smaller than the width. For example, if input + `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, output shape + will be (3, 5). Otherwise, the height will always be equal to or larger than + the width. + + Args: + size: Size (height * width) of the target grid. + is_portrait: Whether to return a portrait size of a landscape size. + (default: False) + + Returns: + A two-element tuple, representing height and width respectively. + """ + assert isinstance(size, int) + assert isinstance(row, int) + assert isinstance(col, int) + if size == 0: + return (0, 0) + + if row > 0 and col > 0 and row * col != size: + row = 0 + col = 0 + + if row > 0 and size % row == 0: + return (row, size // row) + if col > 0 and size % col == 0: + return (size // col, col) + + row = int(np.sqrt(size)) + while row > 0: + if size % row == 0: + col = size // row + break + row = row - 1 + + return (col, row) if is_portrait else (row, col) + + +def get_blank_image(height, width, channels=3, is_black=True): + """Gets a blank image, either white of black. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + height: Height of the returned image. + width: Width of the returned image. + channels: Number of channels. (default: 3) + is_black: Whether to return a black image or white image. (default: True) + """ + shape = (height, width, channels) + if is_black: + return np.zeros(shape, dtype=np.uint8) + return np.ones(shape, dtype=np.uint8) * 255 + + +def load_image(path): + """Loads an image from disk. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + path: Path to load the image from. + + Returns: + An image with dtype `np.ndarray` or `None` if input `path` does not exist. + """ + if not os.path.isfile(path): + return None + + image = cv2.imread(path) + return image[:, :, ::-1] + + +def save_image(path, image): + """Saves an image to disk. + + NOTE: The input image (if colorful) is assumed to be with `RGB` channel order + and pixel range [0, 255]. + + Args: + path: Path to save the image to. + image: Image to save. 
+ """ + if image is None: + return + + assert len(image.shape) == 3 and image.shape[2] in [1, 3] + cv2.imwrite(path, image[:, :, ::-1]) + + +def resize_image(image, *args, **kwargs): + """Resizes image. + + This is a wrap of `cv2.resize()`. + + NOTE: THe channel order of the input image will not be changed. + + Args: + image: Image to resize. + """ + if image is None: + return None + + assert image.ndim == 3 and image.shape[2] in [1, 3] + image = cv2.resize(image, *args, **kwargs) + if image.ndim == 2: + return image[:, :, np.newaxis] + return image + + +def add_text_to_image(image, + text='', + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255)): + """Overlays text on given image. + + NOTE: The input image is assumed to be with `RGB` channel order. + + Args: + image: The image to overlay text on. + text: Text content to overlay on the image. (default: '') + position: Target position (bottom-left corner) to add text. If not set, + center of the image will be used by default. (default: None) + font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX) + font_size: Font size of the text added. (default: 1.0) + line_type: Line type used to depict the text. (default: cv2.LINE_8) + line_width: Line width used to depict the text. (default: 1) + color: Color of the text added in `RGB` channel order. (default: + (255, 255, 255)) + + Returns: + An image with target text overlayed on. + """ + if image is None or not text: + return image + + cv2.putText(img=image, + text=text, + org=position, + fontFace=font, + fontScale=font_size, + color=color, + thickness=line_width, + lineType=line_type, + bottomLeftOrigin=False) + + return image + + +def fuse_images(images, + image_size=None, + row=0, + col=0, + is_row_major=True, + is_portrait=False, + row_spacing=0, + col_spacing=0, + border_left=0, + border_right=0, + border_top=0, + border_bottom=0, + black_background=True): + """Fuses a collection of images into an entire image. + + Args: + images: A collection of images to fuse. Should be with shape [num, height, + width, channels]. + image_size: Int or two-element tuple. This field is used to resize the image + before fusing. `None` disables resizing. (default: None) + row: Number of rows used for image fusion. If not set, this field will be + automatically assigned based on `col` and total number of images. + (default: None) + col: Number of columns used for image fusion. If not set, this field will be + automatically assigned based on `row` and total number of images. + (default: None) + is_row_major: Whether the input images should be arranged row-major or + column-major. (default: True) + is_portrait: Only active when both `row` and `col` should be assigned + automatically. (default: False) + row_spacing: Space between rows. (default: 0) + col_spacing: Space between columns. (default: 0) + border_left: Width of left border. (default: 0) + border_right: Width of right border. (default: 0) + border_top: Width of top border. (default: 0) + border_bottom: Width of bottom border. (default: 0) + + Returns: + The fused image. + + Raises: + ValueError: If the input `images` is not with shape [num, height, width, + width]. 
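[Editor's aside] A short, hypothetical sketch tying `get_grid_shape()`, `fuse_images()`, and `save_image()` together on a toy uint8 batch (behaviour as described in the docstrings above; the output file name is illustrative):

    import numpy as np

    assert get_grid_shape(16) == (4, 4)                 # examples from the docstring
    assert get_grid_shape(15) == (3, 5)

    batch = np.random.randint(0, 256, size=(16, 128, 128, 3), dtype=np.uint8)
    grid = fuse_images(batch, row=4, col=4, row_spacing=2, col_spacing=2)
    save_image('grid.jpg', grid)                        # save_image expects RGB order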
+ """ + if images is None: + return images + + if not images.ndim == 4: + raise ValueError(f'Input `images` should be with shape [num, height, ' + f'width, channels], but {images.shape} is received!') + + num, image_height, image_width, channels = images.shape + if image_size is not None: + if isinstance(image_size, int): + image_size = (image_size, image_size) + assert isinstance(image_size, (list, tuple)) and len(image_size) == 2 + width, height = image_size + else: + height, width = image_height, image_width + row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait) + fused_height = ( + height * row + row_spacing * (row - 1) + border_top + border_bottom) + fused_width = ( + width * col + col_spacing * (col - 1) + border_left + border_right) + fused_image = get_blank_image( + fused_height, fused_width, channels=channels, is_black=black_background) + images = images.reshape(row, col, image_height, image_width, channels) + if not is_row_major: + images = images.transpose(1, 0, 2, 3, 4) + + for i in range(row): + y = border_top + i * (height + row_spacing) + for j in range(col): + x = border_left + j * (width + col_spacing) + if image_size is not None: + image = cv2.resize(images[i, j], image_size) + else: + image = images[i, j] + fused_image[y:y + height, x:x + width] = image + + return fused_image + + +def get_sortable_html_header(column_name_list, sort_by_ascending=False): + """Gets header for sortable html page. + + Basically, the html page contains a sortable table, where user can sort the + rows by a particular column by clicking the column head. + + Example: + + column_name_list = [name_1, name_2, name_3] + header = get_sortable_html_header(column_name_list) + footer = get_sortable_html_footer() + sortable_table = ... + html_page = header + sortable_table + footer + + Args: + column_name_list: List of column header names. + sort_by_ascending: Default sorting order. If set as `True`, the html page + will be sorted by ascending order when the header is clicked for the first + time. + + Returns: + A string, which represents for the header for a sortable html page. + """ + header = '\n'.join([ + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '']) + for idx, column_name in enumerate(column_name_list): + header += f' \n' + header += '\n' + header += '\n' + header += '\n' + + return header + + +def get_sortable_html_footer(): + """Gets footer for sortable html page. + + Check function `get_sortable_html_header()` for more details. + """ + return '\n
{column_name}
\n\n\n\n' + + +def encode_image_to_html_str(image, image_size=None): + """Encodes an image to html language. + + Args: + image: The input image to encode. Should be with `RGB` channel order. + image_size: Int or two-element tuple. This field is used to resize the image + before encoding. `None` disables resizing. (default: None) + + Returns: + A string which represents the encoded image. + """ + if image is None: + return '' + + assert len(image.shape) == 3 and image.shape[2] in [1, 3] + + # Change channel order to `BGR`, which is opencv-friendly. + image = image[:, :, ::-1] + + # Resize the image if needed. + if image_size is not None: + if isinstance(image_size, int): + image_size = (image_size, image_size) + assert isinstance(image_size, (list, tuple)) and len(image_size) == 2 + image = cv2.resize(image, image_size) + + # Encode the image to html-format string. + encoded_image = cv2.imencode(".jpg", image)[1].tostring() + encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8') + html_str = f'' + + return html_str + + +class HtmlPageVisualizer(object): + """Defines the html page visualizer. + + This class can be used to visualize image results as html page. Basically, it + is based on an html-format sorted table with helper functions + `get_sortable_html_header()`, `get_sortable_html_footer()`, and + `encode_image_to_html_str()`. To simplify the usage, specifying the following + fields is enough to create a visualization page: + + (1) num_rows: Number of rows of the table (header-row exclusive). + (2) num_cols: Number of columns of the table. + (3) header contents (optional): Title of each column. + + NOTE: `grid_size` can be used to assign `num_rows` and `num_cols` + automatically. + + Example: + + html = HtmlPageVisualizer(num_rows, num_cols) + html.set_headers([...]) + for i in range(num_rows): + for j in range(num_cols): + html.set_cell(i, j, text=..., image=...) + html.save('visualize.html') + """ + + def __init__(self, + num_rows=0, + num_cols=0, + grid_size=0, + is_portrait=False, + viz_size=None): + if grid_size > 0: + num_rows, num_cols = get_grid_shape( + grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait) + assert num_rows > 0 and num_cols > 0 + + self.num_rows = num_rows + self.num_cols = num_cols + self.viz_size = viz_size + self.headers = ['' for _ in range(self.num_cols)] + self.cells = [[{ + 'text': '', + 'image': '', + } for _ in range(self.num_cols)] for _ in range(self.num_rows)] + + def set_header(self, column_idx, content): + """Sets the content of a particular header by column index.""" + self.headers[column_idx] = content + + def set_headers(self, contents): + """Sets the contents of all headers.""" + if isinstance(contents, str): + contents = [contents] + assert isinstance(contents, (list, tuple)) + assert len(contents) == self.num_cols + for column_idx, content in enumerate(contents): + self.set_header(column_idx, content) + + def set_cell(self, row_idx, column_idx, text='', image=None): + """Sets the content of a particular cell. + + Basically, a cell contains some text as well as an image. Both text and + image can be empty. + + Args: + row_idx: Row index of the cell to edit. + column_idx: Column index of the cell to edit. + text: Text to add into the target cell. + image: Image to show in the target cell. Should be with `RGB` channel + order. 
+ """ + self.cells[row_idx][column_idx]['text'] = text + self.cells[row_idx][column_idx]['image'] = encode_image_to_html_str( + image, self.viz_size) + + def save(self, save_path): + """Saves the html page.""" + html = '' + for i in range(self.num_rows): + html += f'\n' + for j in range(self.num_cols): + text = self.cells[i][j]['text'] + image = self.cells[i][j]['image'] + if text: + html += f' {text}

{image}\n' + else: + html += f' {image}\n' + html += f'\n' + + header = get_sortable_html_header(self.headers) + footer = get_sortable_html_footer() + + with open(save_path, 'w') as f: + f.write(header + html + footer) + + +class VideoReader(object): + """Defines the video reader. + + This class can be used to read frames from a given video. + """ + + def __init__(self, path): + """Initializes the video reader by loading the video from disk.""" + if not os.path.isfile(path): + raise ValueError(f'Video `{path}` does not exist!') + + self.path = path + self.video = cv2.VideoCapture(path) + assert self.video.isOpened() + self.position = 0 + + self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) + self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.fps = self.video.get(cv2.CAP_PROP_FPS) + + def __del__(self): + """Releases the opened video.""" + self.video.release() + + def read(self, position=None): + """Reads a certain frame. + + NOTE: The returned frame is assumed to be with `RGB` channel order. + + Args: + position: Optional. If set, the reader will read frames from the exact + position. Otherwise, the reader will read next frames. (default: None) + """ + if position is not None and position < self.length: + self.video.set(cv2.CAP_PROP_POS_FRAMES, position) + self.position = position + + success, frame = self.video.read() + self.position = self.position + 1 + + return frame[:, :, ::-1] if success else None + + +class VideoWriter(object): + """Defines the video writer. + + This class can be used to create a video. + + NOTE: `.avi` and `DIVX` is the most recommended codec format since it does not + rely on other dependencies. + """ + + def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'): + """Creates the video writer.""" + self.path = path + self.frame_height = frame_height + self.frame_width = frame_width + self.fps = fps + self.codec = codec + + self.video = cv2.VideoWriter(filename=path, + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=fps, + frameSize=(frame_width, frame_height)) + + def __del__(self): + """Releases the opened video.""" + self.video.release() + + def write(self, frame): + """Writes a target frame. + + NOTE: The input frame is assumed to be with `RGB` channel order. 
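[Editor's aside] A small usage sketch for the two video classes above (hypothetical file names; frames are RGB on both the read and write side, as the docstrings note):

    reader = VideoReader('input.avi')                   # hypothetical input clip
    writer = VideoWriter('output.avi', reader.frame_height, reader.frame_width,
                         fps=reader.fps, codec='DIVX')
    for _ in range(min(10, reader.length)):             # copy the first few frames
        frame = reader.read()
        if frame is None:
            break
        writer.write(frame)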
+ """ + self.video.write(frame[:, :, ::-1]) diff --git a/models/StyleCLIP/mapper/__init__.py b/models/StyleCLIP/mapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/datasets/__init__.py b/models/StyleCLIP/mapper/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/datasets/latents_dataset.py b/models/StyleCLIP/mapper/datasets/latents_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dde6ef52b7488e864ccd2fa2930b5100c1025c17 --- /dev/null +++ b/models/StyleCLIP/mapper/datasets/latents_dataset.py @@ -0,0 +1,15 @@ +from torch.utils.data import Dataset + + +class LatentsDataset(Dataset): + + def __init__(self, latents, opts): + self.latents = latents + self.opts = opts + + def __len__(self): + return self.latents.shape[0] + + def __getitem__(self, index): + + return self.latents[index] diff --git a/models/StyleCLIP/mapper/latent_mappers.py b/models/StyleCLIP/mapper/latent_mappers.py new file mode 100644 index 0000000000000000000000000000000000000000..63637adc9646986a3546edd19f4555a2f75a379f --- /dev/null +++ b/models/StyleCLIP/mapper/latent_mappers.py @@ -0,0 +1,81 @@ +import torch +from torch import nn +from torch.nn import Module + +from models.StyleCLIP.models.stylegan2.model import EqualLinear, PixelNorm + + +class Mapper(Module): + + def __init__(self, opts): + super(Mapper, self).__init__() + + self.opts = opts + layers = [PixelNorm()] + + for i in range(4): + layers.append( + EqualLinear( + 512, 512, lr_mul=0.01, activation='fused_lrelu' + ) + ) + + self.mapping = nn.Sequential(*layers) + + + def forward(self, x): + x = self.mapping(x) + return x + + +class SingleMapper(Module): + + def __init__(self, opts): + super(SingleMapper, self).__init__() + + self.opts = opts + + self.mapping = Mapper(opts) + + def forward(self, x): + out = self.mapping(x) + return out + + +class LevelsMapper(Module): + + def __init__(self, opts): + super(LevelsMapper, self).__init__() + + self.opts = opts + + if not opts.no_coarse_mapper: + self.course_mapping = Mapper(opts) + if not opts.no_medium_mapper: + self.medium_mapping = Mapper(opts) + if not opts.no_fine_mapper: + self.fine_mapping = Mapper(opts) + + def forward(self, x): + x_coarse = x[:, :4, :] + x_medium = x[:, 4:8, :] + x_fine = x[:, 8:, :] + + if not self.opts.no_coarse_mapper: + x_coarse = self.course_mapping(x_coarse) + else: + x_coarse = torch.zeros_like(x_coarse) + if not self.opts.no_medium_mapper: + x_medium = self.medium_mapping(x_medium) + else: + x_medium = torch.zeros_like(x_medium) + if not self.opts.no_fine_mapper: + x_fine = self.fine_mapping(x_fine) + else: + x_fine = torch.zeros_like(x_fine) + + + out = torch.cat([x_coarse, x_medium, x_fine], dim=1) + + return out + diff --git a/models/StyleCLIP/mapper/options/__init__.py b/models/StyleCLIP/mapper/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/options/test_options.py b/models/StyleCLIP/mapper/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..aab2e5a5bba1038b97110fa6c8e8bce14de7390c --- /dev/null +++ b/models/StyleCLIP/mapper/options/test_options.py @@ -0,0 +1,42 @@ +from argparse import ArgumentParser + + +class TestOptions: + + def __init__(self): + self.parser = 
ArgumentParser() + self.initialize() + + def initialize(self): + # arguments for inference script + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint') + self.parser.add_argument('--couple_outputs', action='store_true', + help='Whether to also save inputs + outputs side-by-side') + + self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') + self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") + self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") + self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") + self.parser.add_argument('--stylegan_size', default=1024, type=int) + + self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') + self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation") + self.parser.add_argument('--test_workers', default=2, type=int, + help='Number of test/inference dataloader workers') + + self.parser.add_argument('--n_images', type=int, default=None, + help='Number of images to output. If None, run on all data') + + self.parser.add_argument('--run_id', type=str, default='PKNWUQRQRKXQ', + help='The generator id to use') + + self.parser.add_argument('--image_name', type=str, default='', + help='image to run on') + + self.parser.add_argument('--edit_name', type=str, default='', + help='edit type') + + def parse(self): + opts = self.parser.parse_args() + return opts diff --git a/models/StyleCLIP/mapper/options/train_options.py b/models/StyleCLIP/mapper/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..a365217f8b76d38aaef4a42b90152ec7a8e7bf1f --- /dev/null +++ b/models/StyleCLIP/mapper/options/train_options.py @@ -0,0 +1,49 @@ +from argparse import ArgumentParser + + +class TrainOptions: + + def __init__(self): + self.parser = ArgumentParser() + self.initialize() + + def initialize(self): + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') + self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") + self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") + self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") + self.parser.add_argument('--latents_train_path', default="train_faces.pt", type=str, help="The latents for the training") + self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation") + self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given") + self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given") + + self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training') + self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference') + self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') + self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') + + 
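+        # Example invocation (a sketch only; paths and the prompt are
+        # placeholders, and it assumes the working directory is models/StyleCLIP
+        # so that the `mapper.*` imports in scripts/train.py resolve):
+        #   python mapper/scripts/train.py \
+        #       --exp_dir experiments/purple_hair \
+        #       --description "a person with purple hair" \
+        #       --latents_train_path train_faces.pt \
+        #       --latents_test_path test_faces.pt
+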
self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate') + self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') + + self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') + self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor') + self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor') + + self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights') + self.parser.add_argument('--stylegan_size', default=1024, type=int) + self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint') + + self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps') + self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') + self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard') + self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval') + self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval') + + self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt') + + + def parse(self): + opts = self.parser.parse_args() + return opts \ No newline at end of file diff --git a/models/StyleCLIP/mapper/scripts/inference.py b/models/StyleCLIP/mapper/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..98d765b3607bc6ecf4d137ac3a876b400269c82a --- /dev/null +++ b/models/StyleCLIP/mapper/scripts/inference.py @@ -0,0 +1,80 @@ +import os +import pickle +from argparse import Namespace +import torchvision +import torch +import sys +import time + +from configs import paths_config, global_config +from models.StyleCLIP.mapper.styleclip_mapper import StyleCLIPMapper +from utils.models_utils import load_tuned_G, load_old_G + +sys.path.append(".") +sys.path.append("..") + + +def run(test_opts, model_id, image_name, use_multi_id_G): + out_path_results = os.path.join(test_opts.exp_dir, test_opts.data_dir_name) + os.makedirs(out_path_results, exist_ok=True) + out_path_results = os.path.join(out_path_results, test_opts.image_name) + os.makedirs(out_path_results, exist_ok=True) + + # update test configs with configs used during training + ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') + opts = ckpt['opts'] + opts.update(vars(test_opts)) + opts = Namespace(**opts) + + net = StyleCLIPMapper(opts, test_opts.run_id) + net.eval() + net.to(global_config.device) + + generator_type = paths_config.multi_id_model_type if use_multi_id_G else image_name + + new_G = load_tuned_G(model_id, generator_type) + old_G = load_old_G() + + run_styleclip(net, new_G, opts, paths_config.pti_results_keyword, out_path_results, test_opts) + run_styleclip(net, old_G, opts, paths_config.e4e_results_keyword, out_path_results, test_opts) + + +def run_styleclip(net, G, opts, method, out_path_results, test_opts): + net.set_G(G) + + out_path_results = os.path.join(out_path_results, 
method) + os.makedirs(out_path_results, exist_ok=True) + + latent = torch.load(opts.latents_test_path) + + global_i = 0 + global_time = [] + with torch.no_grad(): + input_cuda = latent.cuda().float() + tic = time.time() + result_batch = run_on_batch(input_cuda, net, test_opts.couple_outputs) + toc = time.time() + global_time.append(toc - tic) + + for i in range(opts.test_batch_size): + im_path = f'{test_opts.image_name}_{test_opts.edit_name}' + if test_opts.couple_outputs: + couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)]) + torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"), + normalize=True, range=(-1, 1)) + else: + torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"), + normalize=True, range=(-1, 1)) + torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt")) + + +def run_on_batch(inputs, net, couple_outputs=False): + w = inputs + with torch.no_grad(): + w_hat = w + 0.06 * net.mapper(w) + x_hat = net.decoder.synthesis(w_hat, noise_mode='const', force_fp32=True) + result_batch = (x_hat, w_hat) + if couple_outputs: + x = net.decoder.synthesis(w, noise_mode='const', force_fp32=True) + result_batch = (x_hat, w_hat, x) + return result_batch diff --git a/models/StyleCLIP/mapper/scripts/train.py b/models/StyleCLIP/mapper/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4141436fb3edee8ab5f7576fde0c0e53b529ef66 --- /dev/null +++ b/models/StyleCLIP/mapper/scripts/train.py @@ -0,0 +1,32 @@ +""" +This file runs the main training/val loop +""" +import os +import json +import sys +import pprint + +sys.path.append(".") +sys.path.append("..") + +from mapper.options.train_options import TrainOptions +from mapper.training.coach import Coach + + +def main(opts): + if os.path.exists(opts.exp_dir): + raise Exception('Oops... 
{} already exists'.format(opts.exp_dir)) + os.makedirs(opts.exp_dir, exist_ok=True) + + opts_dict = vars(opts) + pprint.pprint(opts_dict) + with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: + json.dump(opts_dict, f, indent=4, sort_keys=True) + + coach = Coach(opts) + coach.train() + + +if __name__ == '__main__': + opts = TrainOptions().parse() + main(opts) diff --git a/models/StyleCLIP/mapper/styleclip_mapper.py b/models/StyleCLIP/mapper/styleclip_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..86c04bee5744a551f4c0d31359e0de1f5492ff7e --- /dev/null +++ b/models/StyleCLIP/mapper/styleclip_mapper.py @@ -0,0 +1,76 @@ +import torch +from torch import nn +from models.StyleCLIP.mapper import latent_mappers +from models.StyleCLIP.models.stylegan2.model import Generator + + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class StyleCLIPMapper(nn.Module): + + def __init__(self, opts, run_id): + super(StyleCLIPMapper, self).__init__() + self.opts = opts + # Define architecture + self.mapper = self.set_mapper() + self.run_id = run_id + + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_mapper(self): + if self.opts.mapper_type == 'SingleMapper': + mapper = latent_mappers.SingleMapper(self.opts) + elif self.opts.mapper_type == 'LevelsMapper': + mapper = latent_mappers.LevelsMapper(self.opts) + else: + raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type)) + return mapper + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True) + + def set_G(self, new_G): + self.decoder = new_G + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.mapper(x) + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images = self.decoder.synthesis(codes, noise_mode='const') + result_latent = None + # images, result_latent = self.decoder([codes], + # input_is_latent=input_is_latent, + # randomize_noise=randomize_noise, + # return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images diff --git a/models/StyleCLIP/mapper/training/__init__.py b/models/StyleCLIP/mapper/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/training/coach.py b/models/StyleCLIP/mapper/training/coach.py new file mode 100644 index 0000000000000000000000000000000000000000..fd38eb226106a21e19beb306cd9b0de6a1e7db04 --- /dev/null +++ b/models/StyleCLIP/mapper/training/coach.py @@ -0,0 +1,242 @@ +import os + +import clip +import torch +import torchvision +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import criteria.clip_loss as 
clip_loss +from criteria import id_loss +from mapper.datasets.latents_dataset import LatentsDataset +from mapper.styleclip_mapper import StyleCLIPMapper +from mapper.training.ranger import Ranger +from mapper.training import train_utils + + +class Coach: + def __init__(self, opts): + self.opts = opts + + self.global_step = 0 + + self.device = 'cuda:0' + self.opts.device = self.device + + # Initialize network + self.net = StyleCLIPMapper(self.opts).to(self.device) + + # Initialize loss + if self.opts.id_lambda > 0: + self.id_loss = id_loss.IDLoss(self.opts).to(self.device).eval() + if self.opts.clip_lambda > 0: + self.clip_loss = clip_loss.CLIPLoss(opts) + if self.opts.latent_l2_lambda > 0: + self.latent_l2_loss = nn.MSELoss().to(self.device).eval() + + # Initialize optimizer + self.optimizer = self.configure_optimizers() + + # Initialize dataset + self.train_dataset, self.test_dataset = self.configure_datasets() + self.train_dataloader = DataLoader(self.train_dataset, + batch_size=self.opts.batch_size, + shuffle=True, + num_workers=int(self.opts.workers), + drop_last=True) + self.test_dataloader = DataLoader(self.test_dataset, + batch_size=self.opts.test_batch_size, + shuffle=False, + num_workers=int(self.opts.test_workers), + drop_last=True) + + self.text_inputs = torch.cat([clip.tokenize(self.opts.description)]).cuda() + + # Initialize logger + log_dir = os.path.join(opts.exp_dir, 'logs') + os.makedirs(log_dir, exist_ok=True) + self.log_dir = log_dir + self.logger = SummaryWriter(log_dir=log_dir) + + # Initialize checkpoint dir + self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_val_loss = None + if self.opts.save_interval is None: + self.opts.save_interval = self.opts.max_steps + + def train(self): + self.net.train() + while self.global_step < self.opts.max_steps: + for batch_idx, batch in enumerate(self.train_dataloader): + self.optimizer.zero_grad() + w = batch + w = w.to(self.device) + with torch.no_grad(): + x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1) + w_hat = w + 0.1 * self.net.mapper(w) + x_hat, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1) + loss, loss_dict = self.calc_loss(w, x, w_hat, x_hat) + loss.backward() + self.optimizer.step() + + # Logging related + if self.global_step % self.opts.image_interval == 0 or ( + self.global_step < 1000 and self.global_step % 1000 == 0): + self.parse_and_log_images(x, x_hat, title='images_train') + if self.global_step % self.opts.board_interval == 0: + self.print_metrics(loss_dict, prefix='train') + self.log_metrics(loss_dict, prefix='train') + + # Validation related + val_loss_dict = None + if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps: + val_loss_dict = self.validate() + if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss): + self.best_val_loss = val_loss_dict['loss'] + self.checkpoint_me(val_loss_dict, is_best=True) + + if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: + if val_loss_dict is not None: + self.checkpoint_me(val_loss_dict, is_best=False) + else: + self.checkpoint_me(loss_dict, is_best=False) + + if self.global_step == self.opts.max_steps: + print('OMG, finished training!') + break + + self.global_step += 1 + + def validate(self): + self.net.eval() + agg_loss_dict = [] + for batch_idx, batch in 
enumerate(self.test_dataloader): + if batch_idx > 200: + break + + w = batch + + with torch.no_grad(): + w = w.to(self.device).float() + x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=True, truncation=1) + w_hat = w + 0.1 * self.net.mapper(w) + x_hat, _ = self.net.decoder([w_hat], input_is_latent=True, randomize_noise=True, truncation=1) + loss, cur_loss_dict = self.calc_loss(w, x, w_hat, x_hat) + agg_loss_dict.append(cur_loss_dict) + + # Logging related + self.parse_and_log_images(x, x_hat, title='images_val', index=batch_idx) + + # For first step just do sanity test on small amount of data + if self.global_step == 0 and batch_idx >= 4: + self.net.train() + return None # Do not log, inaccurate in first batch + + loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict) + self.log_metrics(loss_dict, prefix='test') + self.print_metrics(loss_dict, prefix='test') + + self.net.train() + return loss_dict + + def checkpoint_me(self, loss_dict, is_best): + save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step) + save_dict = self.__get_save_dict() + checkpoint_path = os.path.join(self.checkpoint_dir, save_name) + torch.save(save_dict, checkpoint_path) + with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: + if is_best: + f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)) + else: + f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict)) + + def configure_optimizers(self): + params = list(self.net.mapper.parameters()) + if self.opts.optim_name == 'adam': + optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) + else: + optimizer = Ranger(params, lr=self.opts.learning_rate) + return optimizer + + def configure_datasets(self): + if self.opts.latents_train_path: + train_latents = torch.load(self.opts.latents_train_path) + else: + train_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda() + train_latents = [] + for b in range(self.opts.train_dataset_size // self.opts.batch_size): + with torch.no_grad(): + _, train_latents_b = self.net.decoder([train_latents_z[b: b + self.opts.batch_size]], + truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True) + train_latents.append(train_latents_b) + train_latents = torch.cat(train_latents) + + if self.opts.latents_test_path: + test_latents = torch.load(self.opts.latents_test_path) + else: + test_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda() + test_latents = [] + for b in range(self.opts.test_dataset_size // self.opts.test_batch_size): + with torch.no_grad(): + _, test_latents_b = self.net.decoder([test_latents_z[b: b + self.opts.test_batch_size]], + truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True) + test_latents.append(test_latents_b) + test_latents = torch.cat(test_latents) + + train_dataset_celeba = LatentsDataset(latents=train_latents.cpu(), + opts=self.opts) + test_dataset_celeba = LatentsDataset(latents=test_latents.cpu(), + opts=self.opts) + train_dataset = train_dataset_celeba + test_dataset = test_dataset_celeba + print("Number of training samples: {}".format(len(train_dataset))) + print("Number of test samples: {}".format(len(test_dataset))) + return train_dataset, test_dataset + + def calc_loss(self, w, x, w_hat, x_hat): + loss_dict = {} + loss = 0.0 + if self.opts.id_lambda > 0: + loss_id, sim_improvement = self.id_loss(x_hat, x) + loss_dict['loss_id'] = float(loss_id) + loss_dict['id_improve'] = float(sim_improvement) + loss = 
loss_id * self.opts.id_lambda + if self.opts.clip_lambda > 0: + loss_clip = self.clip_loss(x_hat, self.text_inputs).mean() + loss_dict['loss_clip'] = float(loss_clip) + loss += loss_clip * self.opts.clip_lambda + if self.opts.latent_l2_lambda > 0: + loss_l2_latent = self.latent_l2_loss(w_hat, w) + loss_dict['loss_l2_latent'] = float(loss_l2_latent) + loss += loss_l2_latent * self.opts.latent_l2_lambda + loss_dict['loss'] = float(loss) + return loss, loss_dict + + def log_metrics(self, metrics_dict, prefix): + for key, value in metrics_dict.items(): + #pass + print(f"step: {self.global_step} \t metric: {prefix}/{key} \t value: {value}") + self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step) + + def print_metrics(self, metrics_dict, prefix): + print('Metrics for {}, step {}'.format(prefix, self.global_step)) + for key, value in metrics_dict.items(): + print('\t{} = '.format(key), value) + + def parse_and_log_images(self, x, x_hat, title, index=None): + if index is None: + path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}.jpg') + else: + path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}_{str(index).zfill(5)}.jpg') + os.makedirs(os.path.dirname(path), exist_ok=True) + torchvision.utils.save_image(torch.cat([x.detach().cpu(), x_hat.detach().cpu()]), path, + normalize=True, scale_each=True, range=(-1, 1), nrow=self.opts.batch_size) + + def __get_save_dict(self): + save_dict = { + 'state_dict': self.net.state_dict(), + 'opts': vars(self.opts) + } + return save_dict \ No newline at end of file diff --git a/models/StyleCLIP/mapper/training/ranger.py b/models/StyleCLIP/mapper/training/ranger.py new file mode 100644 index 0000000000000000000000000000000000000000..9442fd10d42fcc19f4e0dd798d1573b31ed2c0a0 --- /dev/null +++ b/models/StyleCLIP/mapper/training/ranger.py @@ -0,0 +1,164 @@ +# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. + +# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer +# and/or +# https://github.com/lessw2020/Best-Deep-Learning-Optimizers + +# Ranger has now been used to capture 12 records on the FastAI leaderboard. + +# This version = 20.4.11 + +# Credits: +# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization +# RAdam --> https://github.com/LiyuanLucasLiu/RAdam +# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. +# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 + +# summary of changes: +# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. +# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), +# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. +# changes 8/31/19 - fix references to *self*.N_sma_threshold; +# changed eps to 1e-5 as better default than 1e-8. 
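+
+# Typical usage (a minimal sketch; `model`, `loader`, `criterion` and the
+# hyper-parameter values are placeholders, not values mandated by this repo):
+#   optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)
+#   for x, y in loader:
+#       optimizer.zero_grad()
+#       loss = criterion(model(x), y)
+#       loss.backward()
+#       optimizer.step()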
+ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger configs + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam configs + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. + + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + def __setstate__(self, state): + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * 
(N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... + # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss \ No newline at end of file diff --git a/models/StyleCLIP/mapper/training/train_utils.py b/models/StyleCLIP/mapper/training/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0 --- /dev/null +++ b/models/StyleCLIP/mapper/training/train_utils.py @@ -0,0 +1,13 @@ + +def aggregate_loss_dict(agg_loss_dict): + mean_vals = {} + for output in agg_loss_dict: + for key in output: + mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] + for key in mean_vals: + if len(mean_vals[key]) > 0: + mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) + else: + print('{} has no value'.format(key)) + mean_vals[key] = 0 + return mean_vals diff --git a/models/StyleCLIP/models/__init__.py b/models/StyleCLIP/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/models/facial_recognition/__init__.py b/models/StyleCLIP/models/facial_recognition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/models/facial_recognition/helpers.py b/models/StyleCLIP/models/facial_recognition/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b51fdf97141407fcc1c9d249a086ddbfd042469f --- /dev/null +++ b/models/StyleCLIP/models/facial_recognition/helpers.py @@ -0,0 +1,119 @@ +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. 
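+    Fields: `in_channel` (input channels), `depth` (output channels) and
+    `stride` (stride of the unit); `get_block` below expands one such spec into
+    a list of units.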
""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/models/StyleCLIP/models/facial_recognition/model_irse.py b/models/StyleCLIP/models/facial_recognition/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c79e0366e4a6fd92011e86df80f8b31ec671ae --- /dev/null +++ b/models/StyleCLIP/models/facial_recognition/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from models.facial_recognition.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from 
[TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/models/StyleCLIP/models/stylegan2/__init__.py b/models/StyleCLIP/models/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/models/stylegan2/model.py b/models/StyleCLIP/models/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5559203f4f3843fc814b090780ffa129a6fdf0 --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/model.py @@ -0,0 +1,674 @@ +import math +import random + +import torch +from torch import nn +from torch.nn import functional as F + +from models.StyleCLIP.models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + 
self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, 
pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = 
self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = 
styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, 
stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out + diff --git a/models/StyleCLIP/models/stylegan2/op/__init__.py b/models/StyleCLIP/models/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/StyleCLIP/models/stylegan2/op/fused_act.py b/models/StyleCLIP/models/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2d575bc9198e6d46eee040eb374c6d8f64c3363c --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/op/fused_act.py @@ -0,0 +1,40 @@ +import os + +import torch +from torch import nn +from torch.nn import functional as F + +module_path = os.path.dirname(__file__) + + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + rest_dim = [1] * (input.ndim - bias.ndim - 1) + input = input.cuda() + if input.ndim == 3: + return ( + F.leaky_relu( + input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope + ) + * scale + ) + else: + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope + ) + * scale + ) + diff --git a/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py b/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..02fc25af780868d9b883631eb6b03a25c225d745 --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py @@ -0,0 +1,60 @@ +import os + +import torch +from torch.nn import functional as F + + +module_path = os.path.dirname(__file__) + + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 
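+
+    # out_h/out_w are the filtered-and-downsampled spatial dims; the final view
+    # restores the (batch, channel, height, width) layout that was flattened
+    # into the leading dimension for the depthwise filtering above.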
+ + return out.view(-1, channel, out_h, out_w) \ No newline at end of file diff --git a/models/StyleCLIP/optimization/run_optimization.py b/models/StyleCLIP/optimization/run_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..766d0c81400951202bed51e3f1812e1260ccf071 --- /dev/null +++ b/models/StyleCLIP/optimization/run_optimization.py @@ -0,0 +1,128 @@ +import argparse +import math +import os +import pickle + +import torch +import torchvision +from torch import optim +from tqdm import tqdm + +from StyleCLIP.criteria.clip_loss import CLIPLoss +from StyleCLIP.models.stylegan2.model import Generator +import clip +from StyleCLIP.utils import ensure_checkpoint_exists + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + + return initial_lr * lr_ramp + + +def main(args, use_old_G): + ensure_checkpoint_exists(args.ckpt) + text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() + os.makedirs(args.results_dir, exist_ok=True) + new_generator_path = f'/disk2/danielroich/Sandbox/stylegan2_ada_pytorch/checkpoints/model_{args.run_id}_{args.image_name}.pt' + old_generator_path = '/disk2/danielroich/Sandbox/pretrained_models/ffhq.pkl' + + if not use_old_G: + with open(new_generator_path, 'rb') as f: + G = torch.load(f).cuda().eval() + else: + with open(old_generator_path, 'rb') as f: + G = pickle.load(f)['G_ema'].cuda().eval() + + if args.latent_path: + latent_code_init = torch.load(args.latent_path).cuda() + elif args.mode == "edit": + latent_code_init_not_trunc = torch.randn(1, 512).cuda() + with torch.no_grad(): + latent_code_init = G.mapping(latent_code_init_not_trunc, None) + + latent = latent_code_init.detach().clone() + latent.requires_grad = True + + clip_loss = CLIPLoss(args) + + optimizer = optim.Adam([latent], lr=args.lr) + + pbar = tqdm(range(args.step)) + + for i in pbar: + t = i / args.step + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + + img_gen = G.synthesis(latent, noise_mode='const') + + c_loss = clip_loss(img_gen, text_inputs) + + if args.mode == "edit": + l2_loss = ((latent_code_init - latent) ** 2).sum() + loss = c_loss + args.l2_lambda * l2_loss + else: + loss = c_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description( + ( + f"loss: {loss.item():.4f};" + ) + ) + if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0: + with torch.no_grad(): + img_gen = G.synthesis(latent, noise_mode='const') + + torchvision.utils.save_image(img_gen, + f"/disk2/danielroich/Sandbox/StyleCLIP/results/inference_results/{str(i).zfill(5)}.png", + normalize=True, range=(-1, 1)) + + if args.mode == "edit": + with torch.no_grad(): + img_orig = G.synthesis(latent_code_init, noise_mode='const') + + final_result = torch.cat([img_orig, img_gen]) + else: + final_result = img_gen + + return final_result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--description", type=str, default="a person with purple hair", + help="the text that guides the editing/generation") + parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", + help="pretrained StyleGAN2 weights") + parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") + parser.add_argument("--lr_rampup", type=float, default=0.05) + parser.add_argument("--lr", type=float, 
default=0.1) + parser.add_argument("--step", type=int, default=300, help="number of optimization steps") + parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], + help="choose between editing an image and generating a free one") + parser.add_argument("--l2_lambda", type=float, default=0.008, + help="weight of the latent distance (used for editing only)") + parser.add_argument("--latent_path", type=str, default=None, + help="starts the optimization from the given latent code if provided. Otherwise, starts from " + "the mean latent in a free generation, and from a random one in editing. " + "Expects a .pt format") + parser.add_argument("--truncation", type=float, default=0.7, + help="used only for the initial latent vector, and only when a latent code path is " + "not provided") + parser.add_argument("--save_intermediate_image_every", type=int, default=20, + help="if > 0 then saves intermediate results during the optimization") + parser.add_argument("--results_dir", type=str, default="results") + + args = parser.parse_args() + + result_image = main(args, use_old_G=True)  # main() requires use_old_G; True loads the pre-trained FFHQ generator + + torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), + normalize=True, scale_each=True, range=(-1, 1)) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/__init__.py b/models/e4e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/discriminator.py b/models/e4e/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..16bf3722c7f2e35cdc9bd177a33ed0975e67200d --- /dev/null +++ b/models/e4e/discriminator.py @@ -0,0 +1,20 @@ +from torch import nn + + +class LatentCodesDiscriminator(nn.Module): + def __init__(self, style_dim, n_mlp): + super().__init__() + + self.style_dim = style_dim + + layers = [] + for i in range(n_mlp-1): + layers.append( + nn.Linear(style_dim, style_dim) + ) + layers.append(nn.LeakyReLU(0.2)) + layers.append(nn.Linear(512, 1)) + self.mlp = nn.Sequential(*layers) + + def forward(self, w): + return self.mlp(w) diff --git a/models/e4e/encoders/__init__.py b/models/e4e/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/encoders/helpers.py b/models/e4e/encoders/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a58b34ea5ca6912fe53c63dede0a8696f5c024 --- /dev/null +++ b/models/e4e/encoders/helpers.py @@ -0,0 +1,140 @@ +from collections import namedtuple +import torch +import torch.nn.functional as F +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. 
""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +def _upsample_add(x, y): + """Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. 
+ """ + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y diff --git a/models/e4e/encoders/model_irse.py b/models/e4e/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..976ce2c61104efdc6b0015d895830346dd01bc10 --- /dev/null +++ b/models/e4e/encoders/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from encoder4editing.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/models/e4e/encoders/psp_encoders.py b/models/e4e/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7c70e5e2586bd6a0de825e45a80e9116156166 --- /dev/null +++ b/models/e4e/encoders/psp_encoders.py @@ -0,0 +1,200 @@ +from enum import Enum +import math +import numpy as np +import torch +from torch import nn +from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module + +from models.e4e.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add +from models.e4e.stylegan2.model import 
EqualLinear + + +class ProgressiveStage(Enum): + WTraining = 0 + Delta1Training = 1 + Delta2Training = 2 + Delta3Training = 3 + Delta4Training = 4 + Delta5Training = 5 + Delta6Training = 6 + Delta7Training = 7 + Delta8Training = 8 + Delta9Training = 9 + Delta10Training = 10 + Delta11Training = 11 + Delta12Training = 12 + Delta13Training = 13 + Delta14Training = 14 + Delta15Training = 15 + Delta16Training = 16 + Delta17Training = 17 + Inference = 18 + + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = _upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = _upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out + + +class Encoder4Editing(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(Encoder4Editing, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + 
for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + self.progressive_stage = ProgressiveStage.Inference + + def get_deltas_starting_dimensions(self): + ''' Get a list of the initial dimension of every delta from which it is applied ''' + return list(range(self.style_count)) # Each dimension has a delta applied to it + + def set_progressive_stage(self, new_stage: ProgressiveStage): + self.progressive_stage = new_stage + print('Changed progressive stage to: ', new_stage) + + def forward(self, x): + x = self.input_layer(x) + + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + # Infer main W and duplicate it + w0 = self.styles[0](c3) + w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) + stage = self.progressive_stage.value + features = c3 + for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas + if i == self.coarse_ind: + p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features + features = p2 + elif i == self.middle_ind: + p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features + features = p1 + delta_i = self.styles[i](features) + w[:, i] += delta_i + return w diff --git a/models/e4e/latent_codes_pool.py b/models/e4e/latent_codes_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..0281d4b5e80f8eb26e824fa35b4f908dcb6634e6 --- /dev/null +++ b/models/e4e/latent_codes_pool.py @@ -0,0 +1,55 @@ +import random +import torch + + +class LatentCodesPool: + """This class implements latent codes buffer that stores previously generated w latent codes. + This buffer enables us to update discriminators using a history of generated w's + rather than the ones produced by the latest encoder. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_ws = 0 + self.ws = [] + + def query(self, ws): + """Return w's from the pool. + Parameters: + ws: the latest generated w's from the generator + Returns w's from the buffer. + By 50/100, the buffer will return input w's. + By 50/100, the buffer will return w's previously stored in the buffer, + and insert the current w's to the buffer. 
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return ws + return_ws = [] + for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) + # w = torch.unsqueeze(image.data, 0) + if w.ndim == 2: + i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate + w = w[i] + self.handle_w(w, return_ws) + return_ws = torch.stack(return_ws, 0) # collect all the images and return + return return_ws + + def handle_w(self, w, return_ws): + if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer + self.num_ws = self.num_ws + 1 + self.ws.append(w) + return_ws.append(w) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.ws[random_id].clone() + self.ws[random_id] = w + return_ws.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_ws.append(w) diff --git a/models/e4e/psp.py b/models/e4e/psp.py new file mode 100644 index 0000000000000000000000000000000000000000..bf9f75dbaa66997abfc1e3e0e4f19ddfec7fedac --- /dev/null +++ b/models/e4e/psp.py @@ -0,0 +1,97 @@ +import matplotlib +from configs import paths_config +matplotlib.use('Agg') +import torch +from torch import nn +from models.e4e.encoders import psp_encoders +from models.e4e.stylegan2.model import Generator + + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class pSp(nn.Module): + + def __init__(self, opts): + super(pSp, self).__init__() + self.opts = opts + # Define architecture + self.encoder = self.set_encoder() + self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2) + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_encoder(self): + if self.opts.encoder_type == 'GradualStyleEncoder': + encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'Encoder4Editing': + encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) + else: + raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) + return encoder + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) + self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading encoders weights from irse50!') + encoder_ckpt = torch.load(paths_config.ir_se50) + self.encoder.load_state_dict(encoder_ckpt, strict=False) + print('Loading decoder weights from pretrained!') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=False) + self.__load_latent_avg(ckpt, repeat=self.encoder.style_count) + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if self.opts.start_from_latent_avg: + if codes.ndim == 2: + codes = codes + 
self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] + else: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.opts.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None diff --git a/models/e4e/stylegan2/__init__.py b/models/e4e/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/stylegan2/model.py b/models/e4e/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ede4360148e260363887662bae7fe68c987ee60e --- /dev/null +++ b/models/e4e/stylegan2/model.py @@ -0,0 +1,674 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from .op.fused_act import FusedLeakyReLU, fused_leaky_relu +from .op.upfirdn2d import upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = 
nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, 
self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = 
int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + elif return_features: + return image, out + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, 
+ out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/models/e4e/stylegan2/op/__init__.py b/models/e4e/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/e4e/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/e4e/stylegan2/op/fused_act.py b/models/e4e/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..90949545ba955dabf2e17d8cf5e524d5cb190a63 --- /dev/null +++ b/models/e4e/stylegan2/op/fused_act.py @@ -0,0 +1,34 @@ +import os + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + + +module_path = os.path.dirname(__file__) + + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + rest_dim = [1] * (input.ndim - bias.ndim - 1) + input = 
input.cuda() + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope + ) + * scale + ) + diff --git a/models/e4e/stylegan2/op/fused_bias_act.cpp b/models/e4e/stylegan2/op/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/models/e4e/stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/models/e4e/stylegan2/op/fused_bias_act_kernel.cu b/models/e4e/stylegan2/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/models/e4e/stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/models/e4e/stylegan2/op/upfirdn2d.cpp b/models/e4e/stylegan2/op/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e --- /dev/null +++ b/models/e4e/stylegan2/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/models/e4e/stylegan2/op/upfirdn2d.py b/models/e4e/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..02fc25af780868d9b883631eb6b03a25c225d745 --- /dev/null +++ b/models/e4e/stylegan2/op/upfirdn2d.py @@ -0,0 +1,60 @@ +import os + +import torch +from torch.nn import functional as F + + +module_path = os.path.dirname(__file__) + + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // 
down_x + 1 + + return out.view(-1, channel, out_h, out_w) \ No newline at end of file diff --git a/models/e4e/stylegan2/op/upfirdn2d_kernel.cu b/models/e4e/stylegan2/op/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e --- /dev/null +++ b/models/e4e/stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = 
tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + 
out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/pretrained_models/deeplab_model/R-101-GN-WS.pth.tar b/pretrained_models/deeplab_model/R-101-GN-WS.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..f2533b78b9479b4439ab4b651579b4f9aadf3fe5 --- /dev/null +++ b/pretrained_models/deeplab_model/R-101-GN-WS.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7355634b145d620f4b0d6d7d0295df460e02a617029a3adb71991b952c7db0d3 +size 178260167 diff --git a/pretrained_models/deeplab_model/deeplab_model.pth b/pretrained_models/deeplab_model/deeplab_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..72ec3520a284467f119bec3ba781437a7f8c99a6 --- /dev/null +++ b/pretrained_models/deeplab_model/deeplab_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b36429f3d7adc6b9bd7ca90c07817f608a189a38e70dd9464b32d0585d846c5b +size 464446305 diff --git a/pretrained_models/model_ir_se50.pth b/pretrained_models/model_ir_se50.pth new file mode 100644 index 0000000000000000000000000000000000000000..d3a030dd9a353d94023d3fc3a5baa0991ca3873b --- /dev/null +++ b/pretrained_models/model_ir_se50.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a035c768259b98ab1ce0e646312f48b9e1e218197a0f80ac6765e88f8b6ddf28 +size 175367323 diff --git a/PTI/run_pti.py b/run_pti.py similarity index 90% rename from PTI/run_pti.py rename to run_pti.py index 1cbe5398f6f2ccbc00834d997af65fa91c5de7dd..d3788be0aa054414f3608b5851df145c65c26020 100644 --- a/PTI/run_pti.py +++ b/run_pti.py @@ -6,7 +6,6 @@ import os import sys from configs import global_config, paths_config -import wandb from training.coaches.multi_id_coach import MultiIDCoach from training.coaches.single_id_coach import SingleIDCoach @@ -21,13 +20,6 @@ def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False): global_config.run_name = "".join(choice(ascii_uppercase) for i in range(12)) else: global_config.run_name = run_name - - if use_wandb: - run = wandb.init( - project=paths_config.pti_results_keyword, - reinit=True, - name=global_config.run_name, - ) global_config.pivotal_training_steps = 1 global_config.training_step = 1 diff --git a/torch_utils/__pycache__/__init__.cpython-39.pyc b/torch_utils/__pycache__/__init__.cpython-39.pyc index 8c5297410e63e17cbb200b8d06ff9ad71d8c4de9..ac3757ab7939f070f9d700054e8b5c8d8d5dd13a 100644 Binary files a/torch_utils/__pycache__/__init__.cpython-39.pyc and b/torch_utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/torch_utils/__pycache__/custom_ops.cpython-39.pyc b/torch_utils/__pycache__/custom_ops.cpython-39.pyc index f1d857540856af0f54edc1dbabcacd6b113332d9..5760d0b7be50944a5fd56f5e14e71c493af906bd 100644 Binary files a/torch_utils/__pycache__/custom_ops.cpython-39.pyc and b/torch_utils/__pycache__/custom_ops.cpython-39.pyc differ diff --git a/torch_utils/__pycache__/misc.cpython-39.pyc b/torch_utils/__pycache__/misc.cpython-39.pyc index 8bc767000c3e3970764cf7926b65c8c1eef0699a..5454c56fe0fce631c4c7e1420bc41b7916d9b18b 100644 Binary files a/torch_utils/__pycache__/misc.cpython-39.pyc and b/torch_utils/__pycache__/misc.cpython-39.pyc differ diff --git a/torch_utils/__pycache__/persistence.cpython-39.pyc b/torch_utils/__pycache__/persistence.cpython-39.pyc index 
8c69104deacd31cd5a2013ec3aba03ff08f32860..ea3a5b02e6480a341e91b40389c23a1e4306b959 100644 Binary files a/torch_utils/__pycache__/persistence.cpython-39.pyc and b/torch_utils/__pycache__/persistence.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/__init__.cpython-39.pyc b/torch_utils/ops/__pycache__/__init__.cpython-39.pyc index 6d1c49f3b2ab434e446868e42bcf743d25deedb8..cc1c03bcc67a259b43c825a8594286811d15eb4d 100644 Binary files a/torch_utils/ops/__pycache__/__init__.cpython-39.pyc and b/torch_utils/ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc b/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc index f2255e29702430554e9e556951ddc54047d8eac3..1e9e88ce8fe594c15ceec21299374b73f3c47a19 100644 Binary files a/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc and b/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc index d2fe02e17b1038d6ed3ffee17eb463ebb09754ac..52225ff3b32fa8547ec29dce64add155214395c9 100644 Binary files a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc and b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc b/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc index f642517501eb342e94d8ac92ee019182b4eb16b4..063f4e6fcf0ea79d45ff06e0c7c6e37bf5e8739c 100644 Binary files a/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc and b/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/fma.cpython-39.pyc b/torch_utils/ops/__pycache__/fma.cpython-39.pyc index a570d90ba20af0e326584c43d13ffc35d876186d..c32c63818c735ad0537ba552fe5878fc8b31fef8 100644 Binary files a/torch_utils/ops/__pycache__/fma.cpython-39.pyc and b/torch_utils/ops/__pycache__/fma.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc b/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc index 78a7a5f1b3c272209712c8dab4e183011e772639..91f37c24642eefe730c65290706936aea796c501 100644 Binary files a/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc and b/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc b/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc index 2da9e5ddd9fd8e4195b752e1eb67a365ca8226e7..f811a0d69d79a23f4e828c98ae9c7444bff2431d 100644 Binary files a/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc and b/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc differ diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/coaches/__init__.py b/training/coaches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PTI/training/coaches/base_coach.py b/training/coaches/base_coach.py similarity index 94% rename from PTI/training/coaches/base_coach.py rename to training/coaches/base_coach.py index 025e1f77e9458e551a3ee650ae20e3e05c6d325b..20e24813835f22897ae9ca5ccdb75938f852704b 100644 --- a/PTI/training/coaches/base_coach.py +++ b/training/coaches/base_coach.py @@ -42,18 +42,6 @@ class BaseCoach: ] ) - self.years = [str(y) for y in range(1900, 2020, 10)] - # 
self.years.remove(paths_config.input_data_id[:4]) - if hyperparameters.color_transfer_lambda > 0: - self.siblings = {} - for y in self.years: - with open( - f"/phoenix/S7/wikitime_models/id_child_models/{y}_564.pkl", "rb" - ) as f: - self.siblings[y] = ( - pickle.load(f)["G_ema"].to(global_config.device).eval() - ) - # Initialize loss self.lpips_loss = ( LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval() @@ -61,7 +49,7 @@ class BaseCoach: self.id_loss = ( id_loss.IDLoss( - "/share/phoenix/nfs04/S7/wikitime_models/model_ir_se50.pth", + paths_config.ir_se50, official=False, ) .to(global_config.device) diff --git a/PTI/training/coaches/multi_id_coach.py b/training/coaches/multi_id_coach.py similarity index 100% rename from PTI/training/coaches/multi_id_coach.py rename to training/coaches/multi_id_coach.py diff --git a/PTI/training/coaches/single_id_coach.py b/training/coaches/single_id_coach.py similarity index 100% rename from PTI/training/coaches/single_id_coach.py rename to training/coaches/single_id_coach.py diff --git a/training/projectors/__init__.py b/training/projectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PTI/training/projectors/w_plus_projector.py b/training/projectors/w_plus_projector.py similarity index 100% rename from PTI/training/projectors/w_plus_projector.py rename to training/projectors/w_plus_projector.py diff --git a/PTI/training/projectors/w_projector.py b/training/projectors/w_projector.py similarity index 100% rename from PTI/training/projectors/w_projector.py rename to training/projectors/w_projector.py diff --git a/PTI/utils/ImagesDataset.py b/utils/ImagesDataset.py similarity index 100% rename from PTI/utils/ImagesDataset.py rename to utils/ImagesDataset.py diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PTI/utils/align_data.py b/utils/align_data.py similarity index 100% rename from PTI/utils/align_data.py rename to utils/align_data.py diff --git a/PTI/utils/alignment.py b/utils/alignment.py similarity index 100% rename from PTI/utils/alignment.py rename to utils/alignment.py diff --git a/PTI/utils/data_utils.py b/utils/data_utils.py similarity index 100% rename from PTI/utils/data_utils.py rename to utils/data_utils.py diff --git a/PTI/utils/log_utils.py b/utils/log_utils.py similarity index 100% rename from PTI/utils/log_utils.py rename to utils/log_utils.py diff --git a/PTI/utils/models_utils.py b/utils/models_utils.py similarity index 100% rename from PTI/utils/models_utils.py rename to utils/models_utils.py diff --git a/utils.py b/utils/utils.py similarity index 100% rename from utils.py rename to utils/utils.py
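For reference, below is a minimal usage sketch (not part of the diff) showing how the relocated modules might be driven after this restructuring. It assumes the repository root is on `sys.path`, that `configs/paths_config.py` exposes the `ir_se50` attribute referenced in the `base_coach.py` hunk above, and that `run_PTI` keeps the signature shown in `run_pti.py`; the run name and weight path are illustrative only.

```python
# Hedged sketch: drive pivotal tuning from the new flat module layout.
# Assumes execution from the repository root, with configs/, training/, utils/
# and run_pti.py importable as top-level packages/modules.
from configs import paths_config
from run_pti import run_PTI

# Point the ID loss at the IR-SE50 weights now tracked via Git LFS under
# pretrained_models/ (see the model_ir_se50.pth pointer file added above).
paths_config.ir_se50 = "pretrained_models/model_ir_se50.pth"

# The wandb import and wandb.init call were removed from run_pti.py, so the
# use_wandb flag is left at False here.
run_PTI(run_name="pti_smoke_test", use_wandb=False, use_multi_id_training=False)
```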