sczhou commited on
Commit
224ac55
·
1 Parent(s): 6c247ba

transfer color for grayscale inputs.

Browse files
facelib/utils/face_restoration_helper.py CHANGED
@@ -6,7 +6,7 @@ from torchvision.transforms.functional import normalize
6
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
- from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
10
  from basicsr.utils.misc import get_device
11
 
12
 
@@ -300,10 +300,12 @@ class FaceRestoreHelper(object):
300
  torch.save(inverse_affine, save_path)
301
 
302
 
303
- def add_restored_face(self, face):
304
  if self.is_gray:
305
- face = bgr2gray(face) # convert img into grayscale
306
- self.restored_faces.append(face)
 
 
307
 
308
 
309
  def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
 
6
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
+ from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
10
  from basicsr.utils.misc import get_device
11
 
12
 
 
300
  torch.save(inverse_affine, save_path)
301
 
302
 
303
+ def add_restored_face(self, restored_face, input_face=None):
304
  if self.is_gray:
305
+ restored_face = bgr2gray(restored_face) # convert img into grayscale
306
+ if input_face is not None:
307
+ restored_face = adain_npy(restored_face, input_face) # transfer the color
308
+ self.restored_faces.append(restored_face)
309
 
310
 
311
  def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
facelib/utils/misc.py CHANGED
@@ -172,3 +172,31 @@ def bgr2gray(img, out_channel=3):
172
  if out_channel == 3:
173
  gray = gray[:,:,np.newaxis].repeat(3, axis=2)
174
  return gray
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  if out_channel == 3:
173
  gray = gray[:,:,np.newaxis].repeat(3, axis=2)
174
  return gray
175
+
176
+
177
+ def calc_mean_std(feat, eps=1e-5):
178
+ """
179
+ Args:
180
+ feat (numpy): 3D [w h c]s
181
+ """
182
+ size = feat.shape
183
+ assert len(size) == 3, 'The input feature should be 3D tensor.'
184
+ c = size[2]
185
+ feat_var = feat.reshape(-1, c).var(axis=0) + eps
186
+ feat_std = np.sqrt(feat_var).reshape(1, 1, c)
187
+ feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
188
+ return feat_mean, feat_std
189
+
190
+
191
+ def adain_npy(content_feat, style_feat):
192
+ """Adaptive instance normalization for numpy.
193
+
194
+ Args:
195
+ content_feat (numpy): The input feature.
196
+ style_feat (numpy): The reference feature.
197
+ """
198
+ size = content_feat.shape
199
+ style_mean, style_std = calc_mean_std(style_feat)
200
+ content_mean, content_std = calc_mean_std(content_feat)
201
+ normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
202
+ return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
inference_codeformer.py CHANGED
@@ -205,7 +205,7 @@ if __name__ == '__main__':
205
  restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
206
 
207
  restored_face = restored_face.astype('uint8')
208
- face_helper.add_restored_face(restored_face)
209
 
210
  # paste_back
211
  if not args.has_aligned:
 
205
  restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
206
 
207
  restored_face = restored_face.astype('uint8')
208
+ face_helper.add_restored_face(restored_face, cropped_face)
209
 
210
  # paste_back
211
  if not args.has_aligned: