import torch
import torch.nn.functional as F
from transformers import AutoModel, PreTrainedModel

from .config import AutextificationMTLConfig


class AutextificationMTLModel(PreTrainedModel):
    """Multi-task model: a shared transformer encoder with two binary heads,
    one for machine-generated ("bot") text detection and one for language
    (English) probability."""

    config_class = AutextificationMTLConfig

    def __init__(self, config: AutextificationMTLConfig):
        super().__init__(config)

        # Shared pretrained transformer encoder.
        self.encoder = AutoModel.from_pretrained(config.transformer_name)
        embedding_size = self.encoder.config.hidden_size

        # Shared projection followed by one single-logit head per task.
        self.hidden = torch.nn.Linear(embedding_size, config.hidden_nodes)
        self.out_generated = torch.nn.Linear(config.hidden_nodes, 1)
        self.out_language = torch.nn.Linear(config.hidden_nodes, 1)

        # Decision threshold applied to the "generated" probability.
        self.threshold = config.threshold

    def forward(self, tensor):
        # `tensor` is expected to be a dict-like batch of tokenized inputs
        # (e.g. the output of a Hugging Face tokenizer with return_tensors="pt").
        output = self.encoder(
            input_ids=tensor["input_ids"],
            attention_mask=tensor["attention_mask"],
            return_dict=True,
        )
        # Pooled [CLS] representation; requires an encoder that exposes
        # `pooler_output` (e.g. BERT-like models).
        pooler_output = output["pooler_output"]

        # Shared hidden layer, then one sigmoid probability per task.
        out = F.relu(self.hidden(pooler_output))
        out_generated = torch.sigmoid(self.out_generated(out))
        out_language = torch.sigmoid(self.out_language(out))
        # Hard decision for the generated-text head.
        out_verdict = out_generated > self.threshold

        return {
            "is_bot": out_verdict,
            "bot_prob": out_generated,
            "english_prob": out_language,
        }
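

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model itself. It assumes that
    # AutextificationMTLConfig accepts `transformer_name`, `hidden_nodes` and
    # `threshold` as keyword arguments (the attributes read above), and that the
    # chosen checkpoint is BERT-like, i.e. its AutoModel output has `pooler_output`.
    # Because of the relative import at the top, run this as a module
    # (`python -m <package>.<module>`), not as a standalone script.
    from transformers import AutoTokenizer

    config = AutextificationMTLConfig(
        transformer_name="bert-base-multilingual-cased",  # assumed example checkpoint
        hidden_nodes=128,                                  # assumed example value
        threshold=0.5,                                     # assumed example value
    )
    model = AutextificationMTLModel(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.transformer_name)
    batch = tokenizer(
        ["This text may or may not have been written by a language model."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        predictions = model(batch)
    print(predictions["is_bot"], predictions["bot_prob"], predictions["english_prob"])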