File size: 1,014 Bytes
9e773fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import PreTrainedModel
from .ReconResNetBase import ReconResNetBase
from .ReconResNetConfig import ReconResNetConfig

class ReconResNet(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around ``ReconResNetBase``.

    Every architecture hyper-parameter is read from the supplied config object
    and passed, under the same keyword name, to ``ReconResNetBase``. The
    resulting network is stored as ``self.model`` and ``forward`` simply
    delegates to it.
    """

    # Config class associated with this model (consumed by transformers'
    # PreTrainedModel machinery).
    config_class = ReconResNetConfig

    def __init__(self, config):
        """Construct the wrapped ``ReconResNetBase`` from ``config``.

        Args:
            config: A ``ReconResNetConfig`` (or any object) exposing every
                attribute named in ``forwarded_fields`` below; a missing
                attribute raises ``AttributeError``.
        """
        # Parent init must run before any submodule is assigned to self.
        super().__init__(config)

        # Config attributes forwarded verbatim, as keyword arguments of the
        # same name, to ReconResNetBase.
        forwarded_fields = (
            "in_channels",
            "out_channels",
            "res_blocks",
            "starting_nfeatures",
            "updown_blocks",
            "is_relu_leaky",
            "do_batchnorm",
            "res_drop_prob",
            "is_replicatepad",
            "out_act",
            "forwardV",
            "upinterp_algo",
            "post_interp_convtrans",
            "is3D",
        )
        base_kwargs = {field: getattr(config, field) for field in forwarded_fields}
        self.model = ReconResNetBase(**base_kwargs)

    def forward(self, x):
        """Apply the wrapped ``ReconResNetBase`` to ``x`` and return its output as-is."""
        return self.model(x)