import torchio as tio from typing import Union, Optional, Sequence from torchio.typing import TypeTripletInt from torchio import Subject, Image from torchio.utils import to_tuple class CropOrPad_None(tio.CropOrPad): def __init__( self, target_shape: Union[int, TypeTripletInt, None] = None, padding_mode: Union[str, float] = 0, mask_name: Optional[str] = None, labels: Optional[Sequence[int]] = None, **kwargs ): # WARNING: Ugly workaround to allow None values if target_shape is not None: self.original_target_shape = to_tuple(target_shape, length=3) target_shape = [1 if t_s is None else t_s for t_s in target_shape] super().__init__(target_shape, padding_mode, mask_name, labels, **kwargs) def apply_transform(self, subject: Subject): # WARNING: This makes the transformation subject dependent - reverse transformation must be adapted if self.target_shape is not None: self.target_shape = [s_s if t_s is None else t_s for t_s, s_s in zip(self.original_target_shape, subject.spatial_shape)] return super().apply_transform(subject=subject) class SubjectToTensor(object): """Transforms TorchIO Subjects into a Python dict and changes axes order from TorchIO to Torch""" def __call__(self, subject: Subject): return {key: val.data.swapaxes(1,-1) if isinstance(val, Image) else val for key,val in subject.items()} class ImageToTensor(object): """Transforms TorchIO Image into a Numpy/Torch Tensor and changes axes order from TorchIO [B, C, W, H, D] to Torch [B, C, D, H, W]""" def __call__(self, image: Image): return image.data.swapaxes(1,-1)