File size: 508 Bytes
137840f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import PreTrainedModel

from share import ResidualBlock, ResNet20
from .configuration_resnet20 import ResNetConfig

# Lookup table from the ``block_type`` string carried by the model config to
# the block class handed to ``ResNet20``. Only "basic" is registered here.
BLOCK_MAPPING = {
    "basic": ResidualBlock,
}

class ResNet20Model(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a ``ResNet20`` network.

    The wrapped network is built from a ``ResNetConfig`` and stored on
    ``self.model``; calls to :meth:`forward` are passed straight through
    to it.
    """

    # Configuration class associated with this model.
    config_class = ResNetConfig

    def __init__(self, config):
        """Build the underlying ``ResNet20`` from ``config``.

        Args:
            config: Configuration object providing ``block_type`` (a key of
                ``BLOCK_MAPPING``) and ``num_classes``.

        Raises:
            KeyError: If ``config.block_type`` is not a key of
                ``BLOCK_MAPPING``.
        """
        super().__init__(config)

        # Resolve the configured block name to its class before building.
        block_cls = BLOCK_MAPPING[config.block_type]
        self.model = ResNet20(block_type=block_cls, num_classes=config.num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output.

        Args:
            x: Input forwarded unchanged to ``self.model`` (presumably a
                batch of image tensors -- confirm against ``ResNet20``).

        Returns:
            Whatever ``self.model(x)`` returns.
        """
        network = self.model
        return network(x)