import torch
from transformers import PreTrainedModel

from .configuration_resnet import ResnetConfig


class ResnetModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` wrapping a single linear layer.

    Despite the name, no ResNet backbone is built here: the wrapped module is
    ``torch.nn.Linear(5, 10)``, so inputs must have a last dimension of size 5
    and outputs have a last dimension of size 10.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the model from *config*.

        The config is only forwarded to ``PreTrainedModel``; the layer sizes
        below are hard-coded and do not read any config fields.
        """
        super().__init__(config)
        self.model = torch.nn.Linear(5, 10)

    def forward(self, tensor):
        """Apply the wrapped linear layer to *tensor* and return the result.

        ``torch.nn.Linear`` has no ``forward_features`` method (that API
        belongs to timm backbones), so the layer is invoked directly through
        ``nn.Module.__call__``, which also runs any registered forward hooks.
        """
        return self.model(tensor)