custom-resnet20 / configuration_resnet20.py
Junesnow's picture
Upload model
137840f
raw
history blame contribute delete
448 Bytes
from transformers import PretrainedConfig
class ResNetConfig(PretrainedConfig):
    """Configuration object for the custom ResNet-20 model.

    The class-level ``model_type`` identifier is ``"resnet20"``.

    Args:
        block_type: Residual block variant. Only ``"basic"`` is accepted;
            any other value is rejected before any attribute is set.
        num_classes: Stored as ``self.num_classes``. Presumably the width of
            the classification head built by the model class; confirm
            against the modeling file, which is not shown here.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``block_type`` is not a supported variant.
    """

    model_type = "resnet20"

    def __init__(
        self,
        block_type: str = "basic",
        num_classes: int = 10,
        **kwargs,
    ):
        # A tuple keeps the same equality-based membership test as the
        # original list literal.
        supported_block_types = ("basic",)
        if block_type not in supported_block_types:
            # Reject early so no attribute is assigned for an invalid config.
            raise ValueError(f"Invalid block type {block_type}")

        self.block_type = block_type
        self.num_classes = num_classes
        # Custom fields are set first, then the base class handles the
        # remaining keyword arguments.
        super().__init__(**kwargs)