culture commited on
Commit
1fa4de0
·
1 Parent(s): f37ad14

Upload inference_gfpgan.py

Browse files
Files changed (1) hide show
  1. inference_gfpgan.py +116 -0
inference_gfpgan.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import numpy as np
5
+ import os
6
+ import torch
7
+ from basicsr.utils import imwrite
8
+
9
+ from gfpgan import GFPGANer
10
+
11
+
12
+ def main():
13
+ """Inference demo for GFPGAN.
14
+ """
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
17
+ parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
18
+ parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
19
+ parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
20
+ parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
21
+ parser.add_argument(
22
+ '--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
23
+ parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
24
+ parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
25
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
26
+ parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
27
+ parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
28
+ parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
29
+ parser.add_argument(
30
+ '--ext',
31
+ type=str,
32
+ default='auto',
33
+ help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
34
+ args = parser.parse_args()
35
+
36
+ args = parser.parse_args()
37
+ if args.test_path.endswith('/'):
38
+ args.test_path = args.test_path[:-1]
39
+ os.makedirs(args.save_root, exist_ok=True)
40
+
41
+ # background upsampler
42
+ if args.bg_upsampler == 'realesrgan':
43
+ if not torch.cuda.is_available(): # CPU
44
+ import warnings
45
+ warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
46
+ 'If you really want to use it, please modify the corresponding codes.')
47
+ bg_upsampler = None
48
+ else:
49
+ from basicsr.archs.rrdbnet_arch import RRDBNet
50
+ from realesrgan import RealESRGANer
51
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
52
+ bg_upsampler = RealESRGANer(
53
+ scale=2,
54
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
55
+ model=model,
56
+ tile=args.bg_tile,
57
+ tile_pad=10,
58
+ pre_pad=0,
59
+ half=True) # need to set False in CPU mode
60
+ else:
61
+ bg_upsampler = None
62
+ # set up GFPGAN restorer
63
+ restorer = GFPGANer(
64
+ model_path=args.model_path,
65
+ upscale=args.upscale,
66
+ arch=args.arch,
67
+ channel_multiplier=args.channel,
68
+ bg_upsampler=bg_upsampler)
69
+
70
+ img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
71
+ for img_path in img_list:
72
+ # read image
73
+ img_name = os.path.basename(img_path)
74
+ print(f'Processing {img_name} ...')
75
+ basename, ext = os.path.splitext(img_name)
76
+ input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
77
+
78
+ # restore faces and background if necessary
79
+ cropped_faces, restored_faces, restored_img = restorer.enhance(
80
+ input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
81
+
82
+ # save faces
83
+ for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
84
+ # save cropped face
85
+ save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
86
+ imwrite(cropped_face, save_crop_path)
87
+ # save restored face
88
+ if args.suffix is not None:
89
+ save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
90
+ else:
91
+ save_face_name = f'{basename}_{idx:02d}.png'
92
+ save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
93
+ imwrite(restored_face, save_restore_path)
94
+ # save comparison image
95
+ cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
96
+ imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
97
+
98
+ # save restored img
99
+ if restored_img is not None:
100
+ if args.ext == 'auto':
101
+ extension = ext[1:]
102
+ else:
103
+ extension = args.ext
104
+
105
+ if args.suffix is not None:
106
+ save_restore_path = os.path.join(args.save_root, 'restored_imgs',
107
+ f'{basename}_{args.suffix}.{extension}')
108
+ else:
109
+ save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
110
+ imwrite(restored_img, save_restore_path)
111
+
112
+ print(f'Results are in the [{args.save_root}] folder.')
113
+
114
+
115
+ if __name__ == '__main__':
116
+ main()