|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
class ReconResNetConfig(PretrainedConfig):
    """Hugging Face configuration container for a ReconResNet model.

    Every named constructor argument below is stored unchanged as an
    attribute of the same name on the config instance. Any other keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.

    The meaning of each hyper-parameter is defined by the model class that
    consumes this config, which is not part of this file. The notes below are
    inferred from argument names and should be confirmed against that model.

    Args:
        in_channels: Presumably the number of input channels.
        out_channels: Presumably the number of output channels.
        res_blocks: Presumably the number of residual blocks.
        starting_nfeatures: Presumably the feature width of the first stage.
        updown_blocks: Presumably the number of down-/up-sampling stages.
        is_relu_leaky: Flag; presumably selects a leaky ReLU activation.
        do_batchnorm: Flag; presumably enables batch normalisation.
        res_drop_prob: Presumably a dropout probability used in the residual
            blocks.
        is_replicatepad: Defaults to ``0``; presumably selects the padding
            mode. Verify the accepted values against the model.
        out_act: Presumably the name of the output activation
            (default ``"sigmoid"``).
        forwardV: Defaults to ``0``; presumably selects a forward-pass
            variant.
        upinterp_algo: Defaults to ``'convtrans'``; presumably names the
            upsampling method.
        post_interp_convtrans: Flag; presumably adds a transposed convolution
            after interpolation.
        is3D: Flag; presumably switches the model to 3-D layers.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Class-level identifier for this config type.
    model_type = "ReconResNet"

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        res_blocks: int = 14,
        starting_nfeatures: int = 64,
        updown_blocks: int = 2,
        is_relu_leaky: bool = True,
        do_batchnorm: bool = False,
        res_drop_prob: float = 0.2,
        is_replicatepad: int = 0,
        out_act: str = "sigmoid",
        forwardV: int = 0,
        upinterp_algo: str = 'convtrans',
        post_interp_convtrans: bool = False,
        is3D: bool = False,
        **kwargs,
    ):
        # Collect the model-specific hyper-parameters in declaration order.
        # Attach each one to the instance before the base initialiser
        # consumes the remaining kwargs.
        model_fields = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "res_blocks": res_blocks,
            "starting_nfeatures": starting_nfeatures,
            "updown_blocks": updown_blocks,
            "is_relu_leaky": is_relu_leaky,
            "do_batchnorm": do_batchnorm,
            "res_drop_prob": res_drop_prob,
            "is_replicatepad": is_replicatepad,
            "out_act": out_act,
            "forwardV": forwardV,
            "upinterp_algo": upinterp_algo,
            "post_interp_convtrans": post_interp_convtrans,
            "is3D": is3D,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
|
|