from transformers import PretrainedConfig
class ResNetConfig(PretrainedConfig):
    """Configuration for the ``resnet20`` model type.

    Args:
        block_type: Block variant identifier. ``"basic"`` is the only value
            accepted.
        num_classes: Number of output classes; defaults to 10.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``block_type`` is not one of the supported values.
    """

    model_type = 'resnet20'

    def __init__(self, block_type: str = "basic", num_classes: int = 10, **kwargs):
        supported_block_types = ("basic",)
        # Reject unknown variants before anything is stored on the instance.
        if block_type not in supported_block_types:
            raise ValueError(f"Invalid block type {block_type}")

        self.block_type = block_type
        self.num_classes = num_classes

        # Custom fields are set first; the base class consumes the rest.
        super().__init__(**kwargs)