|
import torch.nn as nn |
|
from torch import cat |
|
from transformers import DistilBertModel |
|
|
|
class JobFakeModel(nn.Module):
    """Three-encoder text model with a shared MLP head giving one score per example.

    Each of the three inputs is encoded by its own pretrained transformer
    (``head1``/``head2``/``head3``) and mean-pooled over its non-padding
    tokens. The three pooled vectors are concatenated and passed through a
    small feed-forward network that ends in a single linear unit.

    No activation is applied to that last unit, so the output is a raw score.
    It is presumably a logit for a fake/real classifier; confirm this against
    the training loss.

    Args:
        base_model: Name of the encoder family to load. Only ``"distilbert"``
            is supported.
        freeze_base: If truthy, every encoder parameter gets
            ``requires_grad=False``, so only ``fc`` is trained.

    Raises:
        ValueError: If ``base_model`` is not supported.
    """

    def __init__(self, base_model, freeze_base):
        super(JobFakeModel, self).__init__()

        self.base_model = base_model

        # Load the encoders first so the MLP input width comes from their
        # configs instead of a hard-coded 768 * 3.
        heads = self._create_base_model()
        concat_dim = sum(head.config.hidden_size for head in heads)

        # Register `fc` before the encoders so parameters()/state_dict() list
        # the MLP first. This keeps index-based optimizer state aligned.
        self.fc = nn.Sequential(
            nn.Linear(concat_dim, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )

        self.head1, self.head2, self.head3 = heads

        if freeze_base:
            for head in heads:
                head.requires_grad_(False)

    def forward(self, x, y, z):
        """Encode the three inputs, concatenate their pooled embeddings and score them.

        Args:
            x, y, z: Keyword-argument mappings for ``head1``, ``head2`` and
                ``head3`` respectively, e.g. tokenizer output containing
                ``input_ids`` and optionally ``attention_mask``.

        Returns:
            Output of ``fc``: one raw score per example. The shape is
            ``(batch, 1)``, assuming each ``last_hidden_state`` is
            ``(batch, seq, hidden)``; verify this for non-HF heads.
        """
        pooled = [
            self._encode(head, inputs)
            for head, inputs in ((self.head1, x), (self.head2, y), (self.head3, z))
        ]
        return self.fc(cat(pooled, dim=1))

    def _encode(self, head, inputs):
        """Run one encoder and mean-pool its last hidden state over real tokens."""
        hidden = head(**inputs).last_hidden_state
        return self._masked_mean(hidden, inputs.get("attention_mask"))

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, skipping positions where the mask is 0.

        Without a mask this is a plain ``hidden.mean(dim=1)``. The
        denominator is clamped to at least 1, so a row whose mask is all
        zeros yields a zero vector instead of NaN.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)

        # A plain mean would also average the padding embeddings, making the
        # result depend on how much padding the batch has.
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        total = (hidden * weights).sum(dim=1)
        count = weights.sum(dim=1).clamp(min=1.0)
        return total / count

    def _create_base_model(self):
        """Load three independent encoders for ``self.base_model``.

        Returns:
            Tuple of three separately loaded ``DistilBertModel`` instances.
            Each comes from its own ``from_pretrained`` call, so no weights
            are shared between them.

        Raises:
            ValueError: If ``self.base_model`` is not ``"distilbert"``.
        """
        if self.base_model != "distilbert":
            raise ValueError("Model not supported")

        return tuple(
            DistilBertModel.from_pretrained("distilbert-base-uncased")
            for _ in range(3)
        )
|
|