from transformers import PretrainedConfig
from typing import List


class _WNet3DConfigBase(PretrainedConfig):
    """Shared configuration constructor for the 3D W-Net config classes.

    The public config classes in this module differ only in ``model_type``.
    Keeping the constructor in one place guarantees they accept and store the
    same hyper-parameters with the same defaults.

    Args:
        in_ch: Stored as ``self.in_ch`` (presumably the number of input
            channels -- confirm against the model). Defaults to 1.
        out_ch: Stored as ``self.out_ch`` (presumably the number of output
            channels/classes -- confirm against the model). Defaults to 5.
        init_features: Stored as ``self.init_features`` (presumably the base
            feature width of the network -- confirm). Defaults to 64.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 5,
        init_features: int = 64,
        **kwargs,
    ):
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.init_features = init_features
        # Base-class init runs last so generic config kwargs are applied
        # after the model-specific attributes are set.
        super().__init__(**kwargs)


class WNet3DConfig(_WNet3DConfigBase):
    """Configuration for the ``WNet`` model type.

    Accepts ``in_ch`` (default 1), ``out_ch`` (default 5), and
    ``init_features`` (default 64). Any extra keyword arguments are
    forwarded to ``PretrainedConfig``.
    """

    model_type = "WNet"


class AttWNet3DConfig(_WNet3DConfigBase):
    """Configuration for the ``AttWNet`` model type.

    Accepts ``in_ch`` (default 1), ``out_ch`` (default 5), and
    ``init_features`` (default 64). Any extra keyword arguments are
    forwarded to ``PretrainedConfig``.
    """

    model_type = "AttWNet"


class WNetMSS3DConfig(_WNet3DConfigBase):
    """Configuration for the ``WNetMSS`` model type.

    Accepts ``in_ch`` (default 1), ``out_ch`` (default 5), and
    ``init_features`` (default 64). Any extra keyword arguments are
    forwarded to ``PretrainedConfig``.
    """

    model_type = "WNetMSS"