import torch.utils.data as data
from pathlib import Path
from torchvision import transforms as T
import torchio as tio

from medical_diffusion.data.augmentation.augmentations_3d import ImageToTensor


class SimpleDataset3D(data.Dataset):
    """Dataset that serves 3D image volumes found below a root directory.

    Each item is loaded with ``tio.ScalarImage``, passed through
    ``self.transform`` and returned as ``{'uid': ..., 'source': ...}``.

    Args:
        path_root: Directory that contains the image files. Item pointers are
            interpreted relative to this directory.
        item_pointers: Optional sequence of paths relative to ``path_root``.
            When omitted or empty, ``path_root`` is crawled recursively for
            files matching ``crawler_ext``.
        crawler_ext: File extension, or iterable of file extensions, used by
            the crawler. Defaults to ``['nii']``; another option is
            ``['nii.gz']``.
        transform: Callable applied to every loaded image. When ``None``, a
            default pipeline is built from the remaining arguments.
        image_resize: Target shape handed to ``tio.Resize``; skipped if ``None``.
        flip: If true, add ``tio.RandomFlip((0, 1, 2))`` to the default pipeline.
        image_crop: Target shape handed to ``tio.CropOrPad``; skipped if ``None``.
        use_znorm: Use z-normalization (for MRI, as the intensity scale is
            arbitrary); otherwise rescale intensities to [-1, 1].
    """

    def __init__(
        self,
        path_root,
        item_pointers=None,
        crawler_ext=None,
        transform=None,
        image_resize=None,
        flip=False,
        image_crop=None,
        use_znorm=True,
    ):
        super().__init__()
        self.path_root = path_root
        # ``None`` sentinels replace the former mutable list defaults, which
        # were shared between all instances of the class.
        self.crawler_ext = ['nii'] if crawler_ext is None else crawler_ext

        if transform is None:
            # Optional stages are only added when requested. This avoids
            # identity lambdas in the pipeline, which the standard pickle
            # module cannot serialise.
            steps = []
            if image_resize is not None:
                steps.append(tio.Resize(image_resize))
            if flip:
                steps.append(tio.RandomFlip((0, 1, 2)))
            if image_crop is not None:
                steps.append(tio.CropOrPad(image_crop))
            steps.append(
                tio.ZNormalization() if use_znorm else tio.RescaleIntensity((-1, 1))
            )
            steps.append(ImageToTensor())  # [C, W, H, D] -> [C, D, H, W]
            self.transform = T.Compose(steps)
        else:
            self.transform = transform

        if item_pointers is not None and len(item_pointers):
            self.item_pointers = item_pointers
        else:
            self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.item_pointers)

    def __getitem__(self, index):
        """Load, transform and return the item at ``index``.

        Returns:
            dict with ``'uid'`` (the stem of the item's relative path) and
            ``'source'`` (the transformed image).
        """
        # Coerce to Path so that pointers supplied as plain strings also work.
        rel_path_item = Path(self.item_pointers[index])
        path_item = Path(self.path_root) / rel_path_item
        img = self.load_item(path_item)
        return {'uid': rel_path_item.stem, 'source': self.transform(img)}

    def load_item(self, path_item):
        """Return the image stored at ``path_item`` as a ``tio.ScalarImage``."""
        # Consider using this or tio.LabelMap rather than SimpleITK
        # (sitk.ReadImage(str(path_item))).
        return tio.ScalarImage(path_item)

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Recursively collect files below ``path_root`` by extension.

        Args:
            path_root: Directory to search.
            extension: A single extension (``'nii'``) or an iterable of
                extensions (``['nii', 'nii.gz']``). A leading dot is ignored.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            List of paths relative to ``path_root``, without duplicates.
        """
        # The extension used to be interpolated directly into the glob
        # pattern, so passing a list (the constructor default) produced a
        # pattern such as "*.['nii']" that matched nothing.
        if isinstance(extension, str):
            extensions = [extension]
        else:
            extensions = list(extension)

        root = Path(path_root)
        found = {}  # dict keeps discovery order while removing duplicates
        for ext in extensions:
            for path in root.rglob(f"*.{str(ext).lstrip('.')}"):
                found.setdefault(path.relative_to(root), None)
        return list(found)