# fakeJobPredictor / model.py
# sebastiansarasti - commit 4736ae1 (verified): "adding the files for the app"
import torch.nn as nn
from torch import cat
from transformers import DistilBertModel
class JobFakeModel(nn.Module):
    """Combine three independent text encoders with a shared MLP head.

    Each of the three inputs is run through its own encoder (``head1``,
    ``head2``, ``head3``). The encoder's ``last_hidden_state`` is averaged
    over dim 1, the three pooled vectors are concatenated along dim 1, and
    the result is passed through ``fc``, whose last layer has one output
    unit.

    NOTE(review): the average is taken over every position of dim 1 with no
    attention-mask weighting. If callers pad their batches, padded positions
    presumably contribute to the pooled vector - confirm this matches how the
    weights were trained before changing it.
    """

    def __init__(self, base_model, freeze_base):
        """Build the MLP head and the three encoders.

        Args:
            base_model: Name of the encoder family. Only ``"distilbert"``
                is accepted.
            freeze_base: When truthy, gradients are disabled for all
                parameters of the three encoders; ``fc`` stays trainable.

        Raises:
            ValueError: If ``base_model`` is not a supported name.
        """
        super().__init__()
        self.base_model = base_model

        # The head is created before the encoders so that the module
        # registration order stays fc, head1, head2, head3.
        # Input width is 3 * 768: one 768-wide pooled vector per encoder.
        self.fc = nn.Sequential(
            nn.Linear(3 * 768, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )

        encoders = self._create_base_model()
        self.head1, self.head2, self.head3 = encoders

        if freeze_base:
            for encoder in encoders:
                # Equivalent to setting requires_grad = False on every
                # parameter of the encoder.
                encoder.requires_grad_(False)

    def forward(self, x, y, z):
        """Encode the three inputs, pool, concatenate and apply ``fc``.

        Args:
            x: Mapping unpacked as keyword arguments into ``head1``
                (presumably tokenizer output - verify against caller).
            y: Same, for ``head2``.
            z: Same, for ``head3``.

        Returns:
            The output of ``fc`` applied to the concatenated pooled vectors.
        """
        routed = zip((self.head1, self.head2, self.head3), (x, y, z))
        pooled = [
            encoder(**tokens).last_hidden_state.mean(dim=1)
            for encoder, tokens in routed
        ]
        return self.fc(cat(pooled, dim=1))

    def _create_base_model(self):
        """Return three separately loaded encoders for ``self.base_model``.

        Raises:
            ValueError: If ``self.base_model`` is not ``"distilbert"``.
        """
        if self.base_model != "distilbert":
            raise ValueError("Model not supported")

        # Three distinct instances, so each input gets its own weights.
        return tuple(
            DistilBertModel.from_pretrained("distilbert-base-uncased")
            for _ in range(3)
        )