|
|
|
|
|
'''

Evaluation of span annotations with traditional, fair, and weighted metrics.



Created 09/2021



@author: Katrin Ortmann

'''
|
|
|
import argparse |
|
import os |
|
import sys |
|
import re |
|
from typing import Iterable |
|
from io import TextIOWrapper |
|
from copy import deepcopy |
|
|
|
|
|
|
|
def precision(evaldict, version="traditional", weights={}): |
|
""" |
|
Calculate traditional, fair or weighted precision value. |
|
|
|
Precision is calculated as the number of true positives |
|
divided by the number of true positives plus false positives |
|
plus (optionally) additional error types. |
|
|
|
Input: |
|
- A dictionary with error types as keys and counts as values, e.g., |
|
{"TP" : 10, "FP" : 2, "LE" : 1, ...} |
|
|
|
For 'traditional' evaluation, true positives (key: TP) and |
|
false positives (key: FP) are required. |
|
The 'fair' evaluation is based on true positives (TP), |
|
false positives (FP), labeling errors (LE), boundary errors (BE) |
|
and labeling-boundary errors (LBE). |
|
The 'weighted' evaluation can include any error type |
|
that is given as key in the weight dictionary. |
|
For missing keys, the count is set to 0. |
|
|
|
- The desired evaluation method. Options are 'traditional', |
|
'fair', and 'weighted'. If no weight dictionary is specified, |
|
'weighted' is identical to 'fair'. |
|
|
|
- A weight dictionary to specify how much an error type should |
|
count as one of the traditional error types (or as true positive). |
|
    By default, every traditional error is counted as one error (or true positive),

    and each additional error type is counted as half a false positive and half a false negative:
|
|
|
{"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}} |
|
|
|
Other suggested weights to count boundary errors as half true positives: |
|
|
|
{"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}} |
|
|
|
Or to include different types of boundary errors: |
|
|
|
{"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BEO" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25}, |
|
"BES" : {"TP" : 0.5, "FP" : 0, "FN" : 0.5}, |
|
"BEL" : {"TP" : 0.5, "FP" : 0.5, "FN" : 0}} |
|
|
|
Output: |
|
The precision for the given input values. |
|
In case of a ZeroDivisionError, the precision is set to 0. |
|
|
|
""" |
|
traditional_weights = { |
|
"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1} |
|
} |
|
default_fair_weights = { |
|
"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5} |
|
} |
|
try: |
|
tp = 0 |
|
fp = 0 |
|
|
|
|
|
if version == "traditional": |
|
weights = traditional_weights |
|
|
|
|
|
|
|
elif version == "fair" or not weights: |
|
weights = default_fair_weights |
|
|
|
|
|
tp += sum( |
|
[w.get("TP", 0) * evaldict.get(error, 0) for error, w in weights.items()] |
|
) |
|
|
|
|
|
fp += sum( |
|
[w.get("FP", 0) * evaldict.get(error, 0) for error, w in weights.items()] |
|
) |
|
|
|
|
|
return tp / (tp + fp) |
|
|
|
|
|
except ZeroDivisionError: |
|
return 0.0 |
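
# Worked example (comment added for illustration, not part of the original
# script): with the default fair weights above and
# evaldict = {"TP": 10, "FP": 2, "BE": 2}, fair precision is
# (1*10) / (10 + 1*2 + 0.5*2) = 10/13 ≈ 0.77, while traditional precision
# on the same counts would be 10 / (10 + 2) ≈ 0.83.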
|
|
|
|
|
|
|
def recall(evaldict, version="traditional", weights={}): |
|
""" |
|
Calculate traditional, fair or weighted recall value. |
|
|
|
Recall is calculated as the number of true positives |
|
divided by the number of true positives plus false negatives |
|
plus (optionally) additional error types. |
|
|
|
Input: |
|
- A dictionary with error types as keys and counts as values, e.g., |
|
{"TP" : 10, "FN" : 2, "LE" : 1, ...} |
|
|
|
For 'traditional' evaluation, true positives (key: TP) and |
|
false negatives (key: FN) are required. |
|
The 'fair' evaluation is based on true positives (TP), |
|
false negatives (FN), labeling errors (LE), boundary errors (BE) |
|
and labeling-boundary errors (LBE). |
|
The 'weighted' evaluation can include any error type |
|
that is given as key in the weight dictionary. |
|
For missing keys, the count is set to 0. |
|
|
|
- The desired evaluation method. Options are 'traditional', |
|
'fair', and 'weighted'. If no weight dictionary is specified, |
|
'weighted' is identical to 'fair'. |
|
|
|
- A weight dictionary to specify how much an error type should |
|
count as one of the traditional error types (or as true positive). |
|
    By default, every traditional error is counted as one error (or true positive),

    and each additional error type is counted as half a false positive and half a false negative:
|
|
|
{"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}} |
|
|
|
Other suggested weights to count boundary errors as half true positives: |
|
|
|
{"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}} |
|
|
|
Or to include different types of boundary errors: |
|
|
|
{"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BEO" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25}, |
|
"BES" : {"TP" : 0.5, "FP" : 0, "FN" : 0.5}, |
|
"BEL" : {"TP" : 0.5, "FP" : 0.5, "FN" : 0}} |
|
|
|
Output: |
|
The recall for the given input values. |
|
In case of a ZeroDivisionError, the recall is set to 0. |
|
|
|
""" |
|
traditional_weights = { |
|
"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1} |
|
} |
|
default_fair_weights = { |
|
"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5} |
|
} |
|
try: |
|
tp = 0 |
|
fn = 0 |
|
|
|
|
|
if version == "traditional": |
|
weights = traditional_weights |
|
|
|
|
|
|
|
elif version == "fair" or not weights: |
|
weights = default_fair_weights |
|
|
|
|
|
tp += sum( |
|
[w.get("TP", 0) * evaldict.get(error, 0) for error, w in weights.items()] |
|
) |
|
|
|
|
|
fn += sum( |
|
[w.get("FN", 0) * evaldict.get(error, 0) for error, w in weights.items()] |
|
) |
|
|
|
|
|
return tp / (tp + fn) |
|
|
|
|
|
except ZeroDivisionError: |
|
return 0.0 |
|
|
|
|
|
|
|
def fscore(evaldict): |
|
""" |
|
    Calculate the F1-Score from the given precision and recall values.
|
|
|
Input: A dictionary with a precision (key: Prec) and recall (key: Rec) value. |
|
Output: The F1-Score. In case of a ZeroDivisionError, the F1-Score is set to 0. |
|
""" |
|
try: |
|
return 2 * (evaldict.get("Prec", 0) * evaldict.get("Rec", 0)) \ |
|
/ (evaldict.get("Prec", 0) + evaldict.get("Rec", 0)) |
|
except ZeroDivisionError: |
|
return 0.0 |
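
# Worked example (comment added for illustration): for
# {"Prec": 0.8, "Rec": 0.6}, fscore returns 2 * (0.8 * 0.6) / (0.8 + 0.6) ≈ 0.69.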
|
|
|
|
|
|
|
def overlap_type(span1, span2): |
|
""" |
|
Determine the error type of two (overlapping) spans. |
|
|
|
    The function checks if and how span1 and span2 overlap.
|
The first span serves as the basis against which the second |
|
span is evaluated. |
|
|
|
span1 ---XXXX--- |
|
span2 ---XXXX--- TP (identical) |
|
span2 ----XXXX-- BEO (overlap) |
|
span2 --XXXX---- BEO (overlap) |
|
span2 ----XX---- BES (smaller) |
|
span2 ---XX----- BES (smaller) |
|
span2 --XXXXXX-- BEL (larger) |
|
span2 --XXXXX--- BEL (larger) |
|
span2 -X-------- False (no overlap) |
|
|
|
Input: |
|
Tuples (beginSpan1, endSpan1) and (beginSpan2, endSpan2), |
|
where begin and end are the indices of the corresponding tokens. |
|
|
|
Output: |
|
Either one of the following strings |
|
- "TP" = span1 and span2 are identical, i.e., actually no error here |
|
- "BES" = span2 is shorter and contained within span1 (with at most one identical boundary) |
|
- "BEL" = span2 is longer and contains span1 (with at most one identical boundary) |
|
- "BEO" = span1 and span2 overlap with no identical boundary |
|
or False if span1 and span2 do not overlap. |
|
""" |
|
|
|
if span1[0] == span2[0] and span1[1] == span2[1]: |
|
return "TP" |
|
|
|
|
|
if span1[0] == span2[0]: |
|
|
|
if span2[1] >= span1[0] and span2[1] < span1[1]: |
|
return "BES" |
|
|
|
else: |
|
return "BEL" |
|
|
|
elif span2[0] < span1[0]: |
|
|
|
if span2[1] < span1[0]: |
|
return False |
|
|
|
elif span2[1] < span1[1]: |
|
return "BEO" |
|
|
|
else: |
|
return "BEL" |
|
|
|
elif span2[0] >= span1[0] and span2[0] <= span1[1]: |
|
|
|
if span2[1] <= span1[1]: |
|
return "BES" |
|
|
|
else: |
|
return "BEO" |
|
|
|
else: |
|
return False |
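
# Illustrative calls (comments added for clarity, not part of the original script):
#   overlap_type((3, 6), (3, 6))  -> "TP"   (identical spans)
#   overlap_type((3, 6), (4, 5))  -> "BES"  (span2 contained in span1)
#   overlap_type((3, 6), (2, 7))  -> "BEL"  (span2 contains span1)
#   overlap_type((3, 6), (4, 8))  -> "BEO"  (partial overlap, no shared boundary)
#   overlap_type((3, 6), (1, 2))  -> False  (no overlap)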
|
|
|
|
|
|
|
def compare_spans(target_spans, system_spans, focus="target"): |
|
""" |
|
Compare system and target spans to identify correct/incorrect annotations. |
|
|
|
The function takes a list of target spans and system spans. |
|
Each span is a 4-tuple of |
|
- label: the span type as string |
|
- begin: the index of first token; equals end for spans of length 1 |
|
- end: the index of the last token; equals begin for spans of length 1 |
|
- tokens: a set of token indices included in the span |
|
(this allows the correct evaluation of |
|
partially and multiply overlapping spans; |
|
to allow for changes of the token set, |
|
the span tuple is actually implemented as a list.) |
|
|
|
The function first performs traditional evaluation on these spans |
|
to identify true positives, false positives, and false negatives. |
|
Then, the additional error types for fair evaluation are determined, |
|
following steps 1 to 4: |
|
1. Count 1:1 mappings (TP, LE) |
|
2. Count boundary errors (BE = BES + BEL + BEO) |
|
3. Count labeling-boundary errors (LBE) |
|
4. Count 1:0 and 0:1 mappings (FN, FP) |
|
|
|
Input: |
|
- List of target spans |
|
- List of system spans |
|
    - Whether to focus on the system or the target annotation (default: target)
|
|
|
Output: A dictionary containing |
|
- the counts of TP, FP, and FN according to traditional evaluation |
|
(per label and overall) |
|
- the counts of TP, FP, LE, BE, BES, BEL, BEO, and FN |
|
(per label and overall; BE = BES + BEL + BEO) |
|
- a confusion matrix {target_label1 : {system_label1 : count, |
|
system_label2 : count, |
|
...}, |
|
target_label2 : ... |
|
} |
|
with an underscore '_' representing an empty label (FN/FP) |
|
""" |
|
|
|
|
|
|
|
def _max_sim(t, S): |
|
""" |
|
Determine the most similar span s from S for span t. |
|
|
|
        Similarity is defined, in this order, as

        1. the maximum number of shared tokens between s and t,

        2. the minimum number of tokens only in t, and

        3. the minimum number of tokens only in s.

        If multiple spans are equally similar, the shortest s is chosen.

        If there is still a tie, the first one in the list is chosen,

        which corresponds to the left-most one when reading from left to right.
|
|
|
Input: |
|
- Span t as 4-tuple [label, begin, end, token_set] |
|
        - List S containing more than one span
|
|
|
Output: The most similar s for t. |
|
""" |
|
S.sort(key=lambda s: (0-len(t[3].intersection(s[3])), |
|
len(t[3].difference(s[3])), |
|
len(s[3].difference(t[3])), |
|
s[2]-s[1])) |
|
return S[0] |
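
    # Tie-breaking example (comment added for illustration): for
    # t = ["NP", 2, 4, {2, 3, 4}] and the equally similar candidates
    # ["NP", 2, 5, {2, 3, 4, 5}] and ["NP", 1, 4, {1, 2, 3, 4}] (same overlap,
    # same number of extra tokens, same length), the first one in S is returned.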
|
|
|
|
|
|
|
traditional_error_types = ["TP", "FP", "FN"] |
|
additional_error_types = ["LE", "BE", "BEO", "BES", "BEL", "LBE"] |
|
|
|
|
|
eval_dict = {"overall" : {"traditional" : {err_type : 0 for err_type |
|
in traditional_error_types}, |
|
"fair" : {err_type : 0 for err_type |
|
in traditional_error_types + additional_error_types}}, |
|
"per_label" : {"traditional" : {}, |
|
"fair" : {}}, |
|
"conf" : {}} |
|
|
|
|
|
for s in target_spans + system_spans: |
|
if not s[0] in eval_dict["per_label"]["traditional"]: |
|
eval_dict["per_label"]["traditional"][s[0]] = {err_type : 0 for err_type |
|
in traditional_error_types} |
|
eval_dict["per_label"]["fair"][s[0]] = {err_type : 0 for err_type |
|
in traditional_error_types + additional_error_types} |
|
|
|
if not s[0] in eval_dict["conf"]: |
|
eval_dict["conf"][s[0]] = {} |
|
eval_dict["conf"]["_"] = {} |
|
for lab in list(eval_dict["conf"])+["_"]: |
|
for lab2 in list(eval_dict["conf"])+["_"]: |
|
eval_dict["conf"][lab][lab2] = 0 |
|
|
|
|
|
|
|
|
|
for t in target_spans: |
|
|
|
if t in system_spans: |
|
eval_dict["overall"]["traditional"]["TP"] += 1 |
|
eval_dict["per_label"]["traditional"][t[0]]["TP"] += 1 |
|
|
|
else: |
|
eval_dict["overall"]["traditional"]["FN"] += 1 |
|
eval_dict["per_label"]["traditional"][t[0]]["FN"] += 1 |
|
for s in system_spans: |
|
|
|
if not s in target_spans: |
|
eval_dict["overall"]["traditional"]["FP"] += 1 |
|
eval_dict["per_label"]["traditional"][s[0]]["FP"] += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tps = [t for t in target_spans if t in system_spans] |
|
for t in tps: |
|
s = [s for s in system_spans if s == t] |
|
if s: |
|
s = s[0] |
|
eval_dict["overall"]["fair"]["TP"] += 1 |
|
eval_dict["per_label"]["fair"][t[0]]["TP"] += 1 |
|
|
|
system_spans.remove(s) |
|
target_spans.remove(t) |
|
|
|
|
|
|
|
les = [t for t in target_spans |
|
if any(t[0] != s[0] and t[1:3] == s[1:3] for s in system_spans)] |
|
for t in les: |
|
s = [s for s in system_spans if t[0] != s[0] and t[1:3] == s[1:3]] |
|
if s: |
|
s = s[0] |
|
|
|
eval_dict["overall"]["fair"]["LE"] += 1 |
|
|
|
if focus == "target": |
|
eval_dict["per_label"]["fair"][t[0]]["LE"] += 1 |
|
elif focus == "system": |
|
eval_dict["per_label"]["fair"][s[0]]["LE"] += 1 |
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
system_spans.remove(s) |
|
target_spans.remove(t) |
|
|
|
|
|
|
|
|
|
counted_target = list() |
|
counted_system = list() |
|
|
|
|
|
target_spans.sort(key=lambda t : t[2] - t[1]) |
|
system_spans.sort(key=lambda s : s[2] - s[1]) |
|
|
|
|
|
|
|
|
|
|
|
i = 0 |
|
while i < len(target_spans): |
|
t = target_spans[i] |
|
|
|
|
|
be = [s for s in system_spans |
|
if t[0] == s[0] and t[1:3] != s[1:3] |
|
and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")] |
|
if not be: |
|
i += 1 |
|
continue |
|
|
|
|
|
if len(be) > 1: |
|
s = _max_sim(t, be) |
|
else: |
|
s = be[0] |
|
|
|
|
|
be_type = overlap_type((t[1], t[2]), (s[1], s[2])) |
|
|
|
|
|
eval_dict["overall"]["fair"]["BE"] += 1 |
|
eval_dict["overall"]["fair"][be_type] += 1 |
|
|
|
|
|
eval_dict["per_label"]["fair"][t[0]]["BE"] += 1 |
|
eval_dict["per_label"]["fair"][t[0]][be_type] += 1 |
|
|
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
|
|
system_spans.remove(s) |
|
target_spans.remove(t) |
|
|
|
|
|
matching_tokens = t[3].intersection(s[3]) |
|
s[3] = s[3].difference(matching_tokens) |
|
t[3] = t[3].difference(matching_tokens) |
|
|
|
|
|
counted_system.append(s) |
|
counted_target.append(t) |
|
|
|
|
|
i = 0 |
|
while i < len(target_spans): |
|
t = target_spans[i] |
|
|
|
|
|
|
|
be = [s for s in counted_system |
|
if t[0] == s[0] and t[1:3] != s[1:3] |
|
and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO") |
|
and t[3].intersection(s[3])] |
|
if not be: |
|
i += 1 |
|
continue |
|
|
|
|
|
if len(be) > 1: |
|
s = _max_sim(t, be) |
|
else: |
|
s = be[0] |
|
|
|
|
|
be_type = overlap_type((t[1], t[2]), (s[1], s[2])) |
|
|
|
|
|
eval_dict["overall"]["fair"]["BE"] += 1 |
|
eval_dict["overall"]["fair"][be_type] += 1 |
|
|
|
|
|
eval_dict["per_label"]["fair"][t[0]]["BE"] += 1 |
|
eval_dict["per_label"]["fair"][t[0]][be_type] += 1 |
|
|
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
|
|
target_spans.remove(t) |
|
|
|
|
|
matching_tokens = t[3].intersection(s[3]) |
|
counted_system[counted_system.index(s)][3] = s[3].difference(matching_tokens) |
|
t[3] = t[3].difference(matching_tokens) |
|
|
|
|
|
counted_target.append(t) |
|
|
|
|
|
i = 0 |
|
while i < len(system_spans): |
|
s = system_spans[i] |
|
|
|
|
|
be = [t for t in counted_target |
|
if t[0] == s[0] and t[1:3] != s[1:3] |
|
and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO") |
|
and t[3].intersection(s[3])] |
|
if not be: |
|
i += 1 |
|
continue |
|
|
|
|
|
if len(be) > 1: |
|
t = _max_sim(s, be) |
|
else: |
|
t = be[0] |
|
|
|
|
|
be_type = overlap_type((t[1], t[2]), (s[1], s[2])) |
|
|
|
|
|
eval_dict["overall"]["fair"]["BE"] += 1 |
|
eval_dict["overall"]["fair"][be_type] += 1 |
|
|
|
|
|
eval_dict["per_label"]["fair"][t[0]]["BE"] += 1 |
|
eval_dict["per_label"]["fair"][t[0]][be_type] += 1 |
|
|
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
|
|
system_spans.remove(s) |
|
|
|
|
|
matching_tokens = t[3].intersection(s[3]) |
|
counted_target[counted_target.index(t)][3] = t[3].difference(matching_tokens) |
|
s[3] = s[3].difference(matching_tokens) |
|
|
|
|
|
counted_system.append(s) |
|
|
|
|
|
|
|
|
|
|
|
i = 0 |
|
while i < len(target_spans): |
|
t = target_spans[i] |
|
|
|
|
|
lbe = [s for s in system_spans |
|
if t[0] != s[0] and t[1:3] != s[1:3] |
|
and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")] |
|
if not lbe: |
|
i += 1 |
|
continue |
|
|
|
|
|
if len(lbe) > 1: |
|
s = _max_sim(t, lbe) |
|
else: |
|
s = lbe[0] |
|
|
|
|
|
eval_dict["overall"]["fair"]["LBE"] += 1 |
|
|
|
|
|
if focus == "target": |
|
eval_dict["per_label"]["fair"][t[0]]["LBE"] += 1 |
|
elif focus == "system": |
|
eval_dict["per_label"]["fair"][s[0]]["LBE"] += 1 |
|
|
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
|
|
system_spans.remove(s) |
|
target_spans.remove(t) |
|
|
|
|
|
matching_tokens = t[3].intersection(s[3]) |
|
s[3] = s[3].difference(matching_tokens) |
|
t[3] = t[3].difference(matching_tokens) |
|
|
|
|
|
counted_system.append(s) |
|
counted_target.append(t) |
|
|
|
|
|
i = 0 |
|
while i < len(target_spans): |
|
t = target_spans[i] |
|
|
|
|
|
lbe = [s for s in counted_system |
|
if t[0] != s[0] and t[1:3] != s[1:3] |
|
and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO") |
|
and t[3].intersection(s[3])] |
|
if not lbe: |
|
i += 1 |
|
continue |
|
|
|
|
|
if len(lbe) > 1: |
|
s = _max_sim(t, lbe) |
|
else: |
|
s = lbe[0] |
|
|
|
|
|
eval_dict["overall"]["fair"]["LBE"] += 1 |
|
|
|
|
|
if focus == "target": |
|
eval_dict["per_label"]["fair"][t[0]]["LBE"] += 1 |
|
elif focus == "system": |
|
eval_dict["per_label"]["fair"][s[0]]["LBE"] += 1 |
|
|
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
|
|
target_spans.remove(t) |
|
|
|
|
|
matching_tokens = t[3].intersection(s[3]) |
|
counted_system[counted_system.index(s)][3] = s[3].difference(matching_tokens) |
|
t[3] = t[3].difference(matching_tokens) |
|
|
|
|
|
counted_target.append(t) |
|
|
|
|
|
i = 0 |
|
while i < len(system_spans): |
|
s = system_spans[i] |
|
|
|
|
|
lbe = [t for t in counted_target |
|
if t[0] != s[0] and t[1:3] != s[1:3] |
|
and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO") |
|
and t[3].intersection(s[3])] |
|
if not lbe: |
|
i += 1 |
|
continue |
|
|
|
|
|
if len(lbe) > 1: |
|
t = _max_sim(s, lbe) |
|
else: |
|
t = lbe[0] |
|
|
|
|
|
eval_dict["overall"]["fair"]["LBE"] += 1 |
|
|
|
|
|
if focus == "target": |
|
eval_dict["per_label"]["fair"][t[0]]["LBE"] += 1 |
|
elif focus == "system": |
|
eval_dict["per_label"]["fair"][s[0]]["LBE"] += 1 |
|
|
|
|
|
eval_dict["conf"][t[0]][s[0]] += 1 |
|
|
|
|
|
system_spans.remove(s) |
|
|
|
|
|
matching_tokens = t[3].intersection(s[3]) |
|
counted_target[counted_target.index(t)][3] = t[3].difference(matching_tokens) |
|
s[3] = s[3].difference(matching_tokens) |
|
|
|
|
|
counted_system.append(s) |
|
|
|
|
|
|
|
|
|
for t in target_spans: |
|
eval_dict["overall"]["fair"]["FN"] += 1 |
|
eval_dict["per_label"]["fair"][t[0]]["FN"] += 1 |
|
eval_dict["conf"][t[0]]["_"] += 1 |
|
|
|
|
|
for s in system_spans: |
|
eval_dict["overall"]["fair"]["FP"] += 1 |
|
eval_dict["per_label"]["fair"][s[0]]["FP"] += 1 |
|
eval_dict["conf"]["_"][s[0]] += 1 |
|
|
|
return eval_dict |
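
# Illustrative comparison (comment added for clarity, not part of the original
# script): a target span ["NP", 1, 3, {1, 2, 3}] vs. a system span
# ["NP", 1, 2, {1, 2}] counts as one FN plus one FP under traditional
# evaluation, but as a single boundary error (BE/BES) under fair evaluation.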
|
|
|
|
|
|
|
def annotation_stats(target_spans, **config): |
|
""" |
|
Count the target annotations to display simple statistics. |
|
|
|
The function takes a list of target spans |
|
with each span being a 4-tuple [label, begin, end, token_set] |
|
and adds the included labels to the general data stats dictionary. |
|
|
|
Input: |
|
- List of target spans |
|
- Config dictionary |
|
|
|
    Output: None. The 'data_stats' dictionary inside the config is updated in-place.
|
""" |
|
stats_dict = config.get("data_stats", {}) |
|
for span in target_spans: |
|
if span[0] in stats_dict: |
|
stats_dict[span[0]] += 1 |
|
else: |
|
stats_dict[span[0]] = 1 |
|
config["data_stats"] = stats_dict |
|
|
|
|
|
|
|
def get_spans(sentence, **config): |
|
""" |
|
Return spans from CoNLL2000 or span files. |
|
|
|
The function determines the data format of the input sentence |
|
and extracts the spans from it accordingly. |
|
|
|
If desired, punctuation can be ignored (config['ignore_punct'] == True) |
|
for files in the CoNLL2000 format that include POS information. |
|
The following list of tags is considered as punctuation: |
|
['$.', '$,', '$(', #STTS |
|
'PUNCT', #UPOS |
|
'PUNKT', 'KOMMA', 'COMMA', 'KLAMMER', #custom |
|
'.', ',', ':', '(', ')', '"', '‘', '“', '’', '”' #PTB |
|
] |
|
|
|
Labels that should be ignored (included in config['exclude'] |
|
or not included in config['labels'] if config['labels'] != 'all') |
|
are also removed from the resulting list. |
|
|
|
Input: |
|
- List of lines for a given sentence |
|
- Config dictionary |
|
|
|
Output: List of spans that are included in the sentence. |
|
""" |
|
|
|
|
|
|
|
def spans_from_conll(sentence): |
|
""" |
|
Read annotation spans from a CoNLL2000 file. |
|
|
|
The function takes a list of lines (belonging to one sentence) |
|
and extracts the annotated spans. The lines are expected to |
|
contain three space-separated columns: |
|
|
|
Form XPOS Annotation |
|
|
|
Form: Word form |
|
XPOS: POS tag of the word (ideally STTS, UPOS, or PTB) |
|
Annotation: Span annotation in BIO format (see below); |
|
multiple spans are separated with the pipe symbol '|' |
|
|
|
BIO tags consist of the token's position in the span |
|
(begin 'B', inside 'I', outside 'O'), a dash '-' and the span label, |
|
e.g., B-NP, I-AC, or in the case of stacked annotations I-RELC|B-NP. |
|
|
|
The function accepts 'O', '_' and '' as annotations outside of spans. |
|
|
|
Input: List of lines belonging to one sentence. |
|
Output: List of spans as 4-tuples [label, begin, end, token_set] |
|
""" |
|
spans = [] |
|
span_stack = [] |
|
|
|
|
|
for t, tok in enumerate(sentence): |
|
|
|
|
|
tok = tok.split() |
|
|
|
|
|
if tok[-1] in ["O", "_", ""]: |
|
|
|
|
|
while span_stack: |
|
spans.append(span_stack.pop(0)) |
|
span_stack = [] |
|
continue |
|
|
|
|
|
|
|
annotations = tok[-1].strip().split("|") |
|
|
|
|
|
|
|
|
|
|
|
while len(span_stack) > len(annotations): |
|
spans.append(span_stack.pop()) |
|
|
|
|
|
for i, annotation in enumerate(annotations): |
|
|
|
|
|
if annotation.startswith("B-"): |
|
|
|
|
|
|
|
if i == 0 and span_stack: |
|
while span_stack: |
|
spans.append(span_stack.pop(0)) |
|
|
|
|
|
|
|
else: |
|
while len(span_stack) > i: |
|
spans.append(span_stack.pop()) |
|
|
|
|
|
                    label = annotation.split("-", 1)[1]
|
|
|
|
|
|
|
s = [label, t+1, t+1, {t+1}] |
|
|
|
|
|
span_stack.append(s) |
|
|
|
|
|
elif annotation.startswith("I-"): |
|
|
|
|
|
span_stack[i][2] = t+1 |
|
|
|
span_stack[i][-1].add(t+1) |
|
|
|
|
|
while span_stack: |
|
spans.append(span_stack.pop(0)) |
|
|
|
return spans |
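
    # Illustrative input/output (comment only, hypothetical sentence): the
    # CoNLL2000 lines ["Der ART B-NP", "Mann NN I-NP", "schläft VVFIN O"]
    # yield the single span ["NP", 1, 2, {1, 2}].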
|
|
|
|
|
|
|
def spans_from_lines(sentence): |
|
""" |
|
Read annotation spans from a span file. |
|
|
|
The function takes a list of lines (belonging to one sentence) |
|
and extracts the annotated spans. The lines are expected to |
|
contain four tab-separated columns: |
|
|
|
Label Begin End Tokens |
|
|
|
Label: Span label |
|
Begin: Index of the first included token (must be convertible to int) |
|
End: Index of the last included token (must be convertible to int |
|
and equal or greater than begin) |
|
Tokens: Comma-separated list of indices of the tokens in the span |
|
(must be convertible to int with begin <= i <= end); |
|
if no (valid) indices are given, the range begin:end is used |
|
|
|
Input: List of lines belonging to one sentence. |
|
Output: List of spans as 4-tuples [label, begin, end, token_set] |
|
""" |
|
spans = [] |
|
for line in sentence: |
|
vals = line.split("\t") |
|
label = vals[0] |
|
if not label: |
|
print("ERROR: Missing label in input.") |
|
return [] |
|
try: |
|
begin = int(vals[1]) |
|
if begin < 1: raise ValueError |
|
except ValueError: |
|
print("ERROR: Begin {0} is not a legal index.".format(vals[1])) |
|
return [] |
|
try: |
|
end = int(vals[2]) |
|
if end < 1: raise ValueError |
|
if end < begin: begin, end = end, begin |
|
except ValueError: |
|
print("ERROR: End {0} is not a legal index.".format(vals[2])) |
|
return [] |
|
try: |
|
toks = [int(v.strip()) for v in vals[-1].split(",") |
|
if int(v.strip()) >= begin and int(v.strip()) <= end] |
|
toks = set(toks) |
|
except ValueError: |
|
toks = [] |
|
if not toks: |
|
                toks = set(range(begin, end+1))
|
spans.append([label, begin, end, toks]) |
|
return spans |
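
    # Illustrative input/output (comment only, hypothetical line): the span-file
    # line "NP\t2\t4\t2,3,4" yields ["NP", 2, 4, {2, 3, 4}]; with an empty or
    # invalid token column, the full range 2..4 is used instead.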
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(sentence[0].split("\t")) == 4: |
|
format = "spans" |
|
spans = spans_from_lines(sentence) |
|
|
|
|
|
elif len(sentence[0].split(" ")) == 3: |
|
format = "conll2000" |
|
spans = spans_from_conll(sentence) |
|
else: |
|
print("ERROR: Unknown input format") |
|
return [] |
|
|
|
|
|
if format == "conll2000" \ |
|
and config.get("ignore_punct") == True: |
|
|
|
|
|
for i, line in enumerate(sentence): |
|
if line.split(" ")[1] in ["$.", "$,", "$(", |
|
"PUNCT", |
|
"PUNKT", "KOMMA", "COMMA", "KLAMMER", |
|
".", ",", ":", "(", ")", "\"", "‘", "“", "’", "”" |
|
]: |
|
|
|
for s in range(len(spans)): |
|
|
|
spans[s][-1].discard(i+1) |
|
|
|
|
|
if spans[s][1] == i+1: |
|
if spans[s][2] != None and spans[s][2] > i+1: |
|
spans[s][1] = i+2 |
|
else: |
|
spans[s][1] = None |
|
|
|
|
|
if spans[s][2] == i+1: |
|
if spans[s][1] != None and spans[s][1] <= i: |
|
spans[s][2] = i |
|
else: |
|
spans[s][2] = None |
|
|
|
|
|
spans = [s for s in spans if s[1] != None and s[2] != None and len(s[3]) > 0] |
|
|
|
|
|
spans = [s for s in spans |
|
if not s[0] in config.get("exclude", []) |
|
and ("all" in config.get("labels", []) |
|
or s[0] in config.get("labels", []))] |
|
|
|
return spans |
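
# Minimal usage sketch (added for illustration; keyword arguments correspond to
# the config keys documented above):
#   spans = get_spans(sentence_lines, labels=["all"], exclude=["NONE", "EMPTY"],
#                     ignore_punct=True)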
|
|
|
|
|
|
|
def get_sentences(filename): |
|
""" |
|
    Read sentences from an input file.
|
|
|
The function iterates through the input file and |
|
yields a list of lines that belong to one sentence. |
|
Sentences are expected to be separated by an empty line. |
|
|
|
Input: Filename of the input file. |
|
Output: Yields a list of lines for each sentence. |
|
""" |
|
file = open(filename, mode="r", encoding="utf-8") |
|
sent = [] |
|
|
|
for line in file: |
|
|
|
if sent and not line.strip(): |
|
yield sent |
|
sent = [] |
|
|
|
elif not line.strip(): |
|
continue |
|
|
|
else: |
|
sent.append(line.strip()) |
|
|
|
|
|
if sent: |
|
yield sent |
|
|
|
file.close() |
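
# Illustrative input (comment only): a file containing the lines
#   "Der ART B-NP", "Mann NN I-NP", "", "Er PPER B-NP"
# yields two sentences: ["Der ART B-NP", "Mann NN I-NP"] and ["Er PPER B-NP"].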
|
|
|
|
|
|
|
def add_dict(base_dict, dict_to_add): |
|
""" |
|
Take a base dictionary and add the values |
|
from another dictionary to it. |
|
|
|
Contrary to standard dict update methods, |
|
this function does not overwrite values in the |
|
base dictionary. Instead, it is meant to add |
|
the values of the second dictionary to the values |
|
in the base dictionary. The dictionary is modified in-place. |
|
|
|
For example: |
|
|
|
>> base = {"A" : 1, "B" : {"c" : 2, "d" : 3}, "C" : [1, 2, 3]} |
|
>> add = {"A" : 1, "B" : {"c" : 1, "e" : 1}, "C" : [4], "D" : 2} |
|
>> add_dict(base, add) |
|
|
|
will create a base dictionary: |
|
|
|
>> base |
|
{'A': 2, 'B': {'c': 3, 'd': 3, 'e': 1}, 'C': [1, 2, 3, 4], 'D': 2} |
|
|
|
The function can handle different types of nested structures. |
|
- Integers and float values are summed up. |
|
    - Lists are extended (concatenated)
|
- Sets are added (set union) |
|
- Dictionaries are added recursively |
|
For other value types, the base dictionary is left unchanged. |
|
|
|
Input: Base dictionary and dictionary to be added. |
|
Output: Base dictionary. |
|
""" |
|
|
|
|
|
for key, val in dict_to_add.items(): |
|
|
|
|
|
if key in base_dict: |
|
|
|
|
|
if isinstance(val, (int, float)) \ |
|
and isinstance(base_dict[key], (int, float)): |
|
|
|
|
|
base_dict[key] += val |
|
|
|
|
|
elif isinstance(val, Iterable) \ |
|
and isinstance(base_dict[key], Iterable): |
|
|
|
|
|
if isinstance(val, list) \ |
|
and isinstance(base_dict[key], list): |
|
|
|
base_dict[key].extend(val) |
|
|
|
|
|
elif isinstance(val, set) \ |
|
and isinstance(base_dict[key], set): |
|
|
|
base_dict[key].update(val) |
|
|
|
|
|
elif isinstance(val, dict) \ |
|
and isinstance(base_dict[key], dict): |
|
|
|
add_dict(base_dict[key], val) |
|
|
|
|
|
else: |
|
|
|
pass |
|
|
|
|
|
else: |
|
|
|
pass |
|
|
|
|
|
else: |
|
|
|
base_dict[key] = deepcopy(val) |
|
|
|
return base_dict |
|
|
|
|
|
|
|
def calculate_results(eval_dict, **config): |
|
""" |
|
Calculate overall precision, recall, and F-scores. |
|
|
|
The function takes an evaluation dictionary with error counts |
|
and applies the precision, recall and fscore functions. |
|
|
|
It will calculate the traditional metrics |
|
and fair and/or weighted metrics, depending on the |
|
value of config['eval_method']. |
|
|
|
The results are stored in the eval dict as 'Prec', 'Rec' and 'F1' |
|
for overall and per-label counts. |
|
|
|
Input: Evaluation dict and config dict. |
|
Output: Evaluation dict with added precision, recall and F1 values. |
|
""" |
|
|
|
|
|
|
|
if "weighted" in config.get("eval_method", []): |
|
eval_dict["overall"]["weighted"] = {} |
|
for err_type in eval_dict["overall"]["fair"]: |
|
eval_dict["overall"]["weighted"][err_type] = eval_dict["overall"]["fair"][err_type] |
|
        eval_dict["per_label"]["weighted"] = {}

        for label in eval_dict["per_label"]["fair"]:
|
eval_dict["per_label"]["weighted"][label] = {} |
|
for err_type in eval_dict["per_label"]["fair"][label]: |
|
eval_dict["per_label"]["weighted"][label][err_type] = eval_dict["per_label"]["fair"][label][err_type] |
|
|
|
|
|
for version in config.get("eval_method", ["traditional", "fair"]): |
|
|
|
|
|
eval_dict["overall"][version]["Prec"] = precision(eval_dict["overall"][version], |
|
version, |
|
config.get("weights", {})) |
|
eval_dict["overall"][version]["Rec"] = recall(eval_dict["overall"][version], |
|
version, |
|
config.get("weights", {})) |
|
eval_dict["overall"][version]["F1"] = fscore(eval_dict["overall"][version]) |
|
|
|
|
|
for label in eval_dict["per_label"][version]: |
|
eval_dict["per_label"][version][label]["Prec"] = precision(eval_dict["per_label"][version][label], |
|
version, |
|
config.get("weights", {})) |
|
eval_dict["per_label"][version][label]["Rec"] = recall(eval_dict["per_label"][version][label], |
|
version, |
|
config.get("weights", {})) |
|
eval_dict["per_label"][version][label]["F1"] = fscore(eval_dict["per_label"][version][label]) |
|
|
|
return eval_dict |
|
|
|
|
|
|
|
def output_results(eval_dict, **config): |
|
""" |
|
Write evaluation results to the output (file). |
|
|
|
The function takes an evaluation dict and writes |
|
all results to the specified output (file): |
|
|
|
1. Traditional evaluation results |
|
2. Additional evaluation results (fair and/or weighted) |
|
3. Result comparison for different evaluation methods |
|
4. Confusion matrix |
|
5. Data statistics |
|
|
|
Input: Evaluation dict and config dict. |
|
""" |
|
outfile = config.get("eval_out", sys.stdout) |
|
|
|
|
|
for version in config.get("eval_method", ["traditional", "fair"]): |
|
print(file=outfile) |
|
print("### {0} evaluation:".format(version.title()), file=outfile) |
|
|
|
|
|
if version == "traditional": |
|
cats = ["TP", "FP", "FN"] |
|
elif version == "fair" or not config.get("weights", {}): |
|
cats = ["TP", "FP", "LE", "BE", "LBE", "FN"] |
|
else: |
|
cats = list(config.get("weights").keys()) |
|
|
|
|
|
print("Label", "\t".join(cats), "Prec", "Rec", "F1", sep="\t", file=outfile) |
|
|
|
|
|
for label,val in sorted(eval_dict["per_label"][version].items()): |
|
            print(label,

                  "\t".join([str(val.get(cat, eval_dict["per_label"]["fair"].get(label, {}).get(cat, 0)))
|
for cat in cats]), |
|
"\t".join(["{:04.2f}".format(val.get(metric, 0)*100) |
|
for metric in ["Prec", "Rec", "F1"]]), |
|
sep="\t", file=outfile) |
|
|
|
|
|
print("overall", |
|
"\t".join([str(eval_dict["overall"][version].get(cat, eval_dict["overall"]["fair"].get(cat, 0))) |
|
for cat in cats]), |
|
"\t".join(["{:04.2f}".format(eval_dict["overall"][version].get(metric, 0)*100) |
|
for metric in ["Prec", "Rec", "F1"]]), |
|
sep="\t", file=outfile) |
|
|
|
|
|
print(file=outfile) |
|
print("### Comparison:", file=outfile) |
|
print("Version", "Prec", "Rec", "F1", sep="\t", file=outfile) |
|
for version in config.get("eval_method", ["traditional", "fair"]): |
|
print(version.title(), |
|
"\t".join(["{:04.2f}".format(eval_dict["overall"][version].get(metric, 0)*100) |
|
for metric in ["Prec", "Rec", "F1"]]), |
|
sep="\t", file=outfile) |
|
|
|
|
|
print(file=outfile) |
|
print("### Confusion matrix:", file=outfile) |
|
|
|
|
|
labels = {lab for lab in eval_dict["conf"]} |
|
|
|
|
|
labels = list(labels.union({syslab |
|
for lab in eval_dict["conf"] |
|
for syslab in eval_dict["conf"][lab]})) |
|
|
|
|
|
labels.sort() |
|
|
|
|
|
print(r"Target\System", "\t".join(labels), sep="\t", file=outfile) |
|
|
|
|
|
for targetlab in labels: |
|
print(targetlab, |
|
"\t".join([str(eval_dict["conf"][targetlab].get(syslab, 0)) |
|
for syslab in labels]), |
|
sep="\t", file=outfile) |
|
|
|
|
|
print(file=outfile) |
|
print("### Target data stats:", file=outfile) |
|
print("Label", "Freq", "%", sep="\t", file=outfile) |
|
total = sum(config.get("data_stats", {}).values()) |
|
for lab, freq in config.get("data_stats", {}).items(): |
|
print(lab, freq, "{:04.2f}".format(freq/total*100), sep="\t", file=outfile) |
|
|
|
|
|
    if outfile is not sys.stdout and isinstance(outfile, TextIOWrapper):
|
outfile.close() |
|
|
|
|
|
|
|
def read_config(config_file): |
|
""" |
|
Function to set program parameters as specified in the config file. |
|
|
|
The following parameters are handled: |
|
|
|
- target_in: path to the target file(s) with gold standard annotation |
|
-> output: 'target_files' : [list of target file paths] |
|
|
|
- system_in: path to the system's output file(s), which are evaluated |
|
-> output: 'system_files' : [list of system file paths] |
|
|
|
    - eval_out: path or filename where evaluation results should be stored
|
if value is a path, output file 'path/eval.csv' is created |
|
if value is 'cmd' or missing, output is set to sys.stdout |
|
-> output: 'eval_out' : output file or sys.stdout |
|
|
|
- labels: comma-separated list of labels to evaluate |
|
defaults to 'all' |
|
-> output: 'labels' : [list of labels as strings] |
|
|
|
- exclude: comma-separated list of labels to exclude from evaluation |
|
always contains 'NONE' and 'EMPTY' |
|
-> output: 'exclude' : [list of labels as strings] |
|
|
|
    - ignore_punct: whether to ignore punctuation during evaluation (true/false)
|
-> output: 'ignore_punct' : True/False |
|
|
|
    - focus: whether to focus the evaluation on 'target' or 'system' annotations
|
defaults to 'target' |
|
-> output: 'focus' : 'target' or 'system' |
|
|
|
- weights: weights that should be applied during calculation of precision |
|
and recall; at the same time can serve as a list of additional |
|
error types to include in the evaluation |
|
the weights are parsed from comma-separated input formulas of the form |
|
|
|
error_type = weight * TP + weight2 * FP + weight3 * FN |
|
|
|
-> output: 'weights' : { 'error type' : { |
|
'TP' : weight, |
|
'FP' : weight, |
|
'FN' : weight |
|
}, |
|
'another error type' : {...} |
|
} |
|
|
|
- eval_method: defines which evaluation method(s) to use |
|
one or more of: 'traditional', 'fair', 'weighted' |
|
if value is 'all' or missing, all available methods are returned |
|
-> output: 'eval_method' : [list of eval methods] |
|
|
|
Input: Filename of the config file. |
|
Output: Settings dictionary. |
|
""" |
|
|
|
|
|
|
|
def _parse_config(key, val): |
|
""" |
|
Internal function to set specific values for the given keys. |
|
In case of illegal values, prints error message and sets key and/or value to None. |
|
Input: Key and value from config file |
|
Output: Modified key and value |
|
""" |
|
if key in ["target_in", "system_in"]: |
|
if os.path.isdir(val): |
|
val = os.path.normpath(val) |
|
files = [os.path.join(val, f) for f in os.listdir(val)] |
|
elif os.path.isfile(val): |
|
files = [os.path.normpath(val)] |
|
else: |
|
print("Error: '{0} = {1}' is not a file/directory.".format(key, val)) |
|
return None, None |
|
if key == "target_in": |
|
return "target_files", files |
|
elif key == "system_in": |
|
return "system_files", files |
|
|
|
elif key == "eval_out": |
|
if os.path.isdir(val): |
|
val = os.path.normpath(val) |
|
outfile = os.path.join(val, "eval.csv") |
|
elif os.path.isfile(val): |
|
outfile = os.path.normpath(val) |
|
elif val == "cmd": |
|
outfile = sys.stdout |
|
else: |
|
try: |
|
p, f = os.path.split(val) |
|
if not os.path.isdir(p): |
|
os.makedirs(p) |
|
outfile = os.path.join(p, f) |
|
                except Exception:
|
print("Error: '{0} = {1}' is not a file/directory.".format(key, val)) |
|
return None, None |
|
return key, outfile |
|
|
|
elif key in ["labels", "exclude"]: |
|
labels = list(set([v.strip() for v in val.split(",") if v.strip()])) |
|
if key == "exclude": |
|
labels.append("NONE") |
|
labels.append("EMPTY") |
|
return key, labels |
|
|
|
elif key == "ignore_punct": |
|
if val.strip().lower() == "false": |
|
return key, False |
|
else: |
|
return key, True |
|
|
|
elif key == "focus": |
|
if val.strip().lower() == "system": |
|
return key, "system" |
|
else: |
|
return key, "target" |
|
|
|
elif key == "weights": |
|
if val == "default": |
|
return key, {"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}} |
|
else: |
|
formulas = val.split(",") |
|
weights = {} |
|
|
|
|
|
for f in formulas: |
|
|
|
|
|
error_type = re.match(r"\s*(?P<Error>\w+)\s*=", f) |
|
if error_type == None: |
|
print("WARNING: No error type found in weight formula '{0}'.".format(f)) |
|
continue |
|
else: |
|
error_type = error_type.group("Error") |
|
|
|
weights[error_type] = {} |
|
|
|
|
|
w_tp = re.search(r"(?P<TP>\d*\.?\d+)\s*\*?\s*TP", f) |
|
if w_tp == None: |
|
print("WARNING: Missing weight for TP for error type {0}. Set to 0.".format(error_type)) |
|
weights[error_type]["TP"] = 0 |
|
else: |
|
try: |
|
w_tp = w_tp.group("TP") |
|
w_tp = float(w_tp) |
|
weights[error_type]["TP"] = w_tp |
|
except ValueError: |
|
print("WARNING: Weight for TP for error type {0} is not a number. Set to 0.".format(error_type)) |
|
weights[error_type]["TP"] = 0 |
|
|
|
|
|
w_fp = re.search(r"(?P<FP>\d*\.?\d+)\s*\*?\s*FP", f) |
|
if w_fp == None: |
|
print("WARNING: Missing weight for FP for error type {0}. Set to 0.".format(error_type)) |
|
weights[error_type]["FP"] = 0 |
|
else: |
|
try: |
|
w_fp = w_fp.group("FP") |
|
w_fp = float(w_fp) |
|
weights[error_type]["FP"] = w_fp |
|
except ValueError: |
|
print("WARNING: Weight for FP for error type {0} is not a number. Set to 0.".format(error_type)) |
|
weights[error_type]["FP"] = 0 |
|
|
|
|
|
w_fn = re.search(r"(?P<FN>\d*\.?\d+)\s*\*?\s*FN", f) |
|
if w_fn == None: |
|
print("WARNING: Missing weight for FN for error type {0}. Set to 0.".format(error_type)) |
|
weights[error_type]["FN"] = 0 |
|
else: |
|
try: |
|
w_fn = w_fn.group("FN") |
|
w_fn = float(w_fn) |
|
weights[error_type]["FN"] = w_fn |
|
except ValueError: |
|
print("WARNING: Weight for FN for error type {0} is not a number. Set to 0.".format(error_type)) |
|
weights[error_type]["FN"] = 0 |
|
if weights: |
|
|
|
if not "TP" in weights: |
|
weights["TP"] = {"TP" : 1} |
|
if not "FP" in weights: |
|
weights["FP"] = {"FP" : 1} |
|
if not "FN" in weights: |
|
weights["FN"] = {"FN" : 1} |
|
return key, weights |
|
else: |
|
print("WARNING: No valid weights found. Using default weights.") |
|
return key, {"TP" : {"TP" : 1}, |
|
"FP" : {"FP" : 1}, |
|
"FN" : {"FN" : 1}, |
|
"LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}, |
|
"LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}} |
|
|
|
elif key == "eval_method": |
|
available_methods = ["traditional", "fair", "weighted"] |
|
if val == "all": |
|
return key, available_methods |
|
else: |
|
methods = [] |
|
for m in available_methods: |
|
if m in [v.strip() for v in val.split(",") |
|
if v.strip() and v.strip().lower() in available_methods]: |
|
methods.append(m) |
|
if methods: |
|
return key, methods |
|
else: |
|
print("WARNING: No evaluation method specified. Applying all methods.") |
|
return key, available_methods |
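
    # Hypothetical config file content (added for illustration; paths and values
    # are placeholders, not defaults shipped with this script):
    #
    #   target_in = gold/
    #   system_in = system/
    #   eval_out = cmd
    #   labels = all
    #   ignore_punct = true
    #   focus = target
    #   eval_method = traditional, fair, weighted
    #   weights = BE = 0.5*TP + 0.25*FP + 0.25*FN, LE = 0*TP + 0.5*FP + 0.5*FN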
|
|
|
|
|
|
|
config = dict() |
|
|
|
f = open(config_file, mode="r", encoding="utf-8") |
|
|
|
for line in f: |
|
|
|
line = line.strip() |
|
|
|
|
|
if not line or line.startswith("#"): |
|
continue |
|
|
|
line = line.split("=") |
|
key = line[0].strip() |
|
val = "=".join(line[1:]).strip() |
|
|
|
|
|
if key in ["target_in", "system_in"]: |
|
print("{0}: {1}".format(key, val)) |
|
config[key] = val |
|
|
|
|
|
key, val = _parse_config(key, val) |
|
|
|
|
|
if key is None or val is None: |
|
continue |
|
|
|
|
|
if key in config: |
|
print("WARNING: duplicate config item '{0}' found.".format(key)) |
|
|
|
config[key] = val |
|
|
|
f.close() |
|
|
|
|
|
if not "target_files" in config or not "system_files" in config: |
|
print("ERROR: Cannot evaluate without target AND system file(s). Quitting.") |
|
return None |
|
|
|
|
|
    elif config.get("eval_out", None) in (None, sys.stdout):
|
config["eval_out"] = sys.stdout |
|
|
|
else: |
|
config["eval_out"] = open(config.get("eval_out"), mode="w", encoding="utf-8") |
|
|
|
|
|
if config.get("labels", None) == None: |
|
config["labels"] = ["all"] |
|
|
|
if config.get("eval_method", None) == None: |
|
config["eval_method"] = ["traditional", "fair", "weighted"] |
|
if not config.get("weights", {}) and "weighted" in config.get("eval_method"): |
|
if not "fair" in config["eval_method"]: |
|
config["eval_method"].append("fair") |
|
del config["eval_method"][config["eval_method"].index("weighted")] |
|
|
|
|
|
print("### Evaluation settings:", file=config.get("eval_out")) |
|
for key in sorted(config.keys()): |
|
if key in ["target_files", "system_files", "eval_out"]: |
|
continue |
|
print("{0}: {1}".format(key, config.get(key)), file=config.get("eval_out")) |
|
print(file=config.get("eval_out")) |
|
|
|
return config |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--config', help='Configuration File', required=True) |
|
|
|
args = parser.parse_args() |
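
    # Typical invocation (script filename assumed, adjust to the actual file name):
    #   python faireval.py --config eval.config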
|
|
|
|
|
config = read_config(args.config) |
|
|
|
|
|
eval_dict = {"overall" : {"traditional" : {}, "fair" : {}}, |
|
"per_label" : {"traditional" : {}, "fair" : {}}, |
|
"conf" : {}} |
|
for method in config.get("eval_method", ["traditional", "fair"]): |
|
eval_dict["overall"][method] = {} |
|
eval_dict["per_label"][method] = {} |
|
|
|
|
|
config["data_stats"] = {} |
|
|
|
|
|
|
|
file_pairs = [] |
|
for t in config.get("target_files", []): |
|
s = [f for f in config.get("system_files", []) |
|
if os.path.split(t)[-1] == os.path.split(f)[-1]] |
|
if s: |
|
file_pairs.append((t, s[0])) |
|
|
|
|
|
for target_file, system_file in file_pairs: |
|
|
|
|
|
for target_sentence, system_sentence in zip(get_sentences(target_file), |
|
get_sentences(system_file)): |
|
|
|
|
|
target_spans = get_spans(target_sentence, **config) |
|
system_spans = get_spans(system_sentence, **config) |
|
|
|
|
|
|
|
annotation_stats(target_spans, **config) |
|
|
|
|
|
sent_counts = compare_spans(target_spans, system_spans, |
|
config.get("focus", "target")) |
|
|
|
|
|
eval_dict = add_dict(eval_dict, sent_counts) |
|
|
|
|
|
eval_dict = calculate_results(eval_dict, **config) |
|
|
|
|
|
output_results(eval_dict, **config) |