sczhou commited on
Commit
637d603
·
1 Parent(s): fbc4a80

V0.1.0 release.

Browse files
README.md CHANGED
@@ -6,7 +6,8 @@
6
 
7
  [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
8
 
9
- <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
 
10
 
11
  [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
12
 
 
6
 
7
  [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
8
 
9
+ <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
10
+
11
 
12
  [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
13
 
facelib/detection/__init__.py CHANGED
@@ -49,17 +49,14 @@ def init_retinaface_model(model_name, half=False, device='cuda'):
49
  def init_yolov5face_model(model_name, device='cuda'):
50
  if model_name == 'YOLOv5l':
51
  model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
52
- f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
53
  elif model_name == 'YOLOv5n':
54
  model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
55
- f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
56
  else:
57
  raise NotImplementedError(f'{model_name} is not implemented.')
58
-
59
- model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
60
- if not os.path.exists(model_path):
61
- download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
62
-
63
  load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
64
  model.detector.load_state_dict(load_net, strict=True)
65
  model.detector.eval()
@@ -71,4 +68,33 @@ def init_yolov5face_model(model_name, device='cuda'):
71
  elif isinstance(m, Conv):
72
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
73
 
74
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def init_yolov5face_model(model_name, device='cuda'):
50
  if model_name == 'YOLOv5l':
51
  model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
52
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth'
53
  elif model_name == 'YOLOv5n':
54
  model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
55
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth'
56
  else:
57
  raise NotImplementedError(f'{model_name} is not implemented.')
58
+
59
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
 
 
 
60
  load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
61
  model.detector.load_state_dict(load_net, strict=True)
62
  model.detector.eval()
 
68
  elif isinstance(m, Conv):
69
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
70
 
71
+ return model
72
+
73
+
74
+ # Download from Google Drive
75
+ # def init_yolov5face_model(model_name, device='cuda'):
76
+ # if model_name == 'YOLOv5l':
77
+ # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
78
+ # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
79
+ # elif model_name == 'YOLOv5n':
80
+ # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
81
+ # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
82
+ # else:
83
+ # raise NotImplementedError(f'{model_name} is not implemented.')
84
+
85
+ # model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
86
+ # if not os.path.exists(model_path):
87
+ # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
88
+
89
+ # load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
90
+ # model.detector.load_state_dict(load_net, strict=True)
91
+ # model.detector.eval()
92
+ # model.detector = model.detector.to(device).float()
93
+
94
+ # for m in model.detector.modules():
95
+ # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
96
+ # m.inplace = True # pytorch 1.7.0 compatibility
97
+ # elif isinstance(m, Conv):
98
+ # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
99
+
100
+ # return model
facelib/parsing/__init__.py CHANGED
@@ -8,10 +8,10 @@ from .parsenet import ParseNet
8
  def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
9
  if model_name == 'bisenet':
10
  model = BiSeNet(num_class=19)
11
- model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
12
  elif model_name == 'parsenet':
13
  model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
14
- model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth'
15
  else:
16
  raise NotImplementedError(f'{model_name} is not implemented.')
17
 
 
8
  def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
9
  if model_name == 'bisenet':
10
  model = BiSeNet(num_class=19)
11
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
12
  elif model_name == 'parsenet':
13
  model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
14
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
15
  else:
16
  raise NotImplementedError(f'{model_name} is not implemented.')
17
 
inference_codeformer.py CHANGED
@@ -6,11 +6,16 @@ import glob
6
  import torch
7
  from torchvision.transforms.functional import normalize
8
  from basicsr.utils import imwrite, img2tensor, tensor2img
 
9
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
10
  import torch.nn.functional as F
11
 
12
  from basicsr.utils.registry import ARCH_REGISTRY
13
 
 
 
 
 
14
  if __name__ == '__main__':
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  parser = argparse.ArgumentParser()
@@ -59,8 +64,10 @@ if __name__ == '__main__':
59
  # ------------------ set up CodeFormer restorer -------------------
60
  net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
61
  connect_list=['32', '64', '128', '256']).to(device)
62
-
63
- ckpt_path = 'weights/CodeFormer/codeformer.pth'
 
 
64
  checkpoint = torch.load(ckpt_path)['params_ema']
65
  net.load_state_dict(checkpoint)
66
  net.eval()
 
6
  import torch
7
  from torchvision.transforms.functional import normalize
8
  from basicsr.utils import imwrite, img2tensor, tensor2img
9
+ from basicsr.utils.download_util import load_file_from_url
10
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
11
  import torch.nn.functional as F
12
 
13
  from basicsr.utils.registry import ARCH_REGISTRY
14
 
15
+ pretrain_model_url = {
16
+ 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
17
+ }
18
+
19
  if __name__ == '__main__':
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  parser = argparse.ArgumentParser()
 
64
  # ------------------ set up CodeFormer restorer -------------------
65
  net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
66
  connect_list=['32', '64', '128', '256']).to(device)
67
+
68
+ # ckpt_path = 'weights/CodeFormer/codeformer.pth'
69
+ ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
70
+ model_dir='weights/CodeFormer', progress=True, file_name=None)
71
  checkpoint = torch.load(ckpt_path)['params_ema']
72
  net.load_state_dict(checkpoint)
73
  net.eval()
requirements.txt CHANGED
@@ -14,7 +14,7 @@ torchvision
14
  tqdm
15
  yapf
16
  lpips
17
- gdown # supports downloading the large file from Google Drive
18
  # cmake
19
  # dlib
20
  # conda install -c conda-forge dlib
 
14
  tqdm
15
  yapf
16
  lpips
17
+ # gdown # supports downloading the large file from Google Drive
18
  # cmake
19
  # dlib
20
  # conda install -c conda-forge dlib
scripts/download_pretrained_models.py CHANGED
@@ -2,31 +2,16 @@ import argparse
2
  import os
3
  from os import path as osp
4
 
5
- # from basicsr.utils.download_util import download_file_from_google_drive
6
- import gdown
7
 
8
 
9
- def download_pretrained_models(method, file_ids):
10
  save_path_root = f'./weights/{method}'
11
  os.makedirs(save_path_root, exist_ok=True)
12
 
13
- for file_name, file_id in file_ids.items():
14
- file_url = 'https://drive.google.com/uc?id='+file_id
15
- save_path = osp.abspath(osp.join(save_path_root, file_name))
16
- if osp.exists(save_path):
17
- user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
18
- if user_response.lower() == 'y':
19
- print(f'Covering {file_name} to {save_path}')
20
- gdown.download(file_url, save_path, quiet=False)
21
- # download_file_from_google_drive(file_id, save_path)
22
- elif user_response.lower() == 'n':
23
- print(f'Skipping {file_name}')
24
- else:
25
- raise ValueError('Wrong input. Only accepts Y/N.')
26
- else:
27
- print(f'Downloading {file_name} to {save_path}')
28
- gdown.download(file_url, save_path, quiet=False)
29
- # download_file_from_google_drive(file_id, save_path)
30
 
31
  if __name__ == '__main__':
32
  parser = argparse.ArgumentParser()
@@ -37,24 +22,18 @@ if __name__ == '__main__':
37
  help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
38
  args = parser.parse_args()
39
 
40
- # file name: file id
41
- # 'dlib': {
42
- # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
43
- # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
44
- # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
45
- # }
46
- file_ids = {
47
  'CodeFormer': {
48
- 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
49
  },
50
  'facelib': {
51
- 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
52
- 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
53
  }
54
  }
55
 
56
  if args.method == 'all':
57
- for method in file_ids.keys():
58
- download_pretrained_models(method, file_ids[method])
59
  else:
60
- download_pretrained_models(args.method, file_ids[args.method])
 
2
  import os
3
  from os import path as osp
4
 
5
+ from basicsr.utils.download_util import load_file_from_url
 
6
 
7
 
8
+ def download_pretrained_models(method, file_urls):
9
  save_path_root = f'./weights/{method}'
10
  os.makedirs(save_path_root, exist_ok=True)
11
 
12
+ for file_name, file_url in file_urls.items():
13
+ save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
14
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  if __name__ == '__main__':
17
  parser = argparse.ArgumentParser()
 
22
  help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
23
  args = parser.parse_args()
24
 
25
+ file_urls = {
 
 
 
 
 
 
26
  'CodeFormer': {
27
+ 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
28
  },
29
  'facelib': {
30
+ 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
31
+ 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
32
  }
33
  }
34
 
35
  if args.method == 'all':
36
+ for method in file_urls.keys():
37
+ download_pretrained_models(method, file_urls[method])
38
  else:
39
+ download_pretrained_models(args.method, file_urls[args.method])
scripts/download_pretrained_models_from_gdrive.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from os import path as osp
4
+
5
+ # from basicsr.utils.download_util import download_file_from_google_drive
6
+ import gdown
7
+
8
+
9
+ def download_pretrained_models(method, file_ids):
10
+ save_path_root = f'./weights/{method}'
11
+ os.makedirs(save_path_root, exist_ok=True)
12
+
13
+ for file_name, file_id in file_ids.items():
14
+ file_url = 'https://drive.google.com/uc?id='+file_id
15
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
16
+ if osp.exists(save_path):
17
+ user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
18
+ if user_response.lower() == 'y':
19
+ print(f'Covering {file_name} to {save_path}')
20
+ gdown.download(file_url, save_path, quiet=False)
21
+ # download_file_from_google_drive(file_id, save_path)
22
+ elif user_response.lower() == 'n':
23
+ print(f'Skipping {file_name}')
24
+ else:
25
+ raise ValueError('Wrong input. Only accepts Y/N.')
26
+ else:
27
+ print(f'Downloading {file_name} to {save_path}')
28
+ gdown.download(file_url, save_path, quiet=False)
29
+ # download_file_from_google_drive(file_id, save_path)
30
+
31
+ if __name__ == '__main__':
32
+ parser = argparse.ArgumentParser()
33
+
34
+ parser.add_argument(
35
+ 'method',
36
+ type=str,
37
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
38
+ args = parser.parse_args()
39
+
40
+ # file name: file id
41
+ # 'dlib': {
42
+ # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
43
+ # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
44
+ # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
45
+ # }
46
+ file_ids = {
47
+ 'CodeFormer': {
48
+ 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
49
+ },
50
+ 'facelib': {
51
+ 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
52
+ 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
53
+ }
54
+ }
55
+
56
+ if args.method == 'all':
57
+ for method in file_ids.keys():
58
+ download_pretrained_models(method, file_ids[method])
59
+ else:
60
+ download_pretrained_models(args.method, file_ids[args.method])