import evaluate
import datasets

from .FairEvalUtils import *

import importlib
from typing import List, Optional, Union

from seqeval.metrics.v1 import check_consistent_length
from seqeval.scheme import Entities, Token, auto_detect

_CITATION = """\
@inproceedings{ortmann2022,
    title = {Fine-Grained Error Analysis and Fair Evaluation of Labeled Spans},
    author = {Katrin Ortmann},
    url = {https://aclanthology.org/2022.lrec-1.150},
    year = {2022},
    date = {2022-06-21},
    booktitle = {Proceedings of the Language Resources and Evaluation Conference (LREC)},
    pages = {1400-1407},
    publisher = {European Language Resources Association},
    address = {Marseille, France},
    pubstate = {published},
    type = {inproceedings}
}
"""

_DESCRIPTION = """\
New evaluation method that more accurately reflects true annotation quality by ensuring that every error is
counted only once, avoiding the extra penalty that traditional evaluation puts on close-to-target annotations.
In addition to the traditional categories of true positives (TP), false positives (FP), and false negatives (FN),
the new method takes into account more fine-grained error types: labeling errors (LE), boundary errors (BE),
and labeling-boundary errors (LBE).
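For example, if a reference span is tagged B-PER I-PER, predicting B-LOC I-LOC over the same tokens is a
labeling error (LE), predicting only B-PER on the first token is a boundary error (BE), and predicting only
B-LOC on the first token is a labeling-boundary error (LBE).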
"""

_KWARGS_DESCRIPTION = """
Outputs the error counts (TP, FP, etc.) and resulting scores (Precision, Recall and F1) from a reference list of
spans compared against a predicted one. The user can choose to see traditional or fair error counts and scores by
switching the argument 'mode'.
For the computation of the fair metrics from the error counts, please refer to: https://aclanthology.org/2022.lrec-1.150.pdf
Args:
    predictions: a list of lists of predicted labels, i.e. estimated targets as returned by a tagger.
    references: a list of lists of ground truth reference labels. Predicted sentences must have the same number of tokens as the references.
    mode: 'fair', 'traditional' or 'weighted'. Controls the desired output. The default value is 'fair'.
        - 'traditional': equivalent to seqeval's metrics / classic span-based evaluation.
        - 'fair': default fair score calculation. It will also show traditional scores for comparison.
        - 'weighted': custom score calculation with the weights passed. It will also show traditional scores for comparison.
    weights: dictionary with the weight of each error for the custom score calculation.
        If none is passed and the mode is set to 'weighted', the following is used:
        {"TP": {"TP": 1},
         "FP": {"FP": 1},
         "FN": {"FN": 1},
         "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
         "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
         "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
    error_format: 'count', 'error_ratio' or 'entity_ratio'. Controls the desired output for TP, FP, BE, LE, etc. The default value is 'count'.
        - 'count': absolute count of each parameter.
        - 'error_ratio': percentage of the total number of errors that each parameter represents.
        - 'entity_ratio': percentage of the total number of ground truth entities that each parameter represents.
    zero_division: which value to substitute as a metric value when encountering zero division. Should be one of [0, 1, "warn"]. "warn" acts as 0, but a warning is also raised.
    suffix: True if the IOB tag is a suffix (after type) instead of a prefix (before type), False otherwise. The default value is False, i.e. the IOB tag is a prefix (before type).
    scheme: the target tagging scheme, which can be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU]. The default value is None.
Returns:
    A dictionary with:
        - Overall error parameter counts (or ratios) and resulting scores.
        - A nested dictionary per label with its respective error parameter counts (or ratios) and resulting scores.

    If mode is 'traditional', the error parameters shown are the classical TP, FP and FN. If mode is 'fair' or
    'weighted', TP remains the same, FP and FN are shown as per the fair definition, and the additional errors BE, LE and LBE are shown.

Examples:
    >>> faireval = evaluate.load("hpi-dhc/FairEval")
    >>> pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']]
    >>> ref = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']]
    >>> results = faireval.compute(predictions=pred, references=ref, mode='fair', error_format='count')
    >>> print(results)
    {
        "MISC": {
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
            "trad_prec": 0.0,
            "trad_rec": 0.0,
            "trad_f1": 0.0,
            "TP": 0,
            "FP": 0.0,
            "FN": 0.0,
            "LE": 0.0,
            "BE": 1.0,
            "LBE": 0.0
        },
        "PER": {
            "precision": 1.0,
            "recall": 1.0,
            "f1": 1.0,
            "trad_prec": 1.0,
            "trad_rec": 1.0,
            "trad_f1": 1.0,
            "TP": 1,
            "FP": 0.0,
            "FN": 0.0,
            "LE": 0.0,
            "BE": 0.0,
            "LBE": 0.0
        },
        "overall_precision": 0.6666666666666666,
        "overall_recall": 0.6666666666666666,
        "overall_f1": 0.6666666666666666,
        "overall_trad_prec": 0.5,
        "overall_trad_rec": 0.5,
        "overall_trad_f1": 0.5,
        "TP": 1,
        "FP": 0.0,
        "FN": 0.0,
        "LE": 0.0,
        "BE": 1.0,
        "LBE": 0.0
    }
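
    Further illustrative calls with the same pred and ref lists (outputs omitted here):

    >>> _ = faireval.compute(predictions=pred, references=ref, mode='traditional')
    >>> _ = faireval.compute(predictions=pred, references=ref, mode='fair', error_format='entity_ratio')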
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class FairEval(evaluate.Metric):

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features({
                "predictions": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
                "references": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
            }),
            homepage="https://huggingface.co/spaces/hpi-dhc/FairEval",
            codebase_urls=["https://github.com/rubcompling/FairEval#acknowledgement"],
            reference_urls=["https://aclanthology.org/2022.lrec-1.150.pdf"]
        )

    def _compute(
        self,
        predictions,
        references,
        suffix: bool = False,
        scheme: Optional[str] = None,
        mode: Optional[str] = 'fair',
        weights: dict = None,
        error_format: Optional[str] = 'count',
        zero_division: Union[str, int] = "warn",
    ):
        """Returns the error parameter counts and scores"""

        if scheme is not None:
            try:
                scheme_module = importlib.import_module("seqeval.scheme")
                scheme = getattr(scheme_module, scheme)
            except AttributeError:
                raise ValueError(f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}")

        y_true = references
        y_pred = predictions

        check_consistent_length(y_true, y_pred)

        # Fall back to seqeval's automatic scheme detection when no (valid) scheme is given.
        if scheme is None or not issubclass(scheme, Token):
            scheme = auto_detect(y_true, suffix)

        # Extract labeled spans with seqeval and convert them to FairEval's span format.
        true_spans = Entities(y_true, scheme, suffix).entities
        pred_spans = Entities(y_pred, scheme, suffix).entities

        true_spans = seq_to_fair(true_spans)
        pred_spans = seq_to_fair(pred_spans)

        # Accumulate error counts sentence by sentence.
        total_errors = compare_spans([], [])
        total_ref_entities = 0
        for i in range(len(true_spans)):
            total_ref_entities += len(true_spans[i])
            sentence_errors = compare_spans(true_spans[i], pred_spans[i])
            total_errors = add_dict(total_errors, sentence_errors)

        if weights is None and mode == 'weighted':
            weights = {"TP": {"TP": 1},
                       "FP": {"FP": 1},
                       "FN": {"FN": 1},
                       "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
                       "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
                       "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
            print("The chosen mode is 'weighted', but no weights are given. Setting weights to:")
            for k in weights:
                print(k, ":", weights[k])
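
        # Each fine-grained error type is redistributed into partial TP/FP/FN counts according to its
        # entry in `weights`; with the defaults above, for example, one boundary error is meant to count
        # as 0.5 TP, 0.25 FP and 0.25 FN in the weighted scores.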

        config = {"labels": "all", "eval_method": ['traditional', 'fair', 'weighted'], "weights": weights}
        results = calculate_results(total_errors, config)
        del results['conf']

        output = {}

        if error_format == 'count':
            trad_divider = 1
            fair_divider = 1
        elif error_format == 'entity_ratio':
            trad_divider = total_ref_entities
            fair_divider = total_ref_entities
        elif error_format == 'error_ratio':
            trad_divider = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
            fair_divider = (results['overall']['fair']['FP'] + results['overall']['fair']['FN']
                            + results['overall']['fair']['LE'] + results['overall']['fair']['BE']
                            + results['overall']['fair']['LBE'])
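
        # These dividers only rescale the raw error counts (TP, FP, FN, LE, BE, LBE) reported below;
        # precision, recall and F1 are always taken from `results` and are unaffected by `error_format`.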

        assert mode in ['traditional', 'fair', 'weighted'], "mode must be 'traditional', 'fair' or 'weighted'"
        assert error_format in ['count', 'error_ratio', 'entity_ratio'], \
            "error_format must be 'count', 'error_ratio' or 'entity_ratio'"

        # Per-label scores and error counts.
        if mode == 'traditional':
            for k, v in results['per_label'][mode].items():
                output[k] = {
                    'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
                    'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
                    'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
        elif mode == 'fair' or mode == 'weighted':
            for k, v in results['per_label'][mode].items():
                output[k] = {
                    'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
                    'trad_prec': results['per_label']['traditional'][k]['Prec'],
                    'trad_rec': results['per_label']['traditional'][k]['Rec'],
                    'trad_f1': results['per_label']['traditional'][k]['F1'],
                    'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
                    'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
                    'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}

        # Overall scores and error counts.
        output['overall_precision'] = results['overall'][mode]['Prec']
        output['overall_recall'] = results['overall'][mode]['Rec']
        output['overall_f1'] = results['overall'][mode]['F1']

        if mode == 'traditional':
            output['TP'] = (results['overall'][mode]['TP'] / trad_divider
                            if error_format == 'entity_ratio' else results['overall'][mode]['TP'])
            output['FP'] = results['overall'][mode]['FP'] / trad_divider
            output['FN'] = results['overall'][mode]['FN'] / trad_divider
        elif mode == 'fair' or mode == 'weighted':
            output['overall_trad_prec'] = results['overall']['traditional']['Prec']
            output['overall_trad_rec'] = results['overall']['traditional']['Rec']
            output['overall_trad_f1'] = results['overall']['traditional']['F1']
            output['TP'] = (results['overall'][mode]['TP'] / fair_divider
                            if error_format == 'entity_ratio' else results['overall'][mode]['TP'])
            output['FP'] = results['overall'][mode]['FP'] / fair_divider
            output['FN'] = results['overall'][mode]['FN'] / fair_divider
            output['LE'] = results['overall'][mode]['LE'] / fair_divider
            output['BE'] = results['overall'][mode]['BE'] / fair_divider
            output['LBE'] = results['overall'][mode]['LBE'] / fair_divider

        return output


def seq_to_fair(seq_sentences):
    """Transforms input annotated sentences from seqeval span format to FairEval span format"""
    out = []
    for seq_sentence in seq_sentences:
        sentence = []
        for entity in seq_sentence:
            # seqeval entities print as "(sent_id, tag, start, end)" with an exclusive end index;
            # drop the sentence id and convert to [tag, start, end] with an inclusive end index.
            span = str(entity).replace('(', '').replace(')', '').replace(' ', '').split(',')
            span = span[1:]
            span[-1] = int(span[-1]) - 1
            span[1] = int(span[1])
            # FairEval spans additionally carry the set of covered token positions.
            span.append({i for i in range(span[1], span[2] + 1)})
            sentence.append(span)
        out.append(sentence)
    return out
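
# Illustration of the conversion above (assuming seqeval's textual entity representation, e.g. "(0, PER, 7, 9)"):
# such an entity becomes the FairEval span ['PER', 7, 8, {7, 8}].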