import os
from glob import glob
from collections import defaultdict

import numpy as np
from PIL import Image


def _frame_index(path):
    """Return the integer that follows the last '_' in the file name of *path*.

    For example ``.../dynamic_mask_12.png`` yields ``12``. Raises
    ``ValueError`` when that part of the name is not an integer.
    """
    return int(os.path.basename(path).split('_')[-1].split('.')[0])


class MaskDataset(object):
    """Index of mask PNG files grouped by sequence, with a loader.

    Attributes:
        is_label: Flag passed to the constructor; selects both the file
            pattern used for indexing and how ``read_masks`` loads files.
        sequences: Dict mapping each sequence name to its sorted list of
            mask file paths.
    """

    def __init__(self, root, sequences, is_label=True):
        """Collect the mask file paths of every sequence under *root*.

        Args:
            root: Directory that contains one sub-directory per sequence.
            sequences: Iterable of sequence (sub-directory) names.
            is_label: When true, every ``*.png`` in the sequence directory
                is indexed in lexicographic path order. Otherwise only
                ``dynamic_mask_*.png`` files are indexed, ordered by the
                integer after the last underscore in the file name.
        """
        self.is_label = is_label
        self.sequences = {}
        for seq in sequences:
            print(root, seq)
            if is_label:
                masks = sorted(glob(os.path.join(root, seq, '*.png')))
            else:
                # Numeric key so that e.g. "_2" sorts before "_10".
                masks = sorted(
                    glob(os.path.join(root, seq, 'dynamic_mask_*.png')),
                    key=_frame_index,
                )
            self.sequences[seq] = masks

    def read_masks(self, seq):
        """Load every mask of sequence *seq* as a list of PIL images.

        With ``is_label`` set, each mask is binarised: every pixel value
        greater than zero becomes 255. Otherwise the image is returned as
        stored in the file.

        Args:
            seq: Sequence name; must be one of the names given to the
                constructor.

        Returns:
            List of ``PIL.Image.Image`` objects in the indexed order.

        Raises:
            KeyError: If *seq* was not indexed by the constructor.
        """
        masks = []
        for path in self.sequences[seq]:
            if self.is_label:
                # np.array copies the pixel data, so the source file can be
                # closed as soon as the conversion is done.
                with Image.open(path) as src:
                    pixels = np.array(src)
                pixels[pixels > 0] = 255
                masks.append(Image.fromarray(pixels))
            else:
                img = Image.open(path)
                # Image.open is lazy and keeps the file open; decoding now
                # lets Pillow release the handle instead of holding one
                # descriptor per mask for the lifetime of the list.
                img.load()
                masks.append(img)
        return masks