import torch
import lightning.pytorch as pl
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
from torch.nn import BCEWithLogitsLoss
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_constant_schedule_with_warmup,
)

class FinanciaMultilabel(pl.LightningModule):
    """Multilabel text classifier wrapping a Hugging Face sequence-classification model."""

    def __init__(self, model, num_labels):
        super().__init__()
        self.model = model
        self.num_labels = num_labels
        self.loss = BCEWithLogitsLoss()
        self.validation_step_outputs = []

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Return raw logits; the loss and the sigmoid are applied outside.
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        token_type_ids = batch["token_type_ids"]
        outputs = self(input_ids, attention_mask, token_type_ids)
        # BCEWithLogitsLoss expects float targets, hence the type_as cast.
        loss = self.loss(
            outputs.view(-1, self.num_labels),
            labels.type_as(outputs).view(-1, self.num_labels),
        )
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        token_type_ids = batch["token_type_ids"]
        outputs = self(input_ids, attention_mask, token_type_ids)
        loss = self.loss(
            outputs.view(-1, self.num_labels),
            labels.type_as(outputs).view(-1, self.num_labels),
        )
        # Store per-batch results; metrics are computed once per epoch in
        # on_validation_epoch_end.
        pred_labels = torch.sigmoid(outputs)
        info = {"val_loss": loss, "pred_labels": pred_labels, "labels": labels}
        self.validation_step_outputs.append(info)

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        pred_labels = torch.cat([x["pred_labels"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        threshold = 0.50
        pred_bools = pred_labels > threshold
        true_bools = labels == 1
        # Micro-averaged F1 over all labels and exact-match (subset) accuracy.
        val_f1_accuracy = f1_score(true_bools.cpu(), pred_bools.cpu(), average="micro") * 100
        val_flat_accuracy = accuracy_score(true_bools.cpu(), pred_bools.cpu()) * 100
        self.log("val_loss", avg_loss)
        self.log("val_f1_accuracy", val_f1_accuracy, prog_bar=True)
        self.log("val_flat_accuracy", val_flat_accuracy, prog_bar=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        # val_loss is minimized, so the plateau scheduler must use mode="min".
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=2, verbose=True, min_lr=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

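# --------------------------------------------------------------------------
# Illustrative training sketch (not part of the original code). It assumes a
# BERT-style checkpoint name and DataLoaders whose batches are dicts with
# "input_ids", "attention_mask", "token_type_ids" and multi-hot "labels".
# The function name, max_epochs and callback settings are arbitrary choices.
def train_example(model_name, num_labels, train_loader, val_loader, max_epochs=5):
    from lightning.pytorch.callbacks import ModelCheckpoint

    backbone = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, ignore_mismatched_sizes=True
    )
    module = FinanciaMultilabel(backbone, num_labels)
    # Keep the checkpoint with the best micro-F1 logged in on_validation_epoch_end.
    checkpoint_cb = ModelCheckpoint(monitor="val_f1_accuracy", mode="max")
    trainer = pl.Trainer(max_epochs=max_epochs, callbacks=[checkpoint_cb])
    trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
    return trainer, checkpoint_cb.best_model_path
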
def load_model(checkpoint_path, model, num_labels, device):
    # Re-create the Hugging Face backbone, then restore the Lightning weights.
    model_huggingface = AutoModelForSequenceClassification.from_pretrained(
        model, num_labels=num_labels, ignore_mismatched_sizes=True
    )
    model = FinanciaMultilabel.load_from_checkpoint(
        checkpoint_path,
        model=model_huggingface,
        num_labels=num_labels,
        map_location=device,
    )
    return model
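

# --------------------------------------------------------------------------
# Illustrative inference sketch (not part of the original code). It assumes
# the same BERT-style checkpoint name used for training, so the tokenizer
# produces token_type_ids; the function name and the 0.5 threshold (matching
# on_validation_epoch_end) are arbitrary choices.
def predict_example(text, checkpoint_path, model_name, num_labels, threshold=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    module = load_model(checkpoint_path, model_name, num_labels, device)
    module.to(device)
    module.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_token_type_ids=True,
    )
    encoded = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        logits = module(
            encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"]
        )
    # Sigmoid + threshold turns the logits into a multi-hot prediction vector.
    return (torch.sigmoid(logits) > threshold).cpu()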