from transformers import PretrainedConfig


class ResNetConfig(PretrainedConfig):
    """Configuration object registered under the ``resnet20`` model type.

    Args:
        block_type: Name of the block variant. Only ``"basic"`` is accepted.
        num_classes: Class count stored on the config. Defaults to 10.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``block_type`` is not one of the accepted values.
    """

    model_type = 'resnet20'

    def __init__(self, block_type="basic", num_classes: int = 10, **kwargs):
        # Reject unknown block variants before any state is stored.
        accepted_block_types = ("basic",)
        if block_type not in accepted_block_types:
            raise ValueError(f"Invalid block type {block_type}")

        # Own fields are set first; remaining keyword arguments go to the
        # base-class initializer.
        self.block_type = block_type
        self.num_classes = num_classes
        super().__init__(**kwargs)