HookBeforeAppDelegate commited on
Commit
1a59253
Β·
verified Β·
1 Parent(s): cbb77d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -134
app.py CHANGED
@@ -1,154 +1,283 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
 
 
7
  import torch
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
  }
65
- """
 
 
 
 
 
 
 
 
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
 
 
 
 
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
 
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
 
 
 
 
 
98
  )
 
 
 
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
 
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
 
 
 
 
 
109
  )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
 
 
 
 
 
 
 
126
  )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
1
+ """
2
+ This file is used for deploying hugging face demo:
3
+ https://huggingface.co/spaces/sczhou/CodeFormer
4
+ """
5
 
6
+ import sys
7
+ sys.path.append('CodeFormer')
8
+ import os
9
+ import cv2
10
  import torch
11
+ import torch.nn.functional as F
12
+ import gradio as gr
13
+
14
+ from torchvision.transforms.functional import normalize
15
 
16
+ from basicsr.archs.rrdbnet_arch import RRDBNet
17
+ from basicsr.utils import imwrite, img2tensor, tensor2img
18
+ from basicsr.utils.download_util import load_file_from_url
19
+ from basicsr.utils.misc import gpu_is_available, get_device
20
+ from basicsr.utils.realesrgan_utils import RealESRGANer
21
+ from basicsr.utils.registry import ARCH_REGISTRY
22
+
23
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
24
+ from facelib.utils.misc import is_gray
25
+
26
+
27
+ os.system("pip freeze")
28
+
29
+ pretrain_model_url = {
30
+ 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
31
+ 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
32
+ 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
33
+ 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  }
35
+ # download weights
36
+ if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
37
+ load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
38
+ if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
39
+ load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
40
+ if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
41
+ load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
42
+ if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
43
+ load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
44
 
45
+ # download images
46
+ torch.hub.download_url_to_file(
47
+ 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
48
+ '01.png')
49
+ torch.hub.download_url_to_file(
50
+ 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
51
+ '02.jpg')
52
+ torch.hub.download_url_to_file(
53
+ 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
54
+ '03.jpg')
55
+ torch.hub.download_url_to_file(
56
+ 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
57
+ '04.jpg')
58
+ torch.hub.download_url_to_file(
59
+ 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
60
+ '05.jpg')
61
 
62
+ def imread(img_path):
63
+ img = cv2.imread(img_path)
64
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
+ return img
66
 
67
+ # set enhancer with RealESRGAN
68
+ def set_realesrgan():
69
+ # half = True if torch.cuda.is_available() else False
70
+ half = True if gpu_is_available() else False
71
+ model = RRDBNet(
72
+ num_in_ch=3,
73
+ num_out_ch=3,
74
+ num_feat=64,
75
+ num_block=23,
76
+ num_grow_ch=32,
77
+ scale=2,
78
+ )
79
+ upsampler = RealESRGANer(
80
+ scale=2,
81
+ model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
82
+ model=model,
83
+ tile=400,
84
+ tile_pad=40,
85
+ pre_pad=0,
86
+ half=half,
87
+ )
88
+ return upsampler
89
 
90
+ upsampler = set_realesrgan()
91
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
+ device = get_device()
93
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
94
+ dim_embd=512,
95
+ codebook_size=1024,
96
+ n_head=8,
97
+ n_layers=9,
98
+ connect_list=["32", "64", "128", "256"],
99
+ ).to(device)
100
+ ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
101
+ checkpoint = torch.load(ckpt_path)["params_ema"]
102
+ codeformer_net.load_state_dict(checkpoint)
103
+ codeformer_net.eval()
104
+
105
+ os.makedirs('output', exist_ok=True)
106
+
107
+ def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
108
+ """Run a single prediction on the model"""
109
+ try: # global try
110
+ # take the default setting for the demo
111
+ has_aligned = False
112
+ only_center_face = False
113
+ draw_box = False
114
+ detection_model = "retinaface_resnet50"
115
+ print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity)
116
+
117
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
118
+ print('\timage size:', img.shape)
119
+
120
+ upscale = int(upscale) # convert type to int
121
+ if upscale > 4: # avoid memory exceeded due to too large upscale
122
+ upscale = 4
123
+ if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution
124
+ upscale = 2
125
+ if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
126
+ upscale = 1
127
+ background_enhance = False
128
+ face_upsample = False
129
+
130
+ face_helper = FaceRestoreHelper(
131
+ upscale,
132
+ face_size=512,
133
+ crop_ratio=(1, 1),
134
+ det_model=detection_model,
135
+ save_ext="png",
136
+ use_parse=True,
137
+ device=device,
138
+ )
139
+ bg_upsampler = upsampler if background_enhance else None
140
+ face_upsampler = upsampler if face_upsample else None
141
 
142
+ if has_aligned:
143
+ # the input faces are already cropped and aligned
144
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
145
+ face_helper.is_gray = is_gray(img, threshold=5)
146
+ if face_helper.is_gray:
147
+ print('\tgrayscale input: True')
148
+ face_helper.cropped_faces = [img]
149
+ else:
150
+ face_helper.read_image(img)
151
+ # get face landmarks for each face
152
+ num_det_faces = face_helper.get_face_landmarks_5(
153
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
154
  )
155
+ print(f'\tdetect {num_det_faces} faces')
156
+ # align and warp each face
157
+ face_helper.align_warp_face()
158
 
159
+ # face restoration for each cropped face
160
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
161
+ # prepare data
162
+ cropped_face_t = img2tensor(
163
+ cropped_face / 255.0, bgr2rgb=True, float32=True
164
+ )
165
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
166
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
167
 
168
+ try:
169
+ with torch.no_grad():
170
+ output = codeformer_net(
171
+ cropped_face_t, w=codeformer_fidelity, adain=True
172
+ )[0]
173
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
174
+ del output
175
+ torch.cuda.empty_cache()
176
+ except RuntimeError as error:
177
+ print(f"Failed inference for CodeFormer: {error}")
178
+ restored_face = tensor2img(
179
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
180
  )
181
 
182
+ restored_face = restored_face.astype("uint8")
183
+ face_helper.add_restored_face(restored_face)
 
 
 
 
 
184
 
185
+ # paste_back
186
+ if not has_aligned:
187
+ # upsample the background
188
+ if bg_upsampler is not None:
189
+ # Now only support RealESRGAN for upsampling background
190
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
191
+ else:
192
+ bg_img = None
193
+ face_helper.get_inverse_affine(None)
194
+ # paste each restored face to the input image
195
+ if face_upsample and face_upsampler is not None:
196
+ restored_img = face_helper.paste_faces_to_input_image(
197
+ upsample_img=bg_img,
198
+ draw_box=draw_box,
199
+ face_upsampler=face_upsampler,
200
  )
201
+ else:
202
+ restored_img = face_helper.paste_faces_to_input_image(
203
+ upsample_img=bg_img, draw_box=draw_box
 
 
 
 
204
  )
205
 
206
+ # save restored img
207
+ save_path = f'output/out.png'
208
+ imwrite(restored_img, str(save_path))
209
+
210
+ restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
211
+ return restored_img, save_path
212
+ except Exception as error:
213
+ print('Global exception', error)
214
+ return None, None
215
+
216
+
217
+ title = "CodeFormer: Robust Face Restoration and Enhancement Network"
218
+ description = r"""<center><img src='https://user-images.githubusercontent.com/14334509/189166076-94bb2cac-4f4e-40fb-a69f-66709e3d98f5.png' alt='CodeFormer logo'></center>
219
+ <b>Official Gradio demo</b> for <a href='https://github.com/sczhou/CodeFormer' target='_blank'><b>Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)</b></a>.<br>
220
+ πŸ”₯ CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.<br>
221
+ πŸ€— Try CodeFormer for improved stable-diffusion generation!<br>
222
+ """
223
+ article = r"""
224
+ If CodeFormer is helpful, please help to ⭐ the <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Github Repo</a>. Thanks!
225
+ [![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
226
+
227
+ ---
228
+
229
+ πŸ“ **Citation**
230
+
231
+ If our work is useful for your research, please consider citing:
232
+ ```bibtex
233
+ @inproceedings{zhou2022codeformer,
234
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
235
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
236
+ booktitle = {NeurIPS},
237
+ year = {2022}
238
+ }
239
+ ```
240
+
241
+ πŸ“‹ **License**
242
+
243
+ This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>.
244
+ Redistribution and use for non-commercial purposes should follow this license.
245
+
246
+ πŸ“§ **Contact**
247
+
248
+ If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
249
+
250
+ <div>
251
+ πŸ€— Find Me:
252
+ <a href="https://twitter.com/ShangchenZhou"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/ShangchenZhou?label=%40ShangchenZhou&style=social" alt="Twitter Follow"></a>
253
+ <a href="https://github.com/sczhou"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/sczhou?style=social" alt="Github Follow"></a>
254
+ </div>
255
+
256
+ <center><img src='https://visitor-badge-sczhou.glitch.me/badge?page_id=sczhou/CodeFormer' alt='visitors'></center>
257
+ """
258
+
259
+ demo = gr.Interface(
260
+ inference, [
261
+ gr.inputs.Image(type="filepath", label="Input"),
262
+ gr.inputs.Checkbox(default=True, label="Background_Enhance"),
263
+ gr.inputs.Checkbox(default=True, label="Face_Upsample"),
264
+ gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
265
+ gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
266
+ ], [
267
+ gr.outputs.Image(type="numpy", label="Output"),
268
+ gr.outputs.File(label="Download the output")
269
+ ],
270
+ title=title,
271
+ description=description,
272
+ article=article,
273
+ examples=[
274
+ ['01.png', True, True, 2, 0.7],
275
+ ['02.jpg', True, True, 2, 0.7],
276
+ ['03.jpg', True, True, 2, 0.7],
277
+ ['04.jpg', True, True, 2, 0.1],
278
+ ['05.jpg', True, True, 2, 0.1]
279
+ ]
280
  )
281
 
282
+ demo.queue(concurrency_count=2)
283
+ demo.launch()