sanghe0n's picture
Upload folder using huggingface_hub
6a29894 verified
import torch
import torch.nn as nn
class MultiPredictModel(nn.Module):
    """Shared encoder followed by three independent classification heads.

    The heads are named ``urgencyLevel``, ``disasterLarge`` and
    ``disasterMedium``. Each head is an MLP built by ``_build_classifier``
    and fed with the same pooled encoder representation.
    """

    # Head name -> key under which that head's logits are exposed in the
    # ``'logits'`` dict returned by ``forward``. Insertion order matches the
    # order of ``self.classifiers``.
    _LOGIT_KEYS = {
        'urgencyLevel': 'logits_ul',
        'disasterLarge': 'logits_dl',
        'disasterMedium': 'logits_dm',
    }

    def __init__(self, encoder, disasterLarge_labels, disasterMedium_labels, urgencyLevel_labels):
        """Build the three heads on top of ``encoder``.

        Args:
            encoder: Callable module exposing ``config.hidden_size`` and
                returning an object with a ``last_hidden_state`` attribute
                when called with ``input_ids`` and ``attention_mask``.
            disasterLarge_labels: Sized, integer-indexable collection of label
                values for the ``disasterLarge`` head.
            disasterMedium_labels: Same, for the ``disasterMedium`` head.
            urgencyLevel_labels: Same, for the ``urgencyLevel`` head.
        """
        super().__init__()
        self.encoder = encoder
        hidden_size = self.encoder.config.hidden_size

        self.classifiers = nn.ModuleDict(modules={
            'urgencyLevel': self._build_classifier(
                input_size=hidden_size, output_size=len(urgencyLevel_labels)),
            'disasterLarge': self._build_classifier(
                input_size=hidden_size, output_size=len(disasterLarge_labels)),
            'disasterMedium': self._build_classifier(
                input_size=hidden_size, output_size=len(disasterMedium_labels)),
        })

        # Kept so ``predict`` can translate a class index back to its label.
        self.labels = {
            'urgencyLevel': urgencyLevel_labels,
            'disasterLarge': disasterLarge_labels,
            'disasterMedium': disasterMedium_labels,
        }

    def _build_classifier(self, input_size, output_size):
        """Return an MLP head: input -> input//2 -> input//4 -> output.

        Each of the two hidden layers is followed by BatchNorm1d, ReLU and
        Dropout(p=0.1); the final linear layer emits raw logits.
        """
        half = input_size // 2
        quarter = input_size // 4
        return nn.Sequential(
            nn.Linear(in_features=input_size, out_features=half),
            nn.BatchNorm1d(num_features=half),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=half, out_features=quarter),
            nn.BatchNorm1d(num_features=quarter),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=quarter, out_features=output_size),
        )

    def forward(self, input_ids, attention_mask, disasterLarge=None, disasterMedium=None, urgencyLevel=None):
        """Run the encoder and all heads; optionally compute the joint loss.

        Args:
            input_ids: Passed through to the encoder.
            attention_mask: Passed through to the encoder.
            disasterLarge: Optional target class indices for that head.
            disasterMedium: Optional target class indices for that head.
            urgencyLevel: Optional target class indices for that head.

        Returns:
            dict with:
              * ``'loss'``: sum of the three cross-entropy losses when all
                three targets are given, otherwise ``None``.
              * ``'logits'``: dict with keys ``'logits_ul'``, ``'logits_dl'``
                and ``'logits_dm'``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Mean over dim 1 of the encoder output.
        # NOTE(review): attention_mask is not applied here, so if dim 1 is the
        # sequence axis (presumably (batch, seq, hidden) - confirm), padded
        # positions contribute to the average. Left as-is because trained
        # weights may depend on this pooling.
        pooled_output = outputs.last_hidden_state.mean(dim=1).float()

        logits = {
            head: classifier(pooled_output)
            for head, classifier in self.classifiers.items()
        }

        total_loss = None
        has_all_targets = (
            urgencyLevel is not None
            and disasterLarge is not None
            and disasterMedium is not None
        )
        if has_all_targets:
            loss_fct = nn.CrossEntropyLoss()
            loss_ul = loss_fct(logits['urgencyLevel'], urgencyLevel.long())
            loss_dl = loss_fct(logits['disasterLarge'], disasterLarge.long())
            loss_dm = loss_fct(logits['disasterMedium'], disasterMedium.long())
            total_loss = loss_ul + loss_dl + loss_dm

        return {
            'loss': total_loss,
            'logits': {
                out_key: logits[head]
                for head, out_key in self._LOGIT_KEYS.items()
            },
        }

    def predict(self, texts, tokenizer, device):
        """Classify one or more texts and return labels with probabilities.

        Switches the module to eval mode (and leaves it there), tokenizes
        with truncation/padding at ``max_length=128``, and runs ``forward``
        under ``torch.no_grad()``.

        Args:
            texts: A single string or an iterable of strings.
            tokenizer: Callable accepting ``text``, ``truncation``,
                ``padding``, ``max_length`` and ``return_tensors`` and
                returning a mapping with ``'input_ids'`` and
                ``'attention_mask'`` that supports ``.to(device=...)``.
            device: Device the encoded batch is moved to.

        Returns:
            A list with one dict per input text:
              * ``'bestValues'``: head name -> most probable label.
              * ``'totalValues'``: ``labels_<head>`` (the head's label
                collection) and ``probs_<head>`` (list of softmax
                probabilities) for every head.
            An empty input yields an empty list.

        Raises:
            KeyError: If ``self.classifiers`` holds a head with no known
                logits key.
        """
        self.eval()

        if isinstance(texts, str):
            texts = [texts]
        else:
            # Materialise so emptiness can be checked for any iterable.
            texts = list(texts)
        if not texts:
            # Nothing to classify; skip tokenizer and encoder entirely.
            return []

        encodings = tokenizer(
            text=texts,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors='pt'
        ).to(device=device)

        with torch.no_grad():
            outputs = self.forward(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask']
            )

        # One softmax / argmax / host transfer per head for the whole batch,
        # instead of one per sample per head.
        probs_by_head = {}
        best_idx_by_head = {}
        for key in self.classifiers.keys():
            try:
                logit_key = self._LOGIT_KEYS[key]
            except KeyError:
                raise KeyError(f"Unknown key: {key}") from None
            probs = torch.nn.functional.softmax(
                input=outputs['logits'][logit_key], dim=-1).cpu()
            probs_by_head[key] = probs.tolist()
            best_idx_by_head[key] = torch.argmax(input=probs, dim=-1).tolist()

        predictions = []
        batch_size = encodings['input_ids'].shape[0]
        for i in range(batch_size):
            total_values = {}
            best_values = {}
            for key, head_probs in probs_by_head.items():
                head_labels = self.labels[key]
                best_values[key] = head_labels[best_idx_by_head[key][i]]
                total_values[f'labels_{key}'] = head_labels
                total_values[f'probs_{key}'] = head_probs[i]
            predictions.append({
                'bestValues': best_values,
                'totalValues': total_values
            })
        return predictions