import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    PreTrainedModel,
)

from .configuration_test import TestConfig


class TestModel(PreTrainedModel):
    # Ties this model class to its custom config so that
    # TestModel.from_pretrained() / save_pretrained() know which config to use.
    config_class = TestConfig

    def __init__(self, config: TestConfig):
        super().__init__(config)
        self.input_dim = config.input_dim
        # A single linear layer mapping input_dim -> output_dim.
        self.model1 = nn.Linear(config.input_dim, config.output_dim)
        # self.model2 = AutoModelForMaskedLM.from_config(
        #     AutoConfig.from_pretrained("bert-base-uncased")
        # )

    def forward(self, tensor):
        return self.model1(tensor)
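

# Usage sketch (not part of the original module). It assumes TestConfig accepts
# `input_dim` and `output_dim` as constructor arguments, matching the attributes
# read in __init__ above:
#
#     import torch
#
#     config = TestConfig(input_dim=4, output_dim=2)
#     model = TestModel(config)
#     print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
#
#     # Subclassing PreTrainedModel provides serialization for free:
#     model.save_pretrained("test-model")
#     reloaded = TestModel.from_pretrained("test-model")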