"""Multi-task disaster text classifier.

A shared transformer encoder feeds three independent classification
heads: urgency level, large disaster category, and medium disaster
category.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


|
class MultiPredictModel(nn.Module):

    def __init__(self, encoder, disasterLarge_labels, disasterMedium_labels, urgencyLevel_labels):
        super().__init__()
        self.encoder = encoder
        hidden_size = self.encoder.config.hidden_size

        # One classification head per task, all sharing the same encoder.
        self.classifiers = nn.ModuleDict({
            'urgencyLevel': self._build_classifier(input_size=hidden_size, output_size=len(urgencyLevel_labels)),
            'disasterLarge': self._build_classifier(input_size=hidden_size, output_size=len(disasterLarge_labels)),
            'disasterMedium': self._build_classifier(input_size=hidden_size, output_size=len(disasterMedium_labels)),
        })
|

        self.labels = {
            'urgencyLevel': urgencyLevel_labels,
            'disasterLarge': disasterLarge_labels,
            'disasterMedium': disasterMedium_labels,
        }
|

    def _build_classifier(self, input_size, output_size):
        # Two hidden layers that halve the width at each step, each with
        # batch norm, ReLU, and light dropout for regularization.
        return nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.BatchNorm1d(input_size // 2),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(input_size // 2, input_size // 4),
            nn.BatchNorm1d(input_size // 4),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(input_size // 4, output_size),
        )
|

    def forward(self, input_ids, attention_mask, disasterLarge=None, disasterMedium=None, urgencyLevel=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Masked mean pooling: average only over real tokens so padding
        # positions do not dilute the sentence representation.
        hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled_output = ((hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)).float()

        # Run the shared representation through every task head.
        logits = {key: classifier(pooled_output) for key, classifier in self.classifiers.items()}
|

        # Compute a loss only when all three label tensors are provided.
        total_loss = None
        has_all_labels = (
            urgencyLevel is not None
            and disasterLarge is not None
            and disasterMedium is not None
        )
        if has_all_labels:
            loss_fct = nn.CrossEntropyLoss()
            loss_ul = loss_fct(logits['urgencyLevel'], urgencyLevel.long())
            loss_dl = loss_fct(logits['disasterLarge'], disasterLarge.long())
            loss_dm = loss_fct(logits['disasterMedium'], disasterMedium.long())
            # Equal-weight sum of the three task losses.
            total_loss = loss_ul + loss_dl + loss_dm
|

        return {
            'loss': total_loss,
            'logits': {
                'logits_ul': logits['urgencyLevel'],
                'logits_dl': logits['disasterLarge'],
                'logits_dm': logits['disasterMedium'],
            },
        }
|

    def predict(self, texts, tokenizer, device):
        self.eval()

        # Accept a single string or a list of strings.
        if isinstance(texts, str):
            texts = [texts]

        encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors='pt',
        ).to(device)
|

        with torch.no_grad():
            outputs = self.forward(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask'],
            )

        # Map each task name to its key in the forward() output.
        logits_keys = {
            'urgencyLevel': 'logits_ul',
            'disasterLarge': 'logits_dl',
            'disasterMedium': 'logits_dm',
        }

        predictions = []
        batch_size = encodings['input_ids'].shape[0]
        for i in range(batch_size):
            total_values = {}
            best_values = {}
            for key in self.classifiers.keys():
                logits = outputs['logits'][logits_keys[key]][i]

                # Per-class probabilities and the argmax label for this task.
                probs = F.softmax(logits, dim=-1).cpu()
                best_values[key] = self.labels[key][torch.argmax(probs).item()]
                total_values[f'labels_{key}'] = self.labels[key]
                total_values[f'probs_{key}'] = probs.numpy().tolist()

            predictions.append({
                'bestValues': best_values,
                'totalValues': total_values,
            })
        return predictions
|
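

# Minimal usage sketch (not part of the original module). The encoder name,
# label lists, and sample text below are illustrative assumptions; any
# Hugging Face encoder that exposes `config.hidden_size` and returns
# `last_hidden_state` should work the same way.
if __name__ == '__main__':
    from transformers import AutoModel, AutoTokenizer

    model_name = 'bert-base-uncased'  # hypothetical encoder choice
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MultiPredictModel(
        encoder=encoder,
        disasterLarge_labels=['flood', 'fire', 'earthquake'],             # placeholder labels
        disasterMedium_labels=['river flood', 'wildfire', 'aftershock'],  # placeholder labels
        urgencyLevel_labels=['low', 'medium', 'high'],                    # placeholder labels
    ).to(device)

    # The heads are untrained here, so outputs are only meaningful after training.
    preds = model.predict('Water is rising fast near the bridge.', tokenizer, device)
    print(preds[0]['bestValues'])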