duongttr commited on
Commit
2b79d08
·
1 Parent(s): 19d9b4e

Update RGB2LAB to optimize time

Browse files
Files changed (1) hide show
  1. src/utils.py +5 -2
src/utils.py CHANGED
@@ -13,6 +13,7 @@ from numba import cuda, jit
13
  import math
14
  import torchvision.utils as vutils
15
  from torch.autograd import Variable
 
16
 
17
  rgb_from_xyz = np.array(
18
  [
@@ -318,7 +319,9 @@ class RGB2Lab(object):
318
  pass
319
 
320
  def __call__(self, inputs):
321
- return color.rgb2lab(inputs)
 
 
322
 
323
 
324
  class ToTensor(object):
@@ -846,4 +849,4 @@ def print_num_params(model, is_trainable=False):
846
  num_params = sum(p.numel() for p in model.parameters())
847
  print(f"| GENERAL | {model_name} | {('{:,}'.format(num_params)).rjust(10)} |")
848
 
849
- return num_params
 
13
  import math
14
  import torchvision.utils as vutils
15
  from torch.autograd import Variable
16
+ import cv2
17
 
18
  rgb_from_xyz = np.array(
19
  [
 
319
  pass
320
 
321
  def __call__(self, inputs):
322
+ normed_inputs = np.float32(inputs) / 255.0
323
+ rgb_inputs = cv2.cvtColor(normed_inputs, cv2.COLOR_RGB2LAB)
324
+ return rgb_inputs
325
 
326
 
327
  class ToTensor(object):
 
849
  num_params = sum(p.numel() for p in model.parameters())
850
  print(f"| GENERAL | {model_name} | {('{:,}'.format(num_params)).rjust(10)} |")
851
 
852
+ return num_params