import torch
import torch.nn as nn

class MultiPredictModel(nn.Module):
    """Multi-task text classifier: a shared encoder feeding three independent
    classification heads (urgency level, large disaster category, medium
    disaster category)."""

    def __init__(self, encoder, disasterLarge_labels, disasterMedium_labels, urgencyLevel_labels):
        super().__init__()
        self.encoder = encoder
        hidden_size = self.encoder.config.hidden_size

        # One MLP head per task; dropout is applied inside each head.
        self.classifiers = nn.ModuleDict({
            'urgencyLevel'   : self._build_classifier(input_size=hidden_size, output_size=len(urgencyLevel_labels)),
            'disasterLarge'  : self._build_classifier(input_size=hidden_size, output_size=len(disasterLarge_labels)),
            'disasterMedium' : self._build_classifier(input_size=hidden_size, output_size=len(disasterMedium_labels)),
        })
    
        self.labels = {
            'urgencyLevel'   : urgencyLevel_labels,
            'disasterLarge'  : disasterLarge_labels,
            'disasterMedium' : disasterMedium_labels
        }
    
    def _build_classifier(self, input_size, output_size):
        """Two hidden layers that halve then quarter the width, each with
        BatchNorm, ReLU, and dropout. Because of the BatchNorm layers,
        single-sample inference must run in eval() mode (predict() handles this)."""
        return nn.Sequential(
            nn.Linear(in_features=input_size, out_features=input_size // 2),
            nn.BatchNorm1d(num_features=input_size // 2),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=input_size // 2, out_features=input_size // 4),
            nn.BatchNorm1d(num_features=input_size // 4),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=input_size // 4, out_features=output_size),
        )
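
    # Illustrative head shape, assuming a 768-dimensional encoder such as
    # BERT-base (the figure is an example, not fixed by this module):
    #     768 -> 384 -> 192 -> output_size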
    
    def forward(self, input_ids, attention_mask, disasterLarge=None, disasterMedium=None, urgencyLevel=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Mask-aware mean pooling: averaging padded positions would make an
        # embedding depend on how much padding its batch happens to carry.
        # (CLS pooling, outputs.last_hidden_state[:, 0, :], is an alternative.)
        mask = attention_mask.unsqueeze(-1).float()
        pooled_output = ((outputs.last_hidden_state * mask).sum(dim=1)
                         / mask.sum(dim=1).clamp(min=1e-9)).float()

        # Run every head on the shared pooled representation.
        logits = {}
        for key, classifier in self.classifiers.items():
            logits[key] = classifier(pooled_output)
    
        total_loss = None
        has_all_labels = (
            urgencyLevel is not None
            and disasterLarge is not None
            and disasterMedium is not None
        )
        if has_all_labels:
            # Unweighted sum of the three per-task cross-entropy losses.
            loss_fct = nn.CrossEntropyLoss()
            loss_ul = loss_fct(logits['urgencyLevel'], urgencyLevel.long())
            loss_dl = loss_fct(logits['disasterLarge'], disasterLarge.long())
            loss_dm = loss_fct(logits['disasterMedium'], disasterMedium.long())
            total_loss = loss_ul + loss_dl + loss_dm
        output = {
            'loss'   : total_loss,
            'logits' : {
                'logits_ul': logits['urgencyLevel'],
                'logits_dl': logits['disasterLarge'],
                'logits_dm': logits['disasterMedium']
            }
        }
        return output
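
    # Training sketch (illustrative, not part of the original source; `model`,
    # `batch`, and the optimizer settings are assumptions):
    #
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    #     out = model(input_ids=batch['input_ids'],
    #                 attention_mask=batch['attention_mask'],
    #                 disasterLarge=batch['disasterLarge'],
    #                 disasterMedium=batch['disasterMedium'],
    #                 urgencyLevel=batch['urgencyLevel'])
    #     out['loss'].backward()   # summed multi-task loss from forward()
    #     optimizer.step()
    #     optimizer.zero_grad()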
    
    def predict(self, texts, tokenizer, device):
        """Tokenize raw text(s) and return, for each input, the best label plus
        the full class probabilities for every task."""
        self.eval()  # required: the heads contain BatchNorm and Dropout

        if isinstance(texts, str):
            texts = [texts]
        
        encodings = tokenizer(
            text=texts,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors='pt'
        ).to(device=device)
        
        with torch.no_grad():
            outputs = self.forward(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask']
            )
        
        # Map each task name to its key in the logits dict returned by forward().
        logits_keys = {
            'urgencyLevel'   : 'logits_ul',
            'disasterLarge'  : 'logits_dl',
            'disasterMedium' : 'logits_dm',
        }

        predictions = []
        batch_size = encodings['input_ids'].shape[0]
        for i in range(batch_size):
            total_values = {}
            best_values = {}
            for key in self.classifiers.keys():
                logits = outputs['logits'][logits_keys[key]][i]
                probs = torch.nn.functional.softmax(logits, dim=-1).cpu()
                best_values[key] = self.labels[key][torch.argmax(probs).item()]
                total_values[f'labels_{key}'] = self.labels[key]
                total_values[f'probs_{key}'] = probs.numpy().tolist()
            
            predictions.append({
                'bestValues'  : best_values,
                'totalValues' : total_values
            })
        return predictions
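

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes a
    # Hugging Face `transformers` encoder; the checkpoint name and the label
    # lists are illustrative placeholders, and the freshly initialized heads
    # will emit essentially random predictions until the model is fine-tuned.
    from transformers import AutoModel, AutoTokenizer

    model_name = 'bert-base-multilingual-cased'  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MultiPredictModel(
        encoder=encoder,
        disasterLarge_labels=['natural', 'social'],           # placeholder labels
        disasterMedium_labels=['flood', 'fire', 'collapse'],  # placeholder labels
        urgencyLevel_labels=['low', 'medium', 'high'],        # placeholder labels
    ).to(device)

    preds = model.predict('Road flooded near the central station', tokenizer, device)
    print(preds[0]['bestValues'])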