import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForMaskedLM  # noqa: F401  (not referenced by TestModel)
from transformers import AutoModelForSequenceClassification  # noqa: F401  (not referenced by TestModel)
from transformers import PreTrainedModel

from .configuration_test import TestConfig


class TestModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` consisting of a single ``nn.Linear`` layer.

    The layer maps features of size ``config.input_dim`` to
    ``config.output_dim``.
    """

    # Tells PreTrainedModel which configuration class this model pairs with
    # (used when loading/saving via the transformers API).
    config_class = TestConfig

    def __init__(self, config: TestConfig) -> None:
        """Build the model from *config*.

        Args:
            config: Configuration providing ``input_dim`` and ``output_dim``.
        """
        super().__init__(config)
        self.input_dim = config.input_dim
        self.model1 = nn.Linear(config.input_dim, config.output_dim)
        # NOTE(review): transformers models conventionally call
        # ``self.post_init()`` at the end of ``__init__``; it is not called
        # here. Confirm whether that is intentional before adding it, since it
        # may change how weights are initialized.

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *tensor*.

        Args:
            tensor: Input whose last dimension must equal ``config.input_dim``
                (``nn.Linear`` operates on the last dimension).

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``config.output_dim``.
        """
        return self.model1(tensor)