import numpy as np
import torch
from monai.transforms import InvertibleTransform
from monai.transforms.transform import MapTransform


class ConcatImages(MapTransform, InvertibleTransform):
    """Concatenate the values of several keys into a new key.

    The arrays stored under ``keys_merge`` are joined with ``np.concatenate``
    along axis 0 and stored under ``keys_out``. Axis 0 is presumably the
    channel axis; confirm against callers. The ``<key>_meta_dict`` entry of
    the first merged key is shared, not copied, as ``<keys_out>_meta_dict``.

    ``data`` may be a single dict or a list of dicts. Each dict is processed
    independently and modified in place.

    NOTE(review): ``np.concatenate`` returns a NumPy array even when the
    inputs are torch tensors, and it fails for CUDA tensors. Confirm this is
    what downstream transforms expect.

    Args:
        keys_merge: keys whose values are concatenated, in order. Must be
            non-empty.
        keys_out: key under which the concatenation is stored.
        allow_missing_keys: when True, keys absent from a dict are skipped.
            A dict containing none of ``keys_merge`` is left untouched, and a
            missing meta dict is simply not copied. When False, any missing
            key raises ``KeyError``.
    """

    def __init__(self, keys_merge, keys_out, allow_missing_keys=True):
        # Initialise MapTransform so that ``self.keys`` exists, as it does on
        # every other MONAI dictionary transform.
        super().__init__(keys_merge, allow_missing_keys)
        self.keys_merge = keys_merge
        self.keys_out = keys_out
        # Meta dict that is forwarded when every merge key is present
        # (kept as a public attribute for backward compatibility).
        self.key_target_meta = keys_merge[0] + "_meta_dict"
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, data):
        """Apply the concatenation to a dict, or to every dict in a list.

        Returns the same object that was passed in, after in-place updates.
        """
        rows = data if isinstance(data, list) else [data]
        for row in rows:
            self._concat_row(row)
        return data

    def _missing_key_error(self, key):
        """Build the KeyError raised for a missing key when it is not allowed."""
        return KeyError(
            f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
            " and allow_missing_keys==False."
        )

    def _concat_row(self, row):
        """Concatenate ``keys_merge`` of one dict into ``keys_out`` in place."""
        if not self.allow_missing_keys:
            for key in self.keys_merge:
                if key not in row:
                    raise self._missing_key_error(key)
        present = [key for key in self.keys_merge if key in row]
        if not present:
            # Only reachable when allow_missing_keys is True.
            return
        row[self.keys_out] = np.concatenate([row[key] for key in present])
        # When every merge key is present this equals ``self.key_target_meta``.
        meta_key = present[0] + "_meta_dict"
        if meta_key in row:
            row[self.keys_out + "_meta_dict"] = row[meta_key]
        elif not self.allow_missing_keys:
            raise self._missing_key_error(meta_key)

    def inverse(self, data):
        """Return ``data`` unchanged.

        This inverse is a no-op: the concatenated key and its meta dict are
        left in place.
        """
        return data


class MergeClassesd(MapTransform):
    """Collapse a multi-channel class map into a single-channel label map.

    For every key in ``self.keys``, the value must be a torch tensor whose
    dimension ``-4`` indexes classes. The expected layout is ``(C, H, W, D)``,
    optionally preceded by leading singleton dimensions; confirm against
    callers.

    Channel ``i`` is multiplied by ``i``. At each voxel the lowest class
    index with a non-zero weighted value wins, so overlapping classes resolve
    to the lowest index. Channel 0 is weighted by 0 and therefore only ever
    contributes zeros. The result has shape ``(1, *spatial)``.
    """

    def __call__(self, data):
        """Merge the class channels of every configured key, in place.

        Raises:
            KeyError: a key is missing and ``allow_missing_keys`` is False.
            ValueError: a tensor has no class channels, or more than one
                non-singleton leading dimension in front of ``(C, *spatial)``.
        """
        for key in self.keys:
            if key in data:
                data[key] = self._merge(data[key])
            elif not self.allow_missing_keys:
                raise KeyError(
                    f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
                    " and allow_missing_keys==False."
                )
        return data

    @staticmethod
    def _merge(img):
        """Return the ``(1, *spatial)`` label map for one class-channel tensor."""
        num_classes = img.size(-4)
        if num_classes == 0:
            raise ValueError("MergeClassesd: input tensor has no class channels to merge.")
        # Strip only *leading* singleton dims (e.g. a batch dim of 1). A bare
        # squeeze() would also drop singleton spatial dims (D == 1 for 2D data)
        # and break the per-channel broadcast below.
        while img.dim() > 4 and img.size(0) == 1:
            img = img.squeeze(0)
        if img.dim() != 4:
            raise ValueError(
                f"MergeClassesd: expected a (C, *spatial) tensor with 3 spatial dims"
                f" (plus optional leading singleton dims), got shape {tuple(img.shape)}."
            )
        # Weight channel i by its class index i (int64, same as the original
        # torch.tensor(list(range(...)))), broadcast over the spatial dims.
        weights = torch.arange(num_classes, device=img.device).view(-1, 1, 1, 1)
        weighted = img * weights
        merged = weighted[0]
        for channel in weighted[1:]:
            # Keep an already-assigned (non-zero) label; otherwise take this class.
            merged = torch.where(merged != 0, merged, channel)
        return merged.unsqueeze(0)