pandrei7's picture
Save the first version of the model
56dfd9c
raw
history blame
1.33 kB
import torch
import torch.nn.functional as F
from transformers import AutoModel, PreTrainedModel
from .config import AutextificationMTLConfig
class AutextificationMTLModel(PreTrainedModel):
    """Two-headed classifier on top of a pretrained transformer encoder.

    The encoder's pooled output goes through one shared hidden layer, which
    then feeds two independent single-logit heads:

    * ``out_generated`` -- reported as ``bot_prob`` / ``is_bot``.
    * ``out_language`` -- reported as ``english_prob``.
    """

    config_class = AutextificationMTLConfig

    def __init__(self, config: AutextificationMTLConfig):
        """Build the encoder, the shared hidden layer and both heads.

        Reads ``transformer_name``, ``hidden_nodes`` and ``threshold`` from
        *config*.
        """
        super().__init__(config)

        # The encoder checkpoint named in the config is loaded on every
        # construction of this model.
        self.encoder = AutoModel.from_pretrained(config.transformer_name)

        # Layers are created in a fixed order (hidden, out_generated,
        # out_language); attribute names double as state-dict keys.
        head_width = config.hidden_nodes
        self.hidden = torch.nn.Linear(
            self.encoder.config.hidden_size, head_width
        )
        self.out_generated = torch.nn.Linear(head_width, 1)
        self.out_language = torch.nn.Linear(head_width, 1)

        # Cut-off applied to the "generated" probability in ``forward``.
        self.threshold = config.threshold

    def forward(self, tensor):
        """Run the encoder and both heads on a tokenized batch.

        Args:
            tensor: Mapping that provides ``"input_ids"`` and
                ``"attention_mask"``; both are passed straight to the
                encoder.

        Returns:
            A dict with three entries:

            * ``"is_bot"``: ``bot_prob > threshold`` (element-wise).
            * ``"bot_prob"``: sigmoid of the ``out_generated`` head.
            * ``"english_prob"``: sigmoid of the ``out_language`` head.
        """
        encoded = self.encoder(
            input_ids=tensor["input_ids"],
            attention_mask=tensor["attention_mask"],
            return_dict=True,
        )

        # NOTE(review): assumes the chosen encoder exposes a pooler output;
        # confirm for every supported ``transformer_name``.
        shared = F.relu(self.hidden(encoded["pooler_output"]))

        bot_prob = self.out_generated(shared).sigmoid()
        english_prob = self.out_language(shared).sigmoid()

        return {
            "is_bot": bot_prob > self.threshold,
            "bot_prob": bot_prob,
            "english_prob": english_prob,
        }