# Custom spaCy NER component that logs per-label and aggregate F1 scores
# during training and writes an interactive Plotly chart plus a JSON dump
# of the metric history next to the saved model.
from spacy.pipeline.ner import EntityRecognizer
from spacy.language import Language
from thinc.api import Config
from sklearn.metrics import f1_score, precision_recall_fscore_support
import plotly.express as px
import plotly.graph_objects as go
import time
import json
import os
from pathlib import Path
default_model_config = """
|
|
[model]
|
|
@architectures = "spacy.TransitionBasedParser.v2"
|
|
state_type = "ner"
|
|
extra_state_tokens = false
|
|
hidden_width = 64
|
|
maxout_pieces = 2
|
|
use_upper = false
|
|
nO = null
|
|
|
|
[model.tok2vec]
|
|
@architectures = "spacy-transformers.TransformerListener.v1"
|
|
grad_factor = 1.0
|
|
pooling = {"@layers":"reduce_mean.v1"}
|
|
upstream = "*"
|
|
"""
|
|
DEFAULT_MODEL = Config().from_str(default_model_config)["model"]
|
|
|
|
|
|


@Language.factory(
    "ner_all_metrics",
    default_config={
        "model": DEFAULT_MODEL,
        "moves": None,
        "scorer": {"@scorers": "spacy.ner_scorer.v1"},
        "incorrect_spans_key": None,
        "update_with_oracle_cut_size": 100,
        # Should match [training] eval_frequency so the step counter
        # reconstructed in `score` lines up with the training log.
        "eval_frequency": 100,
    },
    default_score_weights={
        # Aggregate F1 scores computed by `custom_scorer`.
        "f1_micro": 1.0,
        "f1_macro": 1.0,
        "f1_weighted": 1.0,
        # Per-label F1 for the entity types used in this project.
        "f1_COMPONENT": 1.0,
        "f1_SYSTEM": 1.0,
        "f1_ATTRIBUTE": 1.0,
        # Report spaCy's built-in precision/recall, but keep them out of the
        # combined score used for model selection.
        "ents_p": 0.0,
        "ents_r": 0.0,
    },
)
def create_ner_all_metrics(
    nlp,
    name,
    model,
    moves,
    scorer,
    incorrect_spans_key,
    update_with_oracle_cut_size,
    eval_frequency,
):
    return NERWithAllMetrics(
        nlp.vocab,
        model,
        name=name,
        moves=moves,
        scorer=scorer,
        incorrect_spans_key=incorrect_spans_key,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        eval_frequency=eval_frequency,
    )
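

# Usage note (illustrative; the file name below is a placeholder, not taken
# from the original project): because the factory is registered as
# "ner_all_metrics", a config.cfg can swap it in for the stock NER, e.g.
#
#   [components.ner]
#   factory = "ner_all_metrics"
#   eval_frequency = 100
#
# provided this module is passed to the CLI so the registration code runs:
#
#   python -m spacy train config.cfg --code ner_all_metrics.py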


class NERWithAllMetrics(EntityRecognizer):
    """EntityRecognizer that records per-label and aggregate F1 after every
    evaluation and can dump the history as a Plotly chart and a JSON file."""

    def __init__(self, *args, eval_frequency=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.metric_history = []
        self.max_f1 = 0.0
        self.max_f1_step = 0
        self.eval_frequency = eval_frequency
        self.start_learning_time = None

    def score(self, examples, **kwargs):
        scores = super().score(examples, **kwargs)
        scores = {**scores, **self.custom_scorer(examples)}
        tmp_scores = scores.copy()
        # `score` is called once per evaluation, so the training step can be
        # reconstructed from the number of evaluations seen so far.
        tmp_scores["step"] = len(self.metric_history) * self.eval_frequency
        if tmp_scores["f1_macro"] > self.max_f1:
            self.max_f1 = tmp_scores["f1_macro"]
            self.max_f1_step = tmp_scores["step"]
        self.metric_history.append(tmp_scores)
        return scores

    def custom_scorer(self, examples):
        # Build flat label sequences by exact (start, end, label) matching:
        # a span present in both gold and prediction is a true positive,
        # gold-only spans become false negatives ("O" predicted) and
        # prediction-only spans become false positives ("O" in gold).
        y_true = []
        y_pred = []
        for example in examples:
            gold = {(ent.start_char, ent.end_char, ent.label_) for ent in example.reference.ents}
            pred = {(ent.start_char, ent.end_char, ent.label_) for ent in example.predicted.ents}
            all_spans = gold | pred
            for span in all_spans:
                if span in gold and span in pred:
                    y_true.append(span[2])
                    y_pred.append(span[2])
                elif span in gold:
                    y_true.append(span[2])
                    y_pred.append("O")
                elif span in pred:
                    y_true.append("O")
                    y_pred.append(span[2])

        # Collect labels from both gold and predictions so that spurious
        # predictions of a label missing from the gold data are still scored.
        labels = sorted({label for label in y_true + y_pred if label != "O"})

        _, _, f1_per_label, _ = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, zero_division=0, average=None
        )
        result = {}
        for label, label_f1 in zip(labels, f1_per_label):
            result[f"f1_{label}"] = label_f1

        result["f1_micro"] = f1_score(y_true, y_pred, average="micro", labels=labels, zero_division=0)
        result["f1_macro"] = f1_score(y_true, y_pred, average="macro", labels=labels, zero_division=0)
        result["f1_weighted"] = f1_score(y_true, y_pred, average="weighted", labels=labels, zero_division=0)

        return result

    def preprocess_metric_history(self):
        # Flatten the metric history into long-format columns that Plotly
        # Express can plot directly (one row per metric per evaluation step).
        result = {
            "metric_name": [],
            "metric_value": [],
            "step": [],
        }
        for cur_metrics in self.metric_history:
            cur_step = cur_metrics["step"]
            for key, value in cur_metrics.items():
                if key != "step" and isinstance(value, float):
                    result["metric_name"].append(key)
                    result["metric_value"].append(value)
                    result["step"].append(cur_step)
        return result

    def save_metrics_history(self, path):
        # The timer is started lazily on the first save if it is not
        # already running.
        if self.start_learning_time is None:
            self.start_learning_time = time.monotonic()

        if self.metric_history:
            metrics_history_to_save = self.preprocess_metric_history()
            fig = px.line(metrics_history_to_save, x="step", y="metric_value", color="metric_name")
            # Iterate over a snapshot so the highlight traces added below are
            # not re-visited inside the loop.
            for trace in list(fig.data):
                # Emphasise the aggregate curves, thin out the per-label ones.
                if trace.name in ["f1_micro", "f1_macro", "f1_weighted"]:
                    trace.line.width = 6
                else:
                    trace.line.width = 1

                # Mark each curve's value at the step where macro-F1 peaked.
                idx = list(trace.x).index(self.max_f1_step)
                highlight_y = list(trace.y)[idx]
                fig.add_trace(go.Scatter(
                    x=[self.max_f1_step], y=[highlight_y],
                    mode="markers+text",
                    marker=dict(color=trace.line.color, size=10),
                    text=[f"{round(highlight_y, 2)}"],
                    textposition="top center",
                    name=f"{trace.name} best",
                ))

            current_time_of_training = time.monotonic() - self.start_learning_time
            hours = int(current_time_of_training // 3600)
            minutes = int(current_time_of_training % 3600) // 60
            seconds = round(current_time_of_training % 60)
            current_time_of_training_text = f"{hours} hrs {minutes} min {seconds} sec"

            # layout.title.subtitle needs a reasonably recent Plotly release.
            fig.update_layout(title=dict(
                text="Training statistics",
                subtitle=dict(
                    text=f"Training time amounted to {current_time_of_training_text}",
                    font=dict(color="gray", size=13),
                ),
            ))

            output_dir = os.path.join(str(path), "logs")
            os.makedirs(output_dir, exist_ok=True)
            fig_path = os.path.join(output_dir, "training_metrics.html")
            json_path = os.path.join(output_dir, "training_metrics.json")
            fig.write_html(fig_path)
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump({
                    "data": metrics_history_to_save,
                    "train_time_s": current_time_of_training,
                }, f, indent=2, ensure_ascii=False)

    def to_disk(self, path, *args, **kwargs):
        super().to_disk(path, *args, **kwargs)
        # `path` points at the serialized component directory
        # (e.g. <output>/model-best/ner), so two levels up is the training
        # output directory where the logs/ folder should live.
        output_dir = Path(path)
        output_dir_metrics = output_dir.parent.parent
        self.save_metrics_history(output_dir_metrics)
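

# Minimal smoke test (a sketch under assumptions, not part of the original
# training setup): it builds a blank English pipeline, adds a transformer
# component so the TransformerListener in DEFAULT_MODEL has something to
# listen to, then adds the custom component. Running it may download the
# default "roberta-base" weights; "en" and the eval_frequency override are
# illustrative choices.
if __name__ == "__main__":
    import spacy

    nlp = spacy.blank("en")
    nlp.add_pipe("transformer")
    nlp.add_pipe("ner_all_metrics", config={"eval_frequency": 200})
    print(nlp.pipe_names)  # ['transformer', 'ner_all_metrics']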