sczhou commited on
Commit
bddee53
·
1 Parent(s): 4d598f8

add grayscale judgement (#32)

Browse files
basicsr/utils/img_util.py CHANGED
@@ -168,3 +168,4 @@ def crop_border(imgs, crop_border):
168
  return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
169
  else:
170
  return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
 
 
168
  return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
169
  else:
170
  return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
171
+
facelib/utils/face_restoration_helper.py CHANGED
@@ -6,7 +6,7 @@ from torchvision.transforms.functional import normalize
6
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
- from facelib.utils.misc import img2tensor, imwrite
10
 
11
 
12
  def get_largest_face(det_faces, h, w):
@@ -125,6 +125,9 @@ class FaceRestoreHelper(object):
125
  img = img[:, :, 0:3]
126
 
127
  self.input_img = img
 
 
 
128
 
129
  if min(self.input_img.shape[:2])<512:
130
  f = 512.0/min(self.input_img.shape[:2])
@@ -416,6 +419,9 @@ class FaceRestoreHelper(object):
416
  fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
417
  inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
418
 
 
 
 
419
  if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
420
  alpha = upsample_img[:, :, 3:]
421
  upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
 
6
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
+ from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
10
 
11
 
12
  def get_largest_face(det_faces, h, w):
 
125
  img = img[:, :, 0:3]
126
 
127
  self.input_img = img
128
+ self.is_gray = is_gray(img, threshold=5)
129
+ if self.is_gray:
130
+ print('Grayscale input: True')
131
 
132
  if min(self.input_img.shape[:2])<512:
133
  f = 512.0/min(self.input_img.shape[:2])
 
419
  fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
420
  inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
421
 
422
+ if self.is_gray:
423
+ pasted_face = bgr2gray(pasted_face) # convert img into grayscale
424
+
425
  if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
426
  alpha = upsample_img[:, :, 3:]
427
  upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
facelib/utils/misc.py CHANGED
@@ -1,6 +1,8 @@
1
  import cv2
2
  import os
3
  import os.path as osp
 
 
4
  import torch
5
  from torch.hub import download_url_to_file, get_dir
6
  from urllib.parse import urlparse
@@ -139,3 +141,34 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
139
  continue
140
 
141
  return _scandir(dir_path, suffix=suffix, recursive=recursive)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import os
3
  import os.path as osp
4
+ import numpy as np
5
+ from PIL import Image
6
  import torch
7
  from torch.hub import download_url_to_file, get_dir
8
  from urllib.parse import urlparse
 
141
  continue
142
 
143
  return _scandir(dir_path, suffix=suffix, recursive=recursive)
144
+
145
+
146
+ def is_gray(img, threshold=10):
147
+ img = Image.fromarray(img)
148
+ if len(img.getbands()) == 1:
149
+ return True
150
+ img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
151
+ img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
152
+ img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
153
+ diff1 = (img1 - img2).var()
154
+ diff2 = (img2 - img3).var()
155
+ diff3 = (img3 - img1).var()
156
+ diff_sum = (diff1 + diff2 + diff3) / 3.0
157
+ if diff_sum <= threshold:
158
+ return True
159
+ else:
160
+ return False
161
+
162
+ def rgb2gray(img, out_channel=3):
163
+ r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
164
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
165
+ if out_channel == 3:
166
+ gray = gray[:,:,np.newaxis].repeat(3, axis=2)
167
+ return gray
168
+
169
+ def bgr2gray(img, out_channel=3):
170
+ b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
171
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
172
+ if out_channel == 3:
173
+ gray = gray[:,:,np.newaxis].repeat(3, axis=2)
174
+ return gray