osbm's picture
Upload 9 files
3953219
raw
history blame
11.9 kB
# create transforms for training, validation and test dataset
## TODO: Make Transforms more dynamic by directly building from config args
## Maybe like this
## TFM_NAME=config.transforms.keys()[0]
## tfm_fun=getattr(monai.transforms, TFM_NAME)
## tfms+=[tfm_fun(keys=image+cols, **config.transforms[TFM_NAME], prob=prob, mode=mode)]
## ---------- imports ----------
import os
# only import of base transforms, others are imported as needed
from monai.utils.enums import CommonKeys
from monai.transforms import (
Activationsd,
AsDiscreted,
Compose,
ConcatItemsd,
KeepLargestConnectedComponentd,
LoadImaged,
EnsureChannelFirstd,
EnsureTyped,
SaveImaged,
ScaleIntensityd,
NormalizeIntensityd
)
# images should be interpolated with `bilinear` but masks with `nearest`
## ---------- base transforms ----------
# applied every time
def get_base_transforms(
    config: dict,
    minv: int=0,
    maxv: int=1
)->list:
    """Return the list of transforms shared by every pipeline in this module.

    The list always loads the image and label columns, makes them
    channel-first, rescales the image columns to ``[minv, maxv]`` and then
    normalizes them. Resampling and reorientation are inserted in between
    only when the corresponding config entries are truthy.

    Args:
        config: attribute-style configuration; reads ``data.image_cols``,
            ``data.label_cols``, ``transforms.spacing``,
            ``transforms.orientation`` and (only with spacing enabled)
            ``transforms.mode``.
        minv: lower bound forwarded to ``ScaleIntensityd``.
        maxv: upper bound forwarded to ``ScaleIntensityd``.

    Returns:
        A plain ``list`` of transform instances (not yet wrapped in ``Compose``).
    """
    image_keys = config.data.image_cols
    all_keys = image_keys + config.data.label_cols

    base = [
        LoadImaged(keys=all_keys),
        EnsureChannelFirstd(keys=all_keys),
    ]

    # Optional resampling to a common voxel spacing.
    if config.transforms.spacing:
        from monai.transforms import Spacingd
        resample = Spacingd(
            keys=all_keys,
            pixdim=config.transforms.spacing,
            mode=config.transforms.mode,
        )
        base.append(resample)

    # Optional reorientation to the configured axis codes.
    if config.transforms.orientation:
        from monai.transforms import Orientationd
        reorient = Orientationd(
            keys=all_keys,
            axcodes=config.transforms.orientation,
        )
        base.append(reorient)

    # Intensity handling touches images only, never the labels.
    base.append(ScaleIntensityd(keys=image_keys, minv=minv, maxv=maxv))
    base.append(NormalizeIntensityd(keys=image_keys))
    return base
## ---------- train transforms ----------
def get_train_transforms(config: dict):
    """Build the composed training pipeline.

    The pipeline is assembled in this fixed order:

    1. the base transforms from ``get_base_transforms``;
    2. MRI-specific augmentations (bias field, gaussian smoothing, Gibbs noise);
    3. spatial augmentations (affine, rotate90, rotate, elastic, zoom);
    4. exactly one random-cropping transform (mandatory);
    5. intensity augmentations (noise, shift, sharpen, contrast);
    6. concatenation of image columns into ``CommonKeys.IMAGE`` and of label
       columns into ``CommonKeys.LABEL``.

    An optional transform is added only if its name is a key of
    ``config.transforms``; its arguments are read from that entry.

    Args:
        config: attribute-style configuration; reads ``data.image_cols``,
            ``data.label_cols``, ``transforms`` and, when ``rand_elastic`` is
            enabled, ``config['ndim']``.

    Returns:
        A ``Compose`` wrapping all selected transforms.

    Raises:
        ValueError: if neither ``rand_crop_pos_neg_label`` nor
            ``rand_spatial_crop_samples`` is configured, or if
            ``rand_elastic`` is enabled while ``ndim`` is not 2 or 3.
    """
    tfms = get_base_transforms(config=config)

    enabled = config.transforms.keys()
    image_keys = config.data.image_cols
    all_keys = config.data.image_cols + config.data.label_cols
    # NOTE: `config.transforms.prob` / `.mode` are deliberately read lazily,
    # inside each branch, so a config without augmentations does not need them.

    # ---------- specific transforms for mri ----------
    if 'rand_bias_field' in enabled:
        from monai.transforms import RandBiasFieldd
        args = config.transforms.rand_bias_field
        tfms += [
            RandBiasFieldd(
                keys=image_keys,
                degree=args['degree'],
                coeff_range=args['coeff_range'],
                prob=config.transforms.prob
            )
        ]

    if 'rand_gaussian_smooth' in enabled:
        from monai.transforms import RandGaussianSmoothd
        args = config.transforms.rand_gaussian_smooth
        tfms += [
            RandGaussianSmoothd(
                keys=image_keys,
                sigma_x=args['sigma_x'],
                sigma_y=args['sigma_y'],
                sigma_z=args['sigma_z'],
                prob=config.transforms.prob
            )
        ]

    # The historical config key is misspelled (`rand_gibbs_nose`). It is kept
    # for existing configs and takes precedence; the correctly spelled
    # `rand_gibbs_noise` is accepted as well instead of being silently ignored.
    gibbs_key = next(
        (k for k in ('rand_gibbs_nose', 'rand_gibbs_noise') if k in enabled),
        None
    )
    if gibbs_key is not None:
        from monai.transforms import RandGibbsNoised
        args = config.transforms[gibbs_key]
        tfms += [
            RandGibbsNoised(
                keys=image_keys,
                alpha=args['alpha'],
                prob=config.transforms.prob
            )
        ]

    # ---------- affine transforms ----------
    if 'rand_affine' in enabled:
        from monai.transforms import RandAffined
        args = config.transforms.rand_affine
        tfms += [
            RandAffined(
                keys=all_keys,
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=config.transforms.mode,
                prob=config.transforms.prob
            )
        ]

    if 'rand_rotate90' in enabled:
        from monai.transforms import RandRotate90d
        args = config.transforms.rand_rotate90
        tfms += [
            RandRotate90d(
                keys=all_keys,
                spatial_axes=args['spatial_axes'],
                prob=config.transforms.prob
            )
        ]

    if 'rand_rotate' in enabled:
        from monai.transforms import RandRotated
        args = config.transforms.rand_rotate
        tfms += [
            RandRotated(
                keys=all_keys,
                range_x=args['range_x'],
                range_y=args['range_y'],
                range_z=args['range_z'],
                mode=config.transforms.mode,
                prob=config.transforms.prob
            )
        ]

    if 'rand_elastic' in enabled:
        ndim = config['ndim']
        if ndim == 3:
            from monai.transforms import Rand3DElasticd as RandElasticd
        elif ndim == 2:
            from monai.transforms import Rand2DElasticd as RandElasticd
        else:
            # Previously this fell through and failed later with an
            # unbound-name error on `RandElasticd`.
            raise ValueError(
                f'`ndim` must be 2 or 3 to use `rand_elastic`, got {ndim!r}'
            )
        args = config.transforms.rand_elastic
        # NOTE(review): MONAI documents `spacing`/`magnitude_range` for
        # Rand2DElasticd rather than `sigma_range` -- confirm the 2D path
        # against the installed MONAI version.
        tfms += [
            RandElasticd(
                keys=all_keys,
                sigma_range=args['sigma_range'],
                magnitude_range=args['magnitude_range'],
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=config.transforms.mode,
                prob=config.transforms.prob
            )
        ]

    if 'rand_zoom' in enabled:
        from monai.transforms import RandZoomd
        args = config.transforms.rand_zoom
        tfms += [
            RandZoomd(
                keys=all_keys,
                min_zoom=args['min'],
                max_zoom=args['max'],
                # every `bilinear` entry of the configured modes is swapped
                # for `area` when zooming; other entries pass through unchanged
                mode=['area' if x == 'bilinear' else x for x in config.transforms.mode],
                prob=config.transforms.prob
            )
        ]

    # ---------- random cropping, very effective for large images ----------
    # RandCropByPosNegLabeld is not advisable for data with missing labels,
    # e.g., segmentation of carcinomas which are not present on all images,
    # thus fall back to RandSpatialCropSamplesd. Completely replacing cropping
    # by just resizing could be discussed, but I believe it is not beneficial.
    # For the first version, this is an ugly hack. For the second version,
    # a better version for transforms should be written.
    if 'rand_crop_pos_neg_label' in enabled:
        from monai.transforms import RandCropByPosNegLabeld
        args = config.transforms.rand_crop_pos_neg_label
        tfms += [
            RandCropByPosNegLabeld(
                keys=all_keys,
                label_key=config.data.label_cols[0],
                spatial_size=args['spatial_size'],
                pos=args['pos'],
                neg=args['neg'],
                num_samples=args['num_samples'],
                image_key=config.data.image_cols[0],
                image_threshold=0,
            )
        ]
    elif 'rand_spatial_crop_samples' in enabled:
        from monai.transforms import RandSpatialCropSamplesd
        args = config.transforms.rand_spatial_crop_samples
        tfms += [
            RandSpatialCropSamplesd(
                keys=all_keys,
                roi_size=args['roi_size'],
                random_size=False,
                num_samples=args['num_samples'],
            )
        ]
    else:
        raise ValueError('Either `rand_crop_pos_neg_label` or `rand_spatial_crop_samples` '\
                         'need to be specified')

    # ---------- intensity transforms ----------
    if 'gaussian_noise' in enabled:
        from monai.transforms import RandGaussianNoised
        args = config.transforms.gaussian_noise
        tfms += [
            RandGaussianNoised(
                keys=image_keys,
                mean=args['mean'],
                std=args['std'],
                prob=config.transforms.prob
            )
        ]

    if 'shift_intensity' in enabled:
        from monai.transforms import RandShiftIntensityd
        args = config.transforms.shift_intensity
        tfms += [
            RandShiftIntensityd(
                keys=image_keys,
                offsets=args['offsets'],
                prob=config.transforms.prob
            )
        ]

    if 'gaussian_sharpen' in enabled:
        from monai.transforms import RandGaussianSharpend
        args = config.transforms.gaussian_sharpen
        tfms += [
            RandGaussianSharpend(
                keys=image_keys,
                sigma1_x=args['sigma1_x'],
                sigma1_y=args['sigma1_y'],
                sigma1_z=args['sigma1_z'],
                sigma2_x=args['sigma2_x'],
                sigma2_y=args['sigma2_y'],
                sigma2_z=args['sigma2_z'],
                alpha=args['alpha'],
                prob=config.transforms.prob
            )
        ]

    if 'adjust_contrast' in enabled:
        from monai.transforms import RandAdjustContrastd
        args = config.transforms.adjust_contrast
        tfms += [
            RandAdjustContrastd(
                keys=image_keys,
                gamma=args['gamma'],
                prob=config.transforms.prob
            )
        ]

    # Concat multisequence data to single tensors on the channel dim.
    # Rename images to `CommonKeys.IMAGE` and labels to `CommonKeys.LABEL`
    # for more compatibility with monai.engines
    tfms += [
        ConcatItemsd(
            keys=config.data.image_cols,
            name=CommonKeys.IMAGE,
            dim=0
        )
    ]
    tfms += [
        ConcatItemsd(
            keys=config.data.label_cols,
            name=CommonKeys.LABEL,
            dim=0
        )
    ]
    return Compose(tfms)
## ---------- valid transforms ----------
def get_val_transforms(config: dict):
    """Build the composed validation pipeline.

    Takes the list from ``get_base_transforms`` and appends, in order:
    ``EnsureTyped`` over all image and label columns, a ``ConcatItemsd``
    joining the image columns along dim 0 under ``CommonKeys.IMAGE``, and a
    ``ConcatItemsd`` joining the label columns along dim 0 under
    ``CommonKeys.LABEL``. No random augmentation is added.

    Args:
        config: attribute-style configuration; reads ``data.image_cols`` and
            ``data.label_cols`` (plus whatever ``get_base_transforms`` reads).

    Returns:
        A ``Compose`` wrapping the resulting transforms.
    """
    pipeline = get_base_transforms(config=config)

    image_keys = config.data.image_cols
    label_keys = config.data.label_cols

    pipeline.extend([
        EnsureTyped(keys=image_keys + label_keys),
        # stack all image columns into one entry on the channel dim
        ConcatItemsd(keys=image_keys, name=CommonKeys.IMAGE, dim=0),
        # same for the label columns
        ConcatItemsd(keys=label_keys, name=CommonKeys.LABEL, dim=0),
    ])
    return Compose(pipeline)
## ---------- test transforms ----------
# same as valid transforms
def get_test_transforms(config: dict):
    """Build the composed test pipeline.

    The test pipeline is defined to be the same as the validation pipeline,
    so this delegates to ``get_val_transforms`` instead of repeating its
    body; the two can then never drift apart.

    Args:
        config: attribute-style configuration, forwarded unchanged.

    Returns:
        A freshly built ``Compose``, as returned by ``get_val_transforms``.
    """
    return get_val_transforms(config=config)
def get_val_post_transforms(config: dict):
    """Build the post-processing pipeline applied to validation outputs.

    In order: ``EnsureTyped`` on prediction and label, ``AsDiscreted`` on the
    prediction (argmax plus one-hot), ``AsDiscreted`` on the label (one-hot),
    and ``KeepLargestConnectedComponentd`` on the prediction for labels
    ``1 .. out_channels - 1``.

    Args:
        config: attribute-style configuration; reads ``model.out_channels``.

    Returns:
        A ``Compose`` wrapping the four transforms.
    """
    n_classes = config.model.out_channels

    ensure_types = EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL])
    # Both `to_onehot` and `num_classes` are passed, exactly as before.
    # NOTE(review): presumably for compatibility across MONAI versions -- confirm.
    discretize_pred = AsDiscreted(
        keys=CommonKeys.PRED,
        argmax=True,
        to_onehot=n_classes,
        num_classes=n_classes,
    )
    discretize_label = AsDiscreted(
        keys=CommonKeys.LABEL,
        to_onehot=n_classes,
        num_classes=n_classes,
    )
    # label 0 is excluded from the applied labels
    largest_component = KeepLargestConnectedComponentd(
        keys=CommonKeys.PRED,
        applied_labels=list(range(1, n_classes)),
    )

    return Compose([ensure_types, discretize_pred, discretize_label, largest_component])