sczhou committed
Commit 3014b64 · 1 Parent(s): e373be4

update input arguments (#44)

Files changed (2):
  1. README.md +4 -4
  2. inference_codeformer.py +32 -24
README.md CHANGED
@@ -23,7 +23,7 @@ S-Lab, Nanyang Technological University
 
 **[<font color=#d1585d>News</font>]**: :whale: *Due to copyright issues, we have to delay the release of the training code (expected by the end of this year). Please star and stay tuned for our future updates!*
 ### Update
-- **2022.10.05**: Support video input `--test_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
+- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
 - **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
 - **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
 - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
@@ -100,7 +100,7 @@ You can put the testing images in the `inputs/TestWhole` folder. If you would li
 🧑🏻 Face Restoration (cropped and aligned face)
 ```
 # For cropped and aligned faces
-python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
+python inference_codeformer.py -w 0.5 --has_aligned --input_path [input folder]
 ```
 
 :framed_picture: Whole Image Enhancement
@@ -108,14 +108,14 @@ python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
 # For whole image
 # Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
 # Add '--face_upsample' to further upsample the restored face with Real-ESRGAN
-python inference_codeformer.py --w 1.0 --test_path [input folder/image path]
+python inference_codeformer.py -w 0.7 --input_path [image folder/image path]
 ```
 
 :clapper: Video Enhancement
 ```
 # For video clips
 # Set frame rate of saved video via '--save_video_fps 24'
-python inference_codeformer.py --bg_upsampler realesrgan --face_upsample --w 0.7 --test_path [video path] --save_video_fps 24
+python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path] --save_video_fps 24
 ```
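Worth noting alongside the video command above: the restored clip is written at whatever `--save_video_fps` is set to (default 24), so a source that plays at a different rate will come out faster or slower. A minimal OpenCV sketch for looking up the source rate first (the clip path is a hypothetical placeholder):

```python
import cv2

# Read the source clip's native frame rate so it can be forwarded via
# '--save_video_fps'; the script itself does not copy it from the input.
vidcap = cv2.VideoCapture('inputs/my_clip.mp4')  # hypothetical path
fps = vidcap.get(cv2.CAP_PROP_FPS)
vidcap.release()
print(f'--save_video_fps {round(fps)}')
```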
 
121
 
inference_codeformer.py CHANGED
@@ -52,49 +52,55 @@ if __name__ == '__main__':
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('-i', '--test_path', type=str, default='./inputs/cropped_faces')
-    parser.add_argument('-o', '--save_path', type=str, default=None)
-    parser.add_argument('-w', '--w', type=float, default=0.5, help='Balance the quality and fidelity')
-    parser.add_argument('-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
-    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
-    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
+    parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
+            help='Input image, video or folder. Default: inputs/whole_imgs')
+    parser.add_argument('-o', '--output_path', type=str, default=None,
+            help='Output folder. Default: results/<input_name>_<w>')
+    parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
+            help='Balance the quality and fidelity')
+    parser.add_argument('-s', '--upscale', type=int, default=2,
+            help='The final upsampling scale of the image. Default: 2')
+    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
+    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
+    parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
     # large det_model: 'YOLOv5l', 'retinaface_resnet50'
     # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
-    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
-    parser.add_argument('--draw_box', action='store_true')
-    parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
-    parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
+    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
+            help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. \
+                Default: retinaface_resnet50')
+    parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
+    parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
     parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
-    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
-    parser.add_argument('--save_video_fps', type=int, default=24, help='frame rate for saving video. Default: 24')
+    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
+    parser.add_argument('--save_video_fps', type=int, default=24, help='Frame rate for saving video. Default: 24')
 
     args = parser.parse_args()
 
     # ------------------------ input & output ------------------------
-    w = args.w
-
-    if args.test_path.endswith(('jpg', 'png')): # input single img path
-        input_img_list = [args.test_path]
+    w = args.fidelity_weight
+    input_video = False
+    if args.input_path.endswith(('jpg', 'png')): # input single img path
+        input_img_list = [args.input_path]
         result_root = f'results/test_img_{w}'
-    elif args.test_path.endswith(('mp4', 'mov', 'avi')): # input video path
+    elif args.input_path.endswith(('mp4', 'mov', 'avi')): # input video path
         input_img_list = []
-        vidcap = cv2.VideoCapture(args.test_path)
+        vidcap = cv2.VideoCapture(args.input_path)
         success, image = vidcap.read()
         while success:
             input_img_list.append(image)
             success, image = vidcap.read()
         input_video = True
-        video_name = os.path.basename(args.test_path)[:-4]
+        video_name = os.path.basename(args.input_path)[:-4]
         result_root = f'results/{video_name}_{w}'
     else: # input img folder
-        if args.test_path.endswith('/'): # solve when path ends with /
-            args.test_path = args.test_path[:-1]
+        if args.input_path.endswith('/'): # strip the trailing /
+            args.input_path = args.input_path[:-1]
         # scan all the jpg and png images
-        input_img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
-        result_root = f'results/{os.path.basename(args.test_path)}_{w}'
+        input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jp][pn]g')))
+        result_root = f'results/{os.path.basename(args.input_path)}_{w}'
 
-    if not args.save_path is None: # set output path
-        result_root = args.save_path
+    if args.output_path is not None: # set output path
+        result_root = args.output_path
 
     test_img_num = len(input_img_list)
     # ------------------ set up background upsampler ------------------
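One detail worth flagging in the hunk above: when an option has both a short and a long flag, argparse derives the attribute name (`dest`) from the first long option, so after the rename the value lives in `args.fidelity_weight`, not `args.w` — which is why the script must now read `w = args.fidelity_weight`. A minimal sketch of that behavior:

```python
import argparse

parser = argparse.ArgumentParser()
# dest is inferred from '--fidelity_weight', not from '-w'
parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5)

args = parser.parse_args(['-w', '0.7'])
print(args.fidelity_weight)  # 0.7
# print(args.w)              # AttributeError: 'Namespace' object has no attribute 'w'
```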
@@ -243,6 +249,8 @@ if __name__ == '__main__':
         video_frames.append(img)
     # write images to video
     h, w = video_frames[0].shape[:2]
+    if args.suffix is not None:
+        video_name = f'{video_name}_{args.suffix}'
     save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
     writer = cv2.VideoWriter(save_restore_path, cv2.VideoWriter_fourcc(*"mp4v"),
                              args.save_video_fps, (w, h))
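The hunk above only constructs the writer; here is a self-contained sketch of the full write-out step, with synthetic stand-in frames since the real ones come from the restoration loop. `cv2.VideoWriter` takes the frame size as `(width, height)` rather than `(height, width)`, expects `uint8` BGR frames, and must be released so the mp4 container is finalized.

```python
import cv2
import numpy as np

# Synthetic stand-in frames; in the script these are the restored
# images collected in video_frames.
video_frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(48)]

h, w = video_frames[0].shape[:2]
writer = cv2.VideoWriter('restored.mp4', cv2.VideoWriter_fourcc(*'mp4v'),
                         24, (w, h))  # 24 fps, the --save_video_fps default
for frame in video_frames:
    writer.write(frame)
writer.release()  # finalizes the container
```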
 
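One more note on the input-routing hunk: the folder branch scans with `'*.[jp][pn]g'`, a character-class pattern that matches `.jpg` and `.png` in a single `glob` call (strictly speaking it also matches the unlikely `.jng` and `.ppg`, and on case-sensitive filesystems `.JPG` files are skipped). A small sketch against a hypothetical folder:

```python
import glob
import os

# '[jp]' matches 'j' or 'p' and '[pn]' matches 'p' or 'n', so one scan
# covers both '*.jpg' and '*.png'.
folder = 'inputs/whole_imgs'  # hypothetical folder
input_img_list = sorted(glob.glob(os.path.join(folder, '*.[jp][pn]g')))
print(input_img_list)
```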