yumingj commited on
Commit
2e9f191
·
1 Parent(s): 50368d1

process mask

Browse files
Files changed (1) hide show
  1. model.py +3 -1
model.py CHANGED
@@ -86,7 +86,9 @@ class Model:
86
  @staticmethod
87
  def process_mask(mask: np.ndarray) -> np.ndarray:
88
  if mask.shape != (512, 256, 3):
89
- return None
 
 
90
  seg_map = np.full(mask.shape[:-1], -1)
91
  for index, color in enumerate(COLOR_LIST):
92
  seg_map[np.sum(mask == color, axis=2) == 3] = index
 
86
  @staticmethod
87
  def process_mask(mask: np.ndarray) -> np.ndarray:
88
  if mask.shape != (512, 256, 3):
89
+ mask = image.resize(
90
+ size=(256, 512),
91
+ resample=PIL.Image.NEAREST)
92
  seg_map = np.full(mask.shape[:-1], -1)
93
  for index, color in enumerate(COLOR_LIST):
94
  seg_map[np.sum(mask == color, axis=2) == 3] = index