diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..a63f99ea4342c600458ed1f3ffc23eb4db213c36 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/teaser.jpeg filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4ac8ac2dd81e2d440ca7ee263ec0f8a1396b9054
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.idea
+.DS_Store
+pretrained_models/
+shape_predictor_68_face_landmarks.dat
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..146f2b337c321112dd4cf0de55939d6d15605243
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Yuval Alaluf
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index 96570f28f70f0b8b8aa246a0e78040c3363c1e50..5e8ef7039ebbed5e3d0d092a4160f43de75e63d3 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,330 @@
----
-title: Aging MouthReplace
-emoji: 🦀
-colorFrom: red
-colorTo: green
-sdk: gradio
-sdk_version: 5.3.0
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Only a Matter of Style: Age Transformation Using a Style-Based Regression Model (SIGGRAPH 2021)
+
+> The task of age transformation illustrates the change of an individual's appearance over time. Accurately modeling this complex transformation over an input facial image is extremely challenging as it requires making convincing and possibly large changes to facial features and head shape, while still preserving the input identity. In this work, we present an image-to-image translation method that learns to directly encode real facial images into the latent space of a pre-trained unconditional GAN (e.g., StyleGAN) subject to a given aging shift. We employ a pre-trained age regression network used to explicitly guide the encoder to generate the latent codes corresponding to the desired age. In this formulation, our method approaches the continuous aging process as a regression task between the input age and desired target age, providing fine-grained control on the generated image. Moreover, unlike other approaches that operate solely in the latent space using a prior on the path controlling age, our method learns a more disentangled, non-linear path. We demonstrate that the end-to-end nature of our approach, coupled with the rich semantic latent space of StyleGAN, allows for further editing of the generated images. Qualitative and quantitative evaluations show the advantages of our method compared to state-of-the-art approaches.
+
+
+![SAM teaser](docs/teaser.jpeg)
+
+Inference Notebook: `notebooks/inference_playground.ipynb`
+Animation Notebook: `notebooks/animation_inference_playground.ipynb`
+
+## Description
+Official Implementation of our Style-based Age Manipulation (SAM) paper for both training and evaluation. SAM
+allows modeling fine-grained age transformation using a single input facial image.
+
+## Table of Contents
+ * [Getting Started](#getting-started)
+ + [Prerequisites](#prerequisites)
+ + [Installation](#installation)
+ * [Pretrained Models](#pretrained-models)
+ * [Training](#training)
+ + [Preparing your Data](#preparing-your-data)
+ + [Training SAM](#training-sam)
+ + [Additional Notes](#additional-notes)
+ * [Notebooks](#notebooks)
+ + [Inference Notebook](#inference-notebook)
+ + [MP4 Notebook](#mp4-notebook)
+ * [Testing](#testing)
+ + [Inference](#inference)
+ + [Side-by-Side Inference](#side-by-side-inference)
+ + [Reference-Guided Inference](#reference-guided-inference)
+ + [Style Mixing](#style-mixing)
+ * [Repository structure](#repository-structure)
+ * [Credits](#credits)
+ * [Acknowledgments](#acknowledgments)
+ * [Citation](#citation)
+
+
+## Getting Started
+### Prerequisites
+- Linux or macOS
+- NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
+- Python 3
+
+### Installation
+- Dependencies:
+We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
+All dependencies for defining the environment are provided in `environment/sam_env.yaml`.
+
+## Pretrained Models
+Please download the pretrained aging model from the following link.
+
+| Path | Description
+| :--- | :----------
+|[SAM](https://drive.google.com/file/d/1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC/view?usp=sharing) | SAM trained on the FFHQ dataset for age transformation.
+
+You can run the following commands to download the aging model and the dlib facial-landmarks predictor to their expected locations:
+
+```
+mkdir pretrained_models
+pip install gdown
+gdown "https://drive.google.com/u/0/uc?id=1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC&export=download" -O pretrained_models/sam_ffhq_aging.pt
+wget "https://github.com/italojs/facial-landmarks-recognition/raw/master/shape_predictor_68_face_landmarks.dat"
+```
+
+In addition, we provide various auxiliary models needed for training your own SAM model from scratch.
+This includes the pretrained pSp encoder model for generating the encodings of the input image and the aging classifier
+used to compute the aging loss during training.
+
+| Path | Description
+| :--- | :----------
+|[pSp Encoder](https://drive.google.com/file/d/1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0/view?usp=sharing) | pSp taken from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) trained on the FFHQ dataset for StyleGAN inversion.
+|[FFHQ StyleGAN](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing) | StyleGAN model pretrained on FFHQ taken from [rosinality](https://github.com/rosinality/stylegan2-pytorch) with 1024x1024 output resolution.
+|[IR-SE50 Model](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) | Pretrained IR-SE50 model taken from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our ID loss during training.
+|[VGG Age Classifier](https://drive.google.com/file/d/1atzjZm_dJrCmFWCqWlyspSpr3nI6Evsh/view?usp=sharing) | VGG age classifier taken from DEX and fine-tuned on the FFHQ-Aging dataset for use in our aging loss.
+
+By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`.
+However, you may use your own paths by changing the necessary values in `configs/paths_config.py`.
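+
+For reference, these paths are collected in a single dictionary in `configs/paths_config.py`; pointing the values below at your own download locations is all that is needed (the defaults shown here are taken directly from that file):
+```
+model_paths = {
+    'pretrained_psp_encoder': 'pretrained_models/psp_ffhq_encode.pt',
+    'ir_se50': 'pretrained_models/model_ir_se50.pth',
+    'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+    'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
+    'age_predictor': 'pretrained_models/dex_age_classifier.pth'
+}
+```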
+
+## Training
+### Preparing your Data
+Please refer to `configs/paths_config.py` to define the necessary data paths and model paths for training and inference.
+Then, refer to `configs/data_configs.py` to define the source/target data paths for the train and test sets as well as the
+transforms to be used for training and inference.
+
+As an example, we can first go to `configs/paths_config.py` and define:
+```
+dataset_paths = {
+    'ffhq': '/path/to/ffhq/images256x256',
+    'celeba_test': '/path/to/CelebAMask-HQ/test_img',
+}
+```
+Then, in `configs/data_configs.py`, we define:
+```
+DATASETS = {
+ 'ffhq_aging': {
+ 'transforms': transforms_config.AgingTransforms,
+ 'train_source_root': dataset_paths['ffhq'],
+ 'train_target_root': dataset_paths['ffhq'],
+ 'test_source_root': dataset_paths['celeba_test'],
+ 'test_target_root': dataset_paths['celeba_test'],
+ }
+}
+```
+When defining the datasets for training and inference, we will use the values defined in the above dictionary.
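+
+To make the plumbing concrete, below is a minimal sketch of how such an entry is consumed when building the training set (the actual logic lives in the training code); `Namespace(label_nc=0)` stands in for the parsed training options:
+```
+from argparse import Namespace
+
+from configs.data_configs import DATASETS
+from datasets.images_dataset import ImagesDataset
+
+opts = Namespace(label_nc=0)  # stand-in for the parsed training options
+dataset_args = DATASETS['ffhq_aging']
+transforms_dict = dataset_args['transforms'](opts).get_transforms()
+train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
+                              target_root=dataset_args['train_target_root'],
+                              source_transform=transforms_dict['transform_source'],
+                              target_transform=transforms_dict['transform_gt_train'],
+                              opts=opts)
+```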
+
+
+### Training SAM
+The main training script can be found in `scripts/train.py`.
+Intermediate training results are saved to `opts.exp_dir`. This includes checkpoints, train outputs, and test outputs.
+Additionally, if you have tensorboard installed, you can visualize tensorboard logs in `opts.exp_dir/logs`.
+
+Training SAM with the settings used in the paper can be done by running the following command:
+```
+python scripts/train.py \
+--dataset_type=ffhq_aging \
+--exp_dir=/path/to/experiment \
+--workers=6 \
+--batch_size=6 \
+--test_batch_size=6 \
+--test_workers=6 \
+--val_interval=2500 \
+--save_interval=10000 \
+--start_from_encoded_w_plus \
+--id_lambda=0.1 \
+--lpips_lambda=0.1 \
+--lpips_lambda_aging=0.1 \
+--lpips_lambda_crop=0.6 \
+--l2_lambda=0.25 \
+--l2_lambda_aging=0.25 \
+--l2_lambda_crop=1 \
+--w_norm_lambda=0.005 \
+--aging_lambda=5 \
+--cycle_lambda=1 \
+--input_nc=4 \
+--target_age=uniform_random \
+--use_weighted_id_loss
+```
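+
+To make the age conditioning concrete: with `--input_nc=4`, the encoder receives the 3-channel image plus a constant fourth channel holding the normalized target age, and `--target_age=uniform_random` samples that age uniformly in [0, 100] for every training sample. A minimal sketch using `datasets/augmentations.py`:
+```
+import torch
+
+from datasets.augmentations import AgeTransformer
+
+age_transformer = AgeTransformer(target_age="uniform_random")
+img = torch.randn(3, 256, 256)      # an already-transformed training image
+conditioned = age_transformer(img)  # appends a constant channel holding the sampled age / 100
+print(conditioned.shape)            # torch.Size([4, 256, 256]) -- matching --input_nc=4
+```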
+
+### Additional Notes
+- See `options/train_options.py` for all training-specific flags.
+- Note that using the flag `--start_from_encoded_w_plus` requires you to specify the path to the pretrained pSp encoder.
+  By default, this path is taken from `configs.paths_config.model_paths['pretrained_psp_encoder']`.
+- If you wish to resume from a specific checkpoint (e.g. a pretrained SAM model), you may do so using `--checkpoint_path`.
+
+
+## Notebooks
+### Inference Notebook
+To help visualize the results of SAM we provide a Jupyter notebook found in `notebooks/inference_playground.ipynb`.
+The notebook will download the pretrained aging model and run inference on the images found in `notebooks/images`.
+
+In addition, [Replicate](https://replicate.ai/) has created a demo for SAM where you can easily upload an image and run SAM on a desired set of ages! Check
+out the demo [here](https://replicate.ai/yuval-alaluf/sam).
+
+### MP4 Notebook
+To show full lifespan results using SAM, we provide an additional notebook, `notebooks/animation_inference_playground.ipynb`, that will
+run aging on multiple ages between 0 and 100 and interpolate between the results to display the full aging process.
+The results will be saved as MP4 files in `notebooks/animations`, showing the aging and de-aging results.
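+
+The interpolation itself is simple; a rough sketch of the idea (not the notebook's exact code), assuming `results` is a hypothetical list of same-sized `uint8` RGB arrays ordered by target age and that imageio's ffmpeg backend is available:
+```
+import imageio
+import numpy as np
+
+# `results`: hypothetical list of per-age outputs, each an H x W x 3 uint8 array
+frames = []
+for young, old in zip(results[:-1], results[1:]):
+    for alpha in np.linspace(0, 1, num=25):
+        frames.append(((1 - alpha) * young + alpha * old).astype(np.uint8))
+imageio.mimsave('notebooks/animations/aging.mp4', frames, fps=25)
+```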
+
+## Testing
+### Inference
+Once you have trained your own model, or if you are using a pretrained SAM model, you can use `scripts/inference.py` to run inference
+on a set of images.
+For example:
+```
+python scripts/inference.py \
+--exp_dir=/path/to/experiment \
+--checkpoint_path=experiment/checkpoints/best_model.pt \
+--data_path=/path/to/test_data \
+--test_batch_size=4 \
+--test_workers=4 \
+--couple_outputs \
+--target_age=0,10,20,30,40,50,60,70,80
+```
+Additional notes to consider:
+- During inference, the options used during training are loaded from the saved checkpoint and are then updated using the
+test options passed to the inference script.
+- Adding the flag `--couple_outputs` will save an additional image containing the input and output images side-by-side in the sub-directory
+`inference_coupled`. Otherwise, only the output image is saved to the sub-directory `inference_results`.
+- In the above example, we will run age transformation with target ages 0,10,...,80.
+ - The results of each target age are saved to the sub-directories `inference_results/TARGET_AGE` and `inference_coupled/TARGET_AGE`.
+- By default, the images will be saved at a resolution of 1024x1024, the original output size of StyleGAN.
+  - If you wish to save outputs resized to a resolution of 256x256, you can do so by adding the flag `--resize_outputs`.
+
+### Side-by-Side Inference
+The above inference script saves the results for each target age in a separate sub-directory. Sometimes,
+however, it is more convenient to save all aging results of a given input side-by-side, like the following:
+
+
+
+
+
+To do so, we provide the script `inference_side_by_side.py`, which works in the same manner as the regular inference script:
+```
+python scripts/inference_side_by_side.py \
+--exp_dir=/path/to/experiment \
+--checkpoint_path=experiment/checkpoints/best_model.pt \
+--data_path=/path/to/test_data \
+--test_batch_size=4 \
+--test_workers=4 \
+--target_age=0,10,20,30,40,50,60,70,80
+```
+Here, all aging results 0,10,...,80 will be saved side-by-side with the original input image.
+
+### Reference-Guided Inference
+In the paper, we demonstrated how one can perform style-mixing on the fine-level style inputs with a reference image
+to control global features such as hair color. For example,
+
+
+
+
+
+To perform style mixing using reference images, we provide the script `reference_guided_inference.py`. Here,
+we first perform aging using the specified target age(s). Then, style mixing is performed using the specified
+reference images and the specified layers. For example, one can run:
+```
+python scripts/reference_guided_inference.py \
+--exp_dir=/path/to/experiment \
+--checkpoint_path=experiment/checkpoints/best_model.pt \
+--data_path=/path/to/test_data \
+--test_batch_size=4 \
+--test_workers=4 \
+--ref_images_paths_file=/path/to/ref_list.txt \
+--latent_mask=8,9 \
+--target_age=50,60,70,80
+```
+Here, the reference images should be specified in the file defined by `--ref_images_paths_file` and should have the
+following format:
+```
+/path/to/reference/1.jpg
+/path/to/reference/2.jpg
+/path/to/reference/3.jpg
+/path/to/reference/4.jpg
+/path/to/reference/5.jpg
+```
+In the above example, we perform aging using 4 different target ages. For each target age, we first transform the
+test samples defined by `--data_path` and then perform style mixing on layers 8 and 9, as defined by `--latent_mask`.
+The results for each target age are saved in a separate sub-directory.
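+
+Conceptually, the mixing controlled by `--latent_mask` replaces the selected entries of the aged image's W+ latent code with the corresponding entries from the reference image's code before decoding. A rough sketch of the idea (the actual mixing happens inside the model's forward pass; the tensors below are illustrative):
+```
+import torch
+
+latent_mask = [8, 9]                # layers passed via --latent_mask
+aged_latent = torch.randn(18, 512)  # illustrative W+ code of the age-transformed image
+ref_latent = torch.randn(18, 512)   # illustrative W+ code of the reference image
+
+mixed_latent = aged_latent.clone()
+for i in latent_mask:
+    mixed_latent[i] = ref_latent[i]  # selected layers take the reference style; the rest keep the aging result
+```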
+
+### Style Mixing
+Instead of performing style mixing using a reference image, you can perform style mixing using randomly generated
+w latent vectors by running the script `style_mixing.py`. This script works in a similar manner to the reference
+guided inference, except that you do not need to specify the `--ref_images_paths_file` flag.
+
+## Repository structure
+| Path | Description
+| :--- | :---
+| SAM | Repository root folder
+| ├ configs | Folder containing configs defining model/data paths and data transforms
+| ├ criteria | Folder containing the various loss criteria used for training
+| ├ datasets | Folder with various dataset objects and augmentations
+| ├ docs | Folder containing images displayed in the README
+| ├ environment | Folder containing Anaconda environment used in our experiments
+| ├ models | Folder containing all the models and training objects
+| │ ├ encoders | Folder containing various architecture implementations
+| │ ├ stylegan2 | StyleGAN2 model from [rosinality](https://github.com/rosinality/stylegan2-pytorch)
+| │ ├ psp.py | Implementation of pSp encoder
+| │ └ dex_vgg.py | Implementation of DEX VGG classifier used in computation of aging loss
+| ├ notebooks | Folder with the Jupyter notebooks containing the SAM inference and animation playgrounds
+| ├ options | Folder with training and test command-line options
+| ├ scripts | Folder with running scripts for training and inference
+| ├ training | Folder with main training logic and Ranger implementation from [lessw2020](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer)
+| ├ utils | Folder with various utility functions
+| |
+
+
+## Credits
+**StyleGAN2 model and implementation:**
+https://github.com/rosinality/stylegan2-pytorch
+Copyright (c) 2019 Kim Seonghyeon
+License (MIT) https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE
+
+**IR-SE50 model and implementation:**
+https://github.com/TreB1eN/InsightFace_Pytorch
+Copyright (c) 2018 TreB1eN
+License (MIT) https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/LICENSE
+
+**Ranger optimizer implementation:**
+https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+License (Apache License 2.0) https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer/blob/master/LICENSE
+
+**LPIPS model and implementation:**
+https://github.com/S-aiueo32/lpips-pytorch
+Copyright (c) 2020, Sou Uchida
+License (BSD 2-Clause) https://github.com/S-aiueo32/lpips-pytorch/blob/master/LICENSE
+
+**DEX VGG model and implementation:**
+https://github.com/InterDigitalInc/HRFAE
+Copyright (c) 2020, InterDigital R&D France
+https://github.com/InterDigitalInc/HRFAE/blob/master/LICENSE.txt
+
+**pSp model and implementation:**
+https://github.com/eladrich/pixel2style2pixel
+Copyright (c) 2020 Elad Richardson, Yuval Alaluf
+https://github.com/eladrich/pixel2style2pixel/blob/master/LICENSE
+
+## Acknowledgments
+This code borrows heavily from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel).
+
+## Citation
+If you use this code for your research, please cite our paper, Only a Matter of Style: Age Transformation Using a Style-Based Regression Model:
+
+```
+@article{alaluf2021matter,
+ author = {Alaluf, Yuval and Patashnik, Or and Cohen-Or, Daniel},
+ title = {Only a Matter of Style: Age Transformation Using a Style-Based Regression Model},
+ journal = {ACM Trans. Graph.},
+ issue_date = {August 2021},
+ volume = {40},
+ number = {4},
+ year = {2021},
+ articleno = {45},
+ publisher = {Association for Computing Machinery},
+ url = {https://doi.org/10.1145/3450626.3459805}
+}
+```
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..424070944b5c1148a34eef82587d00c28374522f
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,25 @@
+image: "r8.im/yuval-alaluf/sam"
+build:
+ gpu: true
+ python_version: "3.8"
+ system_packages:
+ - "cmake"
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ - "ninja-build"
+ python_packages:
+ - "Pillow==8.3.1"
+ - "cmake==3.21.1"
+ - "dlib==19.22.1"
+ - "imageio==2.9.0"
+ - "ipython==7.21.0"
+ - "matplotlib==3.1.3"
+ - "numpy==1.21.1"
+ - "opencv-python==4.5.3.56"
+ - "scipy==1.4.1"
+ - "tensorboard==2.2.1"
+ - "torch==1.8.0"
+ - "torchvision==0.9.0"
+ - "tqdm==4.42.1"
+predict: "predict.py:Predictor"
+
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/__pycache__/__init__.cpython-310.pyc b/configs/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71d104a20628dae221a4dd06b4445f7c2f185d6a
Binary files /dev/null and b/configs/__pycache__/__init__.cpython-310.pyc differ
diff --git a/configs/__pycache__/paths_config.cpython-310.pyc b/configs/__pycache__/paths_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce5947935b8d485412efe6eb280f0528b08204da
Binary files /dev/null and b/configs/__pycache__/paths_config.cpython-310.pyc differ
diff --git a/configs/data_configs.py b/configs/data_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9a2e95b34f85a4aa429becc1722b2f1c5bb7ba
--- /dev/null
+++ b/configs/data_configs.py
@@ -0,0 +1,13 @@
+from configs import transforms_config
+from configs.paths_config import dataset_paths
+
+
+DATASETS = {
+ 'ffhq_aging': {
+ 'transforms': transforms_config.AgingTransforms,
+ 'train_source_root': dataset_paths['ffhq'],
+ 'train_target_root': dataset_paths['ffhq'],
+ 'test_source_root': dataset_paths['celeba_test'],
+ 'test_target_root': dataset_paths['celeba_test'],
+ }
+}
diff --git a/configs/paths_config.py b/configs/paths_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f01a3abc8534e58f6c09ba64dcdd8a5daec4cfb
--- /dev/null
+++ b/configs/paths_config.py
@@ -0,0 +1,12 @@
+dataset_paths = {
+ 'celeba_test': '',
+ 'ffhq': '',
+}
+
+model_paths = {
+ 'pretrained_psp_encoder': 'pretrained_models/psp_ffhq_encode.pt',
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+ 'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
+ 'age_predictor': 'pretrained_models/dex_age_classifier.pth'
+}
diff --git a/configs/transforms_config.py b/configs/transforms_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15214cfcf6e51e06b7f7687bcd96cd324db5931
--- /dev/null
+++ b/configs/transforms_config.py
@@ -0,0 +1,37 @@
+from abc import abstractmethod
+import torchvision.transforms as transforms
+
+
+class TransformsConfig(object):
+
+ def __init__(self, opts):
+ self.opts = opts
+
+ @abstractmethod
+ def get_transforms(self):
+ pass
+
+
+class AgingTransforms(TransformsConfig):
+
+ def __init__(self, opts):
+ super(AgingTransforms, self).__init__(opts)
+
+ def get_transforms(self):
+ transforms_dict = {
+ 'transform_gt_train': transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+ 'transform_source': None,
+ 'transform_test': transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+ 'transform_inference': transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+ }
+ return transforms_dict
diff --git a/criteria/__init__.py b/criteria/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/criteria/aging_loss.py b/criteria/aging_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5748d25883a5338f99ab119640715bf8bb2a43c3
--- /dev/null
+++ b/criteria/aging_loss.py
@@ -0,0 +1,59 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from configs.paths_config import model_paths
+from models.dex_vgg import VGG
+
+
+class AgingLoss(nn.Module):
+
+ def __init__(self, opts):
+ super(AgingLoss, self).__init__()
+ self.age_net = VGG()
+ ckpt = torch.load(model_paths['age_predictor'], map_location="cpu")['state_dict']
+ ckpt = {k.replace('-', '_'): v for k, v in ckpt.items()}
+ self.age_net.load_state_dict(ckpt)
+ self.age_net.cuda()
+ self.age_net.eval()
+ self.min_age = 0
+ self.max_age = 100
+ self.opts = opts
+
+ def __get_predicted_age(self, age_pb):
+ predict_age_pb = F.softmax(age_pb)
+ predict_age = torch.zeros(age_pb.size(0)).type_as(predict_age_pb)
+ for i in range(age_pb.size(0)):
+ for j in range(age_pb.size(1)):
+ predict_age[i] += j * predict_age_pb[i][j]
+ return predict_age
+
+ def extract_ages(self, x):
+ x = F.interpolate(x, size=(224, 224), mode='bilinear')
+ predict_age_pb = self.age_net(x)['fc8']
+ predicted_age = self.__get_predicted_age(predict_age_pb)
+ return predicted_age
+
+ def forward(self, y_hat, y, target_ages, id_logs, label=None):
+ n_samples = y.shape[0]
+
+ if id_logs is None:
+ id_logs = []
+
+ input_ages = self.extract_ages(y) / 100.
+ output_ages = self.extract_ages(y_hat) / 100.
+
+ for i in range(n_samples):
+ # if id logs for the same exists, update the dictionary
+ if len(id_logs) > i:
+ id_logs[i].update({f'input_age_{label}': float(input_ages[i]) * 100,
+ f'output_age_{label}': float(output_ages[i]) * 100,
+ f'target_age_{label}': float(target_ages[i]) * 100})
+ # otherwise, create a new entry for the sample
+ else:
+ id_logs.append({f'input_age_{label}': float(input_ages[i]) * 100,
+ f'output_age_{label}': float(output_ages[i]) * 100,
+ f'target_age_{label}': float(target_ages[i]) * 100})
+
+ loss = F.mse_loss(output_ages, target_ages)
+ return loss, id_logs
diff --git a/criteria/id_loss.py b/criteria/id_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..26d0c105ed444f7e1cbc4ba49710e8495def9b64
--- /dev/null
+++ b/criteria/id_loss.py
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from configs.paths_config import model_paths
+from models.encoders.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+ def __init__(self):
+ super(IDLoss, self).__init__()
+ print('Loading ResNet ArcFace')
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+ self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+ self.facenet.eval()
+
+ def extract_feats(self, x):
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
+ x = self.face_pool(x)
+ x_feats = self.facenet(x)
+ return x_feats
+
+ def forward(self, y_hat, y, x, label=None, weights=None):
+ n_samples = x.shape[0]
+ x_feats = self.extract_feats(x)
+ y_feats = self.extract_feats(y)
+ y_hat_feats = self.extract_feats(y_hat)
+ y_feats = y_feats.detach()
+ total_loss = 0
+ sim_improvement = 0
+ id_logs = []
+ count = 0
+ for i in range(n_samples):
+ diff_target = y_hat_feats[i].dot(y_feats[i])
+ diff_input = y_hat_feats[i].dot(x_feats[i])
+ diff_views = y_feats[i].dot(x_feats[i])
+
+ if label is None:
+ id_logs.append({'diff_target': float(diff_target),
+ 'diff_input': float(diff_input),
+ 'diff_views': float(diff_views)})
+ else:
+ id_logs.append({f'diff_target_{label}': float(diff_target),
+ f'diff_input_{label}': float(diff_input),
+ f'diff_views_{label}': float(diff_views)})
+
+ loss = 1 - diff_target
+ if weights is not None:
+ loss = weights[i] * loss
+
+ total_loss += loss
+ id_diff = float(diff_target) - float(diff_views)
+ sim_improvement += id_diff
+ count += 1
+
+ return total_loss / count, sim_improvement / count, id_logs
diff --git a/criteria/lpips/__init__.py b/criteria/lpips/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/criteria/lpips/lpips.py b/criteria/lpips/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..1add6acc84c1c04cfcb536cf31ec5acdf24b716b
--- /dev/null
+++ b/criteria/lpips/lpips.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+ r"""Creates a criterion that measures
+ Learned Perceptual Image Patch Similarity (LPIPS).
+ Arguments:
+ net_type (str): the network type to compare the features:
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+ version (str): the version of LPIPS. Default: 0.1.
+ """
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+ assert version in ['0.1'], 'v0.1 is only supported now'
+
+ super(LPIPS, self).__init__()
+
+ # pretrained network
+ self.net = get_network(net_type).to("cuda")
+
+ # linear layers
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+ self.lin.load_state_dict(get_state_dict(net_type, version))
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ feat_x, feat_y = self.net(x), self.net(y)
+
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
diff --git a/criteria/lpips/networks.py b/criteria/lpips/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0d13ad2d560278f16586da68d3a5eadb26e746
--- /dev/null
+++ b/criteria/lpips/networks.py
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+ if net_type == 'alex':
+ return AlexNet()
+ elif net_type == 'squeeze':
+ return SqueezeNet()
+ elif net_type == 'vgg':
+ return VGG16()
+ else:
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+ def __init__(self, n_channels_list: Sequence[int]):
+ super(LinLayers, self).__init__([
+ nn.Sequential(
+ nn.Identity(),
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+ ) for nc in n_channels_list
+ ])
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+ def __init__(self):
+ super(BaseNet, self).__init__()
+
+ # register buffer
+ self.register_buffer(
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer(
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def set_requires_grad(self, state: bool):
+ for param in chain(self.parameters(), self.buffers()):
+ param.requires_grad = state
+
+ def z_score(self, x: torch.Tensor):
+ return (x - self.mean) / self.std
+
+ def forward(self, x: torch.Tensor):
+ x = self.z_score(x)
+
+ output = []
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+ x = layer(x)
+ if i in self.target_layers:
+ output.append(normalize_activation(x))
+ if len(output) == len(self.target_layers):
+ break
+ return output
+
+
+class SqueezeNet(BaseNet):
+ def __init__(self):
+ super(SqueezeNet, self).__init__()
+
+ self.layers = models.squeezenet1_1(True).features
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+ self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+ def __init__(self):
+ super(AlexNet, self).__init__()
+
+ self.layers = models.alexnet(True).features
+ self.target_layers = [2, 5, 8, 10, 12]
+ self.n_channels_list = [64, 192, 384, 256, 256]
+
+ self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+ def __init__(self):
+ super(VGG16, self).__init__()
+
+ self.layers = models.vgg16(True).features
+ self.target_layers = [4, 9, 16, 23, 30]
+ self.n_channels_list = [64, 128, 256, 512, 512]
+
+ self.set_requires_grad(False)
\ No newline at end of file
diff --git a/criteria/lpips/utils.py b/criteria/lpips/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5
--- /dev/null
+++ b/criteria/lpips/utils.py
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+ # build url
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+ # download
+ old_state_dict = torch.hub.load_state_dict_from_url(
+ url, progress=True,
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
+ )
+
+ # rename keys
+ new_state_dict = OrderedDict()
+ for key, val in old_state_dict.items():
+ new_key = key
+ new_key = new_key.replace('lin', '')
+ new_key = new_key.replace('model.', '')
+ new_state_dict[new_key] = val
+
+ return new_state_dict
diff --git a/criteria/w_norm.py b/criteria/w_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a71ba4498a5319e2382b55ed456498362acf151
--- /dev/null
+++ b/criteria/w_norm.py
@@ -0,0 +1,14 @@
+import torch
+from torch import nn
+
+
+class WNormLoss(nn.Module):
+
+ def __init__(self, opts):
+ super(WNormLoss, self).__init__()
+ self.opts = opts
+
+ def forward(self, latent, latent_avg=None):
+ if self.opts.start_from_latent_avg or self.opts.start_from_encoded_w_plus:
+ latent = latent - latent_avg
+ return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/datasets/__pycache__/__init__.cpython-310.pyc b/datasets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1624116c6ae6d3c98fba78f2547bab5225efff52
Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/datasets/__pycache__/augmentations.cpython-310.pyc b/datasets/__pycache__/augmentations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c77a1a41a5b06c1320e7efe167bdacd4052868c
Binary files /dev/null and b/datasets/__pycache__/augmentations.cpython-310.pyc differ
diff --git a/datasets/augmentations.py b/datasets/augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..efdb02eb045680d4ac8e4217c9b0b72fd1096db2
--- /dev/null
+++ b/datasets/augmentations.py
@@ -0,0 +1,24 @@
+import numpy as np
+import torch
+
+
+class AgeTransformer(object):
+
+ def __init__(self, target_age):
+ self.target_age = target_age
+
+ def __call__(self, img):
+ img = self.add_aging_channel(img)
+ return img
+
+ def add_aging_channel(self, img):
+ target_age = self.__get_target_age()
+ target_age = int(target_age) / 100 # normalize aging amount to be in range [-1,1]
+ img = torch.cat((img, target_age * torch.ones((1, img.shape[1], img.shape[2]))))
+ return img
+
+ def __get_target_age(self):
+ if self.target_age == "uniform_random":
+ return np.random.randint(low=0., high=101, size=1)[0]
+ else:
+ return self.target_age
diff --git a/datasets/images_dataset.py b/datasets/images_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..62bb3e3eb85f3841696bac02fa5fb217488a43cd
--- /dev/null
+++ b/datasets/images_dataset.py
@@ -0,0 +1,33 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class ImagesDataset(Dataset):
+
+ def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
+ self.source_paths = sorted(data_utils.make_dataset(source_root))
+ self.target_paths = sorted(data_utils.make_dataset(target_root))
+ self.source_transform = source_transform
+ self.target_transform = target_transform
+ self.opts = opts
+
+ def __len__(self):
+ return len(self.source_paths)
+
+ def __getitem__(self, index):
+ from_path = self.source_paths[index]
+ from_im = Image.open(from_path)
+ from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
+
+ to_path = self.target_paths[index]
+ to_im = Image.open(to_path).convert('RGB')
+ if self.target_transform:
+ to_im = self.target_transform(to_im)
+
+ if self.source_transform:
+ from_im = self.source_transform(from_im)
+ else:
+ from_im = to_im
+
+ return from_im, to_im
diff --git a/datasets/inference_dataset.py b/datasets/inference_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cacf9bcdf0a2b45295ac2b8d25924ae9033dc7a1
--- /dev/null
+++ b/datasets/inference_dataset.py
@@ -0,0 +1,29 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class InferenceDataset(Dataset):
+
+ def __init__(self, root=None, paths_list=None, opts=None, transform=None, return_path=False):
+ if paths_list is None:
+ self.paths = sorted(data_utils.make_dataset(root))
+ else:
+ self.paths = data_utils.make_dataset_from_paths_list(paths_list)
+ self.transform = transform
+ self.opts = opts
+ self.return_path = return_path
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ from_path = self.paths[index]
+ from_im = Image.open(from_path)
+ from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
+ if self.transform:
+ from_im = self.transform(from_im)
+ if self.return_path:
+ return from_im, from_path
+ else:
+ return from_im
diff --git a/docs/1005_style_mixing.jpg b/docs/1005_style_mixing.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7686522a74ec6f16589469feeb4d4a2f8286bbec
Binary files /dev/null and b/docs/1005_style_mixing.jpg differ
diff --git a/docs/1936.jpg b/docs/1936.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cfc78d3f6b0890f96e4bd782ab9faee5f01a4761
Binary files /dev/null and b/docs/1936.jpg differ
diff --git a/docs/2195.jpg b/docs/2195.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2d9f4ece8a34e3fa597ace333eb4f651159413ac
Binary files /dev/null and b/docs/2195.jpg differ
diff --git a/docs/866.jpg b/docs/866.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1e671f906095465adc97adf8d345e579ac0eb98a
Binary files /dev/null and b/docs/866.jpg differ
diff --git a/docs/teaser.jpeg b/docs/teaser.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..e1b93faae09ed912be5b1ce38b6d40c5713a8b33
--- /dev/null
+++ b/docs/teaser.jpeg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86f384cf500021d8fde41778c21394460ce7b01194fd49de9a27dccd31339ada
+size 3189368
diff --git a/environment/sam_env.yaml b/environment/sam_env.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f3cf596978b0355447ae5414aa4be5f7585ac484
--- /dev/null
+++ b/environment/sam_env.yaml
@@ -0,0 +1,36 @@
+name: sam_env
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - ca-certificates=2020.4.5.1=hecc5488_0
+ - certifi=2020.4.5.1=py36h9f0ad1d_0
+ - libedit=3.1.20181209=hc058e9b_0
+ - libffi=3.2.1=hd88cf55_4
+ - libgcc-ng=9.1.0=hdf63c60_0
+ - libstdcxx-ng=9.1.0=hdf63c60_0
+ - ncurses=6.2=he6710b0_1
+ - ninja=1.10.0=hc9558a2_0
+ - openssl=1.1.1g=h516909a_0
+ - pip=20.0.2=py36_3
+ - python=3.6.7=h0371630_0
+ - python_abi=3.6=1_cp36m
+ - readline=7.0=h7b6447c_5
+ - setuptools=46.4.0=py36_0
+ - sqlite=3.31.1=h62c20be_1
+ - tk=8.6.8=hbc83047_0
+ - wheel=0.34.2=py36_0
+ - xz=5.2.5=h7b6447c_0
+ - zlib=1.2.11=h7b6447c_3
+ - pip:
+ - scipy==1.4.1
+ - matplotlib==3.2.1
+ - tqdm==4.46.0
+ - numpy==1.18.4
+ - opencv-python==4.2.0.34
+ - pillow==7.1.2
+ - tensorboard==2.2.1
+ - torch==1.6.0
+ - torchvision==0.4.2
+prefix: ~/anaconda3/envs/sam_env
\ No newline at end of file
diff --git a/licenses/LICENSE_InterDigitalInc b/licenses/LICENSE_InterDigitalInc
new file mode 100644
index 0000000000000000000000000000000000000000..9a0c71bddff9a53a6bc5bf492b174e67b3bd7731
--- /dev/null
+++ b/licenses/LICENSE_InterDigitalInc
@@ -0,0 +1,150 @@
+LIMITED SOFTWARE EVALUATION LICENSE AGREEMENT
+
+
+
+This Limited Software Evaluation License Agreement (the “Agreement”) is entered into as of April 9th 2020, (“Effective Date”)
+
+The following limited software evaluation license agreement (“the Agreement”) constitute an agreement between you (the “licensee”) and InterDigital R&D France, a French company existing and organized under the laws of France with its registered offices located at 975 avenue des champs blancs 35510 Cesson-Sévigné, FRANCE (hereinafter “InterDigital”)
+This Agreement governs the download and use of the Software (as defined below). Your use of the Software is subject to the terms and conditions set forth in this Agreement. By installing, using, accessing or copying the Software, you hereby irrevocably accept the terms and conditions of this Agreement. If you do not accept all or parts of the terms and conditions of this Agreement you cannot install, use, access nor copy the Software
+
+ Article 1. Definitions
+
+“Affiliate” as used herein shall mean any entity that, directly or indirectly, through one or more intermediates, is controlled by, controls, or is under common control with InterDigital or The Licensee, as the case may be. For purposes of this definition only, the term “control” means the possession of the power to direct or cause the direction of the management and policies of an entity, whether by ownership of voting stock or partnership interest, by contract, or otherwise, including direct or indirect ownership of more than fifty percent (50%) of the voting interest in the entity in question.
+
+“Authorized Purpose” means any use of the Software for research on the Software and evaluation of the Software exclusively, and academic research using the Software without any commercial use. For the avoidance of doubt, a commercial use includes, but is not limited to:
+- using the Software in advertisements of any kind,
+- licensing or selling of the Software,
+- use the Software to provide any service to any third Party
+- use the Software to develop a competitive product of the Software
+
+“Documentation” means textual materials delivered by InterDigital to the Licensee pursuant to this Agreement relating to the Software, in written or electronic format, including but not limited to: technical reference manuals, technical notes, user manuals, and application guides.
+
+“Limited Period” means the life of the copyright owned by InterDigital on the Software in each and every country where such copyright would exist.
+
+“Intellectual Property Rights” means all copyrights, trademarks, trade secrets, patents, mask works and other intellectual property rights recognized in any jurisdiction worldwide, including all applications and registrations with respect thereto.
+
+"Open Source software" shall mean any software, including where appropriate, any and all modifications, derivative works, enhancements, upgrades, improvements, fixed bugs, and/or statically linked to the source code of such software, released under a free software license, that requires as a condition of royalty-free usage, copy, modification and/or redistribution of the Open Source Software to:
+• Redistribute the Open Source Software royalty-free, and/or;
+• Redistribute the Open Source Software under the same license/distribution terms as those contained in the open source or free software license under which it has originally been released and/or;
+• Release to the public, disclose or otherwise make available the source code of the Open Source Software.
+
+For purposes of the Agreement, by means of example and without limitation, any software that is released or distributed under any of the following licenses shall be qualified as Open Source Software: (A) GNU General Public License (GPL), (B) GNU Lesser/Library GPL (LGPL), (C) the Artistic License, (D) the Mozilla Public License, (E) the Common Public License, (F) the Sun Community Source License (SCSL), (G) the Sun Industry Standards Source License (SISSL), (H) BSD License, (I) MIT License, (J) Apache Software License, (K) Open SSL License, (L) IBM Public License, (M) Open Software License.
+
+“Software” means any computer programming code, in object and/or source version, and related Documentation delivered by InterDigital to the Licensee pursuant to this Agreement as described in Exhibit A attached and incorporated herein by reference.
+
+ Article 2. License
+
+InterDigital grants Licensee a free, worldwide, non-exclusive, license on copyright owned on the Software to download, use, modify and reproduce solely for the Authorized Purpose for the Limited Period.
+
+The Licensee shall not pay any royalty, license fee or maintenance fee, or other fee of any nature under this Agreement.
+
+The Licensee shall have the right to correct, adapt, modify, reverse engineer, disassemble, decompile and any action leading to the transformation of Software provided that such action is made to accomplish the Authorized Purpose.
+
+Licensee shall have the right to make a demonstration of the Software, provided that it is in the Purpose and provided that Licensee shall maintain control of the Software at all time. This includes the control of any computer or server on which the Software is installed: no third party shall have access to such computer or server under any circumstances. No computer nor server containing the Software will be left in the possession of any third Party.
+
+ Article 3. Restrictions on use of the Software
+
+Licensee shall not remove, obscure or modify any copyright, trademark or other proprietary rights notices, marks or labels contained on or within the Software, falsify or delete any author attributions, legal notices or other labels of the origin or source of the material.
+
+Licensee shall not have the right to distribute the Software, either modified or not, to any third Party.
+
+The rights granted here above do not include any rights to automatically obtain any upgrade or update of the Software, acquired or otherwise made available by InterDigital. Such deliverance shall be discussed on a case by case basis by the Parties.
+
+ Article 4. Ownership
+
+Title to and ownership of the Software, the Documentation and/or any Intellectual Property Right protecting the Software or/and the Documentation shall, at all times, remain with InterDigital. Licensee agrees that except for the rights granted on copyright on the Software set forth in Section 2 above, in no event does anything in this Agreement grant, provide or convey any other rights, immunities or interest in or to any Intellectual Property Rights (including especially patents) of InterDigital or any of its Affiliates whether by implication, estoppel or otherwise.
+
+
+ Article 5. Publication/Communication
+
+Any publication or oral communication resulting from the use of the Software shall be elaborated in good faith and shall not be driven by a deliberate will to denigrate InterDigital or any of its products. In any publication and on any support joined to an oral communication (for instance a PowerPoint document) resulting from the use of the Software, the following statement shall be inserted:
+
+“HRFAE is an InterDigital product”
+
+And in any publication, the latest publication about the software shall be properly cited. The latest publication currently is:
+"Arxiv preprint (ref to come shortly)”
+
+In any oral communication resulting from the use of the Software, the Licensee shall orally indicate that the Software is InterDigital’s property.
+
+ Article 6. No Warranty - Disclaimer
+
+THE SOFTWARE AND DOCUMENTATION ARE PROVIDED TO LICENSEE ON AN “AS IS” BASIS. INTERDIGITAL MAKES NO WARRANTY THAT THE LICENSED TECHNOLOGY WILL OPERATE ON ANY PARTICULAR HARDWARE, PLATFORM, OR ENVIRONMENT. THERE IS NO WARRANTY THAT THE OPERATION OF THE LICENSED TECHNOLOGY SHALL BE UNINTERRUPTED, WITHOUT BUGS OR ERROR FREE. THE SOFTWARE AND DOCUMENTATION ARE PROVIDED HEREUNDER WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED LIABILITIES AND WARRANTIES OF NONINFRINGEMENT OF INTELLECTUAL PROPERTY, FREEDOM FROM INHERENT DEFECTS, CONFORMITY TO A SAMPLE OR MODEL, MERCHANTABILITY, FITNESS AND/OR SUITABILITY FOR A SPECIFIC OR GENERAL PURPOSE AND THOSE ARISING BY STATUTE OR BY LAW, OR FROM A CAUSE OF DEALING OR USAGE OF TRADE.
+
+InterDigital shall not be obliged to perform any modifications, derivative works, enhancements, upgrades, updates or improvements of the Software or to fix any bug that could arise.
+
+Hence, the Licensee uses the Software at his own cost, risks and responsibility. InterDigital shall not be liable for any damage that could arise to Licensee by using the Software, either in accordance with this Agreement or not.
+
+InterDigital shall not be liable for any consequential or indirect losses, including any indirect loss of profits, revenues, business, and/or anticipated savings, whether or not in the contemplation of the Parties at the time of entering into the Agreement unless expressly set out in the Agreement, or arising from gross negligence, willful misconduct or fraud.
+
+Licensee agrees that it will defend, indemnify and hold harmless InterDigital and its Affiliates against any and all losses, damages, costs and expenses arising from a breach by the Licensee of any of its obligations or representations hereunder, including, without limitation, any third party, and/or any claims in connection with any such breach and/or any use of the Software, including any claim from third party arising from access, use or any other activity in relation to this Software.
+
+The Licensee shall not make any warranty, representation, or commitment on behalf of InterDigital to any other third party.
+
+ Article 7. Open Source Software
+
+InterDigital hereby notifies the Licensee, and the Licensee hereby acknowledges and accepts, that the Software contains Open Source Software. The list of such Open Source Software is enclosed in exhibit B and the relevant license are contained at the root of the Software when downloaded. Hence, the Licensee shall comply with such license and agree on its terms on at its own risks.
+
+The Licensee hereby represents, warrants and covenants to InterDigital that The Licensee’s use of the Software shall not result in the Contamination of all or part of the Software, directly or indirectly, or of any Intellectual Property of InterDigital or its Affiliates.
+
+Contamination effect shall mean that the licensing terms under which one Open Source software, distinct from the Software, is released would also apply, by viral effect, to the software to which such Open Source software is linked to, combined with or otherwise connected to.
+
+ Article 8. No Future Contract Obligation
+
+Neither this Agreement nor the furnishing of the Software, nor any other Confidential Information shall be construed to obligate either party to: (a) enter into any further agreement or negotiation concerning the deployment of the Software; (b) refrain from entering into any agreement or negotiation with any other third party regarding the same or any other subject matter; or (c) refrain from pursuing its business in whatever manner it elects even if this involves competing with the other party.
+
+ Article 9. Term and Termination
+
+This Agreement shall terminate at the end of the Limited Period, unless earlier terminated by either party on the ground of material breach by the other party, which breach is not remedied after thirty (30) days advance written notice, specifying the breach with reasonable particularity and referencing this Agreement.
+
+ Article 10. General Provisions
+
+12.1 Severability. If any provision of this Agreement shall be held to be in contravention of applicable law, this Agreement shall be construed as if such provision were not a part thereof, and in all other respects the terms hereof shall remain in full force and effect.
+
+12.2 Governing Law. Regardless of the place of execution, delivery, performance or any other aspect of this Agreement, this Agreement and all of the rights of the parties under this Agreement shall be governed by, construed under and enforced in accordance with the substantive law of the France without regard to conflicts of law principles. In case of a dispute that could not be settled amicably, the courts of Nanterre shall be exclusively competent.
+
+12.3 Survival. The provisions of articles 1, 3, 4, 6, 7, 9, 10.2 and 10.6 shall survive termination of this Agreement.
+12.4 Assignment. InterDigital may assign this license to any third Party. Such assignment will be announced on the website as defined in article 5. Licensee may not assign this agreement to any third party without the previous written agreement from InterDigital.
+
+12.5 Entire Agreement. This Agreement constitutes the entire agreement between the parties hereto with respect to the subject matter hereof and supersedes any prior agreements or understanding.
+
+12.6 Notices. To have legal effect, notices must be provided by registered or certified mail, return receipt requested, to the representatives of InterDigital at the following address:
+
+InterDigital
+Legal Dept
+975 avenue des champs blancs
+35510 Cesson-Sévigné
+FRANCE
+
+=======================================================================
+
+Exhibit A
+Software
+
+
+The Software is comprised of the following software and Documentation:
+
+- README.md file that explains the content of the software and the procedure to use it.
+- Source python files, as well as pretrained models
+
+=======================================================================
+
+Exhibit B
+Open Source licenses
+
+
+PIL http://www.pythonware.com/products/pil/license.htm
+
+numpy https://numpy.org/license.html
+
+tensorboardX https://github.com/lanpa/tensorboardX/blob/master/LICENSE
+
+pytorch https://github.com/pytorch/pytorch/blob/master/LICENSE
+
+torchvision https://github.com/pytorch/vision/blob/master/LICENSE
+
+tensorboard_logger https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/LICENSE
+
+argparse https://github.com/ThomasWaldmann/argparse/blob/master/LICENSE.txt
+
+yaml https://github.com/yaml/pyyaml/blob/master/LICENSE
+
diff --git a/licenses/LICENSE_S-aiueo32 b/licenses/LICENSE_S-aiueo32
new file mode 100644
index 0000000000000000000000000000000000000000..81e7b18bd6fcfd5a81e08d0bcb192be28cd6723c
--- /dev/null
+++ b/licenses/LICENSE_S-aiueo32
@@ -0,0 +1,25 @@
+BSD 2-Clause License
+
+Copyright (c) 2020, Sou Uchida
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/licenses/LICENSE_TreB1eN b/licenses/LICENSE_TreB1eN
new file mode 100644
index 0000000000000000000000000000000000000000..1c7d3585c795c41d2334036b01a8d660a5235671
--- /dev/null
+++ b/licenses/LICENSE_TreB1eN
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 TreB1eN
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/licenses/LICENSE_eladrich b/licenses/LICENSE_eladrich
new file mode 100644
index 0000000000000000000000000000000000000000..f1c322f32d65b5f21e2d19d74bbd513b5d0ed85c
--- /dev/null
+++ b/licenses/LICENSE_eladrich
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Elad Richardson, Yuval Alaluf
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/licenses/LICENSE_lessw2020 b/licenses/LICENSE_lessw2020
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/licenses/LICENSE_lessw2020
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/licenses/LICENSE_rosinality b/licenses/LICENSE_rosinality
new file mode 100644
index 0000000000000000000000000000000000000000..81da3fce025084b7005be5405d3842fbea29b5ba
--- /dev/null
+++ b/licenses/LICENSE_rosinality
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Kim Seonghyeon
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3a22aeb8340faafe72b5609763a6d0aeec994b0
Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/__pycache__/psp.cpython-310.pyc b/models/__pycache__/psp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acccde82e88fbff1a8bdc120455169586e5ba3ae
Binary files /dev/null and b/models/__pycache__/psp.cpython-310.pyc differ
diff --git a/models/dex_vgg.py b/models/dex_vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca947508d25643a8419c0f570b6873173fd57e0
--- /dev/null
+++ b/models/dex_vgg.py
@@ -0,0 +1,65 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+"""
+VGG implementation from [InterDigitalInc](https://github.com/InterDigitalInc/HRFAE/blob/master/nets.py)
+"""
+
+class VGG(nn.Module):
+ def __init__(self, pool='max'):
+ super(VGG, self).__init__()
+ # vgg modules
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+ self.fc6 = nn.Linear(25088, 4096, bias=True)
+ self.fc7 = nn.Linear(4096, 4096, bias=True)
+ self.fc8_101 = nn.Linear(4096, 101, bias=True)
+ if pool == 'max':
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
+ elif pool == 'avg':
+ self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
+
+ def forward(self, x):
+ out = {}
+ out['r11'] = F.relu(self.conv1_1(x))
+ out['r12'] = F.relu(self.conv1_2(out['r11']))
+ out['p1'] = self.pool1(out['r12'])
+ out['r21'] = F.relu(self.conv2_1(out['p1']))
+ out['r22'] = F.relu(self.conv2_2(out['r21']))
+ out['p2'] = self.pool2(out['r22'])
+ out['r31'] = F.relu(self.conv3_1(out['p2']))
+ out['r32'] = F.relu(self.conv3_2(out['r31']))
+ out['r33'] = F.relu(self.conv3_3(out['r32']))
+ out['p3'] = self.pool3(out['r33'])
+ out['r41'] = F.relu(self.conv4_1(out['p3']))
+ out['r42'] = F.relu(self.conv4_2(out['r41']))
+ out['r43'] = F.relu(self.conv4_3(out['r42']))
+ out['p4'] = self.pool4(out['r43'])
+ out['r51'] = F.relu(self.conv5_1(out['p4']))
+ out['r52'] = F.relu(self.conv5_2(out['r51']))
+ out['r53'] = F.relu(self.conv5_3(out['r52']))
+ out['p5'] = self.pool5(out['r53'])
+ out['p5'] = out['p5'].view(out['p5'].size(0), -1)
+ out['fc6'] = F.relu(self.fc6(out['p5']))
+ out['fc7'] = F.relu(self.fc7(out['fc6']))
+ out['fc8'] = self.fc8_101(out['fc7'])
+ return out
\ No newline at end of file
diff --git a/models/encoders/__init__.py b/models/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/encoders/__pycache__/__init__.cpython-310.pyc b/models/encoders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b83acda51100a3986bde816b6296c353acec7a5
Binary files /dev/null and b/models/encoders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/encoders/__pycache__/helpers.cpython-310.pyc b/models/encoders/__pycache__/helpers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b88b1c409b7d463da95322088000a0f53b304ff7
Binary files /dev/null and b/models/encoders/__pycache__/helpers.cpython-310.pyc differ
diff --git a/models/encoders/__pycache__/psp_encoders.cpython-310.pyc b/models/encoders/__pycache__/psp_encoders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9419c5ff173fc5b9531c61ca20b4e53054a6e1f7
Binary files /dev/null and b/models/encoders/__pycache__/psp_encoders.cpython-310.pyc differ
diff --git a/models/encoders/helpers.py b/models/encoders/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51fdf97141407fcc1c9d249a086ddbfd042469f
--- /dev/null
+++ b/models/encoders/helpers.py
@@ -0,0 +1,119 @@
+from collections import namedtuple
+import torch
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Flatten(Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+ norm = torch.norm(input, 2, axis, True)
+ output = torch.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
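+# Block specifications for the IR / IR-SE ResNet backbones (50/100/152 layers) used by ArcFace;
+# the encoders in this repo instantiate the 50-layer IR-SE variant.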
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+ return blocks
+
+
+class SEModule(Module):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class bottleneck_IR(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class bottleneck_IR_SE(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR_SE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth),
+ SEModule(depth, 16)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
diff --git a/models/encoders/model_irse.py b/models/encoders/model_irse.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e65d76d69d2fba71b35f6587f659c73f506eea
--- /dev/null
+++ b/models/encoders/model_irse.py
@@ -0,0 +1,48 @@
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+ super(Backbone, self).__init__()
+ assert input_size in [112, 224], "input_size should be 112 or 224"
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ if input_size == 112:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 7 * 7, 512),
+ BatchNorm1d(512, affine=affine))
+ else:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 14 * 14, 512),
+ BatchNorm1d(512, affine=affine))
+
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer(x)
+ return l2_norm(x)
diff --git a/models/encoders/psp_encoders.py b/models/encoders/psp_encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..06b7bc58da908c5ad0eaefb5d8b2b0407ce5a8bb
--- /dev/null
+++ b/models/encoders/psp_encoders.py
@@ -0,0 +1,114 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
+
+from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE
+from models.stylegan2.model import EqualLinear
+
+
+class GradualStyleBlock(Module):
+ def __init__(self, in_c, out_c, spatial):
+ super(GradualStyleBlock, self).__init__()
+ self.out_c = out_c
+ self.spatial = spatial
+ num_pools = int(np.log2(spatial))
+ modules = []
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()]
+ for i in range(num_pools - 1):
+ modules += [
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()
+ ]
+ self.convs = nn.Sequential(*modules)
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+ def forward(self, x):
+ x = self.convs(x)
+ x = x.view(-1, self.out_c)
+ x = self.linear(x)
+ return x
+
+
+class GradualStyleEncoder(Module):
+ def __init__(self, num_layers, mode='ir', n_styles=18, opts=None):
+ super(GradualStyleEncoder, self).__init__()
+        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ self.styles = nn.ModuleList()
+ self.style_count = n_styles
+ self.coarse_ind = 3
+ self.middle_ind = 7
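+        # Following pSp, the style heads are split into coarse (indices < 3), medium (3-6) and
+        # fine (7+) groups, each reading a progressively shallower feature map in forward().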
+ for i in range(self.style_count):
+ if i < self.coarse_ind:
+ style = GradualStyleBlock(512, 512, 16)
+ elif i < self.middle_ind:
+ style = GradualStyleBlock(512, 512, 32)
+ else:
+ style = GradualStyleBlock(512, 512, 64)
+ self.styles.append(style)
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+ def _upsample_add(self, x, y):
+        '''Upsample x and add it to the lateral feature map y.
+        Args:
+          x: (Tensor) top feature map to be upsampled.
+          y: (Tensor) lateral feature map.
+        Returns:
+          (Tensor) the summed feature map.
+        Note: in PyTorch, when the input size is odd, a feature map upsampled with
+        `F.upsample(..., scale_factor=2, mode='nearest')` may not match the size of
+        the lateral feature map, e.g.
+        original input size: [N,_,15,15] ->
+        conv2d feature map size: [N,_,8,8] ->
+        upsampled feature map size: [N,_,16,16].
+        We therefore use bilinear interpolation, which supports arbitrary output sizes.
+        '''
+ _, _, H, W = y.size()
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
+
+ def forward(self, x):
+ x = self.input_layer(x)
+
+ latents = []
+ modulelist = list(self.body._modules.values())
+ for i, l in enumerate(modulelist):
+ x = l(x)
+ if i == 6:
+ c1 = x
+ elif i == 20:
+ c2 = x
+ elif i == 23:
+ c3 = x
+
+ for j in range(self.coarse_ind):
+ latents.append(self.styles[j](c3))
+
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
+ for j in range(self.coarse_ind, self.middle_ind):
+ latents.append(self.styles[j](p2))
+
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
+ for j in range(self.middle_ind, self.style_count):
+ latents.append(self.styles[j](p1))
+
+ out = torch.stack(latents, dim=1)
+ return out
diff --git a/models/psp.py b/models/psp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee9454be7fd52bb56ea2a92ef46372db5611f2e0
--- /dev/null
+++ b/models/psp.py
@@ -0,0 +1,131 @@
+"""
+This file defines the core research contribution
+"""
+import copy
+from argparse import Namespace
+
+import torch
+from torch import nn
+import math
+
+from configs.paths_config import model_paths
+from models.encoders import psp_encoders
+from models.stylegan2.model import Generator
+
+
+class pSp(nn.Module):
+
+ def __init__(self, opts):
+ super(pSp, self).__init__()
+ self.set_opts(opts)
+ self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
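+        # e.g. output_size=1024 -> 18 style vectors, matching the W+ entries of the StyleGAN2 decoder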
+ # Define architecture
+ self.encoder = self.set_encoder()
+ self.decoder = Generator(self.opts.output_size, 512, 8)
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+ # Load weights if needed
+ self.load_weights()
+
+ def set_encoder(self):
+ return psp_encoders.GradualStyleEncoder(50, 'ir_se', self.n_styles, self.opts)
+
+ def load_weights(self):
+ if self.opts.checkpoint_path is not None:
+ print(f'Loading SAM from checkpoint: {self.opts.checkpoint_path}')
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+ self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False)
+ self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)
+ if self.opts.start_from_encoded_w_plus:
+ self.pretrained_encoder = self.__get_pretrained_psp_encoder()
+ self.pretrained_encoder.load_state_dict(self.__get_keys(ckpt, 'pretrained_encoder'), strict=True)
+ self.__load_latent_avg(ckpt)
+ else:
+            print('Loading encoder weights from irse50!')
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
+ # Transfer the RGB input of the irse50 network to the first 3 input channels of SAM's encoder
+ if self.opts.input_nc != 3:
+ shape = encoder_ckpt['input_layer.0.weight'].shape
+ altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32)
+ altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight']
+ encoder_ckpt['input_layer.0.weight'] = altered_input_layer
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
+ print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
+ ckpt = torch.load(self.opts.stylegan_weights)
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
+ self.__load_latent_avg(ckpt, repeat=self.n_styles)
+ if self.opts.start_from_encoded_w_plus:
+ self.pretrained_encoder = self.__load_pretrained_psp_encoder()
+ self.pretrained_encoder.eval()
+
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+ inject_latent=None, return_latents=False, alpha=None, input_is_full=False):
+ if input_code:
+ codes = x
+ else:
+ codes = self.encoder(x)
+ # normalize with respect to the center of an average face
+ if self.opts.start_from_latent_avg:
+ codes = codes + self.latent_avg
+ # normalize with respect to the latent of the encoded image of pretrained pSp encoder
+ elif self.opts.start_from_encoded_w_plus:
+ with torch.no_grad():
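+                # the input carries an extra aging channel on top of RGB, so only the leading
+                # channels are passed to the RGB-only pretrained pSp encoder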
+ encoded_latents = self.pretrained_encoder(x[:, :-1, :, :])
+ encoded_latents = encoded_latents + self.latent_avg
+ codes = codes + encoded_latents
+
+ if latent_mask is not None:
+ for i in latent_mask:
+ if inject_latent is not None:
+ if alpha is not None:
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+ else:
+ codes[:, i] = inject_latent[:, i]
+ else:
+ codes[:, i] = 0
+
+ input_is_latent = (not input_code) or (input_is_full)
+ images, result_latent = self.decoder([codes],
+ input_is_latent=input_is_latent,
+ randomize_noise=randomize_noise,
+ return_latents=return_latents)
+
+ if resize:
+ images = self.face_pool(images)
+
+ if return_latents:
+ return images, result_latent
+ else:
+ return images
+
+ def set_opts(self, opts):
+ self.opts = opts
+
+ def __load_latent_avg(self, ckpt, repeat=None):
+ if 'latent_avg' in ckpt:
+ self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
+ if repeat is not None:
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
+ else:
+ self.latent_avg = None
+
+ def __get_pretrained_psp_encoder(self):
+ opts_encoder = vars(copy.deepcopy(self.opts))
+ opts_encoder['input_nc'] = 3
+ opts_encoder = Namespace(**opts_encoder)
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.n_styles, opts_encoder)
+ return encoder
+
+ def __load_pretrained_psp_encoder(self):
+ print(f'Loading pSp encoder from checkpoint: {self.opts.pretrained_psp_path}')
+ ckpt = torch.load(self.opts.pretrained_psp_path, map_location='cpu')
+ encoder_ckpt = self.__get_keys(ckpt, name='encoder')
+ encoder = self.__get_pretrained_psp_encoder()
+ encoder.load_state_dict(encoder_ckpt, strict=False)
+ return encoder
+
+ @staticmethod
+ def __get_keys(d, name):
+ if 'state_dict' in d:
+ d = d['state_dict']
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+ return d_filt
diff --git a/models/stylegan2/__init__.py b/models/stylegan2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/stylegan2/__pycache__/__init__.cpython-310.pyc b/models/stylegan2/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8adf83680a732532ef000086757cf39cebbb9f75
Binary files /dev/null and b/models/stylegan2/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/stylegan2/__pycache__/model.cpython-310.pyc b/models/stylegan2/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..828ac3e7ec4cc4017d788c8cb93f75f94626212f
Binary files /dev/null and b/models/stylegan2/__pycache__/model.cpython-310.pyc differ
diff --git a/models/stylegan2/model.py b/models/stylegan2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..72a7a7afcbd06e3c21e126c5e87376a1ce750493
--- /dev/null
+++ b/models/stylegan2/model.py
@@ -0,0 +1,671 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+ )
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+ return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+ f'upsample={self.upsample}, downsample={self.downsample})'
+ )
+
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ weight = self.scale * self.weight * style
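+        # StyleGAN2 weight (de)modulation: each sample's style scales the input channels of the
+        # conv weight; the demodulation below renormalizes per output channel to keep activations stable.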
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None):
+ out = self.conv(input, style)
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
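+        # e.g. size=1024 -> log_size=10 -> 18 W+ entries, matching the style indices consumed in forward()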
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ return_features=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+ ]
+
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
+ if len(styles) < 2:
+ inject_index = self.n_latent
+
+ if styles[0].ndim < 3:
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+ latent = torch.cat([latent, latent2], 1)
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ elif return_features:
+ return image, out
+ else:
+ return image, None
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
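+        # Minibatch standard deviation: forward() appends a per-group stddev feature map before the
+        # final conv, the usual StyleGAN/ProGAN trick for encouraging sample diversity.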
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width)
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
diff --git a/models/stylegan2/op/__init__.py b/models/stylegan2/op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3
--- /dev/null
+++ b/models/stylegan2/op/__init__.py
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
diff --git a/models/stylegan2/op/__pycache__/__init__.cpython-310.pyc b/models/stylegan2/op/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d747ff9c29348084fdac1986c33743ca3992e07
Binary files /dev/null and b/models/stylegan2/op/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/stylegan2/op/__pycache__/fused_act.cpython-310.pyc b/models/stylegan2/op/__pycache__/fused_act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86b1f1981827066dd129a26438e5ac2f392a98b8
Binary files /dev/null and b/models/stylegan2/op/__pycache__/fused_act.cpython-310.pyc differ
diff --git a/models/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc b/models/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..136b6ca6b362405ba87b853a0a6da02539246748
Binary files /dev/null and b/models/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc differ
diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..973a84fffde53668d31397da5fb993bbc95f7be0
--- /dev/null
+++ b/models/stylegan2/op/fused_act.py
@@ -0,0 +1,85 @@
+import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+fused = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ ],
+)
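+# The extension is JIT-compiled on first import and requires a CUDA toolchain; the op fuses the
+# bias addition, leaky ReLU and gain into a single kernel (see fused_bias_act_kernel.cu).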
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused.fused_bias_act(
+ grad_output, empty, out, 3, 1, negative_slope, scale
+ )
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale
+ )
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/models/stylegan2/op/fused_bias_act.cpp b/models/stylegan2/op/fused_bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949
--- /dev/null
+++ b/models/stylegan2/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/models/stylegan2/op/fused_bias_act_kernel.cu b/models/stylegan2/op/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8
--- /dev/null
+++ b/models/stylegan2/op/fused_bias_act_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
\ No newline at end of file
diff --git a/models/stylegan2/op/upfirdn2d.cpp b/models/stylegan2/op/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e
--- /dev/null
+++ b/models/stylegan2/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9cb52219689592e2745600abb19fad02740a139
--- /dev/null
+++ b/models/stylegan2/op/upfirdn2d.py
@@ -0,0 +1,184 @@
+import os
+
+import torch
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ ],
+)
+
+
+class UpFirDn2dBackward(Function):
+ @staticmethod
+ def forward(
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+ ):
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_op.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+ )
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
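+ # Worked example (illustrative numbers, not from the repo): a 64x64 input with
+ # up=2, a 4-tap kernel and pad=(2, 1) gives out_h = (64*2 + 2 + 1 - 4) // 1 + 1 = 128,
+ # i.e. exactly double the input resolution.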
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_op.upfirdn2d(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ )
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ out = UpFirDn2d.apply(
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ )
+
+ return out
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
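+ # Pure-PyTorch reference of the same upsample -> pad -> FIR filter -> downsample
+ # pipeline as the CUDA op; note it expects the input laid out as (N, H, W, C),
+ # matching the reshaped tensors the CUDA path works on.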
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+
+ return out[:, ::down_y, ::down_x, :]
\ No newline at end of file
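
As a quick orientation for the `upfirdn2d` wrapper defined in this file, a minimal usage sketch follows (illustrative only, not code from the repo). The filter is a normalized 4-tap binomial kernel; the pad values and the ×4 gain on the upsampling kernel follow the usual StyleGAN2 convention and are assumptions here. The compiled extension is CUDA-only, so the tensors must live on a GPU.

```python
import torch
from models.stylegan2.op.upfirdn2d import upfirdn2d

# Normalized 4-tap binomial FIR kernel (outer product of [1, 3, 3, 1]).
k = torch.tensor([1., 3., 3., 1.])
kernel = k[:, None] * k[None, :]
kernel = (kernel / kernel.sum()).to('cuda')

x = torch.randn(1, 3, 64, 64, device='cuda')

# Same-size blur: out = (64 + 1 + 2 - 4) // 1 + 1 = 64.
blurred = upfirdn2d(x, kernel, up=1, down=1, pad=(1, 2))

# 2x upsampling: out = (64*2 + 2 + 1 - 4) // 1 + 1 = 128; the *4 gain compensates
# for the zeros inserted by the upsampler.
upsampled = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))
```
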
diff --git a/models/stylegan2/op/upfirdn2d_kernel.cu b/models/stylegan2/op/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a88bc7720da6cd54fccd0c4a03dd20fde85c063d
--- /dev/null
+++ b/models/stylegan2/op/upfirdn2d_kernel.cu
@@ -0,0 +1,369 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int out_y = minor_idx / p.minor_dim;
+ minor_idx -= out_y * p.minor_dim;
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major && major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, out_x = out_x_base;
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+ const scalar_t *x_p =
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+ minor_idx];
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+ int x_px = p.minor_dim;
+ int k_px = -p.up_x;
+ int x_py = p.in_w * p.minor_dim;
+ int k_py = -p.up_y * p.kernel_w;
+
+ scalar_t v = 0.0f;
+
+ for (int y = 0; y < h; y++) {
+ for (int x = 0; x < w; x++) {
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+ x_p += x_px;
+ k_p += k_px;
+ }
+
+ x_p += x_py - w * x_px;
+ k_p += k_py - w * k_px;
+ }
+
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+ tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major & major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
+ loop_x < p.loop_x & tile_out_x < p.out_w;
+ loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+ in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+ p.minor_dim +
+ minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+ out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+#pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] *
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+ const torch::Tensor &kernel, int up_x, int up_y,
+ int down_x, int down_y, int pad_x0, int pad_x1,
+ int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+ p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+ p.down_x;
+
+ auto out =
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
+
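+ // Modes 1-6 select a specialized, compile-time-tiled kernel for the common
+ // up/down/kernel-size combinations; anything else falls back to the generic
+ // upfirdn2d_kernel_large in the switch further below.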
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ } else {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 4;
+ block_size = dim3(4, 32, 1);
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ default:
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+ }
+ });
+
+ return out;
+}
\ No newline at end of file
diff --git a/notebooks/animation_inference_playground.ipynb b/notebooks/animation_inference_playground.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..a89f99a863357e6976d074023d7154d95d04f54d
--- /dev/null
+++ b/notebooks/animation_inference_playground.ipynb
@@ -0,0 +1,432 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0aWBuYb2UDIO"
+ },
+ "source": [
+ "# SAM: Animation Inference Playground"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Uuviq3qQkUFy"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.chdir('/content')\n",
+ "CODE_DIR = 'SAM'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "QQ6XEmlHlXbk",
+ "outputId": "2d2af9bb-1bbe-4946-84db-dbeaf62aa226"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JaRUFuVHkzye",
+ "outputId": "12fc6dcd-951b-472f-b9d6-0a09f749e931"
+ },
+ "outputs": [],
+ "source": [
+ "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
+ "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n",
+ "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "23baccYQlU9E"
+ },
+ "outputs": [],
+ "source": [
+ "os.chdir(f'./{CODE_DIR}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lEWIzkaLSsFY"
+ },
+ "outputs": [],
+ "source": [
+ "from argparse import Namespace\n",
+ "import os\n",
+ "import sys\n",
+ "import pprint\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "import torchvision.transforms as transforms\n",
+ "\n",
+ "sys.path.append(\".\")\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "from datasets.augmentations import AgeTransformer\n",
+ "from utils.common import tensor2im\n",
+ "from models.psp import pSp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9J3NEVlESsFl"
+ },
+ "outputs": [],
+ "source": [
+ "EXPERIMENT_TYPE = 'ffhq_aging'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eFjfO9q9SsFm"
+ },
+ "source": [
+ "## Step 1: Download Pretrained Model\n",
+ "As part of this repository, we provide our pretrained aging model.\n",
+ "We'll download the model for the selected experiments as save it to the folder `../pretrained_models`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e2L9GRCFSsFm"
+ },
+ "outputs": [],
+ "source": [
+ "def get_download_model_command(file_id, file_name):\n",
+ " \"\"\" Get wget download command for downloading the desired model and save to directory ../pretrained_models. \"\"\"\n",
+ " current_directory = os.getcwd()\n",
+ " save_path = os.path.join(os.path.dirname(current_directory), \"pretrained_models\")\n",
+ " if not os.path.exists(save_path):\n",
+ " os.makedirs(save_path)\n",
+ " url = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n",
+ " return url"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yotXNk8PSsFn"
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_PATHS = {\n",
+ " \"ffhq_aging\": {\"id\": \"1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC\", \"name\": \"sam_ffhq_aging.pt\"}\n",
+ "}\n",
+ "\n",
+ "path = MODEL_PATHS[EXPERIMENT_TYPE]\n",
+ "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tXUqcxd8SsFo",
+ "outputId": "0e2426ec-f96f-44cb-d1a3-736807bc4f37"
+ },
+ "outputs": [],
+ "source": [
+ "!wget {download_command}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mismxuEvSsFp"
+ },
+ "source": [
+ "## Step 3: Define Inference Parameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NRzP-untSsFq"
+ },
+ "source": [
+ "Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the\n",
+ "image to perform inference on.\n",
+ "While we provide default values to run this script, feel free to change as needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pHeJRfsnSsFq"
+ },
+ "outputs": [],
+ "source": [
+ "EXPERIMENT_DATA_ARGS = {\n",
+ " \"ffhq_aging\": {\n",
+ " \"model_path\": \"../pretrained_models/sam_ffhq_aging.pt\",\n",
+ " \"transform\": transforms.Compose([\n",
+ " transforms.Resize((256, 256)),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
+ " }\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "IgzLA96mSsFq"
+ },
+ "outputs": [],
+ "source": [
+ "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kqlL9U5uSsFr"
+ },
+ "source": [
+ "## Step 4: Load Pretrained Model\n",
+ "We assume that you have downloaded the pretrained aging model and placed it in the path defined above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8khx-2fcSsFr"
+ },
+ "outputs": [],
+ "source": [
+ "model_path = EXPERIMENT_ARGS['model_path']\n",
+ "ckpt = torch.load(model_path, map_location='cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fVfLlsfjSsFr",
+ "outputId": "ccc5cc29-e59d-414c-a216-fa967ece4eb9"
+ },
+ "outputs": [],
+ "source": [
+ "opts = ckpt['opts']\n",
+ "pprint.pprint(opts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vxYxSJk1SsFs"
+ },
+ "outputs": [],
+ "source": [
+ "# update the training options\n",
+ "opts['checkpoint_path'] = model_path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Z4LNoYt_SsFs",
+ "outputId": "030da871-e4a2-42a9-9c4b-8b96e028fd40"
+ },
+ "outputs": [],
+ "source": [
+ "opts = Namespace(**opts)\n",
+ "net = pSp(opts)\n",
+ "net.eval()\n",
+ "net.cuda()\n",
+ "print('Model successfully loaded!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3mNH95EJSsFs"
+ },
+ "source": [
+ "### Utils for Generating MP4 "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GASwMVmJSsFs"
+ },
+ "outputs": [],
+ "source": [
+ "import imageio\n",
+ "from tqdm import tqdm\n",
+ "import matplotlib\n",
+ "from IPython.display import HTML\n",
+ "from base64 import b64encode\n",
+ "\n",
+ "matplotlib.use('module://ipykernel.pylab.backend_inline')\n",
+ "%matplotlib inline\n",
+ "\n",
+ "\n",
+ "def generate_mp4(out_name, images, kwargs):\n",
+ " writer = imageio.get_writer(out_name + '.mp4', **kwargs)\n",
+ " for image in images:\n",
+ " writer.append_data(image)\n",
+ " writer.close()\n",
+ "\n",
+ "\n",
+ "def run_on_batch_to_vecs(inputs, net):\n",
+ " _, result_batch = net(inputs.to(\"cuda\").float(), return_latents=True, randomize_noise=False, resize=False)\n",
+ " return result_batch.cpu()\n",
+ "\n",
+ "\n",
+ "def get_result_from_vecs(vectors_a, vectors_b, alpha):\n",
+ " results = []\n",
+ " for i in range(len(vectors_a)):\n",
+ " cur_vec = vectors_b[i] * alpha + vectors_a[i] * (1 - alpha)\n",
+ " res = net(cur_vec.cuda(), randomize_noise=False, input_code=True, input_is_full=True, resize=False)\n",
+ " results.append(res[0])\n",
+ " return results\n",
+ "\n",
+ "\n",
+ "def show_mp4(filename, width=400):\n",
+ " mp4 = open(filename + '.mp4', 'rb').read()\n",
+ " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+ " display(HTML(\"\"\"\n",
+ " \n",
+ " \n",
+ " \n",
+ " \"\"\" % (width, data_url)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "y4oqFBfTSsFt",
+ "outputId": "6ff2e285-e5e0-4d9c-e16d-af43eee1e8f1"
+ },
+ "outputs": [],
+ "source": [
+ "SEED = 42\n",
+ "np.random.seed(SEED)\n",
+ "\n",
+ "img_transforms = EXPERIMENT_ARGS['transform']\n",
+ "n_transition = 25\n",
+ "kwargs = {'fps': 40}\n",
+ "save_path = \"notebooks/animations\"\n",
+ "os.makedirs(save_path, exist_ok=True)\n",
+ "\n",
+ "#################################################################\n",
+ "# TODO: define your image paths here to be fed into the model\n",
+ "#################################################################\n",
+ "root_dir = 'notebooks/images'\n",
+ "ims = ['866', '1287', '2468']\n",
+ "im_paths = [os.path.join(root_dir, im) + '.jpg' for im in ims]\n",
+ "\n",
+ "# NOTE: Please make sure the images are pre-aligned!\n",
+ "\n",
+ "target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]\n",
+ "age_transformers = [AgeTransformer(target_age=age) for age in target_ages]\n",
+ "\n",
+ "for image_path in im_paths:\n",
+ " image_name = os.path.basename(image_path)\n",
+ " print(f'Working on image: {image_name}')\n",
+ " original_image = Image.open(image_path).convert(\"RGB\")\n",
+ " input_image = img_transforms(original_image)\n",
+ " all_vecs = []\n",
+ " for idx, age_transformer in enumerate(age_transformers):\n",
+ "\n",
+ " input_age_batch = [age_transformer(input_image.cpu()).to('cuda')]\n",
+ " input_age_batch = torch.stack(input_age_batch)\n",
+ "\n",
+ " # get latent vector for the current target age amount\n",
+ " with torch.no_grad():\n",
+ " result_vec = run_on_batch_to_vecs(input_age_batch, net)\n",
+ " result_image = get_result_from_vecs([result_vec], result_vec, 0)[0]\n",
+ " all_vecs.append([result_vec])\n",
+ "\n",
+ " images = []\n",
+ " for i in range(1, len(target_ages)):\n",
+ " alpha_vals = np.linspace(0, 1, n_transition).tolist()\n",
+ " for alpha in tqdm(alpha_vals):\n",
+ " result_image = get_result_from_vecs(all_vecs[i-1], all_vecs[i], alpha)[0]\n",
+ " output_im = tensor2im(result_image)\n",
+ " images.append(np.array(output_im))\n",
+ "\n",
+ " animation_path = os.path.join(save_path, f\"{image_name}_animation\")\n",
+ " generate_mp4(animation_path, images, kwargs)\n",
+ " show_mp4(animation_path)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "inference_playground_mp4.ipynb",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
\ No newline at end of file
diff --git a/notebooks/images/1287.jpg b/notebooks/images/1287.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6a89ac6774e69fde72fbc369ce1e2597f073634c
Binary files /dev/null and b/notebooks/images/1287.jpg differ
diff --git a/notebooks/images/2468.jpg b/notebooks/images/2468.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2e033da2e5e2991aa9b9dc1a24be1d743068bd37
Binary files /dev/null and b/notebooks/images/2468.jpg differ
diff --git a/notebooks/images/866.jpg b/notebooks/images/866.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3561565a2c3ae0f5f4a3fb8042dc28dda73658a2
Binary files /dev/null and b/notebooks/images/866.jpg differ
diff --git a/notebooks/inference_playground.ipynb b/notebooks/inference_playground.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..591e35fdc3e36a4eb9b3e36aded21889ed205429
--- /dev/null
+++ b/notebooks/inference_playground.ipynb
@@ -0,0 +1,539 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5OqEP1SlGeVZ"
+ },
+ "source": [
+ "# SAM: Inference Playground"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "dE2hzjSNQs0p"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.chdir('/content')\n",
+ "CODE_DIR = 'SAM'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "bbaMZ40hQxT0",
+ "outputId": "f7fac42a-77e7-4b79-ab87-b8805a4b8f39"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "43F-3KfeQ08S",
+ "outputId": "f1def785-f7aa-4016-c6f7-afc2463d6b06"
+ },
+ "outputs": [],
+ "source": [
+ "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
+ "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n",
+ "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "av0207x4Q2iL"
+ },
+ "outputs": [],
+ "source": [
+ "os.chdir(f'./{CODE_DIR}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Zvwx9NsiQq9t"
+ },
+ "outputs": [],
+ "source": [
+ "from argparse import Namespace\n",
+ "import os\n",
+ "import sys\n",
+ "import pprint\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "import torchvision.transforms as transforms\n",
+ "\n",
+ "sys.path.append(\".\")\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "from datasets.augmentations import AgeTransformer\n",
+ "from utils.common import tensor2im\n",
+ "from models.psp import pSp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uj3dJjQsQq9y"
+ },
+ "outputs": [],
+ "source": [
+ "EXPERIMENT_TYPE = 'ffhq_aging'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mStxrAtuQq9y"
+ },
+ "source": [
+ "## Step 1: Download Pretrained Model\n",
+ "As part of this repository, we provide our pretrained aging model.\n",
+ "We'll download the model for the selected experiments as save it to the folder `../pretrained_models`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_pC38oLGQq9z"
+ },
+ "outputs": [],
+ "source": [
+ "def get_download_model_command(file_id, file_name):\n",
+ " \"\"\" Get wget download command for downloading the desired model and save to directory ../pretrained_models. \"\"\"\n",
+ " current_directory = os.getcwd()\n",
+ " save_path = os.path.join(os.path.dirname(current_directory), \"pretrained_models\")\n",
+ " if not os.path.exists(save_path):\n",
+ " os.makedirs(save_path)\n",
+ " url = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n",
+ " return url "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rOQ2Vz2kQq9z"
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_PATHS = {\n",
+ " \"ffhq_aging\": {\"id\": \"1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC\", \"name\": \"sam_ffhq_aging.pt\"}\n",
+ "}\n",
+ "\n",
+ "path = MODEL_PATHS[EXPERIMENT_TYPE]\n",
+ "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"]) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "K0nHPvo5Qq9z",
+ "outputId": "3ac7ce05-077a-4d81-b6ca-0e5b2dc61753"
+ },
+ "outputs": [],
+ "source": [
+ "!wget {download_command}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WvRDiRrMQq90"
+ },
+ "source": [
+ "## Step 2: Define Inference Parameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GNaSSzZsQq90"
+ },
+ "source": [
+ "Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the\n",
+ "image to perform inference on.\n",
+ "While we provide default values to run this script, feel free to change as needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yaGqalwuQq90"
+ },
+ "outputs": [],
+ "source": [
+ "EXPERIMENT_DATA_ARGS = {\n",
+ " \"ffhq_aging\": {\n",
+ " \"model_path\": \"../pretrained_models/sam_ffhq_aging.pt\",\n",
+ " \"image_path\": \"notebooks/images/866.jpg\",\n",
+ " \"transform\": transforms.Compose([\n",
+ " transforms.Resize((256, 256)),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
+ " }\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wjkLqLkDQq90"
+ },
+ "outputs": [],
+ "source": [
+ "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YkfqoKJwQq91"
+ },
+ "source": [
+ "## Step 3: Load Pretrained Model\n",
+ "We assume that you have downloaded the pretrained aging model and placed it in the path defined above"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cZuho98JQq91"
+ },
+ "outputs": [],
+ "source": [
+ "model_path = EXPERIMENT_ARGS['model_path']\n",
+ "ckpt = torch.load(model_path, map_location='cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "f6NOxONxQq91",
+ "outputId": "7eecdad5-0678-45d4-d416-898e3fce250d"
+ },
+ "outputs": [],
+ "source": [
+ "opts = ckpt['opts']\n",
+ "pprint.pprint(opts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "J6c93qE9Qq91"
+ },
+ "outputs": [],
+ "source": [
+ "# update the training options\n",
+ "opts['checkpoint_path'] = model_path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JRTfKFrkQq91",
+ "outputId": "1ebe3ebb-d33f-4764-d88c-8ba0a66ce0a8"
+ },
+ "outputs": [],
+ "source": [
+ "opts = Namespace(**opts)\n",
+ "net = pSp(opts)\n",
+ "net.eval()\n",
+ "net.cuda()\n",
+ "print('Model successfully loaded!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z6BegCirQq92"
+ },
+ "source": [
+ "## Step 4: Visualize Input"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kc4Sr31TQq92"
+ },
+ "outputs": [],
+ "source": [
+ "image_path = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE][\"image_path\"]\n",
+ "original_image = Image.open(image_path).convert(\"RGB\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 273
+ },
+ "id": "bKA9BO9_Qq92",
+ "outputId": "51152c46-8c8d-4020-f343-dc01dd523084"
+ },
+ "outputs": [],
+ "source": [
+ "original_image.resize((256, 256))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "u3a50tAcQq92"
+ },
+ "source": [
+ "## Step 5: Perform Inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o6oqf8JwzK0K"
+ },
+ "source": [
+ "### Align Image\n",
+ "\n",
+ "Before running inference we'll run alignment on the input image."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "y244_ejy9Drx",
+ "outputId": "bb583763-1aa1-4745-95f5-4b7bb2f96715"
+ },
+ "outputs": [],
+ "source": [
+ "!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\n",
+ "!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hJ9Ce1aYzmFF"
+ },
+ "outputs": [],
+ "source": [
+ "def run_alignment(image_path):\n",
+ " import dlib\n",
+ " from scripts.align_all_parallel import align_face\n",
+ " predictor = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n",
+ " aligned_image = align_face(filepath=image_path, predictor=predictor) \n",
+ " print(\"Aligned image has shape: {}\".format(aligned_image.size))\n",
+ " return aligned_image "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "aTZcKMdK8y77",
+ "outputId": "18d7a5da-9e98-4373-c296-727216406dd5"
+ },
+ "outputs": [],
+ "source": [
+ "aligned_image = run_alignment(image_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 273
+ },
+ "id": "hUBAfodh5PaM",
+ "outputId": "81545ff1-4184-4a3a-d887-52ad9f71e24a"
+ },
+ "outputs": [],
+ "source": [
+ "aligned_image.resize((256, 256))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gMyoh4X1HYAS"
+ },
+ "source": [
+ "### Run Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XkzQpi1aQq92"
+ },
+ "outputs": [],
+ "source": [
+ "img_transforms = EXPERIMENT_ARGS['transform']\n",
+ "input_image = img_transforms(aligned_image)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lI7yWNPDQq92"
+ },
+ "outputs": [],
+ "source": [
+ "# we'll run the image on multiple target ages \n",
+ "target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]\n",
+ "age_transformers = [AgeTransformer(target_age=age) for age in target_ages]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kLP4pF-2Qq93"
+ },
+ "outputs": [],
+ "source": [
+ "def run_on_batch(inputs, net):\n",
+ " result_batch = net(inputs.to(\"cuda\").float(), randomize_noise=False, resize=False)\n",
+ " return result_batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "nfrY_gEEQq93",
+ "outputId": "b3b8f881-424a-4eac-9e06-1598442f9c62"
+ },
+ "outputs": [],
+ "source": [
+ "# for each age transformed age, we'll concatenate the results to display them side-by-side\n",
+ "results = np.array(aligned_image.resize((1024, 1024)))\n",
+ "for age_transformer in age_transformers:\n",
+ " print(f\"Running on target age: {age_transformer.target_age}\")\n",
+ " with torch.no_grad():\n",
+ " input_image_age = [age_transformer(input_image.cpu()).to('cuda')]\n",
+ " input_image_age = torch.stack(input_image_age)\n",
+ " result_tensor = run_on_batch(input_image_age, net)[0]\n",
+ " result_image = tensor2im(result_tensor)\n",
+ " results = np.concatenate([results, result_image], axis=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IFgwfLTKQq93"
+ },
+ "source": [
+ "### Visualize Result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wpwyrv0iQq93"
+ },
+ "outputs": [],
+ "source": [
+ "results = Image.fromarray(results)\n",
+ "results # this is a very large image (11*1024 x 1024) so it may take some time to display!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4sL7fHp9Qq93",
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# save image at full resolution\n",
+ "results.save(\"notebooks/images/age_transformed_image.jpg\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "inference_playground.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
\ No newline at end of file
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/options/test_options.py b/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ad984db774a61b79f82d38edafa9993e07f6b4
--- /dev/null
+++ b/options/test_options.py
@@ -0,0 +1,49 @@
+from argparse import ArgumentParser
+
+
+class TestOptions:
+
+ def __init__(self):
+ self.parser = ArgumentParser()
+ self.initialize()
+
+ def initialize(self):
+ # arguments for inference script
+ self.parser.add_argument('--exp_dir', type=str,
+ help='Path to experiment output directory')
+ self.parser.add_argument('--checkpoint_path', default=None, type=str,
+ help='Path to pSp model checkpoint')
+ self.parser.add_argument('--data_path', type=str, default='gt_images',
+ help='Path to directory of images to evaluate')
+ self.parser.add_argument('--couple_outputs', action='store_true',
+ help='Whether to also save inputs + outputs side-by-side')
+ self.parser.add_argument('--resize_outputs', action='store_true',
+ help='Whether to resize outputs to 256x256 or keep at 1024x1024')
+
+ self.parser.add_argument('--test_batch_size', default=2, type=int,
+ help='Batch size for testing and inference')
+ self.parser.add_argument('--test_workers', default=2, type=int,
+ help='Number of test/inference dataloader workers')
+
+ # arguments for style-mixing script
+ self.parser.add_argument('--n_images', type=int, default=None,
+ help='Number of images to output. If None, run on all data')
+ self.parser.add_argument('--n_outputs_to_generate', type=int, default=5,
+ help='Number of outputs to generate per input image.')
+ self.parser.add_argument('--mix_alpha', type=float, default=None,
+ help='Alpha value for style-mixing')
+ self.parser.add_argument('--latent_mask', type=str, default=None,
+ help='Comma-separated list of latents to perform style-mixing with')
+
+ # arguments for aging
+ self.parser.add_argument('--target_age', type=str, default=None,
+ help='Target age for inference. Can be comma-separated list for multiple ages.')
+
+ # arguments for reference guided aging inference
+ self.parser.add_argument('--ref_images_paths_file', type=str, default='./ref_images.txt',
+ help='Path to file containing a list of reference images to use for '
+ 'reference guided inference.')
+
+ def parse(self):
+ opts = self.parser.parse_args()
+ return opts
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6d8781f3d5ce94890a1759c378a43aae11ab25b
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,92 @@
+from argparse import ArgumentParser
+from configs.paths_config import model_paths
+
+
+class TrainOptions:
+
+ def __init__(self):
+ self.parser = ArgumentParser()
+ self.initialize()
+
+ def initialize(self):
+ self.parser.add_argument('--exp_dir', type=str,
+ help='Path to experiment output directory')
+ self.parser.add_argument('--dataset_type', default='ffhq_aging', type=str,
+ help='Type of dataset/experiment to run')
+ self.parser.add_argument('--input_nc', default=4, type=int,
+ help='Number of input image channels to the psp encoder')
+ self.parser.add_argument('--label_nc', default=0, type=int,
+ help='Number of input label channels to the psp encoder')
+ self.parser.add_argument('--output_size', default=1024, type=int,
+ help='Output size of generator')
+
+ self.parser.add_argument('--batch_size', default=4, type=int,
+ help='Batch size for training')
+ self.parser.add_argument('--test_batch_size', default=2, type=int,
+ help='Batch size for testing and inference')
+ self.parser.add_argument('--workers', default=4, type=int,
+ help='Number of train dataloader workers')
+ self.parser.add_argument('--test_workers', default=2, type=int,
+ help='Number of test/inference dataloader workers')
+
+ self.parser.add_argument('--learning_rate', default=0.0001, type=float,
+ help='Optimizer learning rate')
+ self.parser.add_argument('--optim_name', default='ranger', type=str,
+ help='Which optimizer to use')
+ self.parser.add_argument('--train_decoder', action='store_true',
+ help='Whether to train the decoder model')
+ self.parser.add_argument('--start_from_latent_avg', action='store_true',
+ help='Whether to add average latent vector to generate codes from encoder.')
+ self.parser.add_argument('--start_from_encoded_w_plus', action='store_true',
+ help='Whether to learn residual wrt w+ of encoded image using pretrained pSp.')
+
+ self.parser.add_argument('--lpips_lambda', default=0, type=float,
+ help='LPIPS loss multiplier factor')
+ self.parser.add_argument('--id_lambda', default=0, type=float,
+ help='ID loss multiplier factor')
+ self.parser.add_argument('--l2_lambda', default=0, type=float,
+ help='L2 loss multiplier factor')
+ self.parser.add_argument('--w_norm_lambda', default=0, type=float,
+ help='W-norm loss multiplier factor')
+ self.parser.add_argument('--aging_lambda', default=0, type=float,
+ help='Aging loss multiplier factor')
+ self.parser.add_argument('--cycle_lambda', default=0, type=float,
+ help='Cycle loss multiplier factor')
+
+ self.parser.add_argument('--lpips_lambda_crop', default=0, type=float,
+ help='LPIPS loss multiplier factor for inner image region')
+ self.parser.add_argument('--l2_lambda_crop', default=0, type=float,
+ help='L2 loss multiplier factor for inner image region')
+
+ self.parser.add_argument('--lpips_lambda_aging', default=0, type=float,
+ help='LPIPS loss multiplier factor for aging')
+ self.parser.add_argument('--l2_lambda_aging', default=0, type=float,
+ help='L2 loss multiplier factor for aging')
+
+ self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
+ help='Path to StyleGAN model weights')
+ self.parser.add_argument('--checkpoint_path', default=None, type=str,
+ help='Path to pSp model checkpoint')
+
+ self.parser.add_argument('--max_steps', default=500000, type=int,
+ help='Maximum number of training steps')
+ self.parser.add_argument('--image_interval', default=100, type=int,
+ help='Interval for logging train images during training')
+ self.parser.add_argument('--board_interval', default=50, type=int,
+ help='Interval for logging metrics to tensorboard')
+ self.parser.add_argument('--val_interval', default=1000, type=int,
+ help='Validation interval')
+ self.parser.add_argument('--save_interval', default=None, type=int,
+ help='Model checkpoint interval')
+
+ # arguments for aging
+ self.parser.add_argument('--target_age', default=None, type=str,
+ help='Target age for training. Use `uniform_random` for random sampling of target age')
+ self.parser.add_argument('--use_weighted_id_loss', action="store_true",
+ help="Whether to weight id loss based on change in age (more change -> less weight)")
+ self.parser.add_argument('--pretrained_psp_path', default=model_paths['pretrained_psp'], type=str,
+ help="Path to pretrained pSp network.")
+
+ def parse(self):
+ opts = self.parser.parse_args()
+ return opts
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..ead1ae05b6c4a552cbfb2fde6f03053cdf26bc0e
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,91 @@
+import tempfile
+from argparse import Namespace
+import dlib
+import imageio
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from cog import BasePredictor, Path, Input
+
+from datasets.augmentations import AgeTransformer
+from models.psp import pSp
+from scripts.align_all_parallel import align_face
+from utils.common import tensor2im
+
+
+class Predictor(BasePredictor):
+ def setup(self):
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ]
+ )
+ model_path = "pretrained_models/sam_ffhq_aging.pt"
+ ckpt = torch.load(model_path, map_location="cpu")
+
+ opts = ckpt["opts"]
+ opts["checkpoint_path"] = model_path
+ opts["device"] = "cuda" if torch.cuda.is_available() else "cpu"
+
+ self.opts = Namespace(**opts)
+
+ def predict(
+ self,
+ image: Path = Input(
+ description="facial image",
+ ),
+ target_age: str = Input(
+ description="age of the output image, when choose 'default' "
+ "a gif for age from 0, 10, 20,...,to 100 will be displayed",
+ ),
+ ) -> Path:
+ net = pSp(self.opts)
+ net.eval()
+ if torch.cuda.is_available():
+ net.cuda()
+
+ # align image
+ aligned_image = run_alignment(str(image))
+ aligned_image.resize((256, 256))
+
+ input_image = self.transform(aligned_image)
+
+ if target_age == "default":
+ target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
+ age_transformers = [AgeTransformer(target_age=age) for age in target_ages]
+ else:
+ age_transformers = [AgeTransformer(target_age=target_age)]
+
+ results = np.array(aligned_image.resize((1024, 1024)))
+ all_imgs = []
+ for age_transformer in age_transformers:
+ print(f"Running on target age: {age_transformer.target_age}")
+ with torch.no_grad():
+ input_image_age = [age_transformer(input_image.cpu()).to("cuda")]
+ input_image_age = torch.stack(input_image_age)
+ result_tensor = run_on_batch(input_image_age, net)[0]
+ result_image = tensor2im(result_tensor)
+ all_imgs.append(result_image)
+ results = np.concatenate([results, result_image], axis=1)
+
+ if target_age == "default":
+ out_path = Path(tempfile.mkdtemp()) / "output.gif"
+ imageio.mimwrite(str(out_path), all_imgs, duration=0.3)
+ else:
+ out_path = Path(tempfile.mkdtemp()) / "output.png"
+ imageio.imwrite(str(out_path), all_imgs[0])
+ return out_path
+
+
+def run_alignment(image_path):
+ predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
+ aligned_image = align_face(filepath=image_path, predictor=predictor)
+ print("Aligned image has shape: {}".format(aligned_image.size))
+ return aligned_image
+
+
+def run_on_batch(inputs, net):
+ result_batch = net(inputs.to("cuda").float(), randomize_noise=False, resize=False)
+ return result_batch
diff --git a/scripts/__pycache__/align_all_parallel.cpython-310.pyc b/scripts/__pycache__/align_all_parallel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10abdb5a25cb423e62b200a277704e0c52e74445
Binary files /dev/null and b/scripts/__pycache__/align_all_parallel.cpython-310.pyc differ
diff --git a/scripts/align_all_parallel.py b/scripts/align_all_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fff4e635870858fbf732396ff2c05e5a39f1d93
--- /dev/null
+++ b/scripts/align_all_parallel.py
@@ -0,0 +1,208 @@
+"""
+brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+author: lzhbrian (https://lzhbrian.me)
+date: 2020.1.5
+note: code is heavily borrowed from
+ https://github.com/NVlabs/ffhq-dataset
+ http://dlib.net/face_landmark_detection.py.html
+
+requirements:
+ apt install cmake
+ conda install Pillow numpy scipy
+ pip install dlib
+ # download face landmark model from:
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+"""
+from argparse import ArgumentParser
+import time
+import numpy as np
+import PIL
+import PIL.Image
+import os
+import scipy
+import scipy.ndimage
+import dlib
+import multiprocessing as mp
+import math
+
+from configs.paths_config import model_paths
+SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"]
+
+
+def get_landmark(filepath, predictor):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ detector = dlib.get_frontal_face_detector()
+
+ img = dlib.load_rgb_image(filepath)
+ dets = detector(img, 1)
+
+ shape = None
+
+ for k, d in enumerate(dets):
+ shape = predictor(img, d)
+
+ if not shape:
+ raise Exception("Could not find face in image. Try another!")
+
+ t = list(shape.parts())
+ a = []
+ for tt in t:
+ a.append([tt.x, tt.y])
+ lm = np.array(a)
+ return lm
+
+
+def align_face(filepath, predictor):
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+
+ lm = get_landmark(filepath, predictor)
+
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # read image
+ img = PIL.Image.open(filepath).convert("RGB")
+
+ output_size = 256
+ transform_size = 256
+ enable_padding = True
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ # Transform.
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+ if output_size < transform_size:
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+ # Save aligned image.
+ return img
+
+
+def chunks(lst, n):
+ """Yield successive n-sized chunks from lst."""
+ for i in range(0, len(lst), n):
+ yield lst[i:i + n]
+
+
+def extract_on_paths(file_paths):
+ predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
+ pid = mp.current_process().name
+ print(f'\t{pid} is starting to extract on #{len(file_paths)} images')
+ tot_count = len(file_paths)
+ count = 0
+ for file_path, res_path in file_paths:
+ count += 1
+ if count % 100 == 0:
+ print(f'{pid} done with {count}/{tot_count}')
+ try:
+ res = align_face(file_path, predictor)
+ res = res.convert('RGB')
+ os.makedirs(os.path.dirname(res_path), exist_ok=True)
+ res.save(res_path)
+ except Exception:
+ continue
+ print('\tDone!')
+
+
+def parse_args():
+ parser = ArgumentParser(add_help=False)
+ parser.add_argument('--num_threads', type=int, default=1)
+ parser.add_argument('--root_path', type=str, default='')
+ args = parser.parse_args()
+ return args
+
+
+def run(args):
+ root_path = args.root_path
+ out_crops_path = root_path + '_crops'
+ if not os.path.exists(out_crops_path):
+ os.makedirs(out_crops_path, exist_ok=True)
+
+ file_paths = []
+ for root, dirs, files in os.walk(root_path):
+ for file in files:
+ file_path = os.path.join(root, file)
+ fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path))
+ res_path = f'{os.path.splitext(fname)[0]}.jpg'
+ if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path):
+ continue
+ file_paths.append((file_path, res_path))
+
+ file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
+ print(len(file_chunks))
+ pool = mp.Pool(args.num_threads)
+ print(f'Running on {len(file_paths)} paths\nHere we goooo')
+ tic = time.time()
+ pool.map(extract_on_paths, file_chunks)
+ toc = time.time()
+ print(f'Mischief managed in {str(toc - tic)}s')
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ run(args)
diff --git a/scripts/inference.py b/scripts/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c0a365fac78c78e5c391a00b1c7e2cee0ff8e59
--- /dev/null
+++ b/scripts/inference.py
@@ -0,0 +1,107 @@
+from argparse import Namespace
+import os
+import time
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+import sys
+sys.path.append(".")
+sys.path.append("..")
+
+from configs import data_configs
+from datasets.inference_dataset import InferenceDataset
+from datasets.augmentations import AgeTransformer
+from utils.common import tensor2im, log_image
+from options.test_options import TestOptions
+from models.psp import pSp
+
+
+def run():
+ test_opts = TestOptions().parse()
+
+ out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
+ out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
+ os.makedirs(out_path_results, exist_ok=True)
+ os.makedirs(out_path_coupled, exist_ok=True)
+
+ # update test options with options used during training
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts.update(vars(test_opts))
+ opts = Namespace(**opts)
+
+ net = pSp(opts)
+ net.eval()
+ net.cuda()
+
+ age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]
+
+ print(f'Loading dataset for {opts.dataset_type}')
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
+ dataset = InferenceDataset(root=opts.data_path,
+ transform=transforms_dict['transform_inference'],
+ opts=opts)
+ dataloader = DataLoader(dataset,
+ batch_size=opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(opts.test_workers),
+ drop_last=False)
+
+ if opts.n_images is None:
+ opts.n_images = len(dataset)
+
+ global_time = []
+ for age_transformer in age_transformers:
+ print(f"Running on target age: {age_transformer.target_age}")
+ global_i = 0
+ for input_batch in tqdm(dataloader):
+ if global_i >= opts.n_images:
+ break
+ with torch.no_grad():
+ input_age_batch = [age_transformer(img.cpu()).to('cuda') for img in input_batch]
+ input_age_batch = torch.stack(input_age_batch)
+ input_cuda = input_age_batch.cuda().float()
+ tic = time.time()
+ result_batch = run_on_batch(input_cuda, net, opts)
+ toc = time.time()
+ global_time.append(toc - tic)
+
+ for i in range(len(input_batch)):
+ result = tensor2im(result_batch[i])
+ im_path = dataset.paths[global_i]
+
+ if opts.couple_outputs or global_i % 100 == 0:
+ input_im = log_image(input_batch[i], opts)
+ resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
+ res = np.concatenate([np.array(input_im.resize(resize_amount)),
+ np.array(result.resize(resize_amount))], axis=1)
+ age_out_path_coupled = os.path.join(out_path_coupled, age_transformer.target_age)
+ os.makedirs(age_out_path_coupled, exist_ok=True)
+ Image.fromarray(res).save(os.path.join(age_out_path_coupled, os.path.basename(im_path)))
+
+ age_out_path_results = os.path.join(out_path_results, age_transformer.target_age)
+ os.makedirs(age_out_path_results, exist_ok=True)
+ image_name = os.path.basename(im_path)
+ im_save_path = os.path.join(age_out_path_results, image_name)
+ Image.fromarray(np.array(result.resize(resize_amount))).save(im_save_path)
+ global_i += 1
+
+ stats_path = os.path.join(opts.exp_dir, 'stats.txt')
+ result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
+ print(result_str)
+
+ with open(stats_path, 'w') as f:
+ f.write(result_str)
+
+
+def run_on_batch(inputs, net, opts):
+ result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs)
+ return result_batch
+
+
+if __name__ == '__main__':
+ run()
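+
+# Example invocation (illustrative; the flag names are assumed to mirror the attributes
+# read from TestOptions above, e.g. checkpoint_path, data_path, exp_dir, target_age):
+#   python scripts/inference.py --exp_dir=/tmp/experiment \
+#       --checkpoint_path=/path/to/trained_sam_model.pt \
+#       --data_path=/path/to/aligned_test_images \
+#       --target_age=0,10,20,30,40,50,60,70,80 --couple_outputs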
diff --git a/scripts/inference_side_by_side.py b/scripts/inference_side_by_side.py
new file mode 100644
index 0000000000000000000000000000000000000000..19edacc59c1add1cb76af46b326d9305aa167104
--- /dev/null
+++ b/scripts/inference_side_by_side.py
@@ -0,0 +1,96 @@
+from argparse import Namespace
+import os
+import time
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+import sys
+sys.path.append(".")
+sys.path.append("..")
+
+from configs import data_configs
+from datasets.inference_dataset import InferenceDataset
+from datasets.augmentations import AgeTransformer
+from utils.common import tensor2im, log_image
+from options.test_options import TestOptions
+from models.psp import pSp
+
+
+def run():
+ test_opts = TestOptions().parse()
+
+ out_path_results = os.path.join(test_opts.exp_dir, 'inference_side_by_side')
+ os.makedirs(out_path_results, exist_ok=True)
+
+ # update test options with options used during training
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts.update(vars(test_opts))
+ opts = Namespace(**opts)
+
+ net = pSp(opts)
+ net.eval()
+ net.cuda()
+
+ age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]
+
+ print(f'Loading dataset for {opts.dataset_type}')
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
+ dataset = InferenceDataset(root=opts.data_path,
+ transform=transforms_dict['transform_inference'],
+ opts=opts,
+ return_path=True)
+ dataloader = DataLoader(dataset,
+ batch_size=opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(opts.test_workers),
+ drop_last=False)
+
+ if opts.n_images is None:
+ opts.n_images = len(dataset)
+
+ global_time = []
+ global_i = 0
+ for input_batch, image_paths in tqdm(dataloader):
+ if global_i >= opts.n_images:
+ break
+ batch_results = {}
+ for idx, age_transformer in enumerate(age_transformers):
+ with torch.no_grad():
+ input_age_batch = [age_transformer(img.cpu()).to('cuda') for img in input_batch]
+ input_age_batch = torch.stack(input_age_batch)
+ input_cuda = input_age_batch.cuda().float()
+ tic = time.time()
+ result_batch = run_on_batch(input_cuda, net, opts)
+ toc = time.time()
+ global_time.append(toc - tic)
+
+ resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
+ for i in range(len(input_batch)):
+ result = tensor2im(result_batch[i])
+ im_path = image_paths[i]
+ input_im = log_image(input_batch[i], opts)
+ if im_path not in batch_results.keys():
+ batch_results[im_path] = np.array(input_im.resize(resize_amount))
+ batch_results[im_path] = np.concatenate([batch_results[im_path],
+ np.array(result.resize(resize_amount))],
+ axis=1)
+
+ for im_path, res in batch_results.items():
+ image_name = os.path.basename(im_path)
+ im_save_path = os.path.join(out_path_results, image_name)
+ Image.fromarray(np.array(res)).save(im_save_path)
+ global_i += 1
+
+
+def run_on_batch(inputs, net, opts):
+ result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs)
+ return result_batch
+
+
+if __name__ == '__main__':
+ run()
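+
+# Output format: for every input image, a single row is saved containing the original image
+# followed by one result per requested target age, concatenated from left to right.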
diff --git a/scripts/reference_guided_inference.py b/scripts/reference_guided_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a41ac9d01c76aa83e7cfb9bd8362d48c7029b46e
--- /dev/null
+++ b/scripts/reference_guided_inference.py
@@ -0,0 +1,137 @@
+from argparse import Namespace
+import os
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+import sys
+sys.path.append(".")
+sys.path.append("..")
+
+from configs import data_configs
+from datasets.inference_dataset import InferenceDataset
+from datasets.augmentations import AgeTransformer
+from utils.common import log_image
+from options.test_options import TestOptions
+from models.psp import pSp
+
+
+def run():
+ test_opts = TestOptions().parse()
+
+ out_path_results = os.path.join(test_opts.exp_dir, 'reference_guided_inference')
+ os.makedirs(out_path_results, exist_ok=True)
+
+ # update test options with options used during training
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts.update(vars(test_opts))
+ opts = Namespace(**opts)
+
+ net = pSp(opts)
+ net.eval()
+ net.cuda()
+
+ age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]
+
+ print(f'Loading dataset for {opts.dataset_type}')
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
+
+ source_dataset = InferenceDataset(root=opts.data_path,
+ transform=transforms_dict['transform_inference'],
+ opts=opts)
+ source_dataloader = DataLoader(source_dataset,
+ batch_size=opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(opts.test_workers),
+ drop_last=False)
+
+ ref_dataset = InferenceDataset(paths_list=opts.ref_images_paths_file,
+ transform=transforms_dict['transform_inference'],
+ opts=opts)
+ ref_dataloader = DataLoader(ref_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ drop_last=False)
+
+ if opts.n_images is None:
+ opts.n_images = len(source_dataset)
+
+ for age_transformer in age_transformers:
+ target_age = age_transformer.target_age
+ print(f"Running on target age: {target_age}")
+ age_save_path = os.path.join(out_path_results, str(target_age))
+ os.makedirs(age_save_path, exist_ok=True)
+ global_i = 0
+ for i, source_batch in enumerate(tqdm(source_dataloader)):
+ if global_i >= opts.n_images:
+ break
+ results_per_source = {idx: [] for idx in range(len(source_batch))}
+ with torch.no_grad():
+ for ref_batch in ref_dataloader:
+ source_batch = source_batch.cuda().float()
+ ref_batch = ref_batch.cuda().float()
+ source_input_age_batch = [age_transformer(img.cpu()).to('cuda') for img in source_batch]
+ source_input_age_batch = torch.stack(source_input_age_batch)
+
+ # compute w+ of ref images to be injected for style-mixing
+ ref_latents = net.pretrained_encoder(ref_batch) + net.latent_avg
+
+ # run age transformation on source images with style-mixing
+ res_batch_mixed = run_on_batch(source_input_age_batch, net, opts, latent_to_inject=ref_latents)
+
+ # store results
+ for idx in range(len(source_batch)):
+ results_per_source[idx].append([ref_batch[0], res_batch_mixed[idx]])
+
+ # save results
+ resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
+ for image_idx, image_results in results_per_source.items():
+ input_im_path = source_dataset.paths[global_i]
+ image = source_batch[image_idx]
+ input_image = log_image(image, opts)
+ # initialize results image
+ ref_inputs = np.zeros_like(input_image.resize(resize_amount))
+ mixing_results = np.array(input_image.resize(resize_amount))
+ for ref_idx in range(len(image_results)):
+ ref_input, mixing_result = image_results[ref_idx]
+ ref_input = log_image(ref_input, opts)
+ mixing_result = log_image(mixing_result, opts)
+ # append current results
+ ref_inputs = np.concatenate([ref_inputs,
+ np.array(ref_input.resize(resize_amount))],
+ axis=1)
+ mixing_results = np.concatenate([mixing_results,
+ np.array(mixing_result.resize(resize_amount))],
+ axis=1)
+ res = np.concatenate([ref_inputs, mixing_results], axis=0)
+ save_path = os.path.join(age_save_path, os.path.basename(input_im_path))
+ Image.fromarray(res).save(save_path)
+ global_i += 1
+
+
+def run_on_batch(inputs, net, opts, latent_to_inject=None):
+ if opts.latent_mask is None:
+ result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs)
+ else:
+ latent_mask = [int(l) for l in opts.latent_mask.split(",")]
+ result_batch = []
+ for image_idx, input_image in enumerate(inputs):
+ # get output image with injected style vector
+ res, res_latent = net(input_image.unsqueeze(0).to("cuda").float(),
+ latent_mask=latent_mask,
+ inject_latent=latent_to_inject,
+ alpha=opts.mix_alpha,
+ resize=opts.resize_outputs,
+ return_latents=True)
+ result_batch.append(res)
+ result_batch = torch.cat(result_batch, dim=0)
+ return result_batch
+
+
+if __name__ == '__main__':
+ run()
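+
+# Output format: for each source image and target age, a two-row grid is saved. The top row
+# holds a blank placeholder followed by the reference images; the bottom row holds the source
+# image followed by the corresponding style-mixed aging results.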
diff --git a/scripts/style_mixing.py b/scripts/style_mixing.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9d7c4e189d75ef9bca834cb69db3374d06ed5c
--- /dev/null
+++ b/scripts/style_mixing.py
@@ -0,0 +1,96 @@
+from argparse import Namespace
+import os
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+import sys
+sys.path.append(".")
+sys.path.append("..")
+
+from datasets.augmentations import AgeTransformer
+from configs import data_configs
+from datasets.inference_dataset import InferenceDataset
+from utils.common import tensor2im, log_image
+from options.test_options import TestOptions
+from models.psp import pSp
+
+
+def run():
+ test_opts = TestOptions().parse()
+
+ assert len(test_opts.target_age.split(',')) == 1, "Style-mixing supports only one target age!"
+
+ mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing', str(test_opts.target_age))
+ os.makedirs(mixed_path_results, exist_ok=True)
+
+ # update test options with options used during training
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts.update(vars(test_opts))
+ opts = Namespace(**opts)
+
+ net = pSp(opts)
+ net.eval()
+ net.cuda()
+
+ print(f'Loading dataset for {opts.dataset_type}')
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
+ dataset = InferenceDataset(root=opts.data_path,
+ transform=transforms_dict['transform_inference'],
+ opts=opts)
+ dataloader = DataLoader(dataset,
+ batch_size=opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(opts.test_workers),
+ drop_last=True)
+
+ age_transformer = AgeTransformer(target_age=opts.target_age)
+
+ latent_mask = [int(l) for l in opts.latent_mask.split(",")]
+ if opts.n_images is None:
+ opts.n_images = len(dataset)
+
+ global_i = 0
+ for i, input_batch in enumerate(tqdm(dataloader)):
+ if global_i >= opts.n_images:
+ break
+ with torch.no_grad():
+ input_age_batch = [age_transformer(img.cpu()).to('cuda') for img in input_batch]
+ input_age_batch = torch.stack(input_age_batch)
+ for image_idx, input_image in enumerate(input_age_batch):
+ # generate random vectors to inject into input image
+ vecs_to_inject = np.random.randn(opts.n_outputs_to_generate, 512).astype('float32')
+ multi_modal_outputs = []
+ for vec_to_inject in vecs_to_inject:
+ cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda")
+ # get latent vector to inject into our input image
+ _, latent_to_inject = net(cur_vec,
+ input_code=True,
+ return_latents=True)
+ # get output image with injected style vector
+ res = net(input_image.unsqueeze(0).to("cuda").float(),
+ latent_mask=latent_mask,
+ inject_latent=latent_to_inject,
+ alpha=opts.mix_alpha,
+ resize=opts.resize_outputs)
+ multi_modal_outputs.append(res[0])
+
+ # visualize multi modal outputs
+ input_im_path = dataset.paths[global_i]
+ image = input_batch[image_idx]
+ input_image = log_image(image, opts)
+ resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
+ res = np.array(input_image.resize(resize_amount))
+ for output in multi_modal_outputs:
+ output = tensor2im(output)
+ res = np.concatenate([res, np.array(output.resize(resize_amount))], axis=1)
+ Image.fromarray(res).save(os.path.join(mixed_path_results, os.path.basename(input_im_path)))
+ global_i += 1
+
+
+if __name__ == '__main__':
+ run()
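+
+# Note: multi-modal results are produced by sampling n_outputs_to_generate random z vectors,
+# mapping each through the network (input_code=True) to obtain a W+ latent, and injecting that
+# latent into the layers listed in --latent_mask (e.g. the finer layers) with blending factor
+# --mix_alpha.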
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c5ff4ad14b32440d0f32d1d7b07c182bfdb3b2e
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,30 @@
+"""
+This file runs the main training/val loop
+"""
+import os
+import json
+import sys
+import pprint
+
+sys.path.append(".")
+sys.path.append("..")
+
+from options.train_options import TrainOptions
+from training.coach_aging import Coach
+
+
+def main():
+ opts = TrainOptions().parse()
+ os.makedirs(opts.exp_dir, exist_ok=True)
+
+ opts_dict = vars(opts)
+ pprint.pprint(opts_dict)
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
+
+ coach = Coach(opts)
+ coach.train()
+
+
+if __name__ == '__main__':
+ main()
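+
+# Example invocation (illustrative; flag names and values are assumed to follow the options
+# consumed by Coach, e.g. dataset_type, exp_dir, batch_size and the various loss lambdas):
+#   python scripts/train.py --dataset_type=ffhq_aging --exp_dir=/tmp/sam_experiment \
+#       --batch_size=6 --test_batch_size=6 --workers=6 --test_workers=6 \
+#       --aging_lambda=5 --id_lambda=0.1 --lpips_lambda=0.1 --l2_lambda=0.25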
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/coach_aging.py b/training/coach_aging.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8c2d3c2ed0014f96131a80111435bdaff2f340
--- /dev/null
+++ b/training/coach_aging.py
@@ -0,0 +1,365 @@
+import os
+import random
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torch.nn.functional as F
+
+from utils import common, train_utils
+from criteria import id_loss, w_norm
+from configs import data_configs
+from datasets.images_dataset import ImagesDataset
+from datasets.augmentations import AgeTransformer
+from criteria.lpips.lpips import LPIPS
+from criteria.aging_loss import AgingLoss
+from models.psp import pSp
+from training.ranger import Ranger
+
+
+class Coach:
+ def __init__(self, opts):
+ self.opts = opts
+
+ self.global_step = 0
+
+ self.device = 'cuda'
+ self.opts.device = self.device
+
+ # Initialize network
+ self.net = pSp(self.opts).to(self.device)
+
+ # Initialize loss
+ self.mse_loss = nn.MSELoss().to(self.device).eval()
+ if self.opts.lpips_lambda > 0:
+ self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
+ if self.opts.id_lambda > 0:
+ self.id_loss = id_loss.IDLoss().to(self.device).eval()
+ if self.opts.w_norm_lambda > 0:
+ self.w_norm_loss = w_norm.WNormLoss(opts=self.opts)
+ if self.opts.aging_lambda > 0:
+ self.aging_loss = AgingLoss(self.opts)
+
+ # Initialize optimizer
+ self.optimizer = self.configure_optimizers()
+
+ # Initialize dataset
+ self.train_dataset, self.test_dataset = self.configure_datasets()
+ self.train_dataloader = DataLoader(self.train_dataset,
+ batch_size=self.opts.batch_size,
+ shuffle=True,
+ num_workers=int(self.opts.workers),
+ drop_last=True)
+ self.test_dataloader = DataLoader(self.test_dataset,
+ batch_size=self.opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(self.opts.test_workers),
+ drop_last=True)
+
+ self.age_transformer = AgeTransformer(target_age=self.opts.target_age)
+
+ # Initialize logger
+ log_dir = os.path.join(opts.exp_dir, 'logs')
+ os.makedirs(log_dir, exist_ok=True)
+ self.logger = SummaryWriter(log_dir=log_dir)
+
+ # Initialize checkpoint dir
+ self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.best_val_loss = None
+ if self.opts.save_interval is None:
+ self.opts.save_interval = self.opts.max_steps
+
+ def perform_forward_pass(self, x):
+ y_hat, latent = self.net.forward(x, return_latents=True)
+ return y_hat, latent
+
+ def __set_target_to_source(self, x, input_ages):
+ return [torch.cat((img, age * torch.ones((1, img.shape[1], img.shape[2])).to(self.device)))
+ for img, age in zip(x, input_ages)]
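+
+	# The desired age is fed to the encoder as a fourth, constant-valued channel appended to
+	# the RGB input (input ages are normalized by 100 above); it is later recovered via
+	# x_input[:, -1, 0, 0] to serve as the target age for the aging loss.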
+
+ def train(self):
+ self.net.train()
+ while self.global_step < self.opts.max_steps:
+ for batch_idx, batch in enumerate(self.train_dataloader):
+ x, y = batch
+ x, y = x.to(self.device).float(), y.to(self.device).float()
+ self.optimizer.zero_grad()
+
+ input_ages = self.aging_loss.extract_ages(x) / 100.
+
+				# skip aging (keep the original age as the target) for roughly a third of the batches
+ no_aging = random.random() <= (1. / 3)
+ if no_aging:
+ x_input = self.__set_target_to_source(x=x, input_ages=input_ages)
+ else:
+ x_input = [self.age_transformer(img.cpu()).to(self.device) for img in x]
+
+ x_input = torch.stack(x_input)
+ target_ages = x_input[:, -1, 0, 0]
+
+ # perform forward/backward pass on real images
+ y_hat, latent = self.perform_forward_pass(x_input)
+ loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent,
+ target_ages=target_ages,
+ input_ages=input_ages,
+ no_aging=no_aging,
+ data_type="real")
+ loss.backward()
+
+				# perform a cycle pass on the generated images by setting the target ages back to the original input ages
+ y_hat_clone = y_hat.clone().detach().requires_grad_(True)
+ input_ages_clone = input_ages.clone().detach().requires_grad_(True)
+ y_hat_inverse = self.__set_target_to_source(x=y_hat_clone, input_ages=input_ages_clone)
+ y_hat_inverse = torch.stack(y_hat_inverse)
+ reverse_target_ages = y_hat_inverse[:, -1, 0, 0]
+ y_recovered, latent_cycle = self.perform_forward_pass(y_hat_inverse)
+ loss, cycle_loss_dict, cycle_id_logs = self.calc_loss(x, y, y_recovered, latent_cycle,
+ target_ages=reverse_target_ages,
+ input_ages=input_ages,
+ no_aging=no_aging,
+ data_type="cycle")
+ loss.backward()
+ self.optimizer.step()
+
+ # combine the logs of both forwards
+ for idx, cycle_log in enumerate(cycle_id_logs):
+ id_logs[idx].update(cycle_log)
+ loss_dict.update(cycle_loss_dict)
+ loss_dict["loss"] = loss_dict["loss_real"] + loss_dict["loss_cycle"]
+
+ # Logging related
+ if self.global_step % self.opts.image_interval == 0 or \
+ (self.global_step < 1000 and self.global_step % 25 == 0):
+ self.parse_and_log_images(id_logs, x, y, y_hat, y_recovered,
+ title='images/train/faces')
+
+ if self.global_step % self.opts.board_interval == 0:
+ self.print_metrics(loss_dict, prefix='train')
+ self.log_metrics(loss_dict, prefix='train')
+
+ # Validation related
+ val_loss_dict = None
+ if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
+ val_loss_dict = self.validate()
+ if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
+ self.best_val_loss = val_loss_dict['loss']
+ self.checkpoint_me(val_loss_dict, is_best=True)
+
+ if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
+ if val_loss_dict is not None:
+ self.checkpoint_me(val_loss_dict, is_best=False)
+ else:
+ self.checkpoint_me(loss_dict, is_best=False)
+
+ if self.global_step == self.opts.max_steps:
+ print('OMG, finished training!')
+ break
+
+ self.global_step += 1
+
+ def validate(self):
+ self.net.eval()
+ agg_loss_dict = []
+ for batch_idx, batch in enumerate(self.test_dataloader):
+ x, y = batch
+ with torch.no_grad():
+ x, y = x.to(self.device).float(), y.to(self.device).float()
+
+ input_ages = self.aging_loss.extract_ages(x) / 100.
+
+				# skip aging (keep the original age as the target) for roughly a third of the batches
+ no_aging = random.random() <= (1. / 3)
+ if no_aging:
+ x_input = self.__set_target_to_source(x=x, input_ages=input_ages)
+ else:
+ x_input = [self.age_transformer(img.cpu()).to(self.device) for img in x]
+
+ x_input = torch.stack(x_input)
+ target_ages = x_input[:, -1, 0, 0]
+
+				# perform forward pass on real images (no backward pass during validation)
+ y_hat, latent = self.perform_forward_pass(x_input)
+ _, cur_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent,
+ target_ages=target_ages,
+ input_ages=input_ages,
+ no_aging=no_aging,
+ data_type="real")
+
+				# perform a cycle pass on the generated images by setting the target ages back to the original input ages
+ y_hat_inverse = self.__set_target_to_source(x=y_hat, input_ages=input_ages)
+ y_hat_inverse = torch.stack(y_hat_inverse)
+ reverse_target_ages = y_hat_inverse[:, -1, 0, 0]
+ y_recovered, latent_cycle = self.perform_forward_pass(y_hat_inverse)
+ loss, cycle_loss_dict, cycle_id_logs = self.calc_loss(x, y, y_recovered, latent_cycle,
+ target_ages=reverse_target_ages,
+ input_ages=input_ages,
+ no_aging=no_aging,
+ data_type="cycle")
+
+ # combine the logs of both forwards
+ for idx, cycle_log in enumerate(cycle_id_logs):
+ id_logs[idx].update(cycle_log)
+ cur_loss_dict.update(cycle_loss_dict)
+ cur_loss_dict["loss"] = cur_loss_dict["loss_real"] + cur_loss_dict["loss_cycle"]
+
+ agg_loss_dict.append(cur_loss_dict)
+
+ # Logging related
+ self.parse_and_log_images(id_logs, x, y, y_hat, y_recovered, title='images/test/faces',
+ subscript='{:04d}'.format(batch_idx))
+
+			# For the first step, just run a quick sanity check on a small amount of data
+ if self.global_step == 0 and batch_idx >= 4:
+ self.net.train()
+ return None # Do not log, inaccurate in first batch
+
+ loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
+ self.log_metrics(loss_dict, prefix='test')
+ self.print_metrics(loss_dict, prefix='test')
+
+ self.net.train()
+ return loss_dict
+
+ def checkpoint_me(self, loss_dict, is_best):
+ save_name = 'best_model.pt' if is_best else f'iteration_{self.global_step}.pt'
+ save_dict = self.__get_save_dict()
+ checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
+ torch.save(save_dict, checkpoint_path)
+ with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
+ if is_best:
+ f.write('**Best**: Step - {}, '
+ 'Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
+ else:
+ f.write(f'Step - {self.global_step}, \n{loss_dict}\n')
+
+ def configure_optimizers(self):
+ params = list(self.net.encoder.parameters())
+ if self.opts.train_decoder:
+ params += list(self.net.decoder.parameters())
+ if self.opts.optim_name == 'adam':
+ optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
+ else:
+ optimizer = Ranger(params, lr=self.opts.learning_rate)
+ return optimizer
+
+ def configure_datasets(self):
+ if self.opts.dataset_type not in data_configs.DATASETS.keys():
+			raise Exception(f'{self.opts.dataset_type} is not a valid dataset_type')
+ print(f'Loading dataset for {self.opts.dataset_type}')
+ dataset_args = data_configs.DATASETS[self.opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](self.opts).get_transforms()
+ train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
+ target_root=dataset_args['train_target_root'],
+ source_transform=transforms_dict['transform_source'],
+ target_transform=transforms_dict['transform_gt_train'],
+ opts=self.opts)
+ test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
+ target_root=dataset_args['test_target_root'],
+ source_transform=transforms_dict['transform_source'],
+ target_transform=transforms_dict['transform_test'],
+ opts=self.opts)
+ print(f"Number of training samples: {len(train_dataset)}")
+ print(f"Number of test samples: {len(test_dataset)}")
+ return train_dataset, test_dataset
+
+ def calc_loss(self, x, y, y_hat, latent, target_ages, input_ages, no_aging, data_type="real"):
+ loss_dict = {}
+ id_logs = []
+ loss = 0.0
+ if self.opts.id_lambda > 0:
+ weights = None
+ if self.opts.use_weighted_id_loss: # compute weighted id loss only on forward pass
+ age_diffs = torch.abs(target_ages - input_ages)
+ weights = train_utils.compute_cosine_weights(x=age_diffs)
+ loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x, label=data_type, weights=weights)
+ loss_dict[f'loss_id_{data_type}'] = float(loss_id)
+ loss_dict[f'id_improve_{data_type}'] = float(sim_improvement)
+ loss = loss_id * self.opts.id_lambda
+ if self.opts.l2_lambda > 0:
+ loss_l2 = F.mse_loss(y_hat, y)
+ loss_dict[f'loss_l2_{data_type}'] = float(loss_l2)
+ if data_type == "real" and not no_aging:
+ l2_lambda = self.opts.l2_lambda_aging
+ else:
+ l2_lambda = self.opts.l2_lambda
+ loss += loss_l2 * l2_lambda
+ if self.opts.lpips_lambda > 0:
+ loss_lpips = self.lpips_loss(y_hat, y)
+ loss_dict[f'loss_lpips_{data_type}'] = float(loss_lpips)
+ if data_type == "real" and not no_aging:
+ lpips_lambda = self.opts.lpips_lambda_aging
+ else:
+ lpips_lambda = self.opts.lpips_lambda
+ loss += loss_lpips * lpips_lambda
+ if self.opts.lpips_lambda_crop > 0:
+ loss_lpips_crop = self.lpips_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])
+ loss_dict['loss_lpips_crop'] = float(loss_lpips_crop)
+ loss += loss_lpips_crop * self.opts.lpips_lambda_crop
+ if self.opts.l2_lambda_crop > 0:
+ loss_l2_crop = F.mse_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])
+ loss_dict['loss_l2_crop'] = float(loss_l2_crop)
+ loss += loss_l2_crop * self.opts.l2_lambda_crop
+ if self.opts.w_norm_lambda > 0:
+ loss_w_norm = self.w_norm_loss(latent, latent_avg=self.net.latent_avg)
+ loss_dict[f'loss_w_norm_{data_type}'] = float(loss_w_norm)
+ loss += loss_w_norm * self.opts.w_norm_lambda
+ if self.opts.aging_lambda > 0:
+ aging_loss, id_logs = self.aging_loss(y_hat, y, target_ages, id_logs, label=data_type)
+ loss_dict[f'loss_aging_{data_type}'] = float(aging_loss)
+ loss += aging_loss * self.opts.aging_lambda
+ loss_dict[f'loss_{data_type}'] = float(loss)
+ if data_type == "cycle":
+ loss = loss * self.opts.cycle_lambda
+ return loss, loss_dict, id_logs
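+
+	# In summary, the total objective per pass is a weighted sum:
+	#   L = id_lambda * L_id + l2_lambda * L_2 + lpips_lambda * L_LPIPS
+	#       + (cropped l2/LPIPS variants) + w_norm_lambda * L_w-reg + aging_lambda * L_age,
+	# where the aging-specific l2/lpips lambdas replace the defaults on the real (aged) pass
+	# and the whole cycle-pass loss is additionally scaled by cycle_lambda.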
+
+ def log_metrics(self, metrics_dict, prefix):
+ for key, value in metrics_dict.items():
+ self.logger.add_scalar(f'{prefix}/{key}', value, self.global_step)
+
+ def print_metrics(self, metrics_dict, prefix):
+ print(f'Metrics for {prefix}, step {self.global_step}')
+ for key, value in metrics_dict.items():
+ print(f'\t{key} = ', value)
+
+ def parse_and_log_images(self, id_logs, x, y, y_hat, y_recovered, title, subscript=None, display_count=2):
+ im_data = []
+ for i in range(display_count):
+ cur_im_data = {
+ 'input_face': common.tensor2im(x[i]),
+ 'target_face': common.tensor2im(y[i]),
+ 'output_face': common.tensor2im(y_hat[i]),
+ 'recovered_face': common.tensor2im(y_recovered[i])
+ }
+ if id_logs is not None:
+ for key in id_logs[i]:
+ cur_im_data[key] = id_logs[i][key]
+ im_data.append(cur_im_data)
+ self.log_images(title, im_data=im_data, subscript=subscript)
+
+ def log_images(self, name, im_data, subscript=None, log_latest=False):
+ fig = common.vis_faces(im_data)
+ step = self.global_step
+ if log_latest:
+ step = 0
+ if subscript:
+ path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
+ else:
+ path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ fig.savefig(path)
+ plt.close(fig)
+
+ def __get_save_dict(self):
+ save_dict = {
+ 'state_dict': self.net.state_dict(),
+ 'opts': vars(self.opts)
+ }
+ # save the latent avg in state_dict for inference if truncation of w was used during training
+ if self.net.latent_avg is not None:
+ save_dict['latent_avg'] = self.net.latent_avg
+ return save_dict
diff --git a/training/ranger.py b/training/ranger.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d63264dda6df0ee40cac143440f0b5f8977a9ad
--- /dev/null
+++ b/training/ranger.py
@@ -0,0 +1,164 @@
+# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# and/or
+# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+# Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+# This version = 20.4.11
+
+# Credits:
+# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+# Lookahead paper --> M.R. Zhang, J. Lucas, G. Hinton, J. Ba, https://arxiv.org/abs/1907.08610
+
+# summary of changes:
+# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+# changes 8/31/19 - fix references to *self*.N_sma_threshold;
+# changed eps to 1e-5 as better default than 1e-8.
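+#
+# In short, each step performs an RAdam update (the step size switches to a variance-rectified
+# form once N_sma exceeds the threshold), optionally preceded by gradient centralization
+# (subtracting the per-filter gradient mean), and every k steps the fast weights are pulled
+# toward a slow "lookahead" copy by a factor of alpha.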
+
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Ranger(Optimizer):
+
+ def __init__(self, params, lr=1e-3, # lr
+ alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
+ betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
+ use_gc=True, gc_conv_only=False
+ # Gradient centralization on or off, applied to conv layers only or conv + fc layers
+ ):
+
+ # parameter checks
+ if not 0.0 <= alpha <= 1.0:
+ raise ValueError(f'Invalid slow update rate: {alpha}')
+ if not 1 <= k:
+ raise ValueError(f'Invalid lookahead steps: {k}')
+ if not lr > 0:
+ raise ValueError(f'Invalid Learning Rate: {lr}')
+ if not eps > 0:
+ raise ValueError(f'Invalid eps: {eps}')
+
+ # parameter comments:
+ # beta1 (momentum) of .95 seems to work better than .90...
+ # N_sma_threshold of 5 seems better in testing than 4.
+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+ # prep defaults and init torch.optim base
+ defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
+ eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ # adjustable threshold
+ self.N_sma_threshhold = N_sma_threshhold
+
+ # look ahead params
+
+ self.alpha = alpha
+ self.k = k
+
+ # radam buffer for state
+ self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+ # gc on or off
+ self.use_gc = use_gc
+
+ # level of gradient centralization
+ self.gc_gradient_threshold = 3 if gc_conv_only else 1
+
+ def __setstate__(self, state):
+ super(Ranger, self).__setstate__(state)
+
+ def step(self, closure=None):
+ loss = None
+
+ # Evaluate averages and grad, update param tensors
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+
+ if grad.is_sparse:
+ raise RuntimeError('Ranger optimizer does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p] # get state dict for this param
+
+ if len(state) == 0: # if first time to run...init dictionary with our desired entries
+ # if self.first_run_check==0:
+ # self.first_run_check=1
+ # print("Initializing slow buffer...should not see this at load from saved model!")
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+
+ # look ahead weight storage now in state dict
+ state['slow_buffer'] = torch.empty_like(p.data)
+ state['slow_buffer'].copy_(p.data)
+
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ # begin computations
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # GC operation for Conv layers and FC layers
+ if grad.dim() > self.gc_gradient_threshold:
+ grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
+
+ state['step'] += 1
+
+				# compute variance moving avg
+				exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+				# compute mean moving avg
+				exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+ buffered = self.radam_buffer[int(state['step'] % 10)]
+
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
+ if N_sma > self.N_sma_threshhold:
+ step_size = math.sqrt(
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ else:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ buffered[2] = step_size
+
+				if group['weight_decay'] != 0:
+					p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+
+ # apply lr
+				if N_sma > self.N_sma_threshhold:
+					denom = exp_avg_sq.sqrt().add_(group['eps'])
+					p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+				else:
+					p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+
+ p.data.copy_(p_data_fp32)
+
+ # integrated look ahead...
+ # we do it at the param level instead of group level
+ if state['step'] % group['k'] == 0:
+ slow_p = state['slow_buffer'] # get access to slow param tensor
+					slow_p.add_(p.data - slow_p, alpha=self.alpha)  # (fast weights - slow weights) * alpha
+ p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
+
+ return loss
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62e6f7b85372d747f76f305874c2ad4a890894e5
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/utils/__pycache__/common.cpython-310.pyc b/utils/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0cefa3d2e6fcb1b2b9091a2a8abd9c095b19fd8
Binary files /dev/null and b/utils/__pycache__/common.cpython-310.pyc differ
diff --git a/utils/common.py b/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..002901428cedc2b9556c92a63b1aec63b22190c4
--- /dev/null
+++ b/utils/common.py
@@ -0,0 +1,47 @@
+from PIL import Image
+import matplotlib.pyplot as plt
+
+
+# Log images
+def log_image(x, opts):
+ return tensor2im(x)
+
+
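+# Converts a CHW tensor with values in [-1, 1] into an HWC uint8 PIL image.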
+def tensor2im(var):
+ var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
+ var = ((var + 1) / 2)
+ var[var < 0] = 0
+ var[var > 1] = 1
+ var = var * 255
+ return Image.fromarray(var.astype('uint8'))
+
+
+def vis_faces(log_hooks):
+ display_count = len(log_hooks)
+ fig = plt.figure(figsize=(12, 4 * display_count))
+ gs = fig.add_gridspec(display_count, 4)
+ for i in range(display_count):
+ hooks_dict = log_hooks[i]
+ vis_faces_with_age(hooks_dict, fig, gs, i)
+ plt.tight_layout()
+ return fig
+
+
+def vis_faces_with_age(hooks_dict, fig, gs, i):
+ fig.add_subplot(gs[i, 0])
+ plt.imshow(hooks_dict['input_face'])
+ plt.title('Input\nOut Sim={:.2f}\nInput Age={:.2f}'.format(float(hooks_dict['diff_input_real']),
+ float(hooks_dict['input_age_real'])))
+ fig.add_subplot(gs[i, 1])
+ plt.imshow(hooks_dict['target_face'])
+ plt.title('Target\nIn={:.2f},Out={:.2f}\nTarget Age={:.2f}'.format(float(hooks_dict['diff_views_real']),
+ float(hooks_dict['diff_target_real']),
+ float(hooks_dict['target_age_real'])))
+ fig.add_subplot(gs[i, 2])
+ plt.imshow(hooks_dict['output_face'])
+	plt.title('Output\nTarget Sim={:.2f}\nOutput Age={:.2f}'.format(float(hooks_dict['diff_target_real']),
+ float(hooks_dict['output_age_real'])))
+ fig.add_subplot(gs[i, 3])
+ plt.imshow(hooks_dict['recovered_face'])
+	plt.title('Recovered\nTarget Sim={:.2f}\nOutput Age={:.2f}'.format(float(hooks_dict['diff_target_cycle']),
+ float(hooks_dict['output_age_cycle'])))
diff --git a/utils/data_utils.py b/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6f4a2396a0a0d62ee29a647cbab3bc79692b9bc
--- /dev/null
+++ b/utils/data_utils.py
@@ -0,0 +1,34 @@
+"""
+Code adapted from pix2pixHD:
+https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py
+"""
+import os
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+ images = []
+ assert os.path.isdir(dir), f'{dir} is not a valid directory'
+ for root, _, fnames in sorted(os.walk(dir)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images
+
+
+def make_dataset_from_paths_list(paths_file):
+ assert os.path.exists(paths_file), f'{paths_file} is not a valid file'
+ with open(paths_file, "r") as f:
+ paths = f.readlines()
+ paths = [p.strip() for p in paths]
+ paths = [p for p in paths if is_image_file(p)]
+ return paths
\ No newline at end of file
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..678806c7c4c8b4272a7e144636aa8a89e4d0008c
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+
+def aggregate_loss_dict(agg_loss_dict):
+ mean_vals = {}
+ for output in agg_loss_dict:
+ for key in output:
+ mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
+ for key in mean_vals:
+ if len(mean_vals[key]) > 0:
+ mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
+ else:
+ print(f'{key} has no value')
+ mean_vals[key] = 0
+ return mean_vals
+
+
+def compute_cosine_weights(x):
+ """ Computes weights to be used in the id loss function with minimum value of 0.5 and maximum value of 1. """
+ values = np.abs(x.cpu().detach().numpy())
+ assert np.min(values) >= 0. and np.max(values) <= 1., "Input values should be between 0. and 1!"
+ weights = 0.25 * (np.cos(np.pi * values)) + 0.75
+ return weights
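+
+# Sanity check of the weighting curve: an age difference of 0 gives weight 1.0,
+# 0.5 gives 0.75, and 1.0 gives the minimum weight of 0.5.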
+