"""Custom spaCy NER component that tracks per-label and averaged F1 scores during
training and writes an interactive Plotly report (plus a JSON copy) each time the
pipeline is saved to disk."""

import json
import os
import time
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import f1_score, precision_recall_fscore_support
from spacy.language import Language
from spacy.pipeline.ner import EntityRecognizer
from thinc.api import Config

# Transition-based NER on top of a shared transformer (TransformerListener),
# i.e. the same [components.ner.model] block you would normally put in config.cfg.
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = false
nO = null

[model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
pooling = {"@layers":"reduce_mean.v1"}
upstream = "*"
"""
DEFAULT_MODEL = Config().from_str(default_model_config)["model"]


@Language.factory(
    "ner_all_metrics",
    default_config={
        "model": DEFAULT_MODEL,
        "moves": None,
        "scorer": {"@scorers": "spacy.ner_scorer.v1"},
        "incorrect_spans_key": None,
        "update_with_oracle_cut_size": 100,
        "eval_frequency": 100,
    },
    # Expose the custom metrics to spaCy's score weighting / best-model selection.
    # f1_COMPONENT, f1_SYSTEM and f1_ATTRIBUTE are this project's entity labels.
    default_score_weights={
        "f1_micro": 1.0,
        "f1_macro": 1.0,
        "f1_weighted": 1.0,
        "f1_COMPONENT": 1.0,
        "f1_SYSTEM": 1.0,
        "f1_ATTRIBUTE": 1.0,
        "ents_p": 0.0,
        "ents_r": 0.0,
    },
)
def create_ner_all_metrics(
    nlp,
    name,
    model,
    moves,
    scorer,
    incorrect_spans_key,
    update_with_oracle_cut_size,
    eval_frequency,
):
    return NERWithAllMetrics(
        nlp.vocab,
        model,
        name=name,
        moves=moves,
        scorer=scorer,
        incorrect_spans_key=incorrect_spans_key,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        eval_frequency=eval_frequency,
    )


class NERWithAllMetrics(EntityRecognizer):
    """EntityRecognizer that keeps a history of evaluation metrics."""

    def __init__(self, *args, eval_frequency=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.metric_history = []              # one dict of scores per evaluation
        self.max_f1 = 0.0                     # best macro-F1 seen so far
        self.max_f1_step = 0                  # training step at which it was reached
        self.eval_frequency = eval_frequency  # should match [training] eval_frequency
        self.start_learning_time = None       # set lazily on the first save

    def score(self, examples, **kwargs):
        # Built-in NER scores (ents_p, ents_r, ents_f, ...) plus our custom F1s.
        scores = super().score(examples, **kwargs)
        scores = {**scores, **self.custom_scorer(examples)}
        tmp_scores = scores.copy()
        # score() is called once per evaluation, so the training step can be
        # reconstructed from the number of evaluations seen so far.
        tmp_scores["step"] = len(self.metric_history) * self.eval_frequency
        if tmp_scores["f1_macro"] > self.max_f1:
            self.max_f1 = tmp_scores["f1_macro"]
            self.max_f1_step = tmp_scores["step"]
        self.metric_history.append(tmp_scores)
        return scores

    def custom_scorer(self, examples):
        # Exact-span matching: a prediction counts as correct only if start, end
        # and label all agree with the gold annotation. Missed gold spans become
        # "O" predictions, spurious predictions become "O" gold labels.
        y_true = []
        y_pred = []
        for example in examples:
            gold = {(ent.start_char, ent.end_char, ent.label_) for ent in example.reference.ents}
            pred = {(ent.start_char, ent.end_char, ent.label_) for ent in example.predicted.ents}
            for span in gold | pred:
                if span in gold and span in pred:
                    y_true.append(span[2])
                    y_pred.append(span[2])
                elif span in gold:
                    y_true.append(span[2])
                    y_pred.append("O")
                else:  # span only in pred
                    y_true.append("O")
                    y_pred.append(span[2])
        labels = sorted({label for label in y_true if label != "O"})
        if not labels:
            # No gold entities in this evaluation set: avoid passing an empty
            # label list to sklearn and report zeros instead.
            return {"f1_micro": 0.0, "f1_macro": 0.0, "f1_weighted": 0.0}
        _, _, f1_per_label, _ = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, zero_division=0, average=None
        )
        result = {f"f1_{label}": f1 for label, f1 in zip(labels, f1_per_label)}
        result["f1_micro"] = f1_score(y_true, y_pred, average="micro", labels=labels, zero_division=0)
        result["f1_macro"] = f1_score(y_true, y_pred, average="macro", labels=labels, zero_division=0)
        result["f1_weighted"] = f1_score(y_true, y_pred, average="weighted", labels=labels, zero_division=0)
        return result

    def preprocess_metric_history(self):
        # Flatten the history into long format ({metric_name, metric_value, step})
        # so it can be fed straight into plotly.express.line.
        result = {"metric_name": [], "metric_value": [], "step": []}
        for cur_metrics in self.metric_history:
            cur_step = cur_metrics["step"]
            for key, value in cur_metrics.items():
                if key != "step" and isinstance(value, float):
                    result["metric_name"].append(key)
                    result["metric_value"].append(value)
                    result["step"].append(cur_step)
        return result

    def save_metrics_history(self, path):
        # The timer is anchored at the first save, so the reported duration is
        # measured from the first checkpoint rather than the very start of training.
        if self.start_learning_time is None:
            self.start_learning_time = time.monotonic()
        if not self.metric_history:
            return
        metrics_history_to_save = self.preprocess_metric_history()
        fig = px.line(metrics_history_to_save, x="step", y="metric_value", color="metric_name")
        # Iterate over a snapshot of the traces: highlight markers are added to
        # the figure inside the loop.
        for trace in list(fig.data):
            # Emphasise the averaged F1 curves, thin out the per-label ones.
            trace.line.width = 6 if trace.name in ("f1_micro", "f1_macro", "f1_weighted") else 1
            # Mark each curve's value at the step where macro-F1 peaked.
            idx = list(trace.x).index(self.max_f1_step)
            highlight_y = list(trace.y)[idx]
            fig.add_trace(go.Scatter(
                x=[self.max_f1_step],
                y=[highlight_y],
                mode="markers+text",
                marker=dict(color=trace.line.color, size=10),
                text=[f"{round(highlight_y, 2)}"],
                textposition="top center",
                name=f"{trace.name} best",
            ))
        current_time_of_training = time.monotonic() - self.start_learning_time
        current_time_of_training_text = (
            f"{int(current_time_of_training // 3600)} hrs "
            f"{int(current_time_of_training % 3600) // 60} min "
            f"{round(current_time_of_training % 60)} sec"
        )
        # title.subtitle requires a recent Plotly release.
        fig.update_layout(title=dict(
            text="Training statistics",
            subtitle=dict(
                text=f"Training time amounted to {current_time_of_training_text}",
                font=dict(color="gray", size=13),
            ),
        ))
        output_dir = os.path.join(str(path), "logs")
        os.makedirs(output_dir, exist_ok=True)
        fig.write_html(os.path.join(output_dir, "training_metrics.html"))
        with open(os.path.join(output_dir, "training_metrics.json"), "w", encoding="utf-8") as f:
            json.dump(
                {"data": metrics_history_to_save, "train_time_s": current_time_of_training},
                f,
                indent=2,
                ensure_ascii=False,
            )

    def to_disk(self, path, *args, **kwargs):
        super().to_disk(path, *args, **kwargs)
        # `path` is the component directory, e.g. <output>/model-best/ner,
        # so two levels up is the training output directory.
        output_dir_metrics = Path(path).parent.parent
        self.save_metrics_history(output_dir_metrics)