"""PreTrainedModel wrappers around the W-Net 3D network variants."""

from transformers import PreTrainedModel

from .w_net_3d import WNet3dUNet, WNet3dAttUNet, WNet3dUNetMSS
from .WNetConfigs import WNet3DConfig, AttWNet3DConfig, WNetMSS3DConfig


class WNet3D(PreTrainedModel):
    """`PreTrainedModel` wrapper that delegates to a `WNet3dUNet` backbone."""

    # Configuration class associated with this wrapper.
    config_class = WNet3DConfig

    def __init__(self, config):
        """Build the wrapped `WNet3dUNet` from *config*.

        Args:
            config: Configuration object; its ``in_ch``, ``out_ch`` and
                ``init_features`` attributes are forwarded to the backbone
                constructor as keyword arguments of the same names.
        """
        super().__init__(config)
        backbone_kwargs = {
            "in_ch": config.in_ch,
            "out_ch": config.out_ch,
            "init_features": config.init_features,
        }
        self.model = WNet3dUNet(**backbone_kwargs)

    def forward(self, x):
        """Run *x* through the wrapped backbone and return its output unchanged.

        NOTE(review): the expected layout of ``x`` is defined by
        ``WNet3dUNet`` and is not visible here -- confirm against that module.
        """
        output = self.model(x)
        return output


class AttWNet3D(PreTrainedModel):
    """`PreTrainedModel` wrapper that delegates to a `WNet3dAttUNet` backbone."""

    # Configuration class associated with this wrapper.
    config_class = AttWNet3DConfig

    def __init__(self, config):
        """Build the wrapped `WNet3dAttUNet` from *config*.

        Args:
            config: Configuration object; its ``in_ch``, ``out_ch`` and
                ``init_features`` attributes are forwarded to the backbone
                constructor as keyword arguments of the same names.
        """
        super().__init__(config)
        backbone_kwargs = {
            "in_ch": config.in_ch,
            "out_ch": config.out_ch,
            "init_features": config.init_features,
        }
        self.model = WNet3dAttUNet(**backbone_kwargs)

    def forward(self, x):
        """Run *x* through the wrapped backbone and return its output unchanged.

        NOTE(review): the expected layout of ``x`` is defined by
        ``WNet3dAttUNet`` and is not visible here -- confirm against that
        module.
        """
        output = self.model(x)
        return output


class WNetMSS3D(PreTrainedModel):
    """`PreTrainedModel` wrapper that delegates to a `WNet3dUNetMSS` backbone."""

    # Configuration class associated with this wrapper.
    config_class = WNetMSS3DConfig

    def __init__(self, config):
        """Build the wrapped `WNet3dUNetMSS` from *config*.

        Args:
            config: Configuration object; its ``in_ch``, ``out_ch`` and
                ``init_features`` attributes are forwarded to the backbone
                constructor as keyword arguments of the same names.
        """
        super().__init__(config)
        backbone_kwargs = {
            "in_ch": config.in_ch,
            "out_ch": config.out_ch,
            "init_features": config.init_features,
        }
        self.model = WNet3dUNetMSS(**backbone_kwargs)

    def forward(self, x):
        """Run *x* through the wrapped backbone and return its output unchanged.

        NOTE(review): the expected layout of ``x`` and the structure of the
        return value are defined by ``WNet3dUNetMSS`` and are not visible
        here -- confirm against that module.
        """
        output = self.model(x)
        return output