soumickmj's picture
Upload ReconResNet
b84dcb2 verified
raw
history blame
1.01 kB
from transformers import PreTrainedModel
from .ReconResNetBase import ReconResNetBase
from .ReconResNetConfig import ReconResNetConfig
class ReconResNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ReconResNetBase``.

    Each architecture hyper-parameter is read from a ``ReconResNetConfig``
    and passed unchanged to ``ReconResNetBase``. The resulting network is
    stored as the ``model`` submodule, and ``forward`` simply delegates to it.
    """

    config_class = ReconResNetConfig

    def __init__(self, config):
        """Build the wrapped network from *config*.

        Args:
            config: A ``ReconResNetConfig`` (or any object exposing the same
                attributes) describing the network architecture.
        """
        super().__init__(config)
        # Each config attribute has exactly the same name as the matching
        # keyword parameter of ReconResNetBase, so they are copied across
        # one-to-one, in a fixed order.
        hparam_names = (
            "in_channels",
            "out_channels",
            "res_blocks",
            "starting_nfeatures",
            "updown_blocks",
            "is_relu_leaky",
            "do_batchnorm",
            "res_drop_prob",
            "is_replicatepad",
            "out_act",
            "forwardV",
            "upinterp_algo",
            "post_interp_convtrans",
            "is3D",
        )
        base_kwargs = {name: getattr(config, name) for name in hparam_names}
        self.model = ReconResNetBase(**base_kwargs)

    def forward(self, x):
        """Run the wrapped ``ReconResNetBase`` on *x* and return its output.

        NOTE(review): the expected input layout is not visible here. It
        presumably depends on ``config.in_channels`` / ``config.is3D``;
        confirm against ReconResNetBase.
        """
        wrapped = self.model
        return wrapped(x)