"""Hugging Face ``PreTrainedModel`` wrappers around the project's 3D UNet backbones.

Each wrapper builds its backbone from a dedicated config class and delegates
``forward`` unchanged. This lets the backbones use the ``transformers``
save/load machinery (``save_pretrained`` / ``from_pretrained``).
"""

from transformers import PreTrainedModel

from .unet3d import UNet, UNetDeepSup
from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig


def _backbone_kwargs(config):
    """Return the constructor keyword arguments shared by the wrapped backbones.

    ``UNet`` and ``UNetDeepSup`` are both built from the same four config
    fields. Reading them in one place keeps the two wrappers consistent.

    Args:
        config: A config object exposing ``in_ch``, ``out_ch``,
            ``init_features`` and ``dropout_rate`` attributes.

    Returns:
        dict: Keyword arguments for the backbone constructor.
    """
    return {
        "in_ch": config.in_ch,
        "out_ch": config.out_ch,
        "init_features": config.init_features,
        "dropout_rate": config.dropout_rate,
    }


class UNet3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the plain 3D ``UNet`` backbone.

    The backbone is stored as ``self.model``. ``forward`` passes its input
    straight through to it.
    """

    config_class = UNet3DConfig

    def __init__(self, config):
        """Build the ``UNet`` backbone from *config*.

        Args:
            config: A ``UNet3DConfig`` providing ``in_ch``, ``out_ch``,
                ``init_features`` and ``dropout_rate``.
        """
        super().__init__(config)
        self.model = UNet(**_backbone_kwargs(config))

    def forward(self, x):
        """Run the wrapped ``UNet`` on *x* and return its output unchanged.

        NOTE(review): input/output shapes are defined by ``UNet`` in
        ``.unet3d`` and are not visible here. Presumably the input is a 5D
        volume batch with ``in_ch`` channels; confirm there.
        """
        return self.model(x)


class UNetMSS3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the ``UNetDeepSup`` backbone.

    Presumably "MSS" refers to multi-scale (deep) supervision, matching the
    ``UNetDeepSup`` backbone name. TODO confirm. The backbone is stored as
    ``self.model``, and ``forward`` passes its input straight through to it.
    """

    config_class = UNetMSS3DConfig

    def __init__(self, config):
        """Build the ``UNetDeepSup`` backbone from *config*.

        Args:
            config: A ``UNetMSS3DConfig`` providing ``in_ch``, ``out_ch``,
                ``init_features`` and ``dropout_rate``.
        """
        super().__init__(config)
        self.model = UNetDeepSup(**_backbone_kwargs(config))

    def forward(self, x):
        """Run the wrapped ``UNetDeepSup`` on *x* and return its output unchanged.

        NOTE(review): whether this returns a single tensor or several
        deep-supervision outputs is decided by ``UNetDeepSup`` and is not
        visible here. Verify in ``.unet3d``.
        """
        return self.model(x)