from transformers import PreTrainedModel

from .ReconResNetBase import ReconResNetBase
from .ReconResNetConfig import ReconResNetConfig


class ReconResNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ReconResNetBase``.

    All architecture hyper-parameters are read from a ``ReconResNetConfig``
    and passed unchanged to ``ReconResNetBase``. ``forward`` simply delegates
    to the wrapped network. Subclassing ``PreTrainedModel`` lets this model
    use the Transformers ``save_pretrained`` / ``from_pretrained`` machinery.
    """

    # Tells ``from_pretrained`` which config class to build for this model.
    config_class = ReconResNetConfig

    def __init__(self, config):
        """Build the wrapped ``ReconResNetBase`` from *config*.

        Args:
            config: A ``ReconResNetConfig``. Each attribute read below is
                passed through as the ``ReconResNetBase`` keyword argument
                of the same name.
        """
        super().__init__(config)
        # The inner network is stored as ``self.model``, so its parameters
        # appear under the ``model.`` prefix in the state dict. Keep this
        # attribute name stable so existing checkpoints still load.
        self.model = ReconResNetBase(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            res_blocks=config.res_blocks,
            starting_nfeatures=config.starting_nfeatures,
            updown_blocks=config.updown_blocks,
            is_relu_leaky=config.is_relu_leaky,
            do_batchnorm=config.do_batchnorm,
            res_drop_prob=config.res_drop_prob,
            is_replicatepad=config.is_replicatepad,
            out_act=config.out_act,
            forwardV=config.forwardV,
            upinterp_algo=config.upinterp_algo,
            post_interp_convtrans=config.post_interp_convtrans,
            is3D=config.is3D,
        )

    def forward(self, x):
        """Run the wrapped ``ReconResNetBase`` on *x* and return its output.

        Args:
            x: Input accepted by ``ReconResNetBase``. Its expected shape and
                layout are not visible here. It is presumably a 2D or 3D
                (``config.is3D``) channel-first tensor; TODO confirm against
                ``ReconResNetBase``.

        Returns:
            Whatever ``ReconResNetBase.forward`` returns, unchanged.
        """
        return self.model(x)