import torch
from torch import nn
from torch.utils.data import Dataset

class SentencePairDataset(Dataset):
    """Wrap (sentence, context) pairs and labels; tokenize each pair on access."""

    def __init__(self, sentence_pairs, labels, tokenizer, max_length):
        self.sentence_pairs = sentence_pairs
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sentence_pairs)

    def __getitem__(self, idx):
        sentence, context = self.sentence_pairs[idx]
        label = self.labels[idx]
        
        # Encode the sentence (claim) and its context jointly as a single
        # sequence pair; the tokenizer inserts the separator token between them.
        encoding = self.tokenizer.encode_plus(
            sentence,
            text_pair=context,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
            truncation="longest_first",  # truncate the longer text first
        )
        
        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label": torch.tensor(label, dtype=torch.long),
        }
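
# A minimal sketch of how a single dataset item looks, assuming a Hugging Face
# transformers tokenizer; the model name and example pair below are
# illustrative placeholders, not part of the original file:
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
#     dataset = SentencePairDataset(
#         [("A claim sentence.", "Its surrounding context.")], [0],
#         tokenizer, max_length=128,
#     )
#     item = dataset[0]
#     # item["input_ids"] and item["attention_mask"] are 1-D tensors of
#     # length max_length; item["label"] is a scalar LongTensor.
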

class MBERTClassifier(nn.Module):
    """mBERT encoder topped with a normalized linear classification head."""

    def __init__(self, mbert, num_classes):
        super().__init__()
        self.mbert = mbert
        self.layer_norm = nn.LayerNorm(self.mbert.config.hidden_size)
        self.dropout = nn.Dropout(0.2)
        self.batch_norm = nn.BatchNorm1d(self.mbert.config.hidden_size)
        # The encoder's hidden size is known up front, so the head is a plain
        # Linear layer mapping hidden_size -> num_classes.
        self.linear = nn.Linear(self.mbert.config.hidden_size, num_classes)
        self.activation = nn.ELU()

    def forward(self, input_ids, attention_mask):
        # With return_dict=False, the encoder returns (sequence_output, pooled_output).
        _, pooled_output = self.mbert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        norm_output = self.layer_norm(pooled_output)
        batch_norm_output = self.batch_norm(norm_output)
        activated_output = self.activation(batch_norm_output)
        # Regularize the pooled features, then project to class logits;
        # the logits are returned raw so predict_proba can apply softmax.
        dropout_output = self.dropout(activated_output)
        logits = self.linear(dropout_output)
        return logits

    def predict_proba(self, input_ids, attention_mask):
        # Softmax over the class logits; call model.eval() beforehand so
        # dropout and batch norm run in inference mode.
        logits = self.forward(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=-1)
        return probabilities
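

if __name__ == "__main__":
    # A minimal end-to-end sketch, assuming the Hugging Face transformers
    # package is available; "bert-base-multilingual-cased" and the toy data
    # below are illustrative assumptions, not part of the original file.
    from torch.utils.data import DataLoader
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    mbert = AutoModel.from_pretrained("bert-base-multilingual-cased")

    pairs = [("A claim sentence.", "Its surrounding context.")]
    dataset = SentencePairDataset(pairs, [1], tokenizer, max_length=128)
    loader = DataLoader(dataset, batch_size=1)

    model = MBERTClassifier(mbert, num_classes=2)
    model.eval()  # put dropout and batch norm into inference mode

    with torch.no_grad():
        batch = next(iter(loader))
        probs = model.predict_proba(batch["input_ids"], batch["attention_mask"])
    print(probs.shape)  # (batch_size, num_classes)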