import torch.nn as nn
from transformers import AutoModelForMaskedLM, PreTrainedModel

from .configuration_test import TestConfig


class TestModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass holding three sub-model attributes.

    * ``model1`` -- an ``nn.Linear(config.input_dim, config.output_dim)``;
      the only module used by :meth:`forward`.
    * ``model2`` -- a nested ``AutoModelForMaskedLM`` loaded from
      ``"bert-base-uncased"`` during construction.
    * ``model3`` -- a placeholder attribute, always ``None``.
    """

    config_class = TestConfig

    def __init__(self, config: TestConfig) -> None:
        super().__init__(config)
        in_features, out_features = config.input_dim, config.output_dim
        # Submodules are assigned in a fixed order (model1, model2, model3);
        # keep it, since registration order determines state_dict key order.
        self.model1 = nn.Linear(in_features, out_features)
        # NOTE(review): this calls from_pretrained on every construction,
        # presumably including when TestModel itself is built via
        # TestModel.from_pretrained(...) -- confirm this is intended.
        self.model2 = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
        self.model3 = None

    def forward(self, tensor):
        """Run *tensor* through the linear layer ``model1`` and return its output."""
        linear = self.model1
        return linear(tensor)