sczhou committed on
Commit
79fc279
·
unverified ·
2 Parent(s): eb22a9a e7b069c

Merge pull request #28 from chenxwh/replicate

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. cog.yaml +25 -0
  3. predict.py +188 -0
README.md CHANGED
@@ -7,7 +7,7 @@
7
  [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
8
 
9
  <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
10
-
11
 
12
  [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
13
 
 
7
  [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
8
 
9
  <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
10
+ [![Replicate](https://replicate.com/cjwbw/codeformer/badge)](https://replicate.com/cjwbw/codeformer)
11
 
12
  [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
13
 
cog.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Cog configuration: describes the container image and the prediction entry point.
build:
  # Build a GPU-enabled image against CUDA 11.3 (matches the cu113 torch wheels below).
  gpu: true
  cuda: "11.3"
  python_version: "3.8"
  # System libraries installed into the image.
  # NOTE(review): presumably the shared libraries opencv-python needs at import time - confirm.
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  # Pinned Python dependencies.
  python_packages:
    - "ipython==8.4.0"
    - "future==0.18.2"
    - "lmdb==1.3.0"
    - "scikit-image==0.19.3"
    - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
    - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
    - "scipy==1.9.0"
    - "gdown==4.5.1"
    - "pyyaml==6.0"
    - "tb-nightly==2.11.0a20220906"
    - "tqdm==4.64.1"
    - "yapf==0.32.0"
    - "lpips==0.1.4"
    - "Pillow==9.2.0"
    - "opencv-python==4.6.0.66"

# Entry point: the Predictor class defined in predict.py.
predict: "predict.py:Predictor"
predict.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ download checkpoints to ./weights beforehand
3
+ python scripts/download_pretrained_models.py facelib
4
+ python scripts/download_pretrained_models.py CodeFormer
5
+ wget -P ./weights 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
6
+ """
7
+
8
+ import tempfile
9
+ import cv2
10
+ import torch
11
+ from torchvision.transforms.functional import normalize
12
+ from cog import BasePredictor, Input, Path
13
+
14
+ from basicsr.utils import imwrite, img2tensor, tensor2img
15
+ from basicsr.archs.rrdbnet_arch import RRDBNet
16
+ from basicsr.utils.realesrgan_utils import RealESRGANer
17
+ from basicsr.utils.registry import ARCH_REGISTRY
18
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
19
+
20
+
21
class Predictor(BasePredictor):
    """Cog predictor that restores faces with CodeFormer.

    Real-ESRGAN is optionally used to upsample the background and the
    restored faces when it is available (see ``set_realesrgan``).
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        # Fall back to CPU when CUDA is absent, consistent with set_realesrgan(),
        # instead of failing on a hard-coded "cuda:0".
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.bg_upsampler = set_realesrgan()
        self.net = ARCH_REGISTRY.get("CodeFormer")(
            dim_embd=512,
            codebook_size=1024,
            n_head=8,
            n_layers=9,
            connect_list=["32", "64", "128", "256"],
        ).to(self.device)
        ckpt_path = "weights/CodeFormer/codeformer.pth"
        # map_location lets a checkpoint saved from GPU tensors load on any device.
        # update file permission if cannot load
        checkpoint = torch.load(ckpt_path, map_location=self.device)["params_ema"]
        self.net.load_state_dict(checkpoint)
        self.net.eval()

    def _restore_face(self, cropped_face, codeformer_fidelity):
        """Run CodeFormer on one cropped face and return the result as a uint8 image.

        The crop is scaled to [0, 1], converted to a tensor and normalized with
        mean/std 0.5 before inference. If inference raises, the error is printed
        and the (unrestored) input crop is converted back and returned instead.
        """
        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

        try:
            with torch.no_grad():
                output = self.net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            del output
            torch.cuda.empty_cache()
        except Exception as error:
            # Deliberate best-effort: keep going with the unrestored crop.
            print(f"\tFailed inference for CodeFormer: {error}")
            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

        return restored_face.astype("uint8")

    def predict(
        self,
        image: Path = Input(description="Input image"),
        codeformer_fidelity: float = Input(
            default=0.5,
            ge=0,
            le=1,
            description="Balance the quality (lower number) and fidelity (higher number).",
        ),
        background_enhance: bool = Input(
            description="Enhance background image with Real-ESRGAN", default=True
        ),
        face_upsample: bool = Input(
            description="Upsample restored faces for high-resolution AI-created images",
            default=True,
        ),
        upscale: int = Input(
            description="The final upsampling scale of the image",
            default=2,
        ),
    ) -> Path:
        """Run a single prediction on the model

        Detects and aligns the faces in ``image``, restores each one with
        CodeFormer, pastes them back onto the (optionally upsampled) image and
        writes the result to ``output.png`` in a fresh temporary directory.

        Raises:
            ValueError: if ``image`` cannot be decoded by OpenCV.
        """

        # take the default setting for the demo
        has_aligned = False
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"

        # Fail fast with a clear message: cv2.imread returns None rather than
        # raising when the file is missing or not a decodable image.
        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError(f"Could not read input image: {image}")

        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=self.device,
        )

        # self.bg_upsampler is None when Real-ESRGAN is unavailable (no CUDA),
        # in which case both enhancements are silently skipped.
        bg_upsampler = self.bg_upsampler if background_enhance else None
        face_upsampler = self.bg_upsampler if face_upsample else None

        if has_aligned:
            # the input faces are already cropped and aligned
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            num_det_faces = self.face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
            )
            print(f"\tdetect {num_det_faces} faces")
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration for each cropped face
        for cropped_face in self.face_helper.cropped_faces:
            self.face_helper.add_restored_face(
                self._restore_face(cropped_face, codeformer_fidelity)
            )

        # paste_back
        restored_img = None
        if not has_aligned:
            # upsample the background
            if bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
            else:
                bg_img = None
            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            # (face_upsampler is already None whenever face_upsample is False)
            if face_upsampler is not None:
                restored_img = self.face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img,
                    draw_box=draw_box,
                    face_upsampler=face_upsampler,
                )
            else:
                restored_img = self.face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img, draw_box=draw_box
                )

        # save restored img
        out_path = Path(tempfile.mkdtemp()) / "output.png"

        if restored_img is not None:
            imwrite(restored_img, str(out_path))

        return out_path
152
+
153
+
154
def imread(img_path):
    """Load the image at ``img_path`` with OpenCV and return it converted from BGR to RGB."""
    bgr_image = cv2.imread(img_path)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
158
+
159
+
160
def set_realesrgan():
    """Create the Real-ESRGAN x2 upsampler used for background/face upsampling.

    Returns:
        A ``RealESRGANer`` wrapping an ``RRDBNet`` loaded from
        ``./weights/RealESRGAN_x2plus.pth`` when CUDA is available; otherwise
        ``None`` after emitting a ``RuntimeWarning``.
    """
    if torch.cuda.is_available():
        network = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=2,
        )
        # Tiled, half-precision inference.
        return RealESRGANer(
            scale=2,
            model_path="./weights/RealESRGAN_x2plus.pth",
            model=network,
            tile=400,
            tile_pad=40,
            pre_pad=0,
            half=True,
        )

    # CPU: the upsampler is disabled rather than run slowly.
    import warnings

    warnings.warn(
        "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
        "If you really want to use it, please modify the corresponding codes.",
        category=RuntimeWarning,
    )
    return None