|
from transformers import PreTrainedModel |
|
from .w_net_3d import WNet3dUNet, WNet3dAttUNet, WNet3dUNetMSS |
|
from .WNetConfigs import WNet3DConfig, AttWNet3DConfig, WNetMSS3DConfig |
|
|
|
class WNet3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ``WNet3dUNet`` to transformers.

    The underlying network is built from ``WNet3DConfig`` and stored as
    ``self.model``. ``forward`` delegates straight to it.
    """

    config_class = WNet3DConfig

    def __init__(self, config):
        """Build the wrapped ``WNet3dUNet`` from ``config``.

        Args:
            config: A ``WNet3DConfig`` providing ``in_ch``, ``out_ch`` and
                ``init_features`` (presumably input/output channel counts and
                the base feature width -- confirm against ``WNetConfigs``).
        """
        super().__init__(config)
        # The three hyper-parameters are read in the same order as before and
        # forwarded unchanged as keyword arguments.
        net_kwargs = {
            "in_ch": config.in_ch,
            "out_ch": config.out_ch,
            "init_features": config.init_features,
        }
        self.model = WNet3dUNet(**net_kwargs)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        result = self.model(x)
        return result
|
|
|
class AttWNet3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ``WNet3dAttUNet`` to transformers.

    The underlying network is built from ``AttWNet3DConfig`` and stored as
    ``self.model``. ``forward`` delegates straight to it.
    """

    config_class = AttWNet3DConfig

    def __init__(self, config):
        """Build the wrapped ``WNet3dAttUNet`` from ``config``.

        Args:
            config: An ``AttWNet3DConfig`` providing ``in_ch``, ``out_ch`` and
                ``init_features`` (presumably input/output channel counts and
                the base feature width -- confirm against ``WNetConfigs``).
        """
        super().__init__(config)
        # Pull the hyper-parameters out first, in the original access order.
        in_channels = config.in_ch
        out_channels = config.out_ch
        base_features = config.init_features
        self.model = WNet3dAttUNet(
            in_ch=in_channels,
            out_ch=out_channels,
            init_features=base_features,
        )

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        prediction = self.model(x)
        return prediction
|
|
|
class WNetMSS3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ``WNet3dUNetMSS`` to transformers.

    The underlying network is built from ``WNetMSS3DConfig`` and stored as
    ``self.model``. ``forward`` delegates straight to it.
    """

    config_class = WNetMSS3DConfig

    def __init__(self, config):
        """Build the wrapped ``WNet3dUNetMSS`` from ``config``.

        Args:
            config: A ``WNetMSS3DConfig`` providing ``in_ch``, ``out_ch`` and
                ``init_features`` (presumably input/output channel counts and
                the base feature width -- confirm against ``WNetConfigs``).
        """
        super().__init__(config)
        # The hyper-parameters are read as a tuple, in the original order, and
        # passed through by keyword.
        in_ch, out_ch, init_features = (
            config.in_ch,
            config.out_ch,
            config.init_features,
        )
        self.model = WNet3dUNetMSS(
            in_ch=in_ch, out_ch=out_ch, init_features=init_features
        )

    def forward(self, x):
        """Run the wrapped network on ``x`` and return whatever it returns."""
        network = self.model
        return network(x)