File size: 1,333 Bytes
56dfd9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn.functional as F
from transformers import AutoModel, PreTrainedModel

from .config import AutextificationMTLConfig


class AutextificationMTLModel(PreTrainedModel):
    """Two-headed classifier built on a pretrained transformer encoder.

    The encoder's pooled output goes through one shared ReLU hidden layer.
    That layer feeds two independent single-logit heads, each squashed with a
    sigmoid:

    * ``out_generated`` -- reported under the ``"bot_prob"`` key, and
      thresholded into ``"is_bot"``.
    * ``out_language`` -- reported under the ``"english_prob"`` key.

    Attributes:
        encoder: Backbone created with ``AutoModel.from_pretrained`` from
            ``config.transformer_name``.
        hidden: Shared projection from the encoder's ``hidden_size`` to
            ``config.hidden_nodes``.
        out_generated: Linear head producing one logit for the "bot" task.
        out_language: Linear head producing one logit for the language task.
        threshold: Cut-off taken from ``config.threshold`` and compared against
            the "bot" probability.
    """

    config_class = AutextificationMTLConfig

    def __init__(self, config: AutextificationMTLConfig):
        """Build the encoder and the task heads described by ``config``.

        Args:
            config: Supplies ``transformer_name``, ``hidden_nodes`` and
                ``threshold``.
        """
        super().__init__(config)

        # The backbone is loaded with pretrained weights by name on every
        # construction.
        # NOTE(review): when this class is itself restored via
        # ``from_pretrained``, those encoder weights are presumably fetched and
        # then overwritten by the checkpoint -- confirm whether that is intended.
        self.encoder = AutoModel.from_pretrained(config.transformer_name)

        # Submodules are registered in a fixed order: encoder, hidden,
        # out_generated, out_language. That order determines parameter
        # ordering and the RNG sequence used for Linear initialisation.
        width_in = self.encoder.config.hidden_size
        width_mid = config.hidden_nodes
        self.hidden = torch.nn.Linear(width_in, width_mid)
        self.out_generated = torch.nn.Linear(width_mid, 1)
        self.out_language = torch.nn.Linear(width_mid, 1)

        self.threshold = config.threshold

    def forward(self, tensor):
        """Score a tokenised batch on both tasks.

        Args:
            tensor: Mapping that must contain ``"input_ids"`` and
                ``"attention_mask"`` (e.g. a tokenizer's output). Presumably
                both are shaped (batch, seq_len); TODO confirm with callers.

        Returns:
            dict with keys:
                ``"is_bot"``: result of ``bot_prob > threshold`` (a boolean
                    tensor when ``threshold`` is a scalar).
                ``"bot_prob"``: sigmoid of the ``out_generated`` logit; the
                    last dimension has size 1.
                ``"english_prob"``: sigmoid of the ``out_language`` logit; the
                    last dimension has size 1.

        Raises:
            KeyError: if either input key is missing, or if the encoder output
                has no ``"pooler_output"`` entry. Encoders without a pooling
                layer would presumably fail here -- verify for new backbones.
        """
        encoded = self.encoder(
            input_ids=tensor["input_ids"],
            attention_mask=tensor["attention_mask"],
            return_dict=True,
        )

        # Both heads share this representation.
        shared = F.relu(self.hidden(encoded["pooler_output"]))

        bot_prob = torch.sigmoid(self.out_generated(shared))
        english_prob = torch.sigmoid(self.out_language(shared))

        return {
            "is_bot": bot_prob > self.threshold,
            "bot_prob": bot_prob,
            "english_prob": english_prob,
        }