|
from transformers import PreTrainedModel
|
|
from .configuration_test import TestConfig
|
|
import torch.nn as nn
|
|
from transformers import AutoModelForMaskedLM, AutoConfig
|
|
from transformers import AutoModelForSequenceClassification
|
|
|
|
|
|
class TestModel(PreTrainedModel):
    """Custom ``PreTrainedModel`` holding a linear layer and a nested masked-LM model.

    The instance owns two submodules:

    * ``model1`` -- an ``nn.Linear`` sized from the configuration; this is the
      only submodule used by :meth:`forward`.
    * ``model2`` -- a model created through ``AutoModelForMaskedLM.from_config``
      from the ``"albert/albert-base-v2"`` configuration. It is constructed in
      ``__init__`` but not called in :meth:`forward`.
    """

    # Configuration class associated with this model.
    config_class = TestConfig

    def __init__(self, config: TestConfig):
        """Create the submodules.

        Args:
            config: Configuration providing ``input_dim`` and ``output_dim``,
                which size the linear layer.
        """
        super().__init__(config)

        self.input_dim = config.input_dim

        # The linear layer is created before the nested model, so parameter
        # registration order is model1 first, then model2.
        self.model1 = nn.Linear(config.input_dim, config.output_dim)

        # The ALBERT configuration is looked up by its identifier and the
        # nested model is then built from that configuration object.
        # NOTE(review): resolving the identifier presumably needs hub access
        # or a local cache at construction time -- confirm for offline use.
        albert_config = AutoConfig.from_pretrained("albert/albert-base-v2")
        self.model2 = AutoModelForMaskedLM.from_config(albert_config)

    def forward(self, tensor):
        """Run ``tensor`` through the linear layer and return the result.

        Args:
            tensor: Input passed directly to ``model1``.

        Returns:
            The output of ``model1(tensor)``.
        """
        projected = self.model1(tensor)
        return projected
|
|
|