# ru_patents_ner-tiny / custom_factory.py
from spacy.pipeline.ner import EntityRecognizer
from spacy.language import Language
from thinc.api import Config
from sklearn.metrics import f1_score, precision_recall_fscore_support
import plotly.express as px
import plotly.graph_objects as go
import time
import json
import os
from pathlib import Path
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = false
nO = null

[model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
pooling = {"@layers":"reduce_mean.v1"}
upstream = "*"
"""
DEFAULT_MODEL = Config().from_str(default_model_config)["model"]


@Language.factory(
    "ner_all_metrics",
    default_config={
        "model": DEFAULT_MODEL,
        "moves": None,
        "scorer": {"@scorers": "spacy.ner_scorer.v1"},
        "incorrect_spans_key": None,
        "update_with_oracle_cut_size": 100,
        "eval_frequency": 100,
    },
    default_score_weights={
        "f1_micro": 1.0,
        "f1_macro": 1.0,
        "f1_weighted": 1.0,
        "f1_COMPONENT": 1.0,
        "f1_SYSTEM": 1.0,
        "f1_ATTRIBUTE": 1.0,
        "ents_p": 0.0,
        "ents_r": 0.0,
    },
)
def create_ner_all_metrics(
    nlp, name,
    model, moves,
    scorer, incorrect_spans_key,
    update_with_oracle_cut_size, eval_frequency,
):
    return NERWithAllMetrics(
        nlp.vocab, model,
        name=name, moves=moves,
        scorer=scorer, incorrect_spans_key=incorrect_spans_key,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        eval_frequency=eval_frequency,
    )
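

# Illustrative sketch (not taken from this repository's training config): the
# factory registered above could be referenced from a spaCy config roughly like
# this. The component name "ner" and the eval_frequency value are assumptions;
# eval_frequency is presumably meant to match [training] eval_frequency so that
# the recorded step values line up with the training steps.
#
#   [components.ner]
#   factory = "ner_all_metrics"
#   eval_frequency = 100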


class NERWithAllMetrics(EntityRecognizer):
    """EntityRecognizer that also tracks per-label and averaged F1 scores over training."""

    def __init__(self, *args, eval_frequency=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.metric_history = []   # one dict of scores per evaluation step
        self.max_f1 = 0.0          # best macro F1 observed so far
        self.max_f1_step = 0       # training step at which the best macro F1 was reached
        self.eval_frequency = eval_frequency
        self.start_learning_time = None

    def score(self, examples, **kwargs):
        # Extend spaCy's built-in NER scores with sklearn-based F1 metrics and keep
        # a per-step history so that training curves can be plotted and saved later.
        scores = super().score(examples, **kwargs)
        scores = {**scores, **self.custom_scorer(examples)}
        tmp_scores = scores.copy()
        tmp_scores["step"] = len(self.metric_history) * self.eval_frequency
        if tmp_scores["f1_macro"] > self.max_f1:
            self.max_f1 = tmp_scores["f1_macro"]
            self.max_f1_step = tmp_scores["step"]
        self.metric_history.append(tmp_scores)
        return scores

    def custom_scorer(self, examples):
        # Compare gold and predicted entities as exact (start_char, end_char, label) spans.
        # A span found on both sides contributes its label to y_true and y_pred; a span
        # found on only one side is paired with "O" on the other.
        y_true = []
        y_pred = []
        for example in examples:
            gold = {(ent.start_char, ent.end_char, ent.label_) for ent in example.reference.ents}
            pred = {(ent.start_char, ent.end_char, ent.label_) for ent in example.predicted.ents}
            all_spans = gold | pred
            for span in all_spans:
                if span in gold and span in pred:
                    y_true.append(span[2])
                    y_pred.append(span[2])
                elif span in gold:
                    y_true.append(span[2])
                    y_pred.append("O")
                elif span in pred:
                    y_true.append("O")
                    y_pred.append(span[2])
        labels = sorted({label for label in y_true if label != "O"})
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, zero_division=0, average=None
        )
        result = {}
        for label, label_f1 in zip(labels, f1):
            result[f"f1_{label}"] = label_f1
        result["f1_micro"] = f1_score(y_true, y_pred, average="micro", labels=labels, zero_division=0)
        result["f1_macro"] = f1_score(y_true, y_pred, average="macro", labels=labels, zero_division=0)
        result["f1_weighted"] = f1_score(y_true, y_pred, average="weighted", labels=labels, zero_division=0)
        return result
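
    # Worked example of the span matching above (illustrative values):
    #   gold = {(0, 5, "SYSTEM")}, pred = {(0, 5, "SYSTEM"), (7, 12, "SYSTEM")}
    #   -> y_true = ["SYSTEM", "O"], y_pred = ["SYSTEM", "SYSTEM"]
    # so the spurious prediction counts as a false positive for SYSTEM
    # (precision 1/2, recall 1/1 for that label).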

    def preprocess_metric_history(self):
        # Flatten the per-step score dicts into long-form columns for plotly.express.
        result = {
            "metric_name": [],
            "metric_value": [],
            "step": []
        }
        for cur_metrics in self.metric_history:
            cur_step = cur_metrics["step"]
            for key, value in cur_metrics.items():
                if key != "step" and isinstance(value, float):
                    result["metric_name"].append(key)
                    result["metric_value"].append(value)
                    result["step"].append(cur_step)
        return result

    def save_metrics_history(self, path):
        # The timer starts on the first call, so the reported duration covers the
        # training time elapsed since the first evaluation/save.
        if self.start_learning_time is None:
            self.start_learning_time = time.monotonic()
        if self.metric_history:
            metrics_history_to_save = self.preprocess_metric_history()
            fig = px.line(metrics_history_to_save, x="step", y="metric_value", color="metric_name")
            for trace in fig.data:
                # Emphasise the averaged F1 curves and mark every curve's value at the
                # step where the best macro F1 was reached.
                if trace.name in ["f1_micro", "f1_macro", "f1_weighted"]:
                    trace.line.width = 6
                else:
                    trace.line.width = 1
                idx = list(trace.x).index(self.max_f1_step)
                highlight_y = list(trace.y)[idx]
                line_color = trace.line.color
                line_name = trace.name
                fig.add_trace(go.Scatter(
                    x=[self.max_f1_step], y=[highlight_y],
                    mode="markers+text",
                    marker=dict(color=line_color, size=10),
                    text=[f"{round(highlight_y, 2)}"],
                    textposition="top center",
                    name=f"{line_name} best"
                ))
            current_time = time.monotonic()
            current_time_of_training = current_time - self.start_learning_time
            current_time_of_training_text = (
                f"{int(current_time_of_training // 3600)} hrs "
                f"{int(current_time_of_training % 3600) // 60} min "
                f"{round(current_time_of_training % 60)} sec"
            )
            fig.update_layout(title=dict(
                text="Training statistics",
                subtitle=dict(
                    text=f"Training time amounted to {current_time_of_training_text}",
                    font=dict(color="gray", size=13),
                )
            ))
            output_dir = os.path.join(str(path), "logs")
            os.makedirs(output_dir, exist_ok=True)
            fig_path = os.path.join(output_dir, "training_metrics.html")
            json_path = os.path.join(output_dir, "training_metrics.json")
            fig.write_html(fig_path)
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump({
                    "data": metrics_history_to_save,
                    "train_time_s": current_time_of_training
                }, f, indent=2, ensure_ascii=False)

    def to_disk(self, path, *args, **kwargs):
        super().to_disk(path, *args, **kwargs)
        # `path` is the component's own directory; the directory two levels up
        # (the training output directory in the usual `spacy train` layout) is
        # where the "logs" folder with the metric plots is written.
        output_dir = Path(path)
        output_dir_metrics = output_dir.parent.parent
        self.save_metrics_history(output_dir_metrics)
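

# Usage sketch (illustrative): this module must be imported so the factory is
# registered before loading a pipeline that uses "ner_all_metrics". The model
# name below is an assumption based on this repository's name.
#
#   import spacy
#   import custom_factory  # noqa: F401  (registers "ner_all_metrics")
#
#   nlp = spacy.load("ru_patents_ner-tiny")
#   doc = nlp("Текст патента ...")
#   print([(ent.text, ent.label_) for ent in doc.ents])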