File size: 448 Bytes
137840f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import PretrainedConfig

class ResNetConfig(PretrainedConfig):
    """Configuration object for a ResNet classifier.

    Holds the residual block variant and the number of output classes.
    Every other keyword argument is handed unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        block_type: Residual block variant; ``"basic"`` is the only value
            accepted.
        num_classes: Number of output classes (default ``10``).
        **kwargs: Forwarded verbatim to ``PretrainedConfig``.

    Raises:
        ValueError: If ``block_type`` is not one of the accepted variants.
    """

    model_type = "resnet20"

    def __init__(self, block_type: str = "basic", num_classes: int = 10, **kwargs):
        accepted_variants = ("basic",)
        if block_type not in accepted_variants:
            raise ValueError(f"Invalid block type {block_type}")

        # Custom fields are assigned before delegating to the base class,
        # preserving the original statement order.
        self.num_classes = num_classes
        self.block_type = block_type

        super().__init__(**kwargs)