sczhou commited on
Commit
e373be4
Β·
1 Parent(s): 00fc5a8

add out path control and save name prefix (#44)

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +91 -88
  3. inference_codeformer.py +14 -4
README.md CHANGED
@@ -97,7 +97,7 @@ You can put the testing images in the `inputs/TestWhole` folder. If you would li
97
  #### Testing on Face Restoration:
98
  [Note] If you want to compare CodeFormer in your paper, please run the following command indicating `--has_aligned` (for cropped and aligned face), as the command for the whole image will involve a process of face-background fusion that may damage hair texture on the boundary, which leads to unfair comparison.
99
 
100
- πŸ‘¨πŸ» Face Restoration (cropped and aligned face)
101
  ```
102
  # For cropped and aligned faces
103
  python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
 
97
  #### Testing on Face Restoration:
98
  [Note] If you want to compare CodeFormer in your paper, please run the following command indicating `--has_aligned` (for cropped and aligned face), as the command for the whole image will involve a process of face-background fusion that may damage hair texture on the boundary, which leads to unfair comparison.
99
 
100
+ πŸ§‘πŸ» Face Restoration (cropped and aligned face)
101
  ```
102
  # For cropped and aligned faces
103
  python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
app.py CHANGED
@@ -103,98 +103,101 @@ os.makedirs('output', exist_ok=True)
103
 
104
  def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
105
  """Run a single prediction on the model"""
106
- # take the default setting for the demo
107
- has_aligned = False
108
- only_center_face = False
109
- draw_box = False
110
- detection_model = "retinaface_resnet50"
111
-
112
- upscale = int(upscale) # covert type to int
113
- face_helper = FaceRestoreHelper(
114
- upscale,
115
- face_size=512,
116
- crop_ratio=(1, 1),
117
- det_model=detection_model,
118
- save_ext="png",
119
- use_parse=True,
120
- device=device,
121
- )
122
- bg_upsampler = upsampler if background_enhance else None
123
- face_upsampler = upsampler if face_upsample else None
124
-
125
- img = cv2.imread(str(image), cv2.IMREAD_COLOR)
126
-
127
- if has_aligned:
128
- # the input faces are already cropped and aligned
129
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
130
- face_helper.is_gray = is_gray(img, threshold=5)
131
- if face_helper.is_gray:
132
- print('Grayscale input: True')
133
- face_helper.cropped_faces = [img]
134
- else:
135
- face_helper.read_image(img)
136
- # get face landmarks for each face
137
- num_det_faces = face_helper.get_face_landmarks_5(
138
- only_center_face=only_center_face, resize=640, eye_dist_threshold=5
139
- )
140
- print(f"\tdetect {num_det_faces} faces")
141
- # align and warp each face
142
- face_helper.align_warp_face()
143
-
144
- # face restoration for each cropped face
145
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
146
- # prepare data
147
- cropped_face_t = img2tensor(
148
- cropped_face / 255.0, bgr2rgb=True, float32=True
149
  )
150
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
151
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
152
-
153
- try:
154
- with torch.no_grad():
155
- output = codeformer_net(
156
- cropped_face_t, w=codeformer_fidelity, adain=True
157
- )[0]
158
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
159
- del output
160
- torch.cuda.empty_cache()
161
- except Exception as error:
162
- print(f"\tFailed inference for CodeFormer: {error}")
163
- restored_face = tensor2img(
164
- cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
165
- )
166
-
167
- restored_face = restored_face.astype("uint8")
168
- face_helper.add_restored_face(restored_face)
169
-
170
- # paste_back
171
- if not has_aligned:
172
- # upsample the background
173
- if bg_upsampler is not None:
174
- # Now only support RealESRGAN for upsampling background
175
- bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
176
  else:
177
- bg_img = None
178
- face_helper.get_inverse_affine(None)
179
- # paste each restored face to the input image
180
- if face_upsample and face_upsampler is not None:
181
- restored_img = face_helper.paste_faces_to_input_image(
182
- upsample_img=bg_img,
183
- draw_box=draw_box,
184
- face_upsampler=face_upsampler,
185
  )
186
- else:
187
- restored_img = face_helper.paste_faces_to_input_image(
188
- upsample_img=bg_img, draw_box=draw_box
 
 
 
 
 
 
189
  )
190
-
191
- # save restored img
192
- save_path = f'output/out.png'
193
- imwrite(restored_img, str(save_path))
194
-
195
- restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
196
- return restored_img, save_path
197
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
 
200
  title = "CodeFormer: Robust Face Restoration and Enhancement Network"
 
103
 
104
  def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
105
  """Run a single prediction on the model"""
106
+ try: # global try
107
+ # take the default setting for the demo
108
+ has_aligned = False
109
+ only_center_face = False
110
+ draw_box = False
111
+ detection_model = "retinaface_resnet50"
112
+
113
+ upscale = int(upscale) # covert type to int
114
+ face_helper = FaceRestoreHelper(
115
+ upscale,
116
+ face_size=512,
117
+ crop_ratio=(1, 1),
118
+ det_model=detection_model,
119
+ save_ext="png",
120
+ use_parse=True,
121
+ device=device,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
+ bg_upsampler = upsampler if background_enhance else None
124
+ face_upsampler = upsampler if face_upsample else None
125
+
126
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
127
+
128
+ if has_aligned:
129
+ # the input faces are already cropped and aligned
130
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
131
+ face_helper.is_gray = is_gray(img, threshold=5)
132
+ if face_helper.is_gray:
133
+ print('Grayscale input: True')
134
+ face_helper.cropped_faces = [img]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  else:
136
+ face_helper.read_image(img)
137
+ # get face landmarks for each face
138
+ num_det_faces = face_helper.get_face_landmarks_5(
139
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
 
 
 
 
140
  )
141
+ print(f"\tdetect {num_det_faces} faces")
142
+ # align and warp each face
143
+ face_helper.align_warp_face()
144
+
145
+ # face restoration for each cropped face
146
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
147
+ # prepare data
148
+ cropped_face_t = img2tensor(
149
+ cropped_face / 255.0, bgr2rgb=True, float32=True
150
  )
151
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
152
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
153
+
154
+ try:
155
+ with torch.no_grad():
156
+ output = codeformer_net(
157
+ cropped_face_t, w=codeformer_fidelity, adain=True
158
+ )[0]
159
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
160
+ del output
161
+ torch.cuda.empty_cache()
162
+ except Exception as error:
163
+ print(f"\tFailed inference for CodeFormer: {error}")
164
+ restored_face = tensor2img(
165
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
166
+ )
167
+
168
+ restored_face = restored_face.astype("uint8")
169
+ face_helper.add_restored_face(restored_face)
170
+
171
+ # paste_back
172
+ if not has_aligned:
173
+ # upsample the background
174
+ if bg_upsampler is not None:
175
+ # Now only support RealESRGAN for upsampling background
176
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
177
+ else:
178
+ bg_img = None
179
+ face_helper.get_inverse_affine(None)
180
+ # paste each restored face to the input image
181
+ if face_upsample and face_upsampler is not None:
182
+ restored_img = face_helper.paste_faces_to_input_image(
183
+ upsample_img=bg_img,
184
+ draw_box=draw_box,
185
+ face_upsampler=face_upsampler,
186
+ )
187
+ else:
188
+ restored_img = face_helper.paste_faces_to_input_image(
189
+ upsample_img=bg_img, draw_box=draw_box
190
+ )
191
+
192
+ # save restored img
193
+ save_path = f'output/out.png'
194
+ imwrite(restored_img, str(save_path))
195
+
196
+ restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
197
+ return restored_img, save_path
198
+ except Exception as error:
199
+ print('global exception', error)
200
+ return None, None
201
 
202
 
203
  title = "CodeFormer: Robust Face Restoration and Enhancement Network"
inference_codeformer.py CHANGED
@@ -52,9 +52,10 @@ if __name__ == '__main__':
52
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53
  parser = argparse.ArgumentParser()
54
 
55
- parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
56
- parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
57
- parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
 
58
  parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
59
  parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
60
  # large det_model: 'YOLOv5l', 'retinaface_resnet50'
@@ -64,12 +65,14 @@ if __name__ == '__main__':
64
  parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
65
  parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
66
  parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
 
67
  parser.add_argument('--save_video_fps', type=int, default=24, help='frame rate for saving video. Default: 24')
68
 
69
  args = parser.parse_args()
70
 
71
  # ------------------------ input & output ------------------------
72
  w = args.w
 
73
  if args.test_path.endswith(('jpg', 'png')): # input single img path
74
  input_img_list = [args.test_path]
75
  result_root = f'results/test_img_{w}'
@@ -89,7 +92,10 @@ if __name__ == '__main__':
89
  # scan all the jpg and png images
90
  input_img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
91
  result_root = f'results/{os.path.basename(args.test_path)}_{w}'
92
-
 
 
 
93
  test_img_num = len(input_img_list)
94
  # ------------------ set up background upsampler ------------------
95
  if args.bg_upsampler == 'realesrgan':
@@ -215,11 +221,15 @@ if __name__ == '__main__':
215
  save_face_name = f'{basename}.png'
216
  else:
217
  save_face_name = f'{basename}_{idx:02d}.png'
 
 
218
  save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
219
  imwrite(restored_face, save_restore_path)
220
 
221
  # save restored img
222
  if not args.has_aligned and restored_img is not None:
 
 
223
  save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
224
  imwrite(restored_img, save_restore_path)
225
 
 
52
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53
  parser = argparse.ArgumentParser()
54
 
55
+ parser.add_argument('-i', '--test_path', type=str, default='./inputs/cropped_faces')
56
+ parser.add_argument('-o', '--save_path', type=str, default=None)
57
+ parser.add_argument('-w', '--w', type=float, default=0.5, help='Balance the quality and fidelity')
58
+ parser.add_argument('-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
59
  parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
60
  parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
61
  # large det_model: 'YOLOv5l', 'retinaface_resnet50'
 
65
  parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
66
  parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
67
  parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
68
+ parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
69
  parser.add_argument('--save_video_fps', type=int, default=24, help='frame rate for saving video. Default: 24')
70
 
71
  args = parser.parse_args()
72
 
73
  # ------------------------ input & output ------------------------
74
  w = args.w
75
+
76
  if args.test_path.endswith(('jpg', 'png')): # input single img path
77
  input_img_list = [args.test_path]
78
  result_root = f'results/test_img_{w}'
 
92
  # scan all the jpg and png images
93
  input_img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
94
  result_root = f'results/{os.path.basename(args.test_path)}_{w}'
95
+
96
+ if not args.save_path is None: # set output path
97
+ result_root = args.save_path
98
+
99
  test_img_num = len(input_img_list)
100
  # ------------------ set up background upsampler ------------------
101
  if args.bg_upsampler == 'realesrgan':
 
221
  save_face_name = f'{basename}.png'
222
  else:
223
  save_face_name = f'{basename}_{idx:02d}.png'
224
+ if args.suffix is not None:
225
+ save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
226
  save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
227
  imwrite(restored_face, save_restore_path)
228
 
229
  # save restored img
230
  if not args.has_aligned and restored_img is not None:
231
+ if args.suffix is not None:
232
+ basename = f'{basename}_{args.suffix}'
233
  save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
234
  imwrite(restored_img, save_restore_path)
235