File size: 1,014 Bytes
9e773fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from transformers import PreTrainedModel
from .ReconResNetBase import ReconResNetBase
from .ReconResNetConfig import ReconResNetConfig
class ReconResNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ReconResNetBase``.

    Every architecture hyper-parameter is read from a ``ReconResNetConfig``
    and forwarded, unchanged and under the same name, to ``ReconResNetBase``.
    The wrapped network is stored as ``self.model``; ``forward`` simply
    delegates to it. Wrapping the base network this way lets it be saved and
    loaded through the standard ``transformers`` ``save_pretrained`` /
    ``from_pretrained`` machinery.
    """

    # Configuration class paired with this architecture; PreTrainedModel uses
    # it when building the config in ``from_pretrained``.
    config_class = ReconResNetConfig

    def __init__(self, config):
        """Build the wrapped ``ReconResNetBase`` from ``config``.

        Args:
            config: A ``ReconResNetConfig`` whose attributes ``in_channels``,
                ``out_channels``, ``res_blocks``, ``starting_nfeatures``,
                ``updown_blocks``, ``is_relu_leaky``, ``do_batchnorm``,
                ``res_drop_prob``, ``is_replicatepad``, ``out_act``,
                ``forwardV``, ``upinterp_algo``, ``post_interp_convtrans``
                and ``is3D`` are passed as same-named keyword arguments to
                ``ReconResNetBase``.
        """
        # Let PreTrainedModel register the config (exposed as ``self.config``).
        super().__init__(config)
        self.model = ReconResNetBase(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            res_blocks=config.res_blocks,
            starting_nfeatures=config.starting_nfeatures,
            updown_blocks=config.updown_blocks,
            is_relu_leaky=config.is_relu_leaky,
            do_batchnorm=config.do_batchnorm,
            res_drop_prob=config.res_drop_prob,
            is_replicatepad=config.is_replicatepad,
            out_act=config.out_act,
            forwardV=config.forwardV,
            upinterp_algo=config.upinterp_algo,
            post_interp_convtrans=config.post_interp_convtrans,
            is3D=config.is3D,
        )

    def forward(self, x):
        """Run the wrapped ``ReconResNetBase`` on ``x`` and return its output.

        Args:
            x: Input passed straight through to ``self.model``.
                NOTE(review): the expected shape/dtype (e.g. 2D vs 3D volumes,
                governed by ``config.is3D``) is defined by ``ReconResNetBase``,
                which is not visible here — confirm there.

        Returns:
            Whatever ``ReconResNetBase`` returns for ``x``, unchanged.
        """
        return self.model(x)