HookBeforeAppDelegate commited on
Commit
d8fafc1
·
verified ·
1 Parent(s): 15346e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -252
app.py CHANGED
@@ -1,283 +1,147 @@
1
- """
2
- This file is used for deploying hugging face demo:
3
- https://huggingface.co/spaces/sczhou/CodeFormer
4
- """
5
-
6
- import sys
7
- sys.path.append('CodeFormer')
8
  import os
 
9
  import cv2
10
- import torch
11
- import torch.nn.functional as F
12
  import gradio as gr
13
-
14
- from torchvision.transforms.functional import normalize
15
-
16
- from basicsr.archs.rrdbnet_arch import RRDBNet
17
- from basicsr.utils import imwrite, img2tensor, tensor2img
18
- from basicsr.utils.download_util import load_file_from_url
19
- from basicsr.utils.misc import gpu_is_available, get_device
20
- from basicsr.utils.realesrgan_utils import RealESRGANer
21
- from basicsr.utils.registry import ARCH_REGISTRY
22
-
23
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
24
- from facelib.utils.misc import is_gray
25
-
26
 
27
  os.system("pip freeze")
28
-
29
- pretrain_model_url = {
30
- 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
31
- 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
32
- 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
33
- 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
34
- }
35
  # download weights
36
- if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
37
- load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
38
- if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
39
- load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
40
- if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
41
- load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
42
- if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
43
- load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
 
 
 
 
44
 
45
- # download images
46
  torch.hub.download_url_to_file(
47
- 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
48
- '01.png')
49
  torch.hub.download_url_to_file(
50
- 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
51
- '02.jpg')
52
  torch.hub.download_url_to_file(
53
- 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
54
- '03.jpg')
55
  torch.hub.download_url_to_file(
56
- 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
57
- '04.jpg')
58
- torch.hub.download_url_to_file(
59
- 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
60
- '05.jpg')
61
-
62
- def imread(img_path):
63
- img = cv2.imread(img_path)
64
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
- return img
66
 
67
- # set enhancer with RealESRGAN
68
- def set_realesrgan():
69
- # half = True if torch.cuda.is_available() else False
70
- half = True if gpu_is_available() else False
71
- model = RRDBNet(
72
- num_in_ch=3,
73
- num_out_ch=3,
74
- num_feat=64,
75
- num_block=23,
76
- num_grow_ch=32,
77
- scale=2,
78
- )
79
- upsampler = RealESRGANer(
80
- scale=2,
81
- model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
82
- model=model,
83
- tile=400,
84
- tile_pad=40,
85
- pre_pad=0,
86
- half=half,
87
- )
88
- return upsampler
89
-
90
- upsampler = set_realesrgan()
91
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
- device = get_device()
93
- codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
94
- dim_embd=512,
95
- codebook_size=1024,
96
- n_head=8,
97
- n_layers=9,
98
- connect_list=["32", "64", "128", "256"],
99
- ).to(device)
100
- ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
101
- checkpoint = torch.load(ckpt_path)["params_ema"]
102
- codeformer_net.load_state_dict(checkpoint)
103
- codeformer_net.eval()
104
 
105
  os.makedirs('output', exist_ok=True)
106
 
107
- def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
108
- """Run a single prediction on the model"""
109
- try: # global try
110
- # take the default setting for the demo
111
- has_aligned = False
112
- only_center_face = False
113
- draw_box = False
114
- detection_model = "retinaface_resnet50"
115
- print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity)
116
-
117
- img = cv2.imread(str(image), cv2.IMREAD_COLOR)
118
- print('\timage size:', img.shape)
119
-
120
- upscale = int(upscale) # convert type to int
121
- if upscale > 4: # avoid memory exceeded due to too large upscale
122
- upscale = 4
123
- if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution
124
- upscale = 2
125
- if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
126
- upscale = 1
127
- background_enhance = False
128
- face_upsample = False
129
-
130
- face_helper = FaceRestoreHelper(
131
- upscale,
132
- face_size=512,
133
- crop_ratio=(1, 1),
134
- det_model=detection_model,
135
- save_ext="png",
136
- use_parse=True,
137
- device=device,
138
- )
139
- bg_upsampler = upsampler if background_enhance else None
140
- face_upsampler = upsampler if face_upsample else None
141
 
142
- if has_aligned:
143
- # the input faces are already cropped and aligned
144
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
145
- face_helper.is_gray = is_gray(img, threshold=5)
146
- if face_helper.is_gray:
147
- print('\tgrayscale input: True')
148
- face_helper.cropped_faces = [img]
 
 
 
 
 
 
 
149
  else:
150
- face_helper.read_image(img)
151
- # get face landmarks for each face
152
- num_det_faces = face_helper.get_face_landmarks_5(
153
- only_center_face=only_center_face, resize=640, eye_dist_threshold=5
154
- )
155
- print(f'\tdetect {num_det_faces} faces')
156
- # align and warp each face
157
- face_helper.align_warp_face()
158
-
159
- # face restoration for each cropped face
160
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
161
- # prepare data
162
- cropped_face_t = img2tensor(
163
- cropped_face / 255.0, bgr2rgb=True, float32=True
164
- )
165
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
166
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
167
-
168
- try:
169
- with torch.no_grad():
170
- output = codeformer_net(
171
- cropped_face_t, w=codeformer_fidelity, adain=True
172
- )[0]
173
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
174
- del output
175
- torch.cuda.empty_cache()
176
- except RuntimeError as error:
177
- print(f"Failed inference for CodeFormer: {error}")
178
- restored_face = tensor2img(
179
- cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
180
- )
181
-
182
- restored_face = restored_face.astype("uint8")
183
- face_helper.add_restored_face(restored_face)
184
-
185
- # paste_back
186
- if not has_aligned:
187
- # upsample the background
188
- if bg_upsampler is not None:
189
- # Now only support RealESRGAN for upsampling background
190
- bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
191
- else:
192
- bg_img = None
193
- face_helper.get_inverse_affine(None)
194
- # paste each restored face to the input image
195
- if face_upsample and face_upsampler is not None:
196
- restored_img = face_helper.paste_faces_to_input_image(
197
- upsample_img=bg_img,
198
- draw_box=draw_box,
199
- face_upsampler=face_upsampler,
200
- )
201
- else:
202
- restored_img = face_helper.paste_faces_to_input_image(
203
- upsample_img=bg_img, draw_box=draw_box
204
- )
205
-
206
- # save restored img
207
- save_path = f'output/out.png'
208
- imwrite(restored_img, str(save_path))
209
 
210
- restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
211
- return restored_img, save_path
212
  except Exception as error:
213
- print('Global exception', error)
214
  return None, None
215
 
216
 
217
- title = "CodeFormer: Robust Face Restoration and Enhancement Network"
218
- description = r"""<center><img src='https://user-images.githubusercontent.com/14334509/189166076-94bb2cac-4f4e-40fb-a69f-66709e3d98f5.png' alt='CodeFormer logo'></center>
219
- <b>Official Gradio demo</b> for <a href='https://github.com/sczhou/CodeFormer' target='_blank'><b>Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)</b></a>.<br>
220
- 🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.<br>
221
- 🤗 Try CodeFormer for improved stable-diffusion generation!<br>
222
  """
223
  article = r"""
224
- If CodeFormer is helpful, please help to ⭐ the <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Github Repo</a>. Thanks!
225
- [![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
226
-
227
- ---
228
-
229
- 📝 **Citation**
230
-
231
- If our work is useful for your research, please consider citing:
232
- ```bibtex
233
- @inproceedings{zhou2022codeformer,
234
- author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
235
- title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
236
- booktitle = {NeurIPS},
237
- year = {2022}
238
- }
239
- ```
240
-
241
- 📋 **License**
242
-
243
- This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>.
244
- Redistribution and use for non-commercial purposes should follow this license.
245
-
246
- 📧 **Contact**
247
-
248
- If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
249
-
250
- <div>
251
- 🤗 Find Me:
252
- <a href="https://twitter.com/ShangchenZhou"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/ShangchenZhou?label=%40ShangchenZhou&style=social" alt="Twitter Follow"></a>
253
- <a href="https://github.com/sczhou"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/sczhou?style=social" alt="Github Follow"></a>
254
- </div>
255
-
256
- <center><img src='https://visitor-badge-sczhou.glitch.me/badge?page_id=sczhou/CodeFormer' alt='visitors'></center>
257
  """
258
-
259
  demo = gr.Interface(
260
  inference, [
261
- gr.inputs.Image(type="filepath", label="Input"),
262
- gr.inputs.Checkbox(default=True, label="Background_Enhance"),
263
- gr.inputs.Checkbox(default=True, label="Face_Upsample"),
264
- gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
265
- gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
266
  ], [
267
- gr.outputs.Image(type="numpy", label="Output"),
268
- gr.outputs.File(label="Download the output")
269
  ],
270
  title=title,
271
  description=description,
272
- article=article,
273
- examples=[
274
- ['01.png', True, True, 2, 0.7],
275
- ['02.jpg', True, True, 2, 0.7],
276
- ['03.jpg', True, True, 2, 0.7],
277
- ['04.jpg', True, True, 2, 0.1],
278
- ['05.jpg', True, True, 2, 0.1]
279
- ]
280
- )
281
-
282
- demo.queue(concurrency_count=2)
283
- demo.launch()
 
 
 
 
 
 
 
 
1
  import os
2
+
3
  import cv2
 
 
4
  import gradio as gr
5
+ import torch
6
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
7
+ from gfpgan.utils import GFPGANer
8
+ from realesrgan.utils import RealESRGANer
 
 
 
 
 
 
 
 
 
9
 
10
  os.system("pip freeze")
 
 
 
 
 
 
 
11
  # download weights
12
+ if not os.path.exists('realesr-general-x4v3.pth'):
13
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
14
+ if not os.path.exists('GFPGANv1.2.pth'):
15
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
16
+ if not os.path.exists('GFPGANv1.3.pth'):
17
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
18
+ if not os.path.exists('GFPGANv1.4.pth'):
19
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
20
+ if not os.path.exists('RestoreFormer.pth'):
21
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
22
+ if not os.path.exists('CodeFormer.pth'):
23
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
24
 
 
25
  torch.hub.download_url_to_file(
26
+ 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
27
+ 'lincoln.jpg')
28
  torch.hub.download_url_to_file(
29
+ 'https://user-images.githubusercontent.com/17445847/187400315-87a90ac9-d231-45d6-b377-38702bd1838f.jpg',
30
+ 'AI-generate.jpg')
31
  torch.hub.download_url_to_file(
32
+ 'https://user-images.githubusercontent.com/17445847/187400981-8a58f7a4-ef61-42d9-af80-bc6234cef860.jpg',
33
+ 'Blake_Lively.jpg')
34
  torch.hub.download_url_to_file(
35
+ 'https://user-images.githubusercontent.com/17445847/187401133-8a3bf269-5b4d-4432-b2f0-6d26ee1d3307.png',
36
+ '10045.png')
 
 
 
 
 
 
 
 
37
 
38
+ # background enhancer with RealESRGAN
39
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
40
+ model_path = 'realesr-general-x4v3.pth'
41
+ half = True if torch.cuda.is_available() else False
42
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  os.makedirs('output', exist_ok=True)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # def inference(img, version, scale, weight):
48
+ def inference(img, version, scale):
49
+ # weight /= 100
50
+ print(img, version, scale)
51
+ if scale > 4:
52
+ scale = 4 # avoid too large scale value
53
+ try:
54
+ extension = os.path.splitext(os.path.basename(str(img)))[1]
55
+ img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
56
+ if len(img.shape) == 3 and img.shape[2] == 4:
57
+ img_mode = 'RGBA'
58
+ elif len(img.shape) == 2: # for gray inputs
59
+ img_mode = None
60
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
61
  else:
62
+ img_mode = None
63
+
64
+ h, w = img.shape[0:2]
65
+ if h > 3500 or w > 3500:
66
+ print('too large size')
67
+ return None, None
68
+
69
+ if h < 300:
70
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
71
+
72
+ if version == 'v1.2':
73
+ face_enhancer = GFPGANer(
74
+ model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
75
+ elif version == 'v1.3':
76
+ face_enhancer = GFPGANer(
77
+ model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
78
+ elif version == 'v1.4':
79
+ face_enhancer = GFPGANer(
80
+ model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
81
+ elif version == 'RestoreFormer':
82
+ face_enhancer = GFPGANer(
83
+ model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
84
+ # elif version == 'CodeFormer':
85
+ # face_enhancer = GFPGANer(
86
+ # model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
87
+
88
+ try:
89
+ # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
90
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
91
+ except RuntimeError as error:
92
+ print('Error', error)
93
+
94
+ try:
95
+ if scale != 2:
96
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
97
+ h, w = img.shape[0:2]
98
+ output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
99
+ except Exception as error:
100
+ print('wrong scale input.', error)
101
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
102
+ extension = 'png'
103
+ else:
104
+ extension = 'jpg'
105
+ save_path = f'output/out.{extension}'
106
+ cv2.imwrite(save_path, output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
109
+ return output, save_path
110
  except Exception as error:
111
+ print('global exception', error)
112
  return None, None
113
 
114
 
115
+ title = "GFPGAN: Practical Face Restoration Algorithm"
116
+ description = r"""Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.<br>
117
+ It can be used to restore your **old photos** or improve **AI-generated faces**.<br>
118
+ To use it, simply upload your image.<br>
119
+ If GFPGAN is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Github Repo</a> and recommend it to your friends 😊
120
  """
121
  article = r"""
122
+ [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
123
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
124
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2101.04061)
125
+ If you have any question, please email 📧 `[email protected]` or `[email protected]`.
126
+ <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>
127
+ <center><img src='https://visitor-badge.glitch.me/badge?page_id=Gradio_Xintao_GFPGAN' alt='visitor badge'></center>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  """
 
129
  demo = gr.Interface(
130
  inference, [
131
+ gr.Image(type="filepath", label="Input"),
132
+ # gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], type="value", value='v1.4', label='version'),
133
+ gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
134
+ gr.Number(label="Rescaling factor", value=2),
135
+ # gr.Slider(0, 100, label='Weight, only for CodeFormer. 0 for better quality, 100 for better identity', value=50)
136
  ], [
137
+ gr.Image(type="numpy", label="Output (The whole image)"),
138
+ gr.File(label="Download the output image")
139
  ],
140
  title=title,
141
  description=description,
142
+ article=article,
143
+ # examples=[['AI-generate.jpg', 'v1.4', 2, 50], ['lincoln.jpg', 'v1.4', 2, 50], ['Blake_Lively.jpg', 'v1.4', 2, 50],
144
+ # ['10045.png', 'v1.4', 2, 50]]).launch()
145
+ examples=[['AI-generate.jpg', 'v1.4', 2], ['lincoln.jpg', 'v1.4', 2], ['Blake_Lively.jpg', 'v1.4', 2],
146
+ ['10045.png', 'v1.4', 2]])
147
+ demo.queue().launch()