import torch
from transformers import PreTrainedModel

from .configuration_resnet import ResinConfig


class ResinModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` wrapping a single linear layer.

    The wrapped layer is ``torch.nn.Linear(5, 10)``. It maps inputs whose
    last dimension has size 5 to outputs whose last dimension has size 10.
    """

    config_class = ResinConfig

    def __init__(self, config):
        """Build the model.

        Args:
            config: A ``ResinConfig`` instance. It is passed to
                ``PreTrainedModel.__init__``; the layer sizes below do not
                read from it.
        """
        super().__init__(config)
        # Attribute name kept as ``model`` so state_dict keys remain
        # compatible with existing checkpoints.
        # NOTE(review): sizes are hard-coded; consider reading them from config.
        self.model = torch.nn.Linear(5, 10)

    def forward(self, tensor):
        """Apply the linear layer to ``tensor``.

        Args:
            tensor: Input tensor whose last dimension must be 5
                (``in_features`` of the wrapped ``nn.Linear``).

        Returns:
            The output of the linear layer, with last dimension 10.
        """
        # nn.Linear has no ``forward_features`` method; calling it raised
        # AttributeError. Call the module itself so hooks run normally.
        return self.model(tensor)