Spaces:
Runtime error
Runtime error
import argparse | |
import glob | |
import logging | |
import os | |
from argparse import Namespace | |
from importlib import import_module | |
import numpy as np | |
import torch | |
from lightning_base import BaseTransformer, add_generic_args, generic_train | |
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score | |
from torch.nn import CrossEntropyLoss | |
from torch.utils.data import DataLoader, TensorDataset | |
from utils_ner import TokenClassificationTask | |
logger = logging.getLogger(__name__) | |
class NERTransformer(BaseTransformer):
    """
    A training module for NER. See BaseTransformer for the core options.

    The concrete token-classification task is looked up by name in the
    ``tasks`` module; it supplies the label set, the example reader and the
    feature converter used by the methods below.
    """

    mode = "token-classification"

    def __init__(self, hparams):
        """Resolve the task named by ``hparams.task_type`` and initialise the base module.

        Args:
            hparams: An ``argparse.Namespace`` or a ``dict`` of hyper-parameters.
                A dict (or dict subclass) is converted to a ``Namespace``.

        Raises:
            ValueError: If the ``tasks`` module has no attribute named
                ``hparams.task_type``.
        """
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        module = import_module("tasks")
        # Guard only the lookup: an AttributeError raised inside the task's own
        # constructor is a different failure and must not be reported as
        # "task not defined".
        try:
            token_classification_task_clazz = getattr(module, hparams.task_type)
        except AttributeError as err:
            raise ValueError(
                f"Task {hparams.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
                f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
            ) from err
        self.token_classification_task: TokenClassificationTask = token_classification_task_clazz()
        self.labels = self.token_classification_task.get_labels(hparams.labels)
        # Positions carrying this label id are skipped by CrossEntropyLoss and
        # are filtered out again when metrics are computed in `_eval_end`.
        self.pad_token_label_id = CrossEntropyLoss().ignore_index
        super().__init__(hparams, len(self.labels), self.mode)

    def forward(self, **inputs):
        """Forward all keyword arguments to the wrapped model and return its output."""
        return self.model(**inputs)

    def _build_inputs(self, batch):
        """Turn a batch tuple into the keyword arguments passed to the model.

        The batch is indexed as ``(input_ids, attention_mask, token_type_ids,
        label_ids)``, matching the ``TensorDataset`` built in ``get_dataloader``.
        """
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
        # For "distilbert" the key is omitted entirely; for other model types
        # it is passed, but as None unless the model is "bert" or "xlnet".
        if self.config.model_type != "distilbert":
            inputs["token_type_ids"] = (
                batch[2] if self.config.model_type in ["bert", "xlnet"] else None
            )  # XLM and RoBERTa don't use token_type_ids
        return inputs

    def training_step(self, batch, batch_num):
        """Compute the training loss for one batch.

        Returns:
            A dict with the single key ``"loss"`` (first element of the model output).
        """
        outputs = self(**self._build_inputs(batch))
        loss = outputs[0]
        return {"loss": loss}

    def prepare_data(self):
        """Called to initialize data. Build and cache features for train/dev/test.

        For each split, features are created from ``args.data_dir`` and saved to
        the split's cache file unless that file already exists and
        ``args.overwrite_cache`` is not set.
        """
        args = self.hparams
        is_xlnet = self.config.model_type in ["xlnet"]
        for split in ["train", "dev", "test"]:
            cached_features_file = self._feature_file(split)
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                logger.info("Loading features from cached file %s", cached_features_file)
                # The result is discarded here; `get_dataloader` reads the file
                # again. NOTE: torch.load unpickles the file, so only cache
                # files from a trusted source should be used.
                torch.load(cached_features_file)
            else:
                logger.info("Creating features from dataset file at %s", args.data_dir)
                examples = self.token_classification_task.read_examples_from_file(args.data_dir, split)
                features = self.token_classification_task.convert_examples_to_features(
                    examples,
                    self.labels,
                    args.max_seq_length,
                    self.tokenizer,
                    cls_token_at_end=is_xlnet,
                    cls_token=self.tokenizer.cls_token,
                    cls_token_segment_id=2 if is_xlnet else 0,
                    sep_token=self.tokenizer.sep_token,
                    sep_token_extra=False,
                    pad_on_left=is_xlnet,
                    pad_token=self.tokenizer.pad_token_id,
                    pad_token_segment_id=self.tokenizer.pad_token_type_id,
                    pad_token_label_id=self.pad_token_label_id,
                )
                logger.info("Saving features into cached file %s", cached_features_file)
                torch.save(features, cached_features_file)

    def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Load the cached features of one split into a DataLoader. Called after prepare data.

        Args:
            mode: Split name handed to ``self._feature_file`` (``prepare_data``
                writes "train", "dev" and "test").
            batch_size: Batch size of the returned loader.
            shuffle: Whether the loader reshuffles the data; forwarded to
                ``DataLoader``.

        Returns:
            A DataLoader over ``(input_ids, attention_mask, token_type_ids, label_ids)``.
        """
        cached_features_file = self._feature_file(mode)
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE: torch.load unpickles the file; use trusted cache files only.
        features = torch.load(cached_features_file)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        if features[0].token_type_ids is not None:
            all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        else:
            # HACK(we will not use this anymore soon): one placeholder zero per
            # example so TensorDataset still gets four equally long tensors.
            all_token_type_ids = torch.tensor([0] * len(features), dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label_ids)
        # `shuffle` used to be accepted but dropped, so a caller requesting
        # shuffled data always received it in file order.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def validation_step(self, batch, batch_nb):
        """Compute validation loss and raw predictions for one batch.

        Returns:
            A dict with ``"val_loss"`` (CPU tensor), ``"pred"`` (logits as a
            NumPy array) and ``"target"`` (label ids as a NumPy array).
        """
        inputs = self._build_inputs(batch)
        outputs = self(**inputs)
        tmp_eval_loss, logits = outputs[:2]
        preds = logits.detach().cpu().numpy()
        out_label_ids = inputs["labels"].detach().cpu().numpy()
        return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

    def _eval_end(self, outputs):
        """Aggregate per-batch step outputs into metrics. Called for both Val and Test.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            A tuple ``(ret, preds_list, out_label_list)``: ``ret`` holds the
            metrics plus a ``"log"`` entry with the same metrics, and the two
            lists hold predicted and gold label names per sequence, with
            positions labelled ``self.pad_token_label_id`` removed.
        """
        val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)
        preds = np.argmax(preds, axis=2)
        out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)

        label_map = dict(enumerate(self.labels))
        out_label_list = []
        preds_list = []
        for gold_row, pred_row in zip(out_label_ids, preds):
            # Keep only positions whose gold label is not the ignore index.
            keep = gold_row != self.pad_token_label_id
            out_label_list.append([label_map[label_id] for label_id in gold_row[keep]])
            preds_list.append([label_map[label_id] for label_id in pred_row[keep]])

        results = {
            "val_loss": val_loss_mean,
            "accuracy_score": accuracy_score(out_label_list, preds_list),
            "precision": precision_score(out_label_list, preds_list),
            "recall": recall_score(out_label_list, preds_list),
            "f1": f1_score(out_label_list, preds_list),
        }

        ret = dict(results)
        ret["log"] = results
        return ret, preds_list, out_label_list

    def validation_epoch_end(self, outputs):
        """Return the validation metrics under "val_loss", "log" and "progress_bar"."""
        # when stable
        ret, _, _ = self._eval_end(outputs)
        logs = ret["log"]
        return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    def test_epoch_end(self, outputs):
        """Return the test metrics under "avg_test_loss", "log" and "progress_bar"."""
        # updating to test_epoch_end instead of deprecated test_end
        ret, _, _ = self._eval_end(outputs)

        # Converting to the dict required by pl
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
        # pytorch_lightning/trainer/logging.py#L139
        logs = ret["log"]
        # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
        return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the base options and the NER specific options on ``parser``.

        Declared as a static method so it works when called on the class or on
        an instance.

        Args:
            parser: The ``argparse.ArgumentParser`` to extend.
            root_dir: Passed through to ``BaseTransformer.add_model_specific_args``.

        Returns:
            The same ``parser`` object.
        """
        # Add NER specific options
        BaseTransformer.add_model_specific_args(parser, root_dir)
        parser.add_argument(
            "--task_type", default="NER", type=str, help="Task type to fine tune in training (e.g. NER, POS, etc)"
        )
        parser.add_argument(
            "--max_seq_length",
            default=128,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )

        parser.add_argument(
            "--labels",
            default="",
            type=str,
            help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
        )
        parser.add_argument(
            "--gpus",
            default=0,
            type=int,
            help="The number of GPUs allocated for this, it is by default 0 meaning none",
        )

        parser.add_argument(
            "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
        )

        return parser
if __name__ == "__main__":
    # Script entry point: parse the command line, train, and optionally run
    # the test loop on the checkpoint of the highest epoch.
    import re

    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = NERTransformer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    model = NERTransformer(args)
    trainer = generic_train(model, args)

    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this default format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
        checkpoint_pattern = os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt")

        def _epoch_key(path):
            """Sort key: numeric epoch first, then the path as a tie-breaker."""
            found = re.search(r"checkpoint-epoch=(\d+)", os.path.basename(path))
            return (int(found.group(1)) if found else -1, path)

        # A plain string sort places "epoch=9" after "epoch=10", which would
        # select the wrong checkpoint once training runs past ten epochs.
        checkpoints = sorted(glob.glob(checkpoint_pattern, recursive=True), key=_epoch_key)
        if not checkpoints:
            raise FileNotFoundError(f"No checkpoint matching {checkpoint_pattern} was found; cannot run prediction.")
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)