"""Hugging Face ``PreTrainedModel`` wrappers around the SR-MRI networks.

Each wrapper builds its underlying network from the fields of its associated
configuration class, so the networks plug into the ``transformers``
``PreTrainedModel`` machinery (config handling, saving and loading).
"""

from transformers import PreTrainedModel

from .RRDB import RRDB
from .unet3DMSS import UNetMSS
from .SRMRIModelsConfigs import RRDBConfiguration, UNetMSSConfiguration


class SRMRIModelUNetMSS(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`UNetMSS`.

    The wrapped network is stored as ``self.model`` and is constructed from
    the matching attributes of a :class:`UNetMSSConfiguration` instance.
    """

    config_class = UNetMSSConfiguration

    def __init__(self, config):
        """Build the wrapped ``UNetMSS`` from ``config``.

        Args:
            config: A ``UNetMSSConfiguration``. Its ``in_channels``,
                ``n_classes``, ``depth``, ``wf``, ``padding``, ``batch_norm``,
                ``up_mode``, ``dropout``, ``mss_level``, ``mss_fromlatent``,
                ``mss_up`` and ``mss_interpb4`` attributes are forwarded
                unchanged to the ``UNetMSS`` constructor.
        """
        super().__init__(config)
        self.model = UNetMSS(
            in_channels=config.in_channels,
            n_classes=config.n_classes,
            depth=config.depth,
            wf=config.wf,
            padding=config.padding,
            batch_norm=config.batch_norm,
            up_mode=config.up_mode,
            dropout=config.dropout,
            mss_level=config.mss_level,
            mss_fromlatent=config.mss_fromlatent,
            mss_up=config.mss_up,
            mss_interpb4=config.mss_interpb4,
        )

    def forward(self, x):
        """Run the wrapped ``UNetMSS`` on ``x`` and return its output unchanged.

        NOTE(review): the expected shape of ``x`` is defined by ``UNetMSS``
        (presumably a batched volume tensor) — confirm there.
        """
        # Call the module instead of ``.forward`` so that any forward hooks
        # registered on the inner network are executed.
        return self.model(x)


class SRMRIModelRRDB(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`RRDB`.

    The wrapped network is stored as ``self.model`` and is constructed from
    the matching attributes of an :class:`RRDBConfiguration` instance.
    """

    config_class = RRDBConfiguration

    def __init__(self, config):
        """Build the wrapped ``RRDB`` from ``config``.

        Args:
            config: An ``RRDBConfiguration``. Its ``nChannels``,
                ``nDenseLayers``, ``nInitFeat``, ``GrowthRate``,
                ``featureFusion`` and ``kernel_config`` attributes are
                forwarded unchanged to the ``RRDB`` constructor.
        """
        super().__init__(config)
        self.model = RRDB(
            nChannels=config.nChannels,
            nDenseLayers=config.nDenseLayers,
            nInitFeat=config.nInitFeat,
            GrowthRate=config.GrowthRate,
            featureFusion=config.featureFusion,
            kernel_config=config.kernel_config,
        )

    def forward(self, x):
        """Run the wrapped ``RRDB`` on ``x`` and return its output unchanged.

        NOTE(review): the expected shape of ``x`` is defined by ``RRDB`` —
        confirm there.
        """
        # Call the module instead of ``.forward`` so that any forward hooks
        # registered on the inner network are executed.
        return self.model(x)