import torch.nn as nn
from torch import cat
from transformers import DistilBertModel


class JobFakeModel(nn.Module):
    """Classifier over three separately-encoded inputs with a shared MLP head.

    Each of the three inputs given to ``forward`` is encoded by its own
    pretrained transformer (``head1``, ``head2``, ``head3``). Each encoder
    output is mean-pooled over the sequence axis. The three pooled vectors
    are concatenated and passed through ``fc``, which ends in
    ``nn.Linear(300, 1)``: one raw, un-activated score per example (no
    sigmoid is applied here).

    Args:
        base_model: Name of the backbone to load. Only ``"distilbert"``
            (``distilbert-base-uncased``) is supported.
        freeze_base: If truthy, every encoder parameter gets
            ``requires_grad = False`` so that only ``fc`` is trained.

    Raises:
        ValueError: If ``base_model`` is not a supported name.
    """

    def __init__(self, base_model, freeze_base):
        super().__init__()
        self.base_model = base_model
        # 768 is the hidden size of distilbert-base-uncased (the only supported
        # backbone); the three pooled embeddings are concatenated.
        # NOTE: fc is created/registered before the encoders, as in the original,
        # so parameters() order (optimizer checkpoints) and RNG consumption under
        # a fixed seed stay the same.
        self.fc = nn.Sequential(
            nn.Linear(768 * 3, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        self.head1, self.head2, self.head3 = self._create_base_model()
        if freeze_base:
            for head in (self.head1, self.head2, self.head3):
                for param in head.parameters():
                    param.requires_grad = False

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Mean-pool ``hidden`` over dim 1, ignoring masked-out (padding) positions.

        Args:
            hidden: Encoder ``last_hidden_state``, (batch, seq, dim).
            attention_mask: (batch, seq) mask with 1 for real tokens and 0 for
                padding, or ``None`` to average over every position.

        Returns:
            Tensor of shape (batch, dim).
        """
        if attention_mask is None:
            # No mask available: fall back to the plain mean over all positions.
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for a fully-masked row.
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts

    def forward(self, x, y, z):
        """Score a batch from three tokenized inputs.

        Args:
            x, y, z: Mappings of encoder keyword arguments (e.g. tokenizer
                output with ``input_ids`` / ``attention_mask``), fed to
                ``head1``, ``head2`` and ``head3`` respectively. If an
                ``attention_mask`` entry is present, it is used so that
                padding does not contribute to the pooled embedding.

        Returns:
            Tensor of shape (batch, 1) with raw scores from ``fc``.
        """
        pooled = [
            self._masked_mean(
                head(**inputs).last_hidden_state, inputs.get("attention_mask")
            )
            for head, inputs in ((self.head1, x), (self.head2, y), (self.head3, z))
        ]
        output = cat(pooled, dim=1)
        return self.fc(output)

    def _create_base_model(self):
        """Load three independent copies of the configured pretrained encoder.

        Returns:
            Tuple of three encoder modules (one per input of ``forward``).

        Raises:
            ValueError: If ``self.base_model`` is not supported.
        """
        if self.base_model == "distilbert":
            # Separate from_pretrained calls give each input its own weights.
            return tuple(
                DistilBertModel.from_pretrained("distilbert-base-uncased")
                for _ in range(3)
            )
        raise ValueError("Model not supported")