import json
from typing import Any, Optional, Union

import numpy as np
import torch
import torchvision.transforms as transforms
from pycocotools import mask as mask_utils
from skimage import io
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from efficientvit.apps.data_provider import DataProvider
from efficientvit.samcore.data_provider.utils import (
    Normalize_and_Pad,
    RandomHFlip,
    ResizeLongestSide,
    SAMDistributedSampler,
)

__all__ = ["SAMDataProvider"]


class OnlineDataset(Dataset):
    """Dataset that loads an image and a fixed number of its masks per item.

    Expected layout under ``root`` (as read by this class):

    * ``sa_images_ids.txt`` -- one integer image id per line.
    * ``images/sa_<id>.jpg`` -- the image.
    * ``masks/sa_<id>.json`` -- a JSON object with an ``"annotations"`` list;
      each entry provides ``"segmentation"`` (decoded with
      ``pycocotools.mask.decode``), ``"point_coords"`` and ``"bbox"``.

    The id list is split by position: the first 99% of ids form the training
    split and the remaining ids form the evaluation split.

    Args:
        root: Dataset root directory.
        train: Select the training split (and random mask sampling) when
            true, otherwise the evaluation split (deterministic selection).
        num_masks: Number of annotations returned per image.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, root, train=True, num_masks=64, transform=None):
        self.root = root
        self.train = train
        self.num_masks = num_masks
        self.transform = transform

        # Use a context manager so the handle is closed deterministically
        # instead of relying on garbage collection.
        with open(f"{self.root}/sa_images_ids.txt", "r") as ids_file:
            image_ids = ids_file.read().splitlines()

        split_at = int(len(image_ids) * 0.99)
        self.data = image_ids[:split_at] if self.train else image_ids[split_at:]

    def __len__(self):
        """Return the number of image ids in the selected split."""
        return len(self.data)

    def _select_indices(self, num_annotations):
        """Pick exactly ``self.num_masks`` annotation indices.

        Every annotation is taken ``num_masks // num_annotations`` times and
        the remaining ``num_masks % num_annotations`` slots are filled with
        distinct indices: drawn at random without replacement in training
        mode, or the leading indices in evaluation mode. When there are more
        annotations than ``num_masks`` the repeat count is zero, so the
        result is simply ``num_masks`` distinct indices.

        Raises:
            ZeroDivisionError: If ``num_annotations`` is zero.
        """
        repeat, residue = divmod(self.num_masks, num_annotations)
        if self.train:
            tail = np.random.choice(num_annotations, size=residue, replace=False)
        else:
            tail = np.arange(residue)
        return np.concatenate([np.arange(num_annotations)] * repeat + [tail], axis=0)

    def __getitem__(self, idx):
        """Load one image with ``num_masks`` masks, prompt points and boxes.

        Note: We provide the simplest data organization here. You can modify
        the code according to your data organization.

        Returns:
            A dict with float32 tensors ``"image"`` (channels moved to the
            first axis), ``"masks"``, ``"points"`` (the first entry of each
            annotation's ``"point_coords"``) and ``"bboxs"``, plus
            ``"shape"``, the last two dimensions of the image tensor. If a
            transform was given, its output is returned instead.

        Raises:
            ZeroDivisionError: If the image has no annotations.
        """
        index = int(self.data[idx])

        image = io.imread(f"{self.root}/images/sa_{index}.jpg")

        # Close the annotation file promptly: this runs once per sample in
        # every DataLoader worker.
        with open(f"{self.root}/masks/sa_{index}.json") as json_file:
            annotations = json.load(json_file)["annotations"]

        selected = [annotations[i] for i in self._select_indices(len(annotations))]

        masks = np.stack([mask_utils.decode(ann["segmentation"]) for ann in selected])
        points = np.stack([ann["point_coords"][0] for ann in selected])
        bboxs = np.stack([ann["bbox"] for ann in selected])

        # The permute requires a rank-3 array; it moves the last axis first
        # (presumably H x W x C -> C x H x W -- confirm for non-RGB inputs).
        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
        masks = torch.tensor(masks, dtype=torch.float32)
        points = torch.tensor(points, dtype=torch.float32)
        bboxs = torch.tensor(bboxs, dtype=torch.float32)

        sample = {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
            "shape": torch.tensor(image.shape[-2:]),
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


class SAMDataProvider(DataProvider):
    """Data provider that builds ``OnlineDataset`` splits and their loaders.

    Args:
        root: Dataset root directory passed to ``OnlineDataset``.
        sub_epochs_per_epoch: Forwarded to ``SAMDistributedSampler`` for the
            training loader.
        num_masks: Number of masks sampled per training image.
        train_batch_size, test_batch_size, valid_size, n_worker, image_size,
        num_replicas, rank, train_ratio, drop_last: Forwarded unchanged, in
            this order, to ``DataProvider.__init__``.
    """

    name = "sam"

    def __init__(
        self,
        root: str,
        sub_epochs_per_epoch: int,
        num_masks: int,
        train_batch_size: int,
        test_batch_size: int,
        valid_size: Optional[Union[int, float]] = None,
        n_worker: int = 8,
        image_size: int = 1024,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        train_ratio: Optional[float] = None,
        drop_last: bool = False,
    ):
        # These attributes are set before the base initializer runs
        # (presumably because it calls build_datasets -- confirm in
        # DataProvider).
        self.root = root
        self.num_masks = num_masks
        self.sub_epochs_per_epoch = sub_epochs_per_epoch

        super().__init__(
            train_batch_size,
            test_batch_size,
            valid_size,
            n_worker,
            image_size,
            num_replicas,
            rank,
            train_ratio,
            drop_last,
        )

    def build_train_transform(self):
        """Compose the training pipeline: random flip, resize, normalize/pad."""
        # self.image_size is indexed here, so the base class is assumed to
        # store it as a sequence -- TODO confirm.
        train_transforms = [
            RandomHFlip(),
            ResizeLongestSide(target_length=self.image_size[0]),
            Normalize_and_Pad(target_length=self.image_size[0]),
        ]
        return transforms.Compose(train_transforms)

    def build_valid_transform(self):
        """Compose the validation pipeline: resize and normalize/pad only."""
        valid_transforms = [
            ResizeLongestSide(target_length=self.image_size[0]),
            Normalize_and_Pad(target_length=self.image_size[0]),
        ]
        return transforms.Compose(valid_transforms)

    def build_datasets(self) -> tuple[Any, Any, Any]:
        """Return ``(train_dataset, val_dataset, test_dataset)``.

        The validation dataset always uses ``num_masks=2`` regardless of
        ``self.num_masks``; no test dataset is built (``None``).
        """
        train_transform = self.build_train_transform()
        valid_transform = self.build_valid_transform()

        train_dataset = OnlineDataset(
            root=self.root, train=True, num_masks=self.num_masks, transform=train_transform
        )
        val_dataset = OnlineDataset(root=self.root, train=False, num_masks=2, transform=valid_transform)
        test_dataset = None

        return train_dataset, val_dataset, test_dataset

    def build_dataloader(
        self,
        dataset: Optional[Any],
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ) -> Optional[DataLoader]:
        """Build a distributed ``DataLoader`` for ``dataset``.

        Training uses ``SAMDistributedSampler`` and always drops the last
        incomplete batch; evaluation uses an unshuffled
        ``DistributedSampler`` and keeps it.

        NOTE: the ``drop_last`` argument is accepted for interface
        compatibility but is not used; the value is fixed per mode as
        described above.

        Returns:
            The data loader, or ``None`` when ``dataset`` is ``None``.
        """
        if dataset is None:
            return None

        if train:
            sampler = SAMDistributedSampler(dataset, sub_epochs_per_epoch=self.sub_epochs_per_epoch)
            return DataLoader(dataset, batch_size, sampler=sampler, drop_last=True, num_workers=n_worker)

        sampler = DistributedSampler(dataset, shuffle=False)
        return DataLoader(dataset, batch_size, sampler=sampler, drop_last=False, num_workers=n_worker)

    def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None:
        """Forward the epoch and sub-epoch to the training sampler.

        Does nothing unless the training loader's sampler is a
        ``SAMDistributedSampler``.
        """
        if isinstance(self.train.sampler, SAMDistributedSampler):
            self.train.sampler.set_epoch_and_sub_epoch(epoch, sub_epoch)