culture commited on
Commit
703e5bb
·
1 Parent(s): a5cfdba

Upload gfpgan/utils.py

Browse files
Files changed (1) hide show
  1. gfpgan/utils.py +130 -0
gfpgan/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import torch
4
+ from basicsr.utils import img2tensor, tensor2img
5
+ from basicsr.utils.download_util import load_file_from_url
6
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
7
+ from torchvision.transforms.functional import normalize
8
+
9
+ from gfpgan.archs.gfpganv1_arch import GFPGANv1
10
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
11
+
12
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
+
14
+
15
+ class GFPGANer():
16
+ """Helper for restoration with GFPGAN.
17
+
18
+ It will detect and crop faces, and then resize the faces to 512x512.
19
+ GFPGAN is used to restored the resized faces.
20
+ The background is upsampled with the bg_upsampler.
21
+ Finally, the faces will be pasted back to the upsample background image.
22
+
23
+ Args:
24
+ model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
25
+ upscale (float): The upscale of the final output. Default: 2.
26
+ arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
27
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
28
+ bg_upsampler (nn.Module): The upsampler for the background. Default: None.
29
+ """
30
+
31
+ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
32
+ self.upscale = upscale
33
+ self.bg_upsampler = bg_upsampler
34
+
35
+ # initialize model
36
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37
+ # initialize the GFP-GAN
38
+ if arch == 'clean':
39
+ self.gfpgan = GFPGANv1Clean(
40
+ out_size=512,
41
+ num_style_feat=512,
42
+ channel_multiplier=channel_multiplier,
43
+ decoder_load_path=None,
44
+ fix_decoder=False,
45
+ num_mlp=8,
46
+ input_is_latent=True,
47
+ different_w=True,
48
+ narrow=1,
49
+ sft_half=True)
50
+ else:
51
+ self.gfpgan = GFPGANv1(
52
+ out_size=512,
53
+ num_style_feat=512,
54
+ channel_multiplier=channel_multiplier,
55
+ decoder_load_path=None,
56
+ fix_decoder=True,
57
+ num_mlp=8,
58
+ input_is_latent=True,
59
+ different_w=True,
60
+ narrow=1,
61
+ sft_half=True)
62
+ # initialize face helper
63
+ self.face_helper = FaceRestoreHelper(
64
+ upscale,
65
+ face_size=512,
66
+ crop_ratio=(1, 1),
67
+ det_model='retinaface_resnet50',
68
+ save_ext='png',
69
+ device=self.device)
70
+
71
+ if model_path.startswith('https://'):
72
+ model_path = load_file_from_url(
73
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
74
+ loadnet = torch.load(model_path)
75
+ if 'params_ema' in loadnet:
76
+ keyname = 'params_ema'
77
+ else:
78
+ keyname = 'params'
79
+ self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
80
+ self.gfpgan.eval()
81
+ self.gfpgan = self.gfpgan.to(self.device)
82
+
83
+ @torch.no_grad()
84
+ def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
85
+ self.face_helper.clean_all()
86
+
87
+ if has_aligned: # the inputs are already aligned
88
+ img = cv2.resize(img, (512, 512))
89
+ self.face_helper.cropped_faces = [img]
90
+ else:
91
+ self.face_helper.read_image(img)
92
+ # get face landmarks for each face
93
+ self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
94
+ # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
95
+ # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
96
+ # align and warp each face
97
+ self.face_helper.align_warp_face()
98
+
99
+ # face restoration
100
+ for cropped_face in self.face_helper.cropped_faces:
101
+ # prepare data
102
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
103
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
104
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
105
+
106
+ try:
107
+ output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
108
+ # convert to image
109
+ restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
110
+ except RuntimeError as error:
111
+ print(f'\tFailed inference for GFPGAN: {error}.')
112
+ restored_face = cropped_face
113
+
114
+ restored_face = restored_face.astype('uint8')
115
+ self.face_helper.add_restored_face(restored_face)
116
+
117
+ if not has_aligned and paste_back:
118
+ # upsample the background
119
+ if self.bg_upsampler is not None:
120
+ # Now only support RealESRGAN for upsampling background
121
+ bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
122
+ else:
123
+ bg_img = None
124
+
125
+ self.face_helper.get_inverse_affine(None)
126
+ # paste each restored face to the input image
127
+ restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
128
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
129
+ else:
130
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, None