sky24h commited on
Commit
ba32963
·
1 Parent(s): f3daba8
Files changed (1) hide show
  1. seg2art/inference_util.py +5 -5
seg2art/inference_util.py CHANGED
@@ -33,15 +33,15 @@ colors = np.array(colors)
33
 
34
 
35
  def remap_label(arr):
36
- # compare only first 1 channel to speed up
37
- arr_r = arr[:, :, 0]
38
 
39
  # remap color to label
40
  for i in range(len(colors)):
41
- arr_r[arr_r == colors[i][0]] = values[i]
42
  # others to 15
43
- arr_r[arr_r > 15] = 15
44
- return arr_r
45
 
46
 
47
  preprocess = transforms.Compose(
 
33
 
34
 
35
  def remap_label(arr):
36
+ # compare only last color channel to speed up
37
+ arr_b = arr[:, :, 0]
38
 
39
  # remap color to label
40
  for i in range(len(colors)):
41
+ arr_b[arr_b == colors[i][2]] = values[i]
42
  # others to 15
43
+ arr_b[arr_b > 15] = 15
44
+ return arr_b
45
 
46
 
47
  preprocess = transforms.Compose(