import torch
import torch.nn.functional as F
from transformers import AutoModel, PreTrainedModel

from .config import AutextificationMTLConfig


class AutextificationMTLModel(PreTrainedModel):
    """Two-headed classifier on top of a pretrained transformer encoder.

    A pooled encoder representation passes through one shared ReLU hidden
    layer and then feeds two independent single-logit heads:

    * ``out_generated`` -> reported as ``bot_prob`` / ``is_bot``;
    * ``out_language``  -> reported as ``english_prob``.
    """

    config_class = AutextificationMTLConfig

    def __init__(self, config: AutextificationMTLConfig):
        """Build the encoder and the task heads from *config*.

        Reads ``config.transformer_name`` (encoder checkpoint to load),
        ``config.hidden_nodes`` (width of the shared hidden layer) and
        ``config.threshold`` (cut-off applied to ``bot_prob``).
        """
        super().__init__(config)
        # The encoder is loaded by checkpoint name on every construction, so
        # its weights come from that checkpoint, not from this model's config.
        self.encoder = AutoModel.from_pretrained(config.transformer_name)
        embedding_size = self.encoder.config.hidden_size
        self.hidden = torch.nn.Linear(embedding_size, config.hidden_nodes)
        self.out_generated = torch.nn.Linear(config.hidden_nodes, 1)
        self.out_language = torch.nn.Linear(config.hidden_nodes, 1)
        self.threshold = config.threshold

    @staticmethod
    def _pool(output):
        """Return one vector per sequence from the encoder *output*.

        Uses the encoder's ``pooler_output`` when it provides one, which is the
        original behaviour. Some encoders have no pooling layer and return no
        such entry. For those, fall back to the hidden state of the first
        token (``last_hidden_state[:, 0]``). Both vectors have the encoder's
        ``hidden_size``, so ``self.hidden`` accepts either.
        """
        pooled = output.get("pooler_output")
        if pooled is None:
            pooled = output["last_hidden_state"][:, 0]
        return pooled

    def forward(self, tensor):
        """Score a tokenized batch.

        Args:
            tensor: Mapping holding ``"input_ids"`` and ``"attention_mask"``
                (presumably the output of a matching tokenizer). Other keys
                are ignored.

        Returns:
            dict with
            ``"is_bot"``: boolean tensor, ``bot_prob > self.threshold``;
            ``"bot_prob"``: sigmoid of the ``out_generated`` head;
            ``"english_prob"``: sigmoid of the ``out_language`` head.
            Each head has a single output unit, so the probabilities keep a
            trailing dimension of size 1.
        """
        output = self.encoder(
            input_ids=tensor["input_ids"],
            attention_mask=tensor["attention_mask"],
            return_dict=True,
        )
        pooled = self._pool(output)

        shared = F.relu(self.hidden(pooled))
        out_generated = torch.sigmoid(self.out_generated(shared))
        out_language = torch.sigmoid(self.out_language(shared))
        out_verdict = out_generated > self.threshold

        return {
            "is_bot": out_verdict,
            "bot_prob": out_generated,
            "english_prob": out_language,
        }