from transformers import PreTrainedModel | |
from share import ResidualBlock, ResNet20 | |
from .configuration_resnet20 import ResNetConfig | |
# Lookup table from the ``block_type`` string stored on the config to the
# residual block class handed to ``ResNet20`` when the model is built.
BLOCK_MAPPING = dict(basic=ResidualBlock)
class ResNet20Model(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around the project's ``ResNet20``.

    Subclassing ``PreTrainedModel`` with ``config_class = ResNetConfig`` lets
    the network be driven by a ``ResNetConfig`` and use the inherited
    ``transformers`` model API (e.g. ``save_pretrained`` / ``from_pretrained``).
    """

    config_class = ResNetConfig

    def __init__(self, config):
        """Build the wrapped ``ResNet20`` from *config*.

        Args:
            config: A ``ResNetConfig``. ``config.block_type`` selects the
                residual block class via ``BLOCK_MAPPING``, and
                ``config.num_classes`` is forwarded to ``ResNet20``.

        Raises:
            KeyError: If ``config.block_type`` is not a key of
                ``BLOCK_MAPPING``.
        """
        super().__init__(config)
        try:
            block_cls = BLOCK_MAPPING[config.block_type]
        except KeyError:
            # Same exception type as before, but with a message that names
            # the offending config value and the supported choices.
            raise KeyError(
                f"Unsupported block_type {config.block_type!r}; "
                f"expected one of {sorted(BLOCK_MAPPING)}"
            ) from None
        self.model = ResNet20(
            block_type=block_cls,
            num_classes=config.num_classes,
        )

    def forward(self, x):
        """Run the wrapped ``ResNet20`` on *x* and return its output unchanged."""
        return self.model(x)