from transformers import PreTrainedModel

from share import ResidualBlock, ResNet20

from .configuration_resnet20 import ResNetConfig

# Maps the string value of ``config.block_type`` to the block class that is
# passed to ``ResNet20(block_type=...)``. Only "basic" is supported here.
BLOCK_MAPPING = {"basic": ResidualBlock}


class ResNet20Model(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around ``share.ResNet20``.

    The network itself lives in ``self.model``; this class only builds it from a
    ``ResNetConfig`` and forwards inputs to it (presumably so the model can be
    used with ``save_pretrained`` / ``from_pretrained`` — confirm with callers).
    """

    config_class = ResNetConfig

    def __init__(self, config):
        """Build the wrapped ``ResNet20`` from ``config``.

        Args:
            config: A ``ResNetConfig``; its ``block_type`` and ``num_classes``
                attributes are read.

        Raises:
            KeyError: If ``config.block_type`` is not a key of ``BLOCK_MAPPING``.
        """
        super().__init__(config)
        try:
            block_cls = BLOCK_MAPPING[config.block_type]
        except KeyError:
            # Keep the KeyError type callers may rely on, but say what is valid
            # instead of surfacing only the bad key.
            raise KeyError(
                f"Unsupported block_type {config.block_type!r}; "
                f"expected one of {sorted(BLOCK_MAPPING)}"
            ) from None
        self.model = ResNet20(
            block_type=block_cls,
            num_classes=config.num_classes,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped ``ResNet20`` and return its output as-is.

        NOTE(review): presumably ``x`` is an image batch tensor and the result
        is class logits of size ``num_classes`` — depends on ``share.ResNet20``.
        """
        return self.model(x)