|
from transformers import PreTrainedModel |
|
from .ReconResNetBase import ReconResNetBase |
|
from .ReconResNetConfig import ReconResNetConfig |
|
|
|
class ReconResNet(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around ``ReconResNetBase``.

    The network itself is built from a ``ReconResNetConfig`` and stored in
    ``self.model``; this class only adapts it to the ``PreTrainedModel``
    interface and delegates all computation to it.
    """

    # Declares ReconResNetConfig as the configuration class for this model.
    config_class = ReconResNetConfig

    def __init__(self, config):
        """Build the wrapped ``ReconResNetBase`` from ``config``.

        Args:
            config: A ``ReconResNetConfig`` instance. Each architecture
                hyper-parameter read below is passed unchanged, under the
                same keyword name, to ``ReconResNetBase``.
        """
        super().__init__(config)
        # Forward every hyper-parameter 1:1 from the config so the config
        # object is the single source of truth for the architecture.
        self.model = ReconResNetBase(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            res_blocks=config.res_blocks,
            starting_nfeatures=config.starting_nfeatures,
            updown_blocks=config.updown_blocks,
            is_relu_leaky=config.is_relu_leaky,
            do_batchnorm=config.do_batchnorm,
            res_drop_prob=config.res_drop_prob,
            is_replicatepad=config.is_replicatepad,
            out_act=config.out_act,
            forwardV=config.forwardV,
            upinterp_algo=config.upinterp_algo,
            post_interp_convtrans=config.post_interp_convtrans,
            is3D=config.is3D,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped ``ReconResNetBase``.

        NOTE(review): ``x`` is presumably a 2D or 3D image tensor depending on
        ``config.is3D`` — TODO confirm against ``ReconResNetBase.forward``.

        Returns:
            Whatever ``ReconResNetBase`` returns for ``x``, unchanged.
        """
        return self.model(x)