sczhou commited on
Commit
a39b60d
·
1 Parent(s): c5b4593

add hugging_face demo.

Browse files
README.md CHANGED
@@ -20,7 +20,8 @@ S-Lab, Nanyang Technological University
20
 
21
  ### Update
22
 
23
- - **2022.09.09**: Integrated to [Replicate](https://replicate.com/). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
 
24
  - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
25
  - **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
26
  - **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
 
20
 
21
  ### Update
22
 
23
+ - **2022.09.14**: Integrated to :hugs: [Hugging Face](https://replicate.com/). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://replicate.com/sczhou/codeformer)
24
+ - **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
25
  - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
26
  - **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
27
  - **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is used for deploying hugging face demo:
3
+ https://huggingface.co/spaces/sczhou/CodeFormer
4
+ """
5
+
6
+ import sys
7
+ sys.path.append('CodeFormer')
8
+ import os
9
+ import cv2
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import gradio as gr
13
+
14
+ from torchvision.transforms.functional import normalize
15
+
16
+ from basicsr.utils import imwrite, img2tensor, tensor2img
17
+ from basicsr.utils.download_util import load_file_from_url
18
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
19
+ from basicsr.archs.rrdbnet_arch import RRDBNet
20
+ from basicsr.utils.realesrgan_utils import RealESRGANer
21
+
22
+ from basicsr.utils.registry import ARCH_REGISTRY
23
+
24
+
25
+ os.system("pip freeze")
26
+
27
+ pretrain_model_url = {
28
+ 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
29
+ 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
30
+ 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
31
+ 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
32
+ }
33
+ # download weights
34
+ if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
35
+ load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
36
+ if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
37
+ load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
38
+ if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
39
+ load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
40
+ if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
41
+ load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
42
+
43
+ # download images
44
+ torch.hub.download_url_to_file(
45
+ 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
46
+ '01.png')
47
+ torch.hub.download_url_to_file(
48
+ 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
49
+ '02.jpg')
50
+ torch.hub.download_url_to_file(
51
+ 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
52
+ '03.jpg')
53
+ torch.hub.download_url_to_file(
54
+ 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
55
+ '04.jpg')
56
+ torch.hub.download_url_to_file(
57
+ 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
58
+ '05.jpg')
59
+
60
+ def imread(img_path):
61
+ img = cv2.imread(img_path)
62
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
63
+ return img
64
+
65
+ # set enhancer with RealESRGAN
66
+ def set_realesrgan():
67
+ half = True if torch.cuda.is_available() else False
68
+ model = RRDBNet(
69
+ num_in_ch=3,
70
+ num_out_ch=3,
71
+ num_feat=64,
72
+ num_block=23,
73
+ num_grow_ch=32,
74
+ scale=2,
75
+ )
76
+ upsampler = RealESRGANer(
77
+ scale=2,
78
+ model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
79
+ model=model,
80
+ tile=400,
81
+ tile_pad=40,
82
+ pre_pad=0,
83
+ half=half,
84
+ )
85
+ return upsampler
86
+
87
+ upsampler = set_realesrgan()
88
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
89
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
90
+ dim_embd=512,
91
+ codebook_size=1024,
92
+ n_head=8,
93
+ n_layers=9,
94
+ connect_list=["32", "64", "128", "256"],
95
+ ).to(device)
96
+ ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
97
+ checkpoint = torch.load(ckpt_path)["params_ema"]
98
+ codeformer_net.load_state_dict(checkpoint)
99
+ codeformer_net.eval()
100
+
101
+ os.makedirs('output', exist_ok=True)
102
+
103
+ def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
104
+ """Run a single prediction on the model"""
105
+ # take the default setting for the demo
106
+ has_aligned = False
107
+ only_center_face = False
108
+ draw_box = False
109
+ detection_model = "retinaface_resnet50"
110
+
111
+ face_helper = FaceRestoreHelper(
112
+ upscale,
113
+ face_size=512,
114
+ crop_ratio=(1, 1),
115
+ det_model=detection_model,
116
+ save_ext="png",
117
+ use_parse=True,
118
+ device=device,
119
+ )
120
+ bg_upsampler = upsampler if background_enhance else None
121
+ face_upsampler = upsampler if face_upsample else None
122
+
123
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
124
+
125
+ if has_aligned:
126
+ # the input faces are already cropped and aligned
127
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
128
+ face_helper.cropped_faces = [img]
129
+ else:
130
+ face_helper.read_image(img)
131
+ # get face landmarks for each face
132
+ num_det_faces = face_helper.get_face_landmarks_5(
133
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
134
+ )
135
+ print(f"\tdetect {num_det_faces} faces")
136
+ # align and warp each face
137
+ face_helper.align_warp_face()
138
+
139
+ # face restoration for each cropped face
140
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
141
+ # prepare data
142
+ cropped_face_t = img2tensor(
143
+ cropped_face / 255.0, bgr2rgb=True, float32=True
144
+ )
145
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
146
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
147
+
148
+ try:
149
+ with torch.no_grad():
150
+ output = codeformer_net(
151
+ cropped_face_t, w=codeformer_fidelity, adain=True
152
+ )[0]
153
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
154
+ del output
155
+ torch.cuda.empty_cache()
156
+ except Exception as error:
157
+ print(f"\tFailed inference for CodeFormer: {error}")
158
+ restored_face = tensor2img(
159
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
160
+ )
161
+
162
+ restored_face = restored_face.astype("uint8")
163
+ face_helper.add_restored_face(restored_face)
164
+
165
+ # paste_back
166
+ if not has_aligned:
167
+ # upsample the background
168
+ if bg_upsampler is not None:
169
+ # Now only support RealESRGAN for upsampling background
170
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
171
+ else:
172
+ bg_img = None
173
+ face_helper.get_inverse_affine(None)
174
+ # paste each restored face to the input image
175
+ if face_upsample and face_upsampler is not None:
176
+ restored_img = face_helper.paste_faces_to_input_image(
177
+ upsample_img=bg_img,
178
+ draw_box=draw_box,
179
+ face_upsampler=face_upsampler,
180
+ )
181
+ else:
182
+ restored_img = face_helper.paste_faces_to_input_image(
183
+ upsample_img=bg_img, draw_box=draw_box
184
+ )
185
+
186
+ # save restored img
187
+ save_path = f'output/out.png'
188
+ imwrite(restored_img, str(save_path))
189
+
190
+ restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
191
+ return restored_img
192
+
193
+
194
+
195
+ title = "CodeFormer: Robust Face Restoration and Enhancement Network"
196
+ description = r"""<center><img src='https://user-images.githubusercontent.com/14334509/189166076-94bb2cac-4f4e-40fb-a69f-66709e3d98f5.png' alt='CodeFormer logo'></center>
197
+ <b>Official Gradio demo</b> for <a href='https://github.com/sczhou/CodeFormer' target='_blank'><b>Towards Robust Blind Face Restoration with Codebook Lookup Transformer</b></a>.<br>
198
+ 🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.<br>
199
+ 🤗 Try CodeFormer for improved stable-diffusion generation!<br>
200
+ """
201
+ article = r"""
202
+ If CodeFormer is helpful, please help to ⭐ the <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Github Repo</a>. Thanks!
203
+ [![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
204
+
205
+ ---
206
+
207
+ 📝 Citation
208
+ If our work is useful for your research, please consider citing:
209
+ ```bibtex
210
+ @article{zhou2022codeformer,
211
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
212
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
213
+ journal = {arXiv preprint arXiv:2206.11253},
214
+ year = {2022}
215
+ }
216
+ ```
217
+
218
+ 📧 Contact
219
+ If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
220
+
221
+ ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
222
+ """
223
+
224
+ gr.Interface(
225
+ inference, [
226
+ gr.inputs.Image(type="filepath", label="Input"),
227
+ gr.inputs.Checkbox(default=True, label="Background_Enhance"),
228
+ gr.inputs.Checkbox(default=True, label="Face_Upsample"),
229
+ gr.inputs.Number(default=2, label="Rescaling_Factor"),
230
+ gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity, 0 for better quality, 1 for better identity')
231
+ ], [
232
+ gr.outputs.Image(type="numpy", label="Output"),
233
+ ],
234
+ title=title,
235
+ description=description,
236
+ article=article,
237
+ examples=[
238
+ ['01.png', True, True, 2, 0.7],
239
+ ['02.jpg', True, True, 2, 0.7],
240
+ ['03.jpg', True, True, 2, 0.7],
241
+ ['04.jpg', True, True, 2, 0.1],
242
+ ['05.jpg', True, True, 2, 0.1]
243
+ ]
244
+ ).launch()
cog.yaml CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  build:
2
  gpu: true
3
  cuda: "11.3"
 
1
+ """
2
+ This file is used for deploying replicate demo:
3
+ https://replicate.com/sczhou/codeformer
4
+ """
5
+
6
  build:
7
  gpu: true
8
  cuda: "11.3"
facelib/detection/__init__.py CHANGED
@@ -25,10 +25,10 @@ def init_detection_model(model_name, half=False, device='cuda'):
25
  def init_retinaface_model(model_name, half=False, device='cuda'):
26
  if model_name == 'retinaface_resnet50':
27
  model = RetinaFace(network_name='resnet50', half=half)
28
- model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
29
  elif model_name == 'retinaface_mobile0.25':
30
  model = RetinaFace(network_name='mobile0.25', half=half)
31
- model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
32
  else:
33
  raise NotImplementedError(f'{model_name} is not implemented.')
34
 
 
25
  def init_retinaface_model(model_name, half=False, device='cuda'):
26
  if model_name == 'retinaface_resnet50':
27
  model = RetinaFace(network_name='resnet50', half=half)
28
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth'
29
  elif model_name == 'retinaface_mobile0.25':
30
  model = RetinaFace(network_name='mobile0.25', half=half)
31
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
32
  else:
33
  raise NotImplementedError(f'{model_name} is not implemented.')
34
 
facelib/utils/face_restoration_helper.py CHANGED
@@ -59,7 +59,7 @@ class FaceRestoreHelper(object):
59
  use_parse=False,
60
  device=None):
61
  self.template_3points = template_3points # improve robustness
62
- self.upscale_factor = upscale_factor
63
  # the cropped face ratio based on the square face
64
  self.crop_ratio = crop_ratio # (h, w)
65
  assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
 
59
  use_parse=False,
60
  device=None):
61
  self.template_3points = template_3points # improve robustness
62
+ self.upscale_factor = int(upscale_factor)
63
  # the cropped face ratio based on the square face
64
  self.crop_ratio = crop_ratio # (h, w)
65
  assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
predict.py CHANGED
@@ -1,15 +1,18 @@
1
  """
2
- download checkpoints to ./weights beforehand
3
- python scripts/download_pretrained_models.py facelib
4
- python scripts/download_pretrained_models.py CodeFormer
5
- wget 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
6
  """
7
 
8
  import tempfile
9
  import cv2
10
  import torch
11
  from torchvision.transforms.functional import normalize
12
- from cog import BasePredictor, Input, Path
 
 
 
13
 
14
  from basicsr.utils import imwrite, img2tensor, tensor2img
15
  from basicsr.archs.rrdbnet_arch import RRDBNet
@@ -22,7 +25,7 @@ class Predictor(BasePredictor):
22
  def setup(self):
23
  """Load the model into memory to make running multiple predictions efficient"""
24
  self.device = "cuda:0"
25
- self.bg_upsampler = set_realesrgan()
26
  self.net = ARCH_REGISTRY.get("CodeFormer")(
27
  dim_embd=512,
28
  codebook_size=1024,
@@ -76,8 +79,8 @@ class Predictor(BasePredictor):
76
  device=self.device,
77
  )
78
 
79
- bg_upsampler = self.bg_upsampler if background_enhance else None
80
- face_upsampler = self.bg_upsampler if face_upsample else None
81
 
82
  img = cv2.imread(str(image), cv2.IMREAD_COLOR)
83
 
@@ -143,10 +146,8 @@ class Predictor(BasePredictor):
143
  )
144
 
145
  # save restored img
146
- out_path = Path(tempfile.mkdtemp()) / "output.png"
147
-
148
- if not has_aligned and restored_img is not None:
149
- imwrite(restored_img, str(out_path))
150
 
151
  return out_path
152
 
@@ -166,7 +167,7 @@ def set_realesrgan():
166
  "If you really want to use it, please modify the corresponding codes.",
167
  category=RuntimeWarning,
168
  )
169
- bg_upsampler = None
170
  else:
171
  model = RRDBNet(
172
  num_in_ch=3,
@@ -176,13 +177,13 @@ def set_realesrgan():
176
  num_grow_ch=32,
177
  scale=2,
178
  )
179
- bg_upsampler = RealESRGANer(
180
  scale=2,
181
- model_path="./weights/RealESRGAN_x2plus.pth",
182
  model=model,
183
  tile=400,
184
  tile_pad=40,
185
  pre_pad=0,
186
  half=True,
187
  )
188
- return bg_upsampler
 
1
  """
2
+ This file is used for deploying replicate demo:
3
+ https://replicate.com/sczhou/codeformer
4
+ running: cog predict -i image=@inputs/whole_imgs/04.jpg -i codeformer_fidelity=0.5 -i upscale=2
5
+ push: cog push r8.im/sczhou/codeformer
6
  """
7
 
8
  import tempfile
9
  import cv2
10
  import torch
11
  from torchvision.transforms.functional import normalize
12
+ try:
13
+ from cog import BasePredictor, Input, Path
14
+ except Exception:
15
+ print('please install cog package')
16
 
17
  from basicsr.utils import imwrite, img2tensor, tensor2img
18
  from basicsr.archs.rrdbnet_arch import RRDBNet
 
25
  def setup(self):
26
  """Load the model into memory to make running multiple predictions efficient"""
27
  self.device = "cuda:0"
28
+ self.upsampler = set_realesrgan()
29
  self.net = ARCH_REGISTRY.get("CodeFormer")(
30
  dim_embd=512,
31
  codebook_size=1024,
 
79
  device=self.device,
80
  )
81
 
82
+ bg_upsampler = self.upsampler if background_enhance else None
83
+ face_upsampler = self.upsampler if face_upsample else None
84
 
85
  img = cv2.imread(str(image), cv2.IMREAD_COLOR)
86
 
 
146
  )
147
 
148
  # save restored img
149
+ out_path = Path(tempfile.mkdtemp()) / 'output.png'
150
+ imwrite(restored_img, str(out_path))
 
 
151
 
152
  return out_path
153
 
 
167
  "If you really want to use it, please modify the corresponding codes.",
168
  category=RuntimeWarning,
169
  )
170
+ upsampler = None
171
  else:
172
  model = RRDBNet(
173
  num_in_ch=3,
 
177
  num_grow_ch=32,
178
  scale=2,
179
  )
180
+ upsampler = RealESRGANer(
181
  scale=2,
182
+ model_path="./weights/realesrgan/RealESRGAN_x2plus.pth",
183
  model=model,
184
  tile=400,
185
  tile_pad=40,
186
  pre_pad=0,
187
  half=True,
188
  )
189
+ return upsampler