import torch
import torch.nn as nn


class MultiPredictModel(nn.Module):
    """Multi-task text classifier: a shared encoder feeds three independent
    heads (urgency level, large disaster category, medium disaster category)."""

    def __init__(self, encoder, disasterLarge_labels, disasterMedium_labels, urgencyLevel_labels):
        super().__init__()
        self.encoder = encoder
        hidden_size = self.encoder.config.hidden_size
        # self.dropout = nn.Dropout(p=0.1)

        # One classification head per task, sized to its label set.
        self.classifiers = nn.ModuleDict(modules={
            'urgencyLevel': self._build_classifier(input_size=hidden_size, output_size=len(urgencyLevel_labels)),
            'disasterLarge': self._build_classifier(input_size=hidden_size, output_size=len(disasterLarge_labels)),
            'disasterMedium': self._build_classifier(input_size=hidden_size, output_size=len(disasterMedium_labels)),
        })
        self.labels = {
            'urgencyLevel': urgencyLevel_labels,
            'disasterLarge': disasterLarge_labels,
            'disasterMedium': disasterMedium_labels,
        }

    def _build_classifier(self, input_size, output_size):
        # Two hidden layers that halve the width at each step, with batch
        # norm and dropout for regularization.
        return nn.Sequential(
            nn.Linear(in_features=input_size, out_features=input_size // 2),
            nn.BatchNorm1d(num_features=input_size // 2),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=input_size // 2, out_features=input_size // 4),
            nn.BatchNorm1d(num_features=input_size // 4),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=input_size // 4, out_features=output_size),
        )

    def forward(self, input_ids, attention_mask, disasterLarge=None, disasterMedium=None, urgencyLevel=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean pooling over the token dimension (CLS pooling kept for reference).
        # pooled_output = outputs.last_hidden_state[:, 0, :].float()
        pooled_output = outputs.last_hidden_state.mean(dim=1).float()
        # pooled_output = self.dropout(pooled_output)

        logits = {key: classifier(pooled_output) for key, classifier in self.classifiers.items()}

        # Compute the loss only when all three label tensors are provided
        # (training); at inference time it stays None.
        total_loss = None
        has_labels = (
            urgencyLevel is not None
            and disasterLarge is not None
            and disasterMedium is not None
        )
        if has_labels:
            loss_fct = nn.CrossEntropyLoss()
            loss_ul = loss_fct(logits['urgencyLevel'], urgencyLevel.long())
            loss_dl = loss_fct(logits['disasterLarge'], disasterLarge.long())
            loss_dm = loss_fct(logits['disasterMedium'], disasterMedium.long())
            total_loss = loss_ul + loss_dl + loss_dm  # unweighted sum over the three tasks

        return {
            'loss': total_loss,
            'logits': {
                'logits_ul': logits['urgencyLevel'],
                'logits_dl': logits['disasterLarge'],
                'logits_dm': logits['disasterMedium'],
            },
        }

    def predict(self, texts, tokenizer, device):
        """Tokenize raw text(s) and return, per input, the best label and the
        full probability distribution for each task."""
        self.eval()
        if isinstance(texts, str):
            texts = [texts]
        encodings = tokenizer(
            text=texts,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors='pt',
        ).to(device=device)

        with torch.no_grad():
            outputs = self.forward(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask'],
            )

        # Map each task to its key in the forward() output.
        logits_key = {
            'urgencyLevel': 'logits_ul',
            'disasterLarge': 'logits_dl',
            'disasterMedium': 'logits_dm',
        }

        predictions = []
        batch_size = encodings['input_ids'].shape[0]
        for i in range(batch_size):
            total_values = {}
            best_values = {}
            for key in self.classifiers.keys():
                logits = outputs['logits'][logits_key[key]][i]
                probs = torch.nn.functional.softmax(input=logits, dim=0).cpu()
                best_values[key] = self.labels[key][torch.argmax(input=probs).item()]
                total_values[f'labels_{key}'] = self.labels[key]
                total_values[f'probs_{key}'] = probs.numpy().tolist()
            predictions.append({
                'bestValues': best_values,
                'totalValues': total_values,
            })
        return predictions
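

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above). It assumes a
# Hugging Face `transformers` encoder; the checkpoint name and the label
# lists below are placeholder assumptions, so swap in the encoder and task
# taxonomy from your own training setup.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoModel, AutoTokenizer

    checkpoint = 'bert-base-uncased'  # assumed checkpoint; any AutoModel encoder with .config.hidden_size works
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    model = MultiPredictModel(
        encoder=encoder,
        disasterLarge_labels=['natural', 'social'],             # hypothetical labels
        disasterMedium_labels=['flood', 'earthquake', 'fire'],  # hypothetical labels
        urgencyLevel_labels=['low', 'medium', 'high'],          # hypothetical labels
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # predict() accepts a single string or a list of strings.
    preds = model.predict(texts='Heavy flooding reported downtown', tokenizer=tokenizer, device=device)
    print(preds[0]['bestValues'])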