File size: 1,256 Bytes
843a3ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from transformers import PretrainedConfig
from typing import List
class ReconResNetConfig(PretrainedConfig):
    """Configuration container for a ReconResNet model.

    Every named constructor argument below is stored unchanged as an
    instance attribute of the same name. Any extra keyword arguments are
    forwarded to ``PretrainedConfig.__init__``. This class does not
    interpret the values; the model that consumes the config does. The
    short descriptions are therefore hedged readings of the argument
    names and should be confirmed against the model implementation.

    Args:
        in_channels: Presumably the number of input channels. Default ``1``.
        out_channels: Presumably the number of output channels. Default ``1``.
        res_blocks: Presumably the number of residual blocks. Default ``14``.
        starting_nfeatures: Presumably the feature width of the first
            stage. Default ``64``.
        updown_blocks: Presumably the number of down/up-sampling stages.
            Default ``2``.
        is_relu_leaky: Flag, presumably choosing LeakyReLU over ReLU.
            Default ``True``.
        do_batchnorm: Flag, presumably enabling batch normalisation.
            Default ``False``.
        res_drop_prob: Presumably a dropout probability used inside the
            residual blocks. Default ``0.2``.
        is_replicatepad: Presumably a padding-mode selector. The default
            is the int ``0``, not a bool. TODO confirm the accepted values.
        out_act: Presumably the name of the output activation.
            Default ``"sigmoid"``.
        forwardV: Presumably selects a forward-pass variant. Default ``0``.
        upinterp_algo: Presumably the name of the upsampling method.
            Default ``'convtrans'``.
        post_interp_convtrans: Flag, presumably adding a transposed
            convolution after interpolation. Default ``False``.
        is3D: Flag, presumably switching the model to 3-D layers.
            Default ``False``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ReconResNet"

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        res_blocks=14,
        starting_nfeatures=64,
        updown_blocks=2,
        is_relu_leaky=True,
        do_batchnorm=False,
        res_drop_prob=0.2,
        is_replicatepad=0,
        out_act="sigmoid",
        forwardV=0,
        upinterp_algo='convtrans',
        post_interp_convtrans=False,
        is3D=False,
        **kwargs,
    ):
        # Listed in signature order so attributes are assigned in the same
        # sequence as before.
        model_settings = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "res_blocks": res_blocks,
            "starting_nfeatures": starting_nfeatures,
            "updown_blocks": updown_blocks,
            "is_relu_leaky": is_relu_leaky,
            "do_batchnorm": do_batchnorm,
            "res_drop_prob": res_drop_prob,
            "is_replicatepad": is_replicatepad,
            "out_act": out_act,
            "forwardV": forwardV,
            "upinterp_algo": upinterp_algo,
            "post_interp_convtrans": post_interp_convtrans,
            "is3D": is3D,
        }
        for setting_name, setting_value in model_settings.items():
            setattr(self, setting_name, setting_value)

        # Model-specific attributes are set first; the base class then
        # handles whatever generic keyword arguments remain.
        super().__init__(**kwargs)
|