diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..07f0db3339ad9053dc95b284c4ae14e014efff89
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,16 @@
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/FAQ.md b/FAQ.md
new file mode 100644
index 0000000000000000000000000000000000000000..caa8c08cfe4302eb8812c823569e8a0be30fa49c
--- /dev/null
+++ b/FAQ.md
@@ -0,0 +1,9 @@
+# FAQ
+
+1. **What is the difference between `--netscale` and `--outscale`?**
+
+A: TODO.
+
+1. **How do I select a model?**
+
+A: TODO.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..552a1eeaf01f4e7077013ed3496600c608f35202
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
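A note on the first FAQ item above: the answer is left as TODO in this patch, but the behaviour can be read off `inference_realesrgan.py` later in this diff. The network itself always upscales by a fixed factor (`netscale`, e.g. 4 for `RealESRGAN_x4plus`), while `--outscale` only controls the final output size, which is obtained by resizing the network output afterwards. Below is a minimal sketch of my reading of that relationship; `apply_outscale` is a hypothetical helper, not a function from this repo:

```python
import cv2
import numpy as np


def apply_outscale(lr_img, sr_img, outscale, netscale=4):
    """Resize the x`netscale` network output `sr_img` to the requested `outscale`."""
    h, w = lr_img.shape[:2]
    if outscale != netscale:  # e.g. -s 2 with an x4 model: the 4x result is shrunk to 2x
        sr_img = cv2.resize(
            sr_img, (int(w * outscale), int(h * outscale)), interpolation=cv2.INTER_LANCZOS4)
    return sr_img


lr = np.zeros((100, 100, 3), np.uint8)  # low-resolution input
sr = np.zeros((400, 400, 3), np.uint8)  # x4 network output
print(apply_outscale(lr, sr, outscale=2).shape)  # (200, 200, 3)
```

So an x4 model run with `-s 2` still executes the full x4 network and then downsizes the result.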
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..b18403e062a7cd846692a462f57232e0734394fe
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,8 @@
+include assets/*
+include inputs/*
+include scripts/*.py
+include inference_realesrgan.py
+include VERSION
+include LICENSE
+include requirements.txt
+include realesrgan/weights/README.md
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..36b007f172e4075a0c07957364e710f8cbd0e1b5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,35 @@
+---
+title: Real ESRGAN
+emoji: 🏃
+colorFrom: blue
+colorTo: blue
+sdk: gradio
+sdk_version: 3.1.7
+app_file: app.py
+pinned: false
+duplicated_from: akhaliq/Real-ESRGAN
+---
+
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
diff --git a/Training.md b/Training.md
new file mode 100644
index 0000000000000000000000000000000000000000..64704e1d2e1f334984232afd12b245235b274a9e
--- /dev/null
+++ b/Training.md
@@ -0,0 +1,100 @@
+# :computer: How to Train Real-ESRGAN
+
+The training code has been released.
+Note that the code has been heavily refactored, so there may be bugs or performance drops. Issue reports are welcome, and I will also retrain the models.
+
+## Overview
+
+The training is divided into two stages. Both stages share the same data synthesis process and training pipeline; only the loss functions differ. Specifically,
+
+1. We first train Real-ESRNet with an L1 loss, starting from the pre-trained ESRGAN model.
+1. We then use the trained Real-ESRNet model to initialize the generator, and train Real-ESRGAN with a combination of L1 loss, perceptual loss and GAN loss.
+
+## Dataset Preparation
+
+We use the DF2K (DIV2K and Flickr2K) + OST datasets for training. Only HR images are required.
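To make the two-stage objective above concrete, here is a rough sketch of the stage-2 generator loss, assuming the weights from the option files in this diff (pixel 1.0, perceptual 1.0, GAN 0.1, `gan_type: vanilla`). `perceptual_fn` stands in for basicsr's `PerceptualLoss`; this is an illustration, not the repo's actual training loop:

```python
import torch
import torch.nn.functional as F


def stage2_generator_loss(fake, gt, disc_logits_fake, perceptual_fn,
                          w_pix=1.0, w_percep=1.0, w_gan=0.1):
    l_pix = w_pix * F.l1_loss(fake, gt)            # L1 loss
    l_percep = w_percep * perceptual_fn(fake, gt)  # VGG-feature distance
    # vanilla GAN loss: the generator tries to make the discriminator say "real"
    l_gan = w_gan * F.binary_cross_entropy_with_logits(
        disc_logits_fake, torch.ones_like(disc_logits_fake))
    return l_pix + l_percep + l_gan  # stage 1 (Real-ESRNet) keeps only l_pix
```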
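For the meta-info txt described later in this section, a minimal hypothetical helper (folder and file names are placeholders) that writes one relative sub-image path per line:

```python
import glob
import os


def write_meta_info(root='datasets/DF2K', sub_folder='DF2K_HR_sub',
                    out_txt='meta_info_DF2Kmultiscale+OST_sub.txt'):
    # one relative path per line, e.g. "DF2K_HR_sub/000001_s001.png"
    paths = sorted(glob.glob(os.path.join(root, sub_folder, '*.png')))
    with open(os.path.join(root, out_txt), 'w') as f:
        for path in paths:
            f.write(os.path.relpath(path, root) + '\n')


if __name__ == '__main__':
    write_meta_info()
```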
+You can download them from:
+
+1. DIV2K: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
+2. Flickr2K: https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar
+3. OST: https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip
+
+For the DF2K dataset, we use a multi-scale strategy, *i.e.*, we downsample the HR images to obtain several ground-truth images at different scales.
+
+We then crop the DF2K images into sub-images for faster IO and processing.
+
+You need to prepare a txt file containing the image paths. The following are some examples from `meta_info_DF2Kmultiscale+OST_sub.txt` (as different users may partition the sub-images differently, this file will not fit your data directly, and you need to prepare your own txt file):
+
+```txt
+DF2K_HR_sub/000001_s001.png
+DF2K_HR_sub/000001_s002.png
+DF2K_HR_sub/000001_s003.png
+...
+```
+
+## Train Real-ESRNet
+
+1. Download the pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) into `experiments/pretrained_models`.
+    ```bash
+    wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
+    ```
+1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
+    ```yml
+    train:
+        name: DF2K+OST
+        type: RealESRGANDataset
+        dataroot_gt: datasets/DF2K  # modify to the root path of your folder
+        meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt  # modify to your own generated meta info txt
+        io_backend:
+            type: disk
+    ```
+1. If you want to perform validation during training, uncomment those lines and modify them accordingly:
+    ```yml
+    # Uncomment these for validation
+    # val:
+    #     name: validation
+    #     type: PairedImageDataset
+    #     dataroot_gt: path_to_gt
+    #     dataroot_lq: path_to_lq
+    #     io_backend:
+    #         type: disk
+
+    ...
+
+    # Uncomment these for validation
+    # validation settings
+    # val:
+    #     val_freq: !!float 5e3
+    #     save_img: True
+
+    #     metrics:
+    #         psnr: # metric name, can be arbitrary
+    #             type: calculate_psnr
+    #             crop_border: 4
+    #             test_y_channel: false
+    ```
+1. Before the formal training, you may run in `--debug` mode to check whether everything is OK. We use four GPUs for training:
+    ```bash
+    CUDA_VISIBLE_DEVICES=0,1,2,3 \
+    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
+    ```
+1. The formal training. We use four GPUs, and the `--auto_resume` argument to automatically resume the training if necessary.
+    ```bash
+    CUDA_VISIBLE_DEVICES=0,1,2,3 \
+    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
+    ```
+
+## Train Real-ESRGAN
+
+1. After the training of Real-ESRNet, you now have the file `experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth`. If you need to specify the pre-trained path to other files, modify the `pretrain_network_g` value in the option file `train_realesrgan_x4plus.yml`.
+1. Modify the option file `train_realesrgan_x4plus.yml` accordingly. Most modifications are similar to those listed above.
+1. Before the formal training, you may run in `--debug` mode to check whether everything is OK. We use four GPUs for training:
+    ```bash
+    CUDA_VISIBLE_DEVICES=0,1,2,3 \
+    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
+    ```
+1. The formal training. We use four GPUs, and the `--auto_resume` argument to automatically resume the training if necessary.
+    ```bash
+    CUDA_VISIBLE_DEVICES=0,1,2,3 \
+    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
+    ```
diff --git a/VERSION b/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..517717a3cfc4cdaf7c7363cbb090b3abe5bdeecd
--- /dev/null
+++ b/VERSION
@@ -0,0 +1 @@
+0.2.3.0
diff --git a/anime.png b/anime.png
new file mode 100644
index 0000000000000000000000000000000000000000..afe0358752b7ce15d71a23064ac1c61a3725b79c
Binary files /dev/null and b/anime.png differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c59221c429e335c3a2e3413c11cc155d5b6122
--- /dev/null
+++ b/app.py
@@ -0,0 +1,68 @@
+import os
+os.system("pip install gradio==2.9b23")
+import random
+import gradio as gr
+from PIL import Image
+import torch
+from random import randint
+import sys
+from subprocess import call
+import psutil
+
+
+
+torch.hub.download_url_to_file('http://people.csail.mit.edu/billf/project%20pages/sresCode/Markov%20Random%20Fields%20for%20Super-Resolution_files/100075_lowres.jpg', 'bear.jpg')
+
+
+def run_cmd(command):
+    try:
+        print(command)
+        call(command, shell=True)
+    except KeyboardInterrupt:
+        print("Process interrupted")
+        sys.exit(1)
+run_cmd("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P .")
+run_cmd("pip install basicsr")
+run_cmd("pip freeze")
+
+os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P .")
+
+
+def inference(img, mode):
+    _id = randint(1, 10000)
+    INPUT_DIR = "/tmp/input_image" + str(_id) + "/"
+    OUTPUT_DIR = "/tmp/output_image" + str(_id) + "/"
+    run_cmd("rm -rf " + INPUT_DIR)
+    run_cmd("rm -rf " + OUTPUT_DIR)
+    run_cmd("mkdir " + INPUT_DIR)
+    run_cmd("mkdir " + OUTPUT_DIR)
+    basewidth = 256
+    wpercent = (basewidth / float(img.size[0]))
+    hsize = int((float(img.size[1]) * float(wpercent)))
+    img = img.resize((basewidth, hsize), Image.ANTIALIAS)
+    img.save(INPUT_DIR + "1.jpg", "JPEG")
+    if mode == "base":
+        run_cmd("python inference_realesrgan.py -n RealESRGAN_x4plus -i " + INPUT_DIR + " -o " + OUTPUT_DIR)
+    else:
+        os.system("python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i " + INPUT_DIR + " -o " + OUTPUT_DIR)
+    return os.path.join(OUTPUT_DIR, "1_out.jpg")
+
+
+
+title = "Real-ESRGAN"
+description = "Gradio demo for Real-ESRGAN. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please click submit only once."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.10833' target='_blank'>Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data</a> | <a href='https://github.com/xinntao/Real-ESRGAN' target='_blank'>Github Repo</a></p>
" + +gr.Interface( + inference, + [gr.inputs.Image(type="pil", label="Input"),gr.inputs.Radio(["base","anime"], type="value", default="base", label="model type")], + gr.outputs.Image(type="file", label="Output"), + title=title, + description=description, + article=article, + examples=[ + ['bear.jpg','base'], + ['anime.png','anime'] + ]).launch() \ No newline at end of file diff --git a/experiments/.DS_Store b/experiments/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..998b6bb5c48ac96fd16743a90223c2a300fa836f Binary files /dev/null and b/experiments/.DS_Store differ diff --git a/experiments/pretrained_models/README.md b/experiments/pretrained_models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d0cc4afcbdd2c733f6b946bb86bd00baa90e8295 --- /dev/null +++ b/experiments/pretrained_models/README.md @@ -0,0 +1 @@ +# Put downloaded pre-trained models here diff --git a/inference_realesrgan.py b/inference_realesrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5ff4d188faaa16c0131be69a08fd22fb608f80 --- /dev/null +++ b/inference_realesrgan.py @@ -0,0 +1,128 @@ +import argparse +import cv2 +import glob +import os +from basicsr.archs.rrdbnet_arch import RRDBNet + +from realesrgan import RealESRGANer +from realesrgan.archs.srvgg_arch import SRVGGNetCompact + + +def main(): + """Inference demo for Real-ESRGAN. + """ + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder') + parser.add_argument( + '-n', + '--model_name', + type=str, + default='RealESRGAN_x4plus', + help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus' + 'RealESRGANv2-anime-xsx2 | RealESRGANv2-animevideo-xsx2-nousm | RealESRGANv2-animevideo-xsx2' + 'RealESRGANv2-anime-xsx4 | RealESRGANv2-animevideo-xsx4-nousm | RealESRGANv2-animevideo-xsx4')) + parser.add_argument('-o', '--output', type=str, default='results', help='Output folder') + parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image') + parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image') + parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing') + parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding') + parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border') + parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face') + parser.add_argument('--half', action='store_true', help='Use half precision during inference') + parser.add_argument( + '--alpha_upsampler', + type=str, + default='realesrgan', + help='The upsampler for the alpha channels. Options: realesrgan | bicubic') + parser.add_argument( + '--ext', + type=str, + default='auto', + help='Image extension. 
Options: auto | jpg | png, auto means using the same extension as inputs')
+    args = parser.parse_args()
+
+    # determine models according to model names
+    args.model_name = args.model_name.split('.')[0]
+    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
+    ]:  # x2 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
+    ]:  # x4 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+        netscale = 4
+
+    # determine model paths (app.py downloads the weights to the current directory)
+    model_path = os.path.join('.', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        raise ValueError(f'Model {args.model_name} does not exist.')
+
+    # restorer
+    upsampler = RealESRGANer(
+        scale=netscale,
+        model_path=model_path,
+        model=model,
+        tile=args.tile,
+        tile_pad=args.tile_pad,
+        pre_pad=args.pre_pad,
+        half=args.half)
+
+    if args.face_enhance:  # Use GFPGAN for face enhancement
+        from gfpgan import GFPGANer
+        face_enhancer = GFPGANer(
+            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
+            upscale=args.outscale,
+            arch='clean',
+            channel_multiplier=2,
+            bg_upsampler=upsampler)
+    os.makedirs(args.output, exist_ok=True)
+
+    if os.path.isfile(args.input):
+        paths = [args.input]
+    else:
+        paths = sorted(glob.glob(os.path.join(args.input, '*')))
+
+    for idx, path in enumerate(paths):
+        imgname, extension = os.path.splitext(os.path.basename(path))
+        print('Testing', idx, imgname)
+
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        if len(img.shape) == 3 and img.shape[2] == 4:
+            img_mode = 'RGBA'
+        else:
+            img_mode = None
+
+        try:
+            if args.face_enhance:
+                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+            else:
+                output, _ = upsampler.enhance(img, outscale=args.outscale)
+        except RuntimeError as error:
+            print('Error', error)
+            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+        else:
+            if args.ext == 'auto':
+                extension = extension[1:]
+            else:
+                extension = args.ext
+            if img_mode == 'RGBA':  # RGBA images should be saved in png format
+                extension = 'png'
+            save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
+            cv2.imwrite(save_path, output)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/inference_realesrgan_video.py b/inference_realesrgan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..639b848e6578a2480ee0784e664c7751e325c477
--- /dev/null
+++ b/inference_realesrgan_video.py
@@ -0,0 +1,199 @@
+import argparse
+import glob
+import mimetypes
+import os
+import queue
+import shutil
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.logger import AvgTimer
+from tqdm import tqdm
+
+from realesrgan import IOConsumer, PrefetchReader, RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+
+
+def main():
+    """Inference demo for Real-ESRGAN.
+    It is mainly for restoring anime videos.
+
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
+    parser.add_argument(
+        '-n',
+        '--model_name',
+        type=str,
+        default='RealESRGAN_x4plus',
+        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | '
+              'RealESRGANv2-anime-xsx2 | RealESRGANv2-animevideo-xsx2-nousm | RealESRGANv2-animevideo-xsx2 | '
+              'RealESRGANv2-anime-xsx4 | RealESRGANv2-animevideo-xsx4-nousm | RealESRGANv2-animevideo-xsx4'))
+    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
+    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
+    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
+    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
+    parser.add_argument('--half', action='store_true', help='Use half precision during inference')
+    parser.add_argument('-v', '--video', action='store_true', help='Output a video using ffmpeg')
+    parser.add_argument('-a', '--audio', action='store_true', help='Keep audio')
+    parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
+    parser.add_argument('--consumer', type=int, default=4, help='Number of IO consumers')
+
+    parser.add_argument(
+        '--alpha_upsampler',
+        type=str,
+        default='realesrgan',
+        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
+    parser.add_argument(
+        '--ext',
+        type=str,
+        default='auto',
+        help='Image extension. 
Options: auto | jpg | png, auto means using the same extension as inputs') + args = parser.parse_args() + + # ---------------------- determine models according to model names ---------------------- # + args.model_name = args.model_name.split('.')[0] + if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + netscale = 4 + elif args.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + netscale = 4 + elif args.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + netscale = 2 + elif args.model_name in [ + 'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2' + ]: # x2 VGG-style model (XS size) + model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu') + netscale = 2 + elif args.model_name in [ + 'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4' + ]: # x4 VGG-style model (XS size) + model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') + netscale = 4 + + # ---------------------- determine model paths ---------------------- # + model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth') + if not os.path.isfile(model_path): + model_path = os.path.join('realesrgan/weights', args.model_name + '.pth') + if not os.path.isfile(model_path): + raise ValueError(f'Model {args.model_name} does not exist.') + + # restorer + upsampler = RealESRGANer( + scale=netscale, + model_path=model_path, + model=model, + tile=args.tile, + tile_pad=args.tile_pad, + pre_pad=args.pre_pad, + half=args.half) + + if args.face_enhance: # Use GFPGAN for face enhancement + from gfpgan import GFPGANer + face_enhancer = GFPGANer( + model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth', + upscale=args.outscale, + arch='clean', + channel_multiplier=2, + bg_upsampler=upsampler) + os.makedirs(args.output, exist_ok=True) + # for saving restored frames + save_frame_folder = os.path.join(args.output, 'frames_tmpout') + os.makedirs(save_frame_folder, exist_ok=True) + + if mimetypes.guess_type(args.input)[0].startswith('video'): # is a video file + video_name = os.path.splitext(os.path.basename(args.input))[0] + frame_folder = os.path.join('tmp_frames', video_name) + os.makedirs(frame_folder, exist_ok=True) + # use ffmpeg to extract frames + os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {frame_folder}/frame%08d.png') + # get image path list + paths = sorted(glob.glob(os.path.join(frame_folder, '*'))) + if args.video: + if args.fps is None: + # get input video fps + import ffmpeg + probe = ffmpeg.probe(args.input) + video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] + args.fps = eval(video_streams[0]['avg_frame_rate']) + elif mimetypes.guess_type(args.input)[0].startswith('image'): # is an image file + paths = [args.input] + video_name = 'video' + else: + paths = sorted(glob.glob(os.path.join(args.input, '*'))) + video_name = 'video' + + timer = AvgTimer() + timer.start() + pbar = tqdm(total=len(paths), unit='frame', desc='inference') + # set up prefetch reader + reader = 
PrefetchReader(paths, num_prefetch_queue=4) + reader.start() + + que = queue.Queue() + consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.consumer)] + for consumer in consumers: + consumer.start() + + for idx, (path, img) in enumerate(zip(paths, reader)): + imgname, extension = os.path.splitext(os.path.basename(path)) + if len(img.shape) == 3 and img.shape[2] == 4: + img_mode = 'RGBA' + else: + img_mode = None + + try: + if args.face_enhance: + _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True) + else: + output, _ = upsampler.enhance(img, outscale=args.outscale) + except RuntimeError as error: + print('Error', error) + print('If you encounter CUDA out of memory, try to set --tile with a smaller number.') + + else: + if args.ext == 'auto': + extension = extension[1:] + else: + extension = args.ext + if img_mode == 'RGBA': # RGBA images should be saved in png format + extension = 'png' + save_path = os.path.join(save_frame_folder, f'{imgname}_out.{extension}') + + que.put({'output': output, 'save_path': save_path}) + + pbar.update(1) + torch.cuda.synchronize() + timer.record() + avg_fps = 1. / (timer.get_avg_time() + 1e-7) + pbar.set_description(f'idx {idx}, fps {avg_fps:.2f}') + + for _ in range(args.consumer): + que.put('quit') + for consumer in consumers: + consumer.join() + pbar.close() + + # merge frames to video + if args.video: + video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4') + if args.audio: + os.system( + f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} -i {args.input}' + f' -map 0:v:0 -map 1:a:0 -c:a copy -c:v libx264 -r {args.fps} -pix_fmt yuv420p {video_save_path}') + else: + os.system(f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} ' + f'-c:v libx264 -r {args.fps} -pix_fmt yuv420p {video_save_path}') + + # delete tmp file + shutil.rmtree(save_frame_folder) + if os.path.isdir(frame_folder): + shutil.rmtree(frame_folder) + + +if __name__ == '__main__': + main() diff --git a/options/finetune_realesrgan_x4plus.yml b/options/finetune_realesrgan_x4plus.yml new file mode 100644 index 0000000000000000000000000000000000000000..aa9806570025dce0a967ca0541a0ea497a57d6a9 --- /dev/null +++ b/options/finetune_realesrgan_x4plus.yml @@ -0,0 +1,188 @@ +# general settings +name: finetune_RealESRGANx4plus_400k +model_type: RealESRGANModel +scale: 4 +num_gpu: auto +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 
0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth + param_key_g: params_ema + strict_load_g: true + pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth + param_key_d: params + strict_load_d: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/finetune_realesrgan_x4plus_pairdata.yml b/options/finetune_realesrgan_x4plus_pairdata.yml new file mode 100644 index 0000000000000000000000000000000000000000..db45d4d275facc1191caa87d2d8618c30624477a --- /dev/null +++ b/options/finetune_realesrgan_x4plus_pairdata.yml @@ -0,0 +1,150 @@ +# general settings +name: finetune_RealESRGANx4plus_400k_pairdata +model_type: RealESRGANModel +scale: 4 +num_gpu: auto +manual_seed: 0 + +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +high_order_degradation: False # do not use the high-order degradation generation process + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: RealESRGANPairedDataset + dataroot_gt: datasets/DF2K + dataroot_lq: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt + io_backend: + type: 
disk + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth + param_key_g: params_ema + strict_load_g: true + pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth + param_key_d: params + strict_load_d: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/setup.cfg b/options/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..9cecd96943e729db110b1960295d9d4bf76c1754 --- /dev/null +++ b/options/setup.cfg @@ -0,0 +1,33 @@ +[flake8] +ignore = + # line break before binary operator (W503) + W503, + # line break after binary operator (W504) + W504, +max-line-length=120 + +[yapf] +based_on_style = pep8 +column_limit = 120 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true + +[isort] +line_length = 120 +multi_line_output = 0 +known_standard_library = pkg_resources,setuptools +known_first_party = realesrgan +known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY + +[codespell] +skip = .git,./docs/build +count = +quiet-level = 3 + +[aliases] +test=pytest + +[tool:pytest] +addopts=tests/ diff --git a/options/train_realesrgan_x2plus.yml b/options/train_realesrgan_x2plus.yml new file mode 100644 index 0000000000000000000000000000000000000000..3c98a0f370def397bdf47ede0fa5f6dd6a4411d5 --- /dev/null +++ b/options/train_realesrgan_x2plus.yml @@ -0,0 +1,186 @@ +# general settings +name: train_RealESRGANx2plus_400k_B12G4 +model_type: RealESRGANModel 
+scale: 2 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + scale: 2 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# 
test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train_realesrgan_x4plus.yml b/options/train_realesrgan_x4plus.yml new file mode 100644 index 0000000000000000000000000000000000000000..763199a35fa0135713b4a87b00c25f63062ac8aa --- /dev/null +++ b/options/train_realesrgan_x4plus.yml @@ -0,0 +1,185 @@ +# general settings +name: train_RealESRGANx4plus_400k_B12G4 +model_type: RealESRGANModel +scale: 4 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # 
before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train_realesrnet_x2plus.yml b/options/train_realesrnet_x2plus.yml new file mode 100644 index 0000000000000000000000000000000000000000..81ee9ef16817eaf17cf993cea1a4a8d51815d96c --- /dev/null +++ b/options/train_realesrnet_x2plus.yml @@ -0,0 +1,145 @@ +# general settings +name: train_RealESRNetx2plus_1000k_B12G4 +model_type: RealESRNetModel +scale: 2 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRNetModel ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + scale: 2 + +# path +path: + pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth + param_key_g: params_ema + strict_load_g: False + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + 
betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train_realesrnet_x4plus.yml b/options/train_realesrnet_x4plus.yml new file mode 100644 index 0000000000000000000000000000000000000000..45670ed824ae0c697a395049b089e50364292dfc --- /dev/null +++ b/options/train_realesrnet_x4plus.yml @@ -0,0 +1,144 @@ +# general settings +name: train_RealESRNetx4plus_1000k_B12G4 +model_type: RealESRNetModel +scale: 4 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRNetModel ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +# path +path: + pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + 
pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..5bd99e9459bcae330c45804e113191e957755c1a --- /dev/null +++ b/packages.txt @@ -0,0 +1,3 @@ +ffmpeg +libsm6 +libxext6 \ No newline at end of file diff --git a/realesrgan/__init__.py b/realesrgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfea78f284116dee22510d4aa91f9e44afb7d472 --- /dev/null +++ b/realesrgan/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .utils import * +#from .version import * diff --git a/realesrgan/archs/__init__.py b/realesrgan/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fbbf3b78e33b61fd4c33a564a9a617010d90de --- /dev/null +++ b/realesrgan/archs/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import arch modules for registry +# scan all the files that end with '_arch.py' under the archs folder +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/realesrgan/archs/discriminator_arch.py b/realesrgan/archs/discriminator_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..4b66ab1226d6793de846bc9828bbe427031a0e2d --- /dev/null +++ b/realesrgan/archs/discriminator_arch.py @@ -0,0 +1,67 @@ +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm + + +@ARCH_REGISTRY.register() +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. + skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 
+ """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out diff --git a/realesrgan/archs/srvgg_arch.py b/realesrgan/archs/srvgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..39460965c9c5ee9cd6eb41c50d33574cb8ba6e50 --- /dev/null +++ b/realesrgan/archs/srvgg_arch.py @@ -0,0 +1,69 @@ +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn as nn +from torch.nn import functional as F + + +@ARCH_REGISTRY.register() +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
+ """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out diff --git a/realesrgan/data/__init__.py b/realesrgan/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f8fdd1aa47c12de9687c578094303eb7369246 --- /dev/null +++ b/realesrgan/data/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import dataset modules for registry +# scan all the files that end with '_dataset.py' under the data folder +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames] diff --git a/realesrgan/data/realesrgan_dataset.py b/realesrgan/data/realesrgan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf2d9e6583a6789b771679734ce55bb8a22e628 --- /dev/null +++ b/realesrgan/data/realesrgan_dataset.py @@ -0,0 +1,192 @@ +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torch.utils import data as data + + +@DATASET_REGISTRY.register() +class RealESRGANDataset(data.Dataset): + """Dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. 
It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. + """ + + def __init__(self, opt): + super(RealESRGANDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder = opt['dataroot_gt'] + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.gt_folder] + self.io_backend_opt['client_keys'] = ['gt'] + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip().split(' ')[0] for line in fin] + self.paths = [os.path.join(self.gt_folder, v) for v in paths] + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
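+        # Loading below is retried a few times; on repeated IO errors we fall back to a
+        # different, randomly chosen image so that one unreadable file cannot stall training.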
+        gt_path = self.paths[index]
+        # avoid errors caused by high latency in reading files
+        retry = 3
+        while retry > 0:
+            try:
+                img_bytes = self.file_client.get(gt_path, 'gt')
+            except (IOError, OSError) as e:
+                logger = get_root_logger()
+                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # switch to another randomly chosen file (note that randint is inclusive on both ends)
+                index = random.randint(0, self.__len__() - 1)
+                gt_path = self.paths[index]
+                time.sleep(1)  # sleep 1s for occasional server congestion
+            else:
+                break
+            finally:
+                retry -= 1
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # -------------------- Do augmentation for training: flip, rotation -------------------- #
+        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
+
+        # crop or pad to 400
+        # TODO: 400 is hard-coded. You may change it accordingly
+        h, w = img_gt.shape[0:2]
+        crop_pad_size = 400
+        # pad
+        if h < crop_pad_size or w < crop_pad_size:
+            pad_h = max(0, crop_pad_size - h)
+            pad_w = max(0, crop_pad_size - w)
+            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+        # crop
+        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
+            h, w = img_gt.shape[0:2]
+            # randomly choose top and left coordinates
+            top = random.randint(0, h - crop_pad_size)
+            left = random.randint(0, w - crop_pad_size)
+            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+
+        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+        kernel_size = random.choice(self.kernel_range)
+        if np.random.uniform() < self.opt['sinc_prob']:
+            # this sinc filter setting is for kernels ranging from [7, 21]
+            if kernel_size < 13:
+                omega_c = np.random.uniform(np.pi / 3, np.pi)
+            else:
+                omega_c = np.random.uniform(np.pi / 5, np.pi)
+            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+        else:
+            kernel = random_mixed_kernels(
+                self.kernel_list,
+                self.kernel_prob,
+                kernel_size,
+                self.blur_sigma,
+                self.blur_sigma, [-math.pi, math.pi],
+                self.betag_range,
+                self.betap_range,
+                noise_range=None)
+        # pad kernel
+        pad_size = (21 - kernel_size) // 2
+        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+        kernel_size = random.choice(self.kernel_range)
+        if np.random.uniform() < self.opt['sinc_prob2']:
+            if kernel_size < 13:
+                omega_c = np.random.uniform(np.pi / 3, np.pi)
+            else:
+                omega_c = np.random.uniform(np.pi / 5, np.pi)
+            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+        else:
+            kernel2 = random_mixed_kernels(
+                self.kernel_list2,
+                self.kernel_prob2,
+                kernel_size,
+                self.blur_sigma2,
+                self.blur_sigma2, [-math.pi, math.pi],
+                self.betag_range2,
+                self.betap_range2,
+                noise_range=None)
+
+        # pad kernel
+        pad_size = (21 - kernel_size) // 2
+        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+        # ------------------------------------- the final sinc kernel ------------------------------------- #
+        if np.random.uniform() < self.opt['final_sinc_prob']:
+            kernel_size = random.choice(self.kernel_range)
+            omega_c = np.random.uniform(np.pi / 3, np.pi)
+            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+            sinc_kernel = torch.FloatTensor(sinc_kernel)
+        else:
+            sinc_kernel = self.pulse_tensor
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+        kernel = torch.FloatTensor(kernel)
+        kernel2 = 
torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) diff --git a/realesrgan/data/realesrgan_paired_dataset.py b/realesrgan/data/realesrgan_paired_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..386c8d72496245dae8df033c2ebbd76b41ff45f1 --- /dev/null +++ b/realesrgan/data/realesrgan_paired_dataset.py @@ -0,0 +1,108 @@ +import os +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torch.utils import data as data +from torchvision.transforms.functional import normalize + + +@DATASET_REGISTRY.register() +class RealESRGANPairedDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(RealESRGANPairedDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + # mean and std for normalizing the input images + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [] + for path in paths: + gt_path, lq_path = path.split(', ') + gt_path = os.path.join(self.gt_folder, gt_path) + lq_path = os.path.join(self.lq_folder, lq_path) + self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) + else: + # disk backend + # it will scan the whole folder to get meta info + # it will be time-consuming for folders with too many files. 
It is recommended using an extra meta txt file + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/realesrgan/models/__init__.py b/realesrgan/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0be7105dc75d150c49976396724085f678dc0675 --- /dev/null +++ b/realesrgan/models/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import model modules for registry +# scan all the files that end with '_model.py' under the model folder +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames] diff --git a/realesrgan/models/realesrgan_model.py b/realesrgan/models/realesrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c298a09c42433177f90001a0a31d029576072ccd --- /dev/null +++ b/realesrgan/models/realesrgan_model.py @@ -0,0 +1,258 @@ +import numpy as np +import random +import torch +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.srgan_model import SRGANModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F + + +@MODEL_REGISTRY.register() +class RealESRGANModel(SRGANModel): + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. 
+ """ + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. 
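+
+        The synthesis below follows a two-stage pipeline: each stage applies a random blur, a
+        random resize (up / down / keep), Gaussian or Poisson noise and JPEG compression, and a
+        final sinc filter is grouped with the resize back to the LQ size.
+
+        Args:
+            data (dict): Dataloader output. During synthetic training it must contain 'gt',
+                'kernel1', 'kernel2' and 'sinc_kernel'; for paired data it contains 'lq'
+                (and optionally 'gt').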
+ """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 
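+            # randomly choose one of the two orders above with equal probability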
+ if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + # usm sharpening + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = 
self.cri_gan(fake_d_pred, False, is_disc=True)
+        loss_dict['l_d_fake'] = l_d_fake
+        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+        l_d_fake.backward()
+        self.optimizer_d.step()
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
diff --git a/realesrgan/models/realesrnet_model.py b/realesrgan/models/realesrnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11668f3712bffcd062c57db14d22ca3a0e1e59d
--- /dev/null
+++ b/realesrgan/models/realesrnet_model.py
@@ -0,0 +1,188 @@
+import numpy as np
+import random
+import torch
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.models.sr_model import SRModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+from torch.nn import functional as F
+
+
+@MODEL_REGISTRY.register()
+class RealESRNetModel(SRModel):
+    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It is trained without GAN losses.
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the network with the configured pixel-wise losses (no GAN training).
+    """
+
+    def __init__(self, opt):
+        super(RealESRNetModel, self).__init__(opt)
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
+        self.queue_size = opt.get('queue_size', 180)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self):
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
+        # initialize
+        b, c, h, w = self.lq.size()
+        if not hasattr(self, 'queue_lr'):
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+            _, c, h, w = self.gt.size()
+            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+            self.queue_ptr = 0
+        if self.queue_ptr == self.queue_size:  # the pool is full
+            # do dequeue and enqueue
+            # shuffle
+            idx = torch.randperm(self.queue_size)
+            self.queue_lr = self.queue_lr[idx]
+            self.queue_gt = self.queue_gt[idx]
+            # get first b samples
+            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+            # update the queue
+            self.queue_lr[0:b, :, :, :] = self.lq.clone()
+            self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+            self.lq = lq_dequeue
+            self.gt = gt_dequeue
+        else:
+            # only do enqueue
+            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+            self.queue_ptr = self.queue_ptr + b
+
+    @torch.no_grad()
+    def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
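+
+        The degradation synthesis below is the same two-stage pipeline as in
+        RealESRGANModel.feed_data; see that method for a step-by-step description.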
+ """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM sharpen the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 
+ if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/realesrgan/train.py b/realesrgan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9cec9ed80d9f362984779548dcec921a636a04 --- /dev/null +++ b/realesrgan/train.py @@ -0,0 +1,11 @@ +# flake8: noqa +import os.path as osp +from basicsr.train import train_pipeline + +import realesrgan.archs +import realesrgan.data +import realesrgan.models + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/realesrgan/utils.py b/realesrgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..10e7c23d04f777c250160e74470fdfacb16eab88 --- /dev/null +++ b/realesrgan/utils.py @@ -0,0 +1,280 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from basicsr.utils.download_util import load_file_from_url +from torch.nn import functional as F + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. 
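+
+    Example (an illustrative sketch; the weight and input paths are assumptions, not files
+    shipped with this repo):
+        >>> import cv2
+        >>> from basicsr.archs.rrdbnet_arch import RRDBNet
+        >>> model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        >>> upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth', model=model, tile=400)
+        >>> img = cv2.imread('inputs/example.png', cv2.IMREAD_UNCHANGED)
+        >>> output, img_mode = upsampler.enhance(img, outscale=4)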
+    """
+
+    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
+        self.scale = scale
+        self.tile_size = tile
+        self.tile_pad = tile_pad
+        self.pre_pad = pre_pad
+        self.mod_scale = None
+        self.half = half
+
+        # initialize model
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
+        if model_path.startswith('https://'):
+            model_path = load_file_from_url(
+                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
+        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+        # prefer to use params_ema
+        if 'params_ema' in loadnet:
+            keyname = 'params_ema'
+        else:
+            keyname = 'params'
+        model.load_state_dict(loadnet[keyname], strict=True)
+        model.eval()
+        self.model = model.to(self.device)
+        if self.half:
+            self.model = self.model.half()
+
+    def pre_process(self, img):
+        """Pre-process, such as pre-pad and mod pad, so that the image size is divisible by the required factor.
+        """
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        self.img = img.unsqueeze(0).to(self.device)
+        if self.half:
+            self.img = self.img.half()
+
+        # pre_pad
+        if self.pre_pad != 0:
+            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+        # mod pad for divisible borders
+        if self.scale == 2:
+            self.mod_scale = 2
+        elif self.scale == 1:
+            self.mod_scale = 4
+        if self.mod_scale is not None:
+            self.mod_pad_h, self.mod_pad_w = 0, 0
+            _, _, h, w = self.img.size()
+            if (h % self.mod_scale != 0):
+                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+            if (w % self.mod_scale != 0):
+                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+    def process(self):
+        # model inference
+        self.output = self.model(self.img)
+
+    def tile_process(self):
+        """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one image.
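+        Tile coordinates are computed on the input, each tile is padded by `tile_pad` pixels to
+        avoid seams, and the padding is cropped away again in the upscaled output.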
+
+        Modified from: https://github.com/ata4/esrgan-launcher
+        """
+        batch, channel, height, width = self.img.shape
+        output_height = height * self.scale
+        output_width = width * self.scale
+        output_shape = (batch, channel, output_height, output_width)
+
+        # start with black image
+        self.output = self.img.new_zeros(output_shape)
+        tiles_x = math.ceil(width / self.tile_size)
+        tiles_y = math.ceil(height / self.tile_size)
+
+        # loop over all tiles
+        for y in range(tiles_y):
+            for x in range(tiles_x):
+                # extract tile from input image
+                ofs_x = x * self.tile_size
+                ofs_y = y * self.tile_size
+                # input tile area on total image
+                input_start_x = ofs_x
+                input_end_x = min(ofs_x + self.tile_size, width)
+                input_start_y = ofs_y
+                input_end_y = min(ofs_y + self.tile_size, height)
+
+                # input tile area on total image with padding
+                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+                input_end_x_pad = min(input_end_x + self.tile_pad, width)
+                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+                input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+                # input tile dimensions
+                input_tile_width = input_end_x - input_start_x
+                input_tile_height = input_end_y - input_start_y
+                tile_idx = y * tiles_x + x + 1
+                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+                # upscale tile
+                try:
+                    with torch.no_grad():
+                        output_tile = self.model(input_tile)
+                except RuntimeError as error:
+                    print('Error', error)
+                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+                    # skip this tile (the region stays black in the output) instead of
+                    # hitting an undefined `output_tile` below
+                    continue
+
+                # output tile area on total image
+                output_start_x = input_start_x * self.scale
+                output_end_x = input_end_x * self.scale
+                output_start_y = input_start_y * self.scale
+                output_end_y = input_end_y * self.scale
+
+                # output tile area without padding
+                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+                # put tile into output image
+                self.output[:, :, output_start_y:output_end_y,
+                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                       output_start_x_tile:output_end_x_tile]
+
+    def post_process(self):
+        # remove extra pad
+        if self.mod_scale is not None:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+        # remove prepad
+        if self.pre_pad != 0:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+        return self.output
+
+    @torch.no_grad()
+    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+        h_input, w_input = img.shape[0:2]
+        # img: numpy
+        img = img.astype(np.float32)
+        if np.max(img) > 256:  # 16-bit image
+            max_range = 65535
+            print('\tInput is a 16-bit image')
+        else:
+            max_range = 255
+        img = img / max_range
+        if len(img.shape) == 2:  # gray image
+            img_mode = 'L'
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+        elif img.shape[2] == 4:  # RGBA image with alpha channel
+            img_mode = 'RGBA'
+            alpha = img[:, :, 3]
+            img = img[:, :, 0:3]
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if alpha_upsampler == 'realesrgan':
+                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+        else:
+            img_mode = 'RGB'
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        # ------------------- process image (without the alpha channel) ------------------- #
+
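+        # pipeline: pre-pad / mod-pad, (tiled) inference, then strip the padding in post_process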
self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') diff --git a/realesrgan/weights/README.md b/realesrgan/weights/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4d7b7e642591ef88575d9e6c360a4d29e0cc1a4f --- /dev/null +++ b/realesrgan/weights/README.md @@ -0,0 +1,3 @@ +# Weights + +Put the downloaded weights to this folder. 
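+
+## Downloading weights programmatically (optional)
+
+A minimal sketch using the same helper that `realesrgan/utils.py` relies on; the release URL
+below is the usual upstream location for `RealESRGAN_x4plus.pth` and should be verified against
+the Real-ESRGAN release page:
+
+```python
+from basicsr.utils.download_util import load_file_from_url
+
+# assumed upstream release asset; check https://github.com/xinntao/Real-ESRGAN/releases
+url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
+load_file_from_url(url=url, model_dir='realesrgan/weights', progress=True, file_name=None)
+```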
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3b073a22b842af8d37c9c3acd527726d881d4264 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +torch +numpy +opencv-python-headless +setuptools +Pillow +gradio +torchvision +addict +future +lmdb +pyyaml +requests +scikit-image +scipy +tb-nightly +tqdm +yapf +psutil \ No newline at end of file diff --git a/scripts/extract_subimages.py b/scripts/extract_subimages.py new file mode 100644 index 0000000000000000000000000000000000000000..9b969ae0d4adff403f2ad362b9afaaaee58e2cef --- /dev/null +++ b/scripts/extract_subimages.py @@ -0,0 +1,135 @@ +import argparse +import cv2 +import numpy as np +import os +import sys +from basicsr.utils import scandir +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def main(args): + """A multi-thread tool to crop large images to sub-images for faster IO. + + opt (dict): Configuration dict. It contains: + n_thread (int): Thread number. + compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size + and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2. + input_folder (str): Path to the input folder. + save_folder (str): Path to save folder. + crop_size (int): Crop size. + step (int): Step for overlapped sliding window. + thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. + + Usage: + For each folder, run this script. + Typically, there are GT folder and LQ folder to be processed for DIV2K dataset. + After process, each sub_folder should have the same number of subimages. + Remember to modify opt configurations according to your settings. + """ + + opt = {} + opt['n_thread'] = args.n_thread + opt['compression_level'] = args.compression_level + opt['input_folder'] = args.input + opt['save_folder'] = args.output + opt['crop_size'] = args.crop_size + opt['step'] = args.step + opt['thresh_size'] = args.thresh_size + extract_subimages(opt) + + +def extract_subimages(opt): + """Crop images to subimages. + + Args: + opt (dict): Configuration dict. It contains: + input_folder (str): Path to the input folder. + save_folder (str): Path to save folder. + n_thread (int): Thread number. + """ + input_folder = opt['input_folder'] + save_folder = opt['save_folder'] + if not osp.exists(save_folder): + os.makedirs(save_folder) + print(f'mkdir {save_folder} ...') + else: + print(f'Folder {save_folder} already exists. Exit.') + sys.exit(1) + + # scan all images + img_list = list(scandir(input_folder, full_path=True)) + + pbar = tqdm(total=len(img_list), unit='image', desc='Extract') + pool = Pool(opt['n_thread']) + for path in img_list: + pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1)) + pool.close() + pool.join() + pbar.close() + print('All processes done.') + + +def worker(path, opt): + """Worker for each process. + + Args: + path (str): Image path. + opt (dict): Configuration dict. It contains: + crop_size (int): Crop size. + step (int): Step for overlapped sliding window. + thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. + save_folder (str): Path to save folder. + compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. + + Returns: + process_info (str): Process information displayed in progress bar. 
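+
+    Note:
+        Sub-images are written to opt['save_folder'] as '{img_name}_s001{ext}',
+        '{img_name}_s002{ext}', and so on.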
+ """ + crop_size = opt['crop_size'] + step = opt['step'] + thresh_size = opt['thresh_size'] + img_name, extension = osp.splitext(osp.basename(path)) + + # remove the x2, x3, x4 and x8 in the filename for DIV2K + img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + + h, w = img.shape[0:2] + h_space = np.arange(0, h - crop_size + 1, step) + if h - (h_space[-1] + crop_size) > thresh_size: + h_space = np.append(h_space, h - crop_size) + w_space = np.arange(0, w - crop_size + 1, step) + if w - (w_space[-1] + crop_size) > thresh_size: + w_space = np.append(w_space, w - crop_size) + + index = 0 + for x in h_space: + for y in w_space: + index += 1 + cropped_img = img[x:x + crop_size, y:y + crop_size, ...] + cropped_img = np.ascontiguousarray(cropped_img) + cv2.imwrite( + osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img, + [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) + process_info = f'Processing {img_name} ...' + return process_info + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder') + parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder') + parser.add_argument('--crop_size', type=int, default=480, help='Crop size') + parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window') + parser.add_argument( + '--thresh_size', + type=int, + default=0, + help='Threshold size. Patches whose size is lower than thresh_size will be dropped.') + parser.add_argument('--n_thread', type=int, default=20, help='Thread number.') + parser.add_argument('--compression_level', type=int, default=3, help='Compression level') + args = parser.parse_args() + + main(args) diff --git a/scripts/generate_meta_info.py b/scripts/generate_meta_info.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3b7a37e85f534075c50e6c33d7cca999d8b836 --- /dev/null +++ b/scripts/generate_meta_info.py @@ -0,0 +1,58 @@ +import argparse +import cv2 +import glob +import os + + +def main(args): + txt_file = open(args.meta_info, 'w') + for folder, root in zip(args.input, args.root): + img_paths = sorted(glob.glob(os.path.join(folder, '*'))) + for img_path in img_paths: + status = True + if args.check: + # read the image once for check, as some images may have errors + try: + img = cv2.imread(img_path) + except (IOError, OSError) as error: + print(f'Read {img_path} error: {error}') + status = False + if img is None: + status = False + print(f'Img is None: {img_path}') + if status: + # get the relative path + img_name = os.path.relpath(img_path, root) + print(img_name) + txt_file.write(f'{img_name}\n') + + +if __name__ == '__main__': + """Generate meta info (txt file) for only Ground-Truth images. + + It can also generate meta info from several folders into one txt file. 
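+
+    Each line of the output file is one image path relative to the given root, for example
+    (hypothetical): DF2K_HR_sub/0001_s001.png
+    This matches the 'meta_info' format read by RealESRGANDataset.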
+ """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', + nargs='+', + default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'], + help='Input folder, can be a list') + parser.add_argument( + '--root', + nargs='+', + default=['datasets/DF2K', 'datasets/DF2K'], + help='Folder root, should have the length as input folders') + parser.add_argument( + '--meta_info', + type=str, + default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt', + help='txt path for meta info') + parser.add_argument('--check', action='store_true', help='Read image to check whether it is ok') + args = parser.parse_args() + + assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got ' + f'{len(args.input)} and {len(args.root)}.') + os.makedirs(os.path.dirname(args.meta_info), exist_ok=True) + + main(args) diff --git a/scripts/generate_meta_info_pairdata.py b/scripts/generate_meta_info_pairdata.py new file mode 100644 index 0000000000000000000000000000000000000000..76dce7e41c803a8055f3627cccb98deb51419b09 --- /dev/null +++ b/scripts/generate_meta_info_pairdata.py @@ -0,0 +1,49 @@ +import argparse +import glob +import os + + +def main(args): + txt_file = open(args.meta_info, 'w') + # sca images + img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*'))) + img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*'))) + + assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got ' + f'{len(img_paths_gt)} and {len(img_paths_lq)}.') + + for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq): + # get the relative paths + img_name_gt = os.path.relpath(img_path_gt, args.root[0]) + img_name_lq = os.path.relpath(img_path_lq, args.root[1]) + print(f'{img_name_gt}, {img_name_lq}') + txt_file.write(f'{img_name_gt}, {img_name_lq}\n') + + +if __name__ == '__main__': + """This script is used to generate meta info (txt file) for paired images. 
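+
+    Each output line holds the GT and LQ paths (relative to their roots) separated by ', ',
+    e.g. 'gt/baboon.png, lq/baboon.png', which is the format parsed by RealESRGANPairedDataset.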
+ """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', + nargs='+', + default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'], + help='Input folder, should be [gt_folder, lq_folder]') + parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ') + parser.add_argument( + '--meta_info', + type=str, + default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt', + help='txt path for meta info') + args = parser.parse_args() + + assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder' + assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder' + os.makedirs(os.path.dirname(args.meta_info), exist_ok=True) + for i in range(2): + if args.input[i].endswith('/'): + args.input[i] = args.input[i][:-1] + if args.root[i] is None: + args.root[i] = os.path.dirname(args.input[i]) + + main(args) diff --git a/scripts/generate_multiscale_DF2K.py b/scripts/generate_multiscale_DF2K.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f5d8324b1624e4cb6163754703b8dac2d188fd --- /dev/null +++ b/scripts/generate_multiscale_DF2K.py @@ -0,0 +1,48 @@ +import argparse +import glob +import os +from PIL import Image + + +def main(args): + # For DF2K, we consider the following three scales, + # and the smallest image whose shortest edge is 400 + scale_list = [0.75, 0.5, 1 / 3] + shortest_edge = 400 + + path_list = sorted(glob.glob(os.path.join(args.input, '*'))) + for path in path_list: + print(path) + basename = os.path.splitext(os.path.basename(path))[0] + + img = Image.open(path) + width, height = img.size + for idx, scale in enumerate(scale_list): + print(f'\t{scale:.2f}') + rlt = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS) + rlt.save(os.path.join(args.output, f'{basename}T{idx}.png')) + + # save the smallest image which the shortest edge is 400 + if width < height: + ratio = height / width + width = shortest_edge + height = int(width * ratio) + else: + ratio = width / height + height = shortest_edge + width = int(height * ratio) + rlt = img.resize((int(width), int(height)), resample=Image.LANCZOS) + rlt.save(os.path.join(args.output, f'{basename}T{idx+1}.png')) + + +if __name__ == '__main__': + """Generate multi-scale versions for GT images with LANCZOS resampling. + It is now used for DF2K dataset (DIV2K + Flickr 2K) + """ + parser = argparse.ArgumentParser() + parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder') + parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder') + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + main(args) diff --git a/scripts/pytorch2onnx.py b/scripts/pytorch2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..09d99b2e0171265e70e7507ed8e882b616b449a1 --- /dev/null +++ b/scripts/pytorch2onnx.py @@ -0,0 +1,36 @@ +import argparse +import torch +import torch.onnx +from basicsr.archs.rrdbnet_arch import RRDBNet + + +def main(args): + # An instance of the model + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + if args.params: + keyname = 'params' + else: + keyname = 'params_ema' + model.load_state_dict(torch.load(args.input)[keyname]) + # set the train mode to false since we will only run the forward pass. 
+ model.train(False) + model.cpu().eval() + + # An example input + x = torch.rand(1, 3, 64, 64) + # Export the model + with torch.no_grad(): + torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True) + print(torch_out.shape) + + +if __name__ == '__main__': + """Convert pytorch model to onnx models""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path') + parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path') + parser.add_argument('--params', action='store_false', help='Use params instead of params_ema') + args = parser.parse_args() + + main(args) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..2293ad7a17db08715a166f58fe91ecb576850bbb --- /dev/null +++ b/setup.cfg @@ -0,0 +1,22 @@ +[flake8] +ignore = + # line break before binary operator (W503) + W503, + # line break after binary operator (W504) + W504, +max-line-length=120 + +[yapf] +based_on_style = pep8 +column_limit = 120 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true + +[isort] +line_length = 120 +multi_line_output = 0 +known_standard_library = pkg_resources,setuptools +known_first_party = realesrgan +known_third_party = basicsr,cv2,numpy,torch +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b92e31d2db1aba50767f4f844540cfd53c609d --- /dev/null +++ b/setup.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import time + +version_file = 'realesrgan/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def get_requirements(filename='requirements.txt'): + here = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + write_version_py() + setup( + name='realesrgan', + 
version=get_version(),
+        description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
+        long_description=readme(),
+        long_description_content_type='text/markdown',
+        author='Xintao Wang',
+        author_email='xintao.wang@outlook.com',
+        keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
+        url='https://github.com/xinntao/Real-ESRGAN',
+        include_package_data=True,
+        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+        classifiers=[
+            'Development Status :: 4 - Beta',
+            'License :: OSI Approved :: BSD License',
+            'Operating System :: OS Independent',
+            'Programming Language :: Python :: 3',
+            'Programming Language :: Python :: 3.7',
+            'Programming Language :: Python :: 3.8',
+        ],
+        license='BSD-3-Clause',
+        setup_requires=['cython', 'numpy'],
+        install_requires=get_requirements(),
+        zip_safe=False)
diff --git a/tests/data/gt.lmdb/data.mdb b/tests/data/gt.lmdb/data.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..f28ad48dd320c1b624cdd30f492cd8fd0c1c9fab
Binary files /dev/null and b/tests/data/gt.lmdb/data.mdb differ
diff --git a/tests/data/gt.lmdb/lock.mdb b/tests/data/gt.lmdb/lock.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..37b3f72fa44829db318abca1f9495d73d7d6e071
Binary files /dev/null and b/tests/data/gt.lmdb/lock.mdb differ
diff --git a/tests/data/gt.lmdb/meta_info.txt b/tests/data/gt.lmdb/meta_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f42295426c1783261024a409005ee693c798951f
--- /dev/null
+++ b/tests/data/gt.lmdb/meta_info.txt
@@ -0,0 +1,2 @@
+baboon.png (480,500,3) 1
+comic.png (360,240,3) 1
diff --git a/tests/data/gt/baboon.png b/tests/data/gt/baboon.png
new file mode 100644
index 0000000000000000000000000000000000000000..c81e18de0346d8801f44495f267148919f6ac70a
Binary files /dev/null and b/tests/data/gt/baboon.png differ
diff --git a/tests/data/gt/comic.png b/tests/data/gt/comic.png
new file mode 100644
index 0000000000000000000000000000000000000000..600f5486503b53b77323a7c28b53822b23d576ba
Binary files /dev/null and b/tests/data/gt/comic.png differ
diff --git a/tests/data/lq.lmdb/data.mdb b/tests/data/lq.lmdb/data.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..c0162153452f63afbc798e99bfcdc1a6866caa0a
Binary files /dev/null and b/tests/data/lq.lmdb/data.mdb differ
diff --git a/tests/data/lq.lmdb/lock.mdb b/tests/data/lq.lmdb/lock.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..c3b69ed59644c8337389f82010234aab8f688b09
Binary files /dev/null and b/tests/data/lq.lmdb/lock.mdb differ
diff --git a/tests/data/lq.lmdb/meta_info.txt b/tests/data/lq.lmdb/meta_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6dfca0d9de4717a97db69167f020f34d8da6c0d0
--- /dev/null
+++ b/tests/data/lq.lmdb/meta_info.txt
@@ -0,0 +1,2 @@
+baboon.png (120,125,3) 1
+comic.png (80,60,3) 1
diff --git a/tests/data/lq/baboon.png b/tests/data/lq/baboon.png
new file mode 100644
index 0000000000000000000000000000000000000000..bbd201245f3bb1736bc35820eb28f0d59eef766f
Binary files /dev/null and b/tests/data/lq/baboon.png differ
diff --git a/tests/data/lq/comic.png b/tests/data/lq/comic.png
new file mode 100644
index 0000000000000000000000000000000000000000..c4e38ab76ecb80deb84fdc8f16f5afa009d95ddd
Binary files /dev/null and b/tests/data/lq/comic.png differ
diff --git a/tests/data/meta_info_gt.txt 
b/tests/data/meta_info_gt.txt new file mode 100644 index 0000000000000000000000000000000000000000..2234632d9ed7db237273779fe7cd6ddcbee4e67f --- /dev/null +++ b/tests/data/meta_info_gt.txt @@ -0,0 +1,2 @@ +baboon.png +comic.png diff --git a/tests/data/meta_info_pair.txt b/tests/data/meta_info_pair.txt new file mode 100644 index 0000000000000000000000000000000000000000..4775dda818d25e1a2ebf67c98df571b26d87c912 --- /dev/null +++ b/tests/data/meta_info_pair.txt @@ -0,0 +1,2 @@ +gt/baboon.png, lq/baboon.png +gt/comic.png, lq/comic.png diff --git a/tests/data/test_realesrgan_dataset.yml b/tests/data/test_realesrgan_dataset.yml new file mode 100644 index 0000000000000000000000000000000000000000..48e6ecc338e730e74ed5a24aefb66ea5e45381e7 --- /dev/null +++ b/tests/data/test_realesrgan_dataset.yml @@ -0,0 +1,28 @@ +name: Demo +type: RealESRGANDataset +dataroot_gt: tests/data/gt +meta_info: tests/data/meta_info_gt.txt +io_backend: + type: disk + +blur_kernel_size: 21 +kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] +kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] +sinc_prob: 1 +blur_sigma: [0.2, 3] +betag_range: [0.5, 4] +betap_range: [1, 2] + +blur_kernel_size2: 21 +kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] +kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] +sinc_prob2: 1 +blur_sigma2: [0.2, 1.5] +betag_range2: [0.5, 4] +betap_range2: [1, 2] + +final_sinc_prob: 1 + +gt_size: 128 +use_hflip: True +use_rot: False diff --git a/tests/data/test_realesrgan_model.yml b/tests/data/test_realesrgan_model.yml new file mode 100644 index 0000000000000000000000000000000000000000..1cbdab23be5cf973c4bea66a85ac0dca2c1d713e --- /dev/null +++ b/tests/data/test_realesrgan_model.yml @@ -0,0 +1,115 @@ +scale: 4 +num_gpu: 1 +manual_seed: 0 +is_train: True +dist: False + +# ----------------- options for synthesizing training data ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 1 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 1 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 1 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 1 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 1 +jpeg_range2: [30, 95] + +gt_size: 32 +queue_size: 1 + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 4 + num_block: 1 + num_grow_ch: 2 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 2 + skip_connection: True + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + 
perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + + +# validation settings +val: + val_freq: !!float 5e3 + save_img: False diff --git a/tests/data/test_realesrgan_paired_dataset.yml b/tests/data/test_realesrgan_paired_dataset.yml new file mode 100644 index 0000000000000000000000000000000000000000..8ea9709d214852ae8f792e3ee732edf542dc382d --- /dev/null +++ b/tests/data/test_realesrgan_paired_dataset.yml @@ -0,0 +1,13 @@ +name: Demo +type: RealESRGANPairedDataset +scale: 4 +dataroot_gt: tests/data +dataroot_lq: tests/data +meta_info: tests/data/meta_info_pair.txt +io_backend: + type: disk + +phase: train +gt_size: 128 +use_hflip: True +use_rot: False diff --git a/tests/data/test_realesrnet_model.yml b/tests/data/test_realesrnet_model.yml new file mode 100644 index 0000000000000000000000000000000000000000..06ceb26f4df3cad96ea8d00cf1ede0dc85d5b8d4 --- /dev/null +++ b/tests/data/test_realesrnet_model.yml @@ -0,0 +1,75 @@ +scale: 4 +num_gpu: 1 +manual_seed: 0 +is_train: True +dist: False + +# ----------------- options for synthesizing training data ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 1 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 1 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 1 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 1 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 1 +jpeg_range2: [30, 95] + +gt_size: 32 +queue_size: 1 + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 4 + num_block: 1 + num_grow_ch: 2 + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + + +# validation settings +val: + val_freq: !!float 5e3 + save_img: False diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..715b4082645c131d43d728ae8f65bcc2430aa8c9 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,151 @@ +import pytest +import yaml + +from realesrgan.data.realesrgan_dataset import RealESRGANDataset +from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset + + +def test_realesrgan_dataset(): + + with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + dataset = RealESRGANDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 2 # whether to read correct meta info + assert dataset.kernel_list == [ + 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso' + ] # correct initialization the degradation configurations + assert dataset.betag_range2 == [0.5, 4] + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'kernel1', 
'kernel2', 'sinc_kernel', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 400, 400) + assert result['kernel1'].shape == (21, 21) + assert result['kernel2'].shape == (21, 21) + assert result['sinc_kernel'].shape == (21, 21) + assert result['gt_path'] == 'tests/data/gt/baboon.png' + + # ------------------ test lmdb backend -------------------- # + opt['dataroot_gt'] = 'tests/data/gt.lmdb' + opt['io_backend']['type'] = 'lmdb' + + dataset = RealESRGANDataset(opt) + assert dataset.io_backend_opt['type'] == 'lmdb' # io backend + assert len(dataset.paths) == 2 # whether to read correct meta info + assert dataset.kernel_list == [ + 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso' + ] # correct initialization the degradation configurations + assert dataset.betag_range2 == [0.5, 4] + + # test __getitem__ + result = dataset.__getitem__(1) + # check returned keys + expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 400, 400) + assert result['kernel1'].shape == (21, 21) + assert result['kernel2'].shape == (21, 21) + assert result['sinc_kernel'].shape == (21, 21) + assert result['gt_path'] == 'comic' + + # ------------------ test with sinc_prob = 0 -------------------- # + opt['dataroot_gt'] = 'tests/data/gt.lmdb' + opt['io_backend']['type'] = 'lmdb' + opt['sinc_prob'] = 0 + opt['sinc_prob2'] = 0 + opt['final_sinc_prob'] = 0 + dataset = RealESRGANDataset(opt) + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 400, 400) + assert result['kernel1'].shape == (21, 21) + assert result['kernel2'].shape == (21, 21) + assert result['sinc_kernel'].shape == (21, 21) + assert result['gt_path'] == 'baboon' + + # ------------------ lmdb backend should have paths ends with lmdb -------------------- # + with pytest.raises(ValueError): + opt['dataroot_gt'] = 'tests/data/gt' + opt['io_backend']['type'] = 'lmdb' + dataset = RealESRGANDataset(opt) + + +def test_realesrgan_paired_dataset(): + + with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + dataset = RealESRGANPairedDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 2 # whether to read correct meta info + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) + assert result['gt_path'] == 'tests/data/gt/baboon.png' + assert result['lq_path'] == 'tests/data/lq/baboon.png' + + # ------------------ test lmdb backend -------------------- # + opt['dataroot_gt'] = 'tests/data/gt.lmdb' + opt['dataroot_lq'] = 'tests/data/lq.lmdb' + opt['io_backend']['type'] = 'lmdb' + + dataset = RealESRGANPairedDataset(opt) + assert dataset.io_backend_opt['type'] == 'lmdb' # io backend + assert len(dataset) == 2 # whether to read correct meta info + + # test __getitem__ + result = dataset.__getitem__(1) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 
'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) + assert result['gt_path'] == 'comic' + assert result['lq_path'] == 'comic' + + # ------------------ test paired_paths_from_folder -------------------- # + opt['dataroot_gt'] = 'tests/data/gt' + opt['dataroot_lq'] = 'tests/data/lq' + opt['io_backend'] = dict(type='disk') + opt['meta_info'] = None + + dataset = RealESRGANPairedDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 2 # whether to read correct meta info + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) + + # ------------------ test normalization -------------------- # + dataset.mean = [0.5, 0.5, 0.5] + dataset.std = [0.5, 0.5, 0.5] + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'lq_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 128, 128) + assert result['lq'].shape == (3, 32, 32) diff --git a/tests/test_discriminator_arch.py b/tests/test_discriminator_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..c56a40c7743630aa63b3e99bca8dc1a85949c4c5 --- /dev/null +++ b/tests/test_discriminator_arch.py @@ -0,0 +1,19 @@ +import torch + +from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN + + +def test_unetdiscriminatorsn(): + """Test arch: UNetDiscriminatorSN.""" + + # model init and forward (cpu) + net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True) + img = torch.rand((1, 3, 32, 32), dtype=torch.float32) + output = net(img) + assert output.shape == (1, 1, 32, 32) + + # model init and forward (gpu) + if torch.cuda.is_available(): + net.cuda() + output = net(img.cuda()) + assert output.shape == (1, 1, 32, 32) diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c20bb1d56ed20222e929e9c94026f6ea383c6026 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,126 @@ +import torch +import yaml +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.data.paired_image_dataset import PairedImageDataset +from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss + +from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN +from realesrgan.models.realesrgan_model import RealESRGANModel +from realesrgan.models.realesrnet_model import RealESRNetModel + + +def test_realesrnet_model(): + with open('tests/data/test_realesrnet_model.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + # build model + model = RealESRNetModel(opt) + # test attributes + assert model.__class__.__name__ == 'RealESRNetModel' + assert isinstance(model.net_g, RRDBNet) + assert isinstance(model.cri_pix, L1Loss) + assert isinstance(model.optimizers[0], torch.optim.Adam) + + # prepare data + gt = torch.rand((1, 3, 32, 32), dtype=torch.float32) + kernel1 = torch.rand((1, 5, 5), dtype=torch.float32) + kernel2 = torch.rand((1, 5, 5), dtype=torch.float32) + sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32) + data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, 
sinc_kernel=sinc_kernel) + model.feed_data(data) + # check dequeue + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # change probability to test if-else + model.opt['gaussian_noise_prob'] = 0 + model.opt['gray_noise_prob'] = 0 + model.opt['second_blur_prob'] = 0 + model.opt['gaussian_noise_prob2'] = 0 + model.opt['gray_noise_prob2'] = 0 + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # ----------------- test nondist_validation -------------------- # + # construct dataloader + dataset_opt = dict( + name='Demo', + dataroot_gt='tests/data/gt', + dataroot_lq='tests/data/lq', + io_backend=dict(type='disk'), + scale=4, + phase='val') + dataset = PairedImageDataset(dataset_opt) + dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + assert model.is_train is True + model.nondist_validation(dataloader, 1, None, False) + assert model.is_train is True + + +def test_realesrgan_model(): + with open('tests/data/test_realesrgan_model.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + # build model + model = RealESRGANModel(opt) + # test attributes + assert model.__class__.__name__ == 'RealESRGANModel' + assert isinstance(model.net_g, RRDBNet) # generator + assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator + assert isinstance(model.cri_pix, L1Loss) + assert isinstance(model.cri_perceptual, PerceptualLoss) + assert isinstance(model.cri_gan, GANLoss) + assert isinstance(model.optimizers[0], torch.optim.Adam) + assert isinstance(model.optimizers[1], torch.optim.Adam) + + # prepare data + gt = torch.rand((1, 3, 32, 32), dtype=torch.float32) + kernel1 = torch.rand((1, 5, 5), dtype=torch.float32) + kernel2 = torch.rand((1, 5, 5), dtype=torch.float32) + sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32) + data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel) + model.feed_data(data) + # check dequeue + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # change probability to test if-else + model.opt['gaussian_noise_prob'] = 0 + model.opt['gray_noise_prob'] = 0 + model.opt['second_blur_prob'] = 0 + model.opt['gaussian_noise_prob2'] = 0 + model.opt['gray_noise_prob2'] = 0 + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 8, 8) + assert model.gt.shape == (1, 3, 32, 32) + + # ----------------- test nondist_validation -------------------- # + # construct dataloader + dataset_opt = dict( + name='Demo', + dataroot_gt='tests/data/gt', + dataroot_lq='tests/data/lq', + io_backend=dict(type='disk'), + scale=4, + phase='val') + dataset = PairedImageDataset(dataset_opt) + dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + assert model.is_train is True + model.nondist_validation(dataloader, 1, None, False) + assert model.is_train is True + + # ----------------- test optimize_parameters -------------------- # + model.feed_data(data) + model.optimize_parameters(1) + assert model.output.shape == (1, 3, 32, 32) + assert isinstance(model.log_dict, dict) + # check returned keys + expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake'] + assert set(expected_keys).issubset(set(model.log_dict.keys())) diff --git a/tests/test_utils.py 
b/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7919b74905495b4b6f4aa957a1f0b5d7a174c782 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,87 @@ +import numpy as np +from basicsr.archs.rrdbnet_arch import RRDBNet + +from realesrgan.utils import RealESRGANer + + +def test_realesrganer(): + # initialize with default model + restorer = RealESRGANer( + scale=4, + model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth', + model=None, + tile=10, + tile_pad=10, + pre_pad=2, + half=False) + assert isinstance(restorer.model, RRDBNet) + assert restorer.half is False + # initialize with user-defined model + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + restorer = RealESRGANer( + scale=4, + model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth', + model=model, + tile=10, + tile_pad=10, + pre_pad=2, + half=True) + # test attribute + assert isinstance(restorer.model, RRDBNet) + assert restorer.half is True + + # ------------------ test pre_process ---------------- # + img = np.random.random((12, 12, 3)).astype(np.float32) + restorer.pre_process(img) + assert restorer.img.shape == (1, 3, 14, 14) + # with modcrop + restorer.scale = 1 + restorer.pre_process(img) + assert restorer.img.shape == (1, 3, 16, 16) + + # ------------------ test process ---------------- # + restorer.process() + assert restorer.output.shape == (1, 3, 64, 64) + + # ------------------ test post_process ---------------- # + restorer.mod_scale = 4 + output = restorer.post_process() + assert output.shape == (1, 3, 60, 60) + + # ------------------ test tile_process ---------------- # + restorer.scale = 4 + img = np.random.random((12, 12, 3)).astype(np.float32) + restorer.pre_process(img) + restorer.tile_process() + assert restorer.output.shape == (1, 3, 64, 64) + + # ------------------ test enhance ---------------- # + img = np.random.random((12, 12, 3)).astype(np.float32) + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (24, 24, 3) + assert result[1] == 'RGB' + + # ------------------ test enhance with 16-bit image---------------- # + img = np.random.random((4, 4, 3)).astype(np.uint16) + 512 + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (8, 8, 3) + assert result[1] == 'RGB' + + # ------------------ test enhance with gray image---------------- # + img = np.random.random((4, 4)).astype(np.float32) + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (8, 8) + assert result[1] == 'L' + + # ------------------ test enhance with RGBA---------------- # + img = np.random.random((4, 4, 4)).astype(np.float32) + result = restorer.enhance(img, outscale=2) + assert result[0].shape == (8, 8, 4) + assert result[1] == 'RGBA' + + # ------------------ test enhance with RGBA, alpha_upsampler---------------- # + restorer.tile_size = 0 + img = np.random.random((4, 4, 4)).astype(np.float32) + result = restorer.enhance(img, outscale=2, alpha_upsampler=None) + assert result[0].shape == (8, 8, 4) + assert result[1] == 'RGBA'
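The `scripts/pytorch2onnx.py` hunk at the top of this patch reaches for the private `torch.onnx._export` and calls both `model.train(False)` and `model.cpu().eval()`, which is redundant (`train(False)` and `eval()` are the same call). A minimal sketch of the same conversion through the public `torch.onnx.export` API; it assumes the checkpoint stores its weights under the `params_ema` key, as the script's `--params` flag implies:

```python
# Sketch only: convert RealESRGAN_x4plus.pth to ONNX with the public export API.
# Assumes basicsr is installed and the checkpoint keeps its weights under 'params_ema'.
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
state_dict = torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth', map_location='cpu')
model.load_state_dict(state_dict['params_ema'])
model.eval()  # eval() already implies train(False)

x = torch.rand(1, 3, 64, 64)  # dummy input that fixes the traced shapes
with torch.no_grad():
    torch_out = model(x)  # public export returns None, so run the model separately for the sanity check
    torch.onnx.export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
print(torch_out.shape)  # torch.Size([1, 3, 256, 256]) for the 4x model
```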
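For orientation on the packaging code: `write_version_py` in `setup.py` stamps `realesrgan/version.py` from the `VERSION` file plus the short git hash, and `get_version()` reads it back by `exec`ing that file, so an sdist install never needs git. With a hypothetical `VERSION` of `0.2.1`, hash `a1b2c3d`, and build time (all three invented for illustration), the generated file would read:

```python
# GENERATED VERSION FILE
# TIME: Mon Aug  1 12:00:00 2022
__version__ = '0.2.1'
__gitsha__ = 'a1b2c3d'
version_info = (0, 2, 1)
```

Because every dotted component here is numeric, `VERSION_INFO` becomes the bare tuple `(0, 2, 1)`; a non-numeric suffix such as `rc0` would be quoted by the `x.isdigit()` branch instead.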
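`tests/test_dataset.py` pulls items out of `RealESRGANDataset` one at a time; during training, the same options object would feed an ordinary `DataLoader`. A sketch built on the test config (the batch size and `shuffle` flag below are arbitrary choices, not values from this patch):

```python
# Sketch: wrap RealESRGANDataset from the test config in a standard DataLoader.
import torch
import yaml

from realesrgan.data.realesrgan_dataset import RealESRGANDataset

with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
    opt = yaml.load(f, Loader=yaml.FullLoader)

dataset = RealESRGANDataset(opt)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

batch = next(iter(loader))
print(batch['gt'].shape)       # torch.Size([2, 3, 400, 400]), matching the per-item assertions
print(batch['kernel1'].shape)  # torch.Size([2, 21, 21]): first-stage blur kernels
```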
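The paired-dataset test switches on normalization by assigning `dataset.mean = dataset.std = [0.5, 0.5, 0.5]`, which is the standard affine map taking pixel values from `[0, 1]` into `[-1, 1]`. A tiny illustration using `torchvision.transforms.functional.normalize`, the helper BasicSR's paired dataset applies for this step:

```python
# Sketch: the effect of mean/std = 0.5 normalization on a tensor in [0, 1].
import torch
from torchvision.transforms.functional import normalize

img = torch.rand(3, 4, 4)  # pixel values in [0, 1)
out = normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(out.min().item(), out.max().item())  # both now inside [-1, 1]
```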
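`test_realesrgan_model` performs one optimization step by hand: `feed_data` runs the two on-the-fly degradation stages to synthesize `lq` from `gt`, and `optimize_parameters` updates generator and discriminator. Stitched together, those calls are the skeleton of a training iteration; a sketch reusing the test config and random tensors (real training would draw batches from `RealESRGANDataset` instead):

```python
# Sketch: a couple of hand-rolled training iterations with RealESRGANModel.
import torch
import yaml

from realesrgan.models.realesrgan_model import RealESRGANModel

with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
    opt = yaml.load(f, Loader=yaml.FullLoader)

model = RealESRGANModel(opt)
data = dict(
    gt=torch.rand(1, 3, 32, 32),
    kernel1=torch.rand(1, 5, 5),
    kernel2=torch.rand(1, 5, 5),
    sinc_kernel=torch.rand(1, 5, 5))

for current_iter in range(1, 3):     # real runs continue to total_iter from the config
    model.feed_data(data)            # synthesizes model.lq from model.gt on the fly
    model.optimize_parameters(current_iter)
print(model.log_dict)                # l_g_pix, l_g_percep, l_g_gan, l_d_real, l_d_fake, ...
```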
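Finally, `tests/test_utils.py` pins down the `RealESRGANer` surface: `pre_process`, `process`, `tile_process`, `post_process`, and the one-call `enhance`. An end-to-end sketch of the typical inference path; the input filename below is invented, and `enhance` expects a BGR array as returned by `cv2.imread`:

```python
# Sketch: one-image inference with RealESRGANer (input path and weights file are assumptions).
import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet

from realesrgan.utils import RealESRGANer

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
restorer = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
    model=model,
    tile=0,       # 0 disables tiling; a value like 400 caps memory use on large inputs
    tile_pad=10,
    pre_pad=0,
    half=False)   # half precision needs a CUDA device

img = cv2.imread('inputs/example.png', cv2.IMREAD_UNCHANGED)  # hypothetical input file
output, img_mode = restorer.enhance(img, outscale=4)          # returns (image, mode), per the tests
cv2.imwrite('results/example_out.png', output)
```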