File size: 448 Bytes
137840f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from transformers import PretrainedConfig
class ResNetConfig(PretrainedConfig):
    """Hyperparameter container for a ResNet model.

    The model class that consumes this config is not visible in this file.
    The config identifies itself with ``model_type = "resnet20"``.

    Args:
        block_type: Residual block variant. Only ``"basic"`` is accepted.
        num_classes: Class count. It is stored as given, without validation.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``block_type`` is anything other than ``"basic"``.
    """

    model_type = "resnet20"

    def __init__(self, block_type: str = "basic", num_classes: int = 10, **kwargs):
        # Reject unsupported variants early, before any state is written.
        allowed_block_types = ("basic",)
        if block_type not in allowed_block_types:
            raise ValueError(f"Invalid block type {block_type}")

        self.block_type = block_type
        self.num_classes = num_classes

        # The subclass fields are set first and the base initializer is called
        # last, mirroring the original ordering.
        super().__init__(**kwargs)