|
from transformers import PreTrainedModel |
|
from .unet3d import UNet, UNetDeepSup |
|
from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig |
|
|
|
class UNet3D(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``UNet`` network.

    The wrapped network is the ``UNet`` class imported from ``.unet3d`` and
    is stored on ``self.model``. Its constructor arguments are read from the
    config object handed to ``__init__``.
    """

    # Configuration class associated with this model architecture.
    config_class = UNet3DConfig

    def __init__(self, config):
        """Build the wrapped ``UNet`` from the fields of ``config``.

        Args:
            config: Configuration object providing the attributes ``in_ch``,
                ``out_ch``, ``init_features`` and ``dropout_rate``
                (presumably a ``UNet3DConfig`` -- confirm against callers).
        """
        super().__init__(config)

        # Collect the constructor arguments first so the mapping from config
        # fields to network arguments is visible in one place.
        unet_kwargs = {
            "in_ch": config.in_ch,
            "out_ch": config.out_ch,
            "init_features": config.init_features,
            "dropout_rate": config.dropout_rate,
        }
        self.model = UNet(**unet_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output as-is."""
        output = self.model(x)
        return output
|
|
|
class UNetMSS3D(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``UNetDeepSup``.

    The wrapped network is the ``UNetDeepSup`` class imported from
    ``.unet3d`` and is stored on ``self.model``. Its constructor arguments
    are read from the config object handed to ``__init__``.

    NOTE(review): the class names suggest a deep-supervision / multi-scale
    variant of the plain U-Net -- confirm against ``unet3d.UNetDeepSup``.
    """

    # Configuration class associated with this model architecture.
    config_class = UNetMSS3DConfig

    def __init__(self, config):
        """Build the wrapped ``UNetDeepSup`` from the fields of ``config``.

        Args:
            config: Configuration object providing the attributes ``in_ch``,
                ``out_ch``, ``init_features`` and ``dropout_rate``
                (presumably a ``UNetMSS3DConfig`` -- confirm against callers).
        """
        super().__init__(config)

        # Collect the constructor arguments first so the mapping from config
        # fields to network arguments is visible in one place.
        deep_sup_kwargs = {
            "in_ch": config.in_ch,
            "out_ch": config.out_ch,
            "init_features": config.init_features,
            "dropout_rate": config.dropout_rate,
        }
        self.model = UNetDeepSup(**deep_sup_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output as-is."""
        output = self.model(x)
        return output