sczhou commited on
Commit
eb22a9a
·
1 Parent(s): 581abcb

update face_upsample.

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. inference_codeformer.py +37 -24
README.md CHANGED
@@ -20,7 +20,7 @@ S-Lab, Nanyang Technological University
20
 
21
  ### Updates
22
 
23
- - **2022.09.04**: Add face upsampling '--face_upsample' for high-resolution AI-created face enhancement.
24
  - **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
25
  - **2022.08.07**: Integrate Real-ESRGAN to support background image enhancement.
26
  - **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`.
 
20
 
21
  ### Updates
22
 
23
+ - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
24
  - **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
25
  - **2022.08.07**: Integrate Real-ESRGAN to support background image enhancement.
26
  - **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`.
inference_codeformer.py CHANGED
@@ -16,6 +16,27 @@ pretrain_model_url = {
16
  'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  if __name__ == '__main__':
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  parser = argparse.ArgumentParser()
@@ -44,27 +65,19 @@ if __name__ == '__main__':
44
 
45
  # ------------------ set up background upsampler ------------------
46
  if args.bg_upsampler == 'realesrgan':
47
- if not torch.cuda.is_available(): # CPU
48
- import warnings
49
- warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
50
- 'If you really want to use it, please modify the corresponding codes.',
51
- category=RuntimeWarning)
52
- bg_upsampler = None
53
- else:
54
- from basicsr.archs.rrdbnet_arch import RRDBNet
55
- from basicsr.utils.realesrgan_utils import RealESRGANer
56
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
57
- bg_upsampler = RealESRGANer(
58
- scale=2,
59
- model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
60
- model=model,
61
- tile=args.bg_tile,
62
- tile_pad=40,
63
- pre_pad=0,
64
- half=True) # need to set False in CPU mode
65
  else:
66
  bg_upsampler = None
67
 
 
 
 
 
 
 
 
 
 
68
  # ------------------ set up CodeFormer restorer -------------------
69
  net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
70
  connect_list=['32', '64', '128', '256']).to(device)
@@ -80,12 +93,12 @@ if __name__ == '__main__':
80
  # large det_model: 'YOLOv5l', 'retinaface_resnet50'
81
  # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
82
  if not args.has_aligned:
83
- print(f'Using [{args.detection_model}] for face detection network.')
84
- if args.bg_upsampler is not None:
85
  print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
86
  else:
87
- print('Background upsampling: False, Face upsampling: False')
88
-
89
  face_helper = FaceRestoreHelper(
90
  args.upscale,
91
  face_size=512,
@@ -149,8 +162,8 @@ if __name__ == '__main__':
149
  bg_img = None
150
  face_helper.get_inverse_affine(None)
151
  # paste each restored face to the input image
152
- if args.face_upsample and bg_upsampler is not None:
153
- restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=bg_upsampler)
154
  else:
155
  restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
156
 
 
16
  'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
17
  }
18
 
19
+ def set_realesrgan():
20
+ if not torch.cuda.is_available(): # CPU
21
+ import warnings
22
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
23
+ 'If you really want to use it, please modify the corresponding codes.',
24
+ category=RuntimeWarning)
25
+ bg_upsampler = None
26
+ else:
27
+ from basicsr.archs.rrdbnet_arch import RRDBNet
28
+ from basicsr.utils.realesrgan_utils import RealESRGANer
29
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
30
+ bg_upsampler = RealESRGANer(
31
+ scale=2,
32
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
33
+ model=model,
34
+ tile=args.bg_tile,
35
+ tile_pad=40,
36
+ pre_pad=0,
37
+ half=True) # need to set False in CPU mode
38
+ return bg_upsampler
39
+
40
  if __name__ == '__main__':
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
  parser = argparse.ArgumentParser()
 
65
 
66
  # ------------------ set up background upsampler ------------------
67
  if args.bg_upsampler == 'realesrgan':
68
+ bg_upsampler = set_realesrgan()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  else:
70
  bg_upsampler = None
71
 
72
+ # ------------------ set up face upsampler ------------------
73
+ if args.face_upsample:
74
+ if bg_upsampler is not None:
75
+ face_upsampler = bg_upsampler
76
+ else:
77
+ face_upsampler = set_realesrgan()
78
+ else:
79
+ face_upsampler = None
80
+
81
  # ------------------ set up CodeFormer restorer -------------------
82
  net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
83
  connect_list=['32', '64', '128', '256']).to(device)
 
93
  # large det_model: 'YOLOv5l', 'retinaface_resnet50'
94
  # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
95
  if not args.has_aligned:
96
+ print(f'Face detection model: {args.detection_model}')
97
+ if bg_upsampler is not None:
98
  print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
99
  else:
100
+ print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
101
+
102
  face_helper = FaceRestoreHelper(
103
  args.upscale,
104
  face_size=512,
 
162
  bg_img = None
163
  face_helper.get_inverse_affine(None)
164
  # paste each restored face to the input image
165
+ if args.face_upsample and face_upsampler is not None:
166
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
167
  else:
168
  restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
169