illorca committed on
Commit
c5d83eb
1 Parent(s): 02e3311

Re-upload project

Files changed (6)
  1. .gitattributes +33 -0
  2. FairEval.py +1651 -0
  3. README.md +96 -0
  4. app.py +6 -0
  5. fairevaluation.py +237 -0
  6. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
FairEval.py ADDED
@@ -0,0 +1,1651 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ Created 09/2021
5
+
6
+ @author: Katrin Ortmann
7
+ '''
8
+
9
+ import argparse
10
+ import os
11
+ import sys
12
+ import re
13
+ from typing import Iterable
14
+ from io import TextIOWrapper
15
+ from copy import deepcopy
16
+
17
+ #####################################
18
+
19
+ def precision(evaldict, version="traditional", weights={}):
20
+ """
21
+ Calculate traditional, fair or weighted precision value.
22
+
23
+ Precision is calculated as the number of true positives
24
+ divided by the number of true positives plus false positives
25
+ plus (optionally) additional error types.
26
+
27
+ Input:
28
+ - A dictionary with error types as keys and counts as values, e.g.,
29
+ {"TP" : 10, "FP" : 2, "LE" : 1, ...}
30
+
31
+ For 'traditional' evaluation, true positives (key: TP) and
32
+ false positives (key: FP) are required.
33
+ The 'fair' evaluation is based on true positives (TP),
34
+ false positives (FP), labeling errors (LE), boundary errors (BE)
35
+ and labeling-boundary errors (LBE).
36
+ The 'weighted' evaluation can include any error type
37
+ that is given as key in the weight dictionary.
38
+ For missing keys, the count is set to 0.
39
+
40
+ - The desired evaluation method. Options are 'traditional',
41
+ 'fair', and 'weighted'. If no weight dictionary is specified,
42
+ 'weighted' is identical to 'fair'.
43
+
44
+ - A weight dictionary to specify how much an error type should
45
+ count as one of the traditional error types (or as true positive).
46
+ By default, every traditional error is counted as one error (or true positive)
47
+ and each error of the additional types is counted as half false positive and half false negative:
48
+
49
+ {"TP" : {"TP" : 1},
50
+ "FP" : {"FP" : 1},
51
+ "FN" : {"FN" : 1},
52
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
53
+ "BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
54
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}}
55
+
56
+ Other suggested weights to count boundary errors as half true positives:
57
+
58
+ {"TP" : {"TP" : 1},
59
+ "FP" : {"FP" : 1},
60
+ "FN" : {"FN" : 1},
61
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
62
+ "BE" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25},
63
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}}
64
+
65
+ Or to include different types of boundary errors:
66
+
67
+ {"TP" : {"TP" : 1},
68
+ "FP" : {"FP" : 1},
69
+ "FN" : {"FN" : 1},
70
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
71
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
72
+ "BEO" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25},
73
+ "BES" : {"TP" : 0.5, "FP" : 0, "FN" : 0.5},
74
+ "BEL" : {"TP" : 0.5, "FP" : 0.5, "FN" : 0}}
75
+
76
+ Output:
77
+ The precision for the given input values.
78
+ In case of a ZeroDivisionError, the precision is set to 0.
79
+
80
+ """
81
+ traditional_weights = {
82
+ "TP" : {"TP" : 1},
83
+ "FP" : {"FP" : 1},
84
+ "FN" : {"FN" : 1}
85
+ }
86
+ default_fair_weights = {
87
+ "TP" : {"TP" : 1},
88
+ "FP" : {"FP" : 1},
89
+ "FN" : {"FN" : 1},
90
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
91
+ "BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
92
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}
93
+ }
94
+ try:
95
+ tp = 0
96
+ fp = 0
97
+
98
+ #Set default weights for traditional evaluation
99
+ if version == "traditional":
100
+ weights = traditional_weights
101
+
102
+ #Set weights to default
103
+ #for fair evaluation or if no weights are given
104
+ elif version == "fair" or not weights:
105
+ weights = default_fair_weights
106
+
107
+ #Add weighted errors to true positive count
108
+ tp += sum(
109
+ [w.get("TP", 0) * evaldict.get(error, 0) for error, w in weights.items()]
110
+ )
111
+
112
+ #Add weighted errors to false positive count
113
+ fp += sum(
114
+ [w.get("FP", 0) * evaldict.get(error, 0) for error, w in weights.items()]
115
+ )
116
+
117
+ #Calculate precision
118
+ return tp / (tp + fp)
119
+
120
+ #Return 0 if there are neither true nor false positives
121
+ except ZeroDivisionError:
122
+ return 0.0
123
+
124
+ ######################
125
+
126
+ def recall(evaldict, version="traditional", weights={}):
127
+ """
128
+ Calculate traditional, fair or weighted recall value.
129
+
130
+ Recall is calculated as the number of true positives
131
+ divided by the number of true positives plus false negatives
132
+ plus (optionally) additional error types.
133
+
134
+ Input:
135
+ - A dictionary with error types as keys and counts as values, e.g.,
136
+ {"TP" : 10, "FN" : 2, "LE" : 1, ...}
137
+
138
+ For 'traditional' evaluation, true positives (key: TP) and
139
+ false negatives (key: FN) are required.
140
+ The 'fair' evaluation is based on true positives (TP),
141
+ false negatives (FN), labeling errors (LE), boundary errors (BE)
142
+ and labeling-boundary errors (LBE).
143
+ The 'weighted' evaluation can include any error type
144
+ that is given as key in the weight dictionary.
145
+ For missing keys, the count is set to 0.
146
+
147
+ - The desired evaluation method. Options are 'traditional',
148
+ 'fair', and 'weighted'. If no weight dictionary is specified,
149
+ 'weighted' is identical to 'fair'.
150
+
151
+ - A weight dictionary to specify how much an error type should
152
+ count as one of the traditional error types (or as true positive).
153
+ By default, every traditional error is counted as one error (or true positive)
154
+ and each error of the additional types is counted as half false positive and half false negative:
155
+
156
+ {"TP" : {"TP" : 1},
157
+ "FP" : {"FP" : 1},
158
+ "FN" : {"FN" : 1},
159
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
160
+ "BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
161
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}}
162
+
163
+ Other suggested weights to count boundary errors as half true positives:
164
+
165
+ {"TP" : {"TP" : 1},
166
+ "FP" : {"FP" : 1},
167
+ "FN" : {"FN" : 1},
168
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
169
+ "BE" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25},
170
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}}
171
+
172
+ Or to include different types of boundary errors:
173
+
174
+ {"TP" : {"TP" : 1},
175
+ "FP" : {"FP" : 1},
176
+ "FN" : {"FN" : 1},
177
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
178
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
179
+ "BEO" : {"TP" : 0.5, "FP" : 0.25, "FN" : 0.25},
180
+ "BES" : {"TP" : 0.5, "FP" : 0, "FN" : 0.5},
181
+ "BEL" : {"TP" : 0.5, "FP" : 0.5, "FN" : 0}}
182
+
183
+ Output:
184
+ The recall for the given input values.
185
+ In case of a ZeroDivisionError, the recall is set to 0.
186
+
187
+ """
188
+ traditional_weights = {
189
+ "TP" : {"TP" : 1},
190
+ "FP" : {"FP" : 1},
191
+ "FN" : {"FN" : 1}
192
+ }
193
+ default_fair_weights = {
194
+ "TP" : {"TP" : 1},
195
+ "FP" : {"FP" : 1},
196
+ "FN" : {"FN" : 1},
197
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
198
+ "BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
199
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}
200
+ }
201
+ try:
202
+ tp = 0
203
+ fn = 0
204
+
205
+ #Set default weights for traditional evaluation
206
+ if version == "traditional":
207
+ weights = traditional_weights
208
+
209
+ #Set weights to default
210
+ #for fair evaluation or if no weights are given
211
+ elif version == "fair" or not weights:
212
+ weights = default_fair_weights
213
+
214
+ #Add weighted errors to true positive count
215
+ tp += sum(
216
+ [w.get("TP", 0) * evaldict.get(error, 0) for error, w in weights.items()]
217
+ )
218
+
219
+ #Add weighted errors to false negative count
220
+ fn += sum(
221
+ [w.get("FN", 0) * evaldict.get(error, 0) for error, w in weights.items()]
222
+ )
223
+
224
+ #Calculate recall
225
+ return tp / (tp + fn)
226
+
227
+ #Return zero if there are neither true positives nor false negatives
228
+ except ZeroDivisionError:
229
+ return 0.0
230
+
231
+ ######################
232
+
233
+ def fscore(evaldict):
234
+ """
235
+ Calculates F1-Score from given precision and recall values.
236
+
237
+ Input: A dictionary with a precision (key: Prec) and recall (key: Rec) value.
238
+ Output: The F1-Score. In case of a ZeroDivisionError, the F1-Score is set to 0.
239
+ """
240
+ try:
241
+ return 2 * (evaldict.get("Prec", 0) * evaldict.get("Rec", 0)) \
242
+ / (evaldict.get("Prec", 0) + evaldict.get("Rec", 0))
243
+ except ZeroDivisionError:
244
+ return 0.0
245
+
246
+ ######################
247
+
248
+ def overlap_type(span1, span2):
249
+ """
250
+ Determine the error type of two (overlapping) spans.
251
+
252
+ The function checks if and how span1 and span2 overlap.
253
+ The first span serves as the basis against which the second
254
+ span is evaluated.
255
+
256
+ span1 ---XXXX---
257
+ span2 ---XXXX--- TP (identical)
258
+ span2 ----XXXX-- BEO (overlap)
259
+ span2 --XXXX---- BEO (overlap)
260
+ span2 ----XX---- BES (smaller)
261
+ span2 ---XX----- BES (smaller)
262
+ span2 --XXXXXX-- BEL (larger)
263
+ span2 --XXXXX--- BEL (larger)
264
+ span2 -X-------- False (no overlap)
265
+
266
+ Input:
267
+ Tuples (beginSpan1, endSpan1) and (beginSpan2, endSpan2),
268
+ where begin and end are the indices of the corresponding tokens.
269
+
270
+ Output:
271
+ Either one of the following strings
272
+ - "TP" = span1 and span2 are identical, i.e., actually no error here
273
+ - "BES" = span2 is shorter and contained within span1 (with at most one identical boundary)
274
+ - "BEL" = span2 is longer and contains span1 (with at most one identical boundary)
275
+ - "BEO" = span1 and span2 overlap with no identical boundary
276
+ or False if span1 and span2 do not overlap.
277
+ """
278
+ #Identical spans
279
+ if span1[0] == span2[0] and span1[1] == span2[1]:
280
+ return "TP"
281
+
282
+ #Start of spans is identical
283
+ if span1[0] == span2[0]:
284
+ #End of 2 is within span1
285
+ if span2[1] >= span1[0] and span2[1] < span1[1]:
286
+ return "BES"
287
+ #End of 2 is behind span1
288
+ else:
289
+ return "BEL"
290
+ #Start of 2 is before span1
291
+ elif span2[0] < span1[0]:
292
+ #End is before span 1
293
+ if span2[1] < span1[0]:
294
+ return False
295
+ #End is within span1
296
+ elif span2[1] < span1[1]:
297
+ return "BEO"
298
+ #End is identical or to the right
299
+ else:
300
+ return "BEL"
301
+ #Start of 2 is within span1
302
+ elif span2[0] >= span1[0] and span2[0] <= span1[1]:
303
+ #End of 2 is within span1
304
+ if span2[1] <= span1[1]:
305
+ return "BES"
306
+ #End of 2 is to the right
307
+ else:
308
+ return "BEO"
309
+ #Start of 2 is behind span1
310
+ else:
311
+ return False
312
+
313
+ #####################################
314
+
315
+ def compare_spans(target_spans, system_spans, focus="target"):
316
+ """
317
+ Compare system and target spans to identify correct/incorrect annotations.
318
+
319
+ The function takes a list of target spans and system spans.
320
+ Each span is a 4-tuple of
321
+ - label: the span type as string
322
+ - begin: the index of first token; equals end for spans of length 1
323
+ - end: the index of the last token; equals begin for spans of length 1
324
+ - tokens: a set of token indices included in the span
325
+ (this allows the correct evaluation of
326
+ partially and multiply overlapping spans;
327
+ to allow for changes of the token set,
328
+ the span tuple is actually implemented as a list.)
329
+
330
+ The function first performs traditional evaluation on these spans
331
+ to identify true positives, false positives, and false negatives.
332
+ Then, the additional error types for fair evaluation are determined,
333
+ following steps 1 to 4:
334
+ 1. Count 1:1 mappings (TP, LE)
335
+ 2. Count boundary errors (BE = BES + BEL + BEO)
336
+ 3. Count labeling-boundary errors (LBE)
337
+ 4. Count 1:0 and 0:1 mappings (FN, FP)
338
+
339
+ Input:
340
+ - List of target spans
341
+ - List of system spans
342
+ - Whether to focus on the system or target annotation (default: target)
343
+
344
+ Output: A dictionary containing
345
+ - the counts of TP, FP, and FN according to traditional evaluation
346
+ (per label and overall)
347
+ - the counts of TP, FP, LE, BE, BES, BEL, BEO, and FN
348
+ (per label and overall; BE = BES + BEL + BEO)
349
+ - a confusion matrix {target_label1 : {system_label1 : count,
350
+ system_label2 : count,
351
+ ...},
352
+ target_label2 : ...
353
+ }
354
+ with an underscore '_' representing an empty label (FN/FP)
355
+ """
356
+
357
+ ##################################
358
+
359
+ def _max_sim(t, S):
360
+ """
361
+ Determine the most similar span s from S for span t.
362
+
363
+ Similarity is defined as
364
+ 1. the maximum number of shared tokens between s and t and
365
+ 2. the minimum number of tokens only in t
366
+ If multiple spans are equally similar, the shortest s is chosen.
367
+ If still multiple spans are equally similar, the first one in the list is chosen,
368
+ which corresponds to the left-most one if sentences are read from left to right.
369
+
370
+ Input:
371
+ - Span t as 4-tuple [label, begin, end, token_set]
372
+ - List S containing more than one span
373
+
374
+ Output: The most similar s for t.
375
+ """
376
+ S.sort(key=lambda s: (0-len(t[3].intersection(s[3])),
377
+ len(t[3].difference(s[3])),
378
+ len(s[3].difference(t[3])),
379
+ s[2]-s[1]))
380
+ return S[0]
381
+
382
+ ##################################
383
+
384
+ traditional_error_types = ["TP", "FP", "FN"]
385
+ additional_error_types = ["LE", "BE", "BEO", "BES", "BEL", "LBE"]
386
+
387
+ #Initialize empty eval dict
388
+ eval_dict = {"overall" : {"traditional" : {err_type : 0 for err_type
389
+ in traditional_error_types},
390
+ "fair" : {err_type : 0 for err_type
391
+ in traditional_error_types + additional_error_types}},
392
+ "per_label" : {"traditional" : {},
393
+ "fair" : {}},
394
+ "conf" : {}}
395
+
396
+ #Initialize per-label dict
397
+ for s in target_spans + system_spans:
398
+ if not s[0] in eval_dict["per_label"]["traditional"]:
399
+ eval_dict["per_label"]["traditional"][s[0]] = {err_type : 0 for err_type
400
+ in traditional_error_types}
401
+ eval_dict["per_label"]["fair"][s[0]] = {err_type : 0 for err_type
402
+ in traditional_error_types + additional_error_types}
403
+ #Initialize confusion matrix
404
+ if not s[0] in eval_dict["conf"]:
405
+ eval_dict["conf"][s[0]] = {}
406
+ eval_dict["conf"]["_"] = {}
407
+ for lab in list(eval_dict["conf"])+["_"]:
408
+ for lab2 in list(eval_dict["conf"])+["_"]:
409
+ eval_dict["conf"][lab][lab2] = 0
410
+
411
+ ################################################
412
+ ### Traditional evaluation (overall + per label)
413
+
414
+ for t in target_spans:
415
+ #Spans in target and system annotation are true positives
416
+ if t in system_spans:
417
+ eval_dict["overall"]["traditional"]["TP"] += 1
418
+ eval_dict["per_label"]["traditional"][t[0]]["TP"] += 1
419
+ #Spans only in target annotation are false negatives
420
+ else:
421
+ eval_dict["overall"]["traditional"]["FN"] += 1
422
+ eval_dict["per_label"]["traditional"][t[0]]["FN"] += 1
423
+ for s in system_spans:
424
+ #Spans only in system annotation are false positives
425
+ if not s in target_spans:
426
+ eval_dict["overall"]["traditional"]["FP"] += 1
427
+ eval_dict["per_label"]["traditional"][s[0]]["FP"] += 1
428
+
429
+ ###########################################################
430
+ ### Fair evaluation (overall, per label + confusion matrix)
431
+
432
+ ### Identical spans (TP and LE)
433
+
434
+ ### TP
435
+ #Identify true positives (identical spans between target and system)
436
+ tps = [t for t in target_spans if t in system_spans]
437
+ for t in tps:
438
+ s = [s for s in system_spans if s == t]
439
+ if s:
440
+ s = s[0]
441
+ eval_dict["overall"]["fair"]["TP"] += 1
442
+ eval_dict["per_label"]["fair"][t[0]]["TP"] += 1
443
+ #After counting, remove from input lists
444
+ system_spans.remove(s)
445
+ target_spans.remove(t)
446
+
447
+ ### LE
448
+ #Identify labeling error: identical span but different label
449
+ les = [t for t in target_spans
450
+ if any(t[0] != s[0] and t[1:3] == s[1:3] for s in system_spans)]
451
+ for t in les:
452
+ s = [s for s in system_spans if t[0] != s[0] and t[1:3] == s[1:3]]
453
+ if s:
454
+ s = s[0]
455
+ #Overall: count as one LE
456
+ eval_dict["overall"]["fair"]["LE"] += 1
457
+ #Per label: depending on focus count for target label or system label
458
+ if focus == "target":
459
+ eval_dict["per_label"]["fair"][t[0]]["LE"] += 1
460
+ elif focus == "system":
461
+ eval_dict["per_label"]["fair"][s[0]]["LE"] += 1
462
+ #Add to confusion matrix
463
+ eval_dict["conf"][t[0]][s[0]] += 1
464
+ #After counting, remove from input lists
465
+ system_spans.remove(s)
466
+ target_spans.remove(t)
467
+
468
+ ### Boundary errors
469
+
470
+ #Create lists to collect matched spans
471
+ counted_target = list()
472
+ counted_system = list()
473
+
474
+ #Sort lists by span length (shortest to longest)
475
+ target_spans.sort(key=lambda t : t[2] - t[1])
476
+ system_spans.sort(key=lambda s : s[2] - s[1])
477
+
478
+ ### BE
479
+
480
+ ## 1. Compare input lists
481
+ #Identify boundary errors: identical label but different, overlapping span
482
+ i = 0
483
+ while i < len(target_spans):
484
+ t = target_spans[i]
485
+
486
+ #Find possible boundary errors
487
+ be = [s for s in system_spans
488
+ if t[0] == s[0] and t[1:3] != s[1:3]
489
+ and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")]
490
+ if not be:
491
+ i += 1
492
+ continue
493
+
494
+ #If there is more than one possible BE, take most similar one
495
+ if len(be) > 1:
496
+ s = _max_sim(t, be)
497
+ else:
498
+ s = be[0]
499
+
500
+ #Determine overlap type
501
+ be_type = overlap_type((t[1], t[2]), (s[1], s[2]))
502
+
503
+ #Overall: Count as BE and more fine-grained BE type
504
+ eval_dict["overall"]["fair"]["BE"] += 1
505
+ eval_dict["overall"]["fair"][be_type] += 1
506
+
507
+ #Per-label: count as general BE and specific BE type
508
+ eval_dict["per_label"]["fair"][t[0]]["BE"] += 1
509
+ eval_dict["per_label"]["fair"][t[0]][be_type] += 1
510
+
511
+ #Add to confusion matrix
512
+ eval_dict["conf"][t[0]][s[0]] += 1
513
+
514
+ #Remove matched spans from input list
515
+ system_spans.remove(s)
516
+ target_spans.remove(t)
517
+
518
+ #Remove matched tokens from spans
519
+ matching_tokens = t[3].intersection(s[3])
520
+ s[3] = s[3].difference(matching_tokens)
521
+ t[3] = t[3].difference(matching_tokens)
522
+
523
+ #Move matched spans to counted list
524
+ counted_system.append(s)
525
+ counted_target.append(t)
526
+
527
+ ## 2. Compare input target list with matched system list
528
+ i = 0
529
+ while i < len(target_spans):
530
+ t = target_spans[i]
531
+
532
+ #Find possible boundary errors in already matched spans
533
+ #that still share unmatched tokens
534
+ be = [s for s in counted_system
535
+ if t[0] == s[0] and t[1:3] != s[1:3]
536
+ and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")
537
+ and t[3].intersection(s[3])]
538
+ if not be:
539
+ i += 1
540
+ continue
541
+
542
+ #If there is more than one possible BE, take most similar one
543
+ if len(be) > 1:
544
+ s = _max_sim(t, be)
545
+ else:
546
+ s = be[0]
547
+
548
+ #Determine overlap type
549
+ be_type = overlap_type((t[1], t[2]), (s[1], s[2]))
550
+
551
+ #Overall: Count as BE and more fine-grained BE type
552
+ eval_dict["overall"]["fair"]["BE"] += 1
553
+ eval_dict["overall"]["fair"][be_type] += 1
554
+
555
+ #Per-label: count as general BE and specific BE type
556
+ eval_dict["per_label"]["fair"][t[0]]["BE"] += 1
557
+ eval_dict["per_label"]["fair"][t[0]][be_type] += 1
558
+
559
+ #Add to confusion matrix
560
+ eval_dict["conf"][t[0]][s[0]] += 1
561
+
562
+ #Remove matched span from input list
563
+ target_spans.remove(t)
564
+
565
+ #Remove matched tokens from spans
566
+ matching_tokens = t[3].intersection(s[3])
567
+ counted_system[counted_system.index(s)][3] = s[3].difference(matching_tokens)
568
+ t[3] = t[3].difference(matching_tokens)
569
+
570
+ #Move target span to counted list
571
+ counted_target.append(t)
572
+
573
+ ## 3. Compare input system list with matched target list
574
+ i = 0
575
+ while i < len(system_spans):
576
+ s = system_spans[i]
577
+
578
+ #Find possible boundary errors in already matched target spans
579
+ be = [t for t in counted_target
580
+ if t[0] == s[0] and t[1:3] != s[1:3]
581
+ and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")
582
+ and t[3].intersection(s[3])]
583
+ if not be:
584
+ i += 1
585
+ continue
586
+
587
+ #If there is more than one possible BE, take most similar one
588
+ if len(be) > 1:
589
+ t = _max_sim(s, be)
590
+ else:
591
+ t = be[0]
592
+
593
+ #Determine overlap type
594
+ be_type = overlap_type((t[1], t[2]), (s[1], s[2]))
595
+
596
+ #Overall: Count as BE and more fine-grained BE type
597
+ eval_dict["overall"]["fair"]["BE"] += 1
598
+ eval_dict["overall"]["fair"][be_type] += 1
599
+
600
+ #Per-label: count as general BE and specific BE type
601
+ eval_dict["per_label"]["fair"][t[0]]["BE"] += 1
602
+ eval_dict["per_label"]["fair"][t[0]][be_type] += 1
603
+
604
+ #Add to confusion matrix
605
+ eval_dict["conf"][t[0]][s[0]] += 1
606
+
607
+ #Remove matched span from input list
608
+ system_spans.remove(s)
609
+
610
+ #Remove matched tokens from spans
611
+ matching_tokens = t[3].intersection(s[3])
612
+ counted_target[counted_target.index(t)][3] = t[3].difference(matching_tokens)
613
+ s[3] = s[3].difference(matching_tokens)
614
+
615
+ #Move system span to counted list
616
+ counted_system.append(s)
617
+
618
+ ### LBE
619
+
620
+ ## 1. Compare input lists
621
+ #Identify labeling-boundary errors: different label but overlapping span
622
+ i = 0
623
+ while i < len(target_spans):
624
+ t = target_spans[i]
625
+
626
+ #Find possible labeling-boundary errors
627
+ lbe = [s for s in system_spans
628
+ if t[0] != s[0] and t[1:3] != s[1:3]
629
+ and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")]
630
+ if not lbe:
631
+ i += 1
632
+ continue
633
+
634
+ #If there is more than one possible LBE, take most similar one
635
+ if len(lbe) > 1:
636
+ s = _max_sim(t, lbe)
637
+ else:
638
+ s = lbe[0]
639
+
640
+ #Overall: count as LBE
641
+ eval_dict["overall"]["fair"]["LBE"] += 1
642
+
643
+ #Per label: depending on focus count as LBE for target or system label
644
+ if focus == "target":
645
+ eval_dict["per_label"]["fair"][t[0]]["LBE"] += 1
646
+ elif focus == "system":
647
+ eval_dict["per_label"]["fair"][s[0]]["LBE"] += 1
648
+
649
+ #Add to confusion matrix
650
+ eval_dict["conf"][t[0]][s[0]] += 1
651
+
652
+ #Remove matched spans from input list
653
+ system_spans.remove(s)
654
+ target_spans.remove(t)
655
+
656
+ #Remove matched tokens from spans
657
+ matching_tokens = t[3].intersection(s[3])
658
+ s[3] = s[3].difference(matching_tokens)
659
+ t[3] = t[3].difference(matching_tokens)
660
+
661
+ #Move spans to counted lists
662
+ counted_system.append(s)
663
+ counted_target.append(t)
664
+
665
+ ## 2. Compare input target list with matched system list
666
+ i = 0
667
+ while i < len(target_spans):
668
+ t = target_spans[i]
669
+
670
+ #Find possible labeling-boundary errors in already matched system spans
671
+ lbe = [s for s in counted_system
672
+ if t[0] != s[0] and t[1:3] != s[1:3]
673
+ and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")
674
+ and t[3].intersection(s[3])]
675
+ if not lbe:
676
+ i += 1
677
+ continue
678
+
679
+ #If there is more than one possible LBE, take most similar one
680
+ if len(lbe) > 1:
681
+ s = _max_sim(t, lbe)
682
+ else:
683
+ s = lbe[0]
684
+
685
+ #Overall: count as LBE
686
+ eval_dict["overall"]["fair"]["LBE"] += 1
687
+
688
+ #Per label: depending on focus count as LBE for target or system label
689
+ if focus == "target":
690
+ eval_dict["per_label"]["fair"][t[0]]["LBE"] += 1
691
+ elif focus == "system":
692
+ eval_dict["per_label"]["fair"][s[0]]["LBE"] += 1
693
+
694
+ #Add to confusion matrix
695
+ eval_dict["conf"][t[0]][s[0]] += 1
696
+
697
+ #Remove matched span from input list
698
+ target_spans.remove(t)
699
+
700
+ #Remove matched tokens from spans
701
+ matching_tokens = t[3].intersection(s[3])
702
+ counted_system[counted_system.index(s)][3] = s[3].difference(matching_tokens)
703
+ t[3] = t[3].difference(matching_tokens)
704
+
705
+ #Move target span to counted list
706
+ counted_target.append(t)
707
+
708
+ ## 3. Compare input system list with matched target list
709
+ i = 0
710
+ while i < len(system_spans):
711
+ s = system_spans[i]
712
+
713
+ #Find possible labeling-boundary errors in already matched target spans
714
+ lbe = [t for t in counted_target
715
+ if t[0] != s[0] and t[1:3] != s[1:3]
716
+ and overlap_type((t[1], t[2]), (s[1], s[2])) in ("BES", "BEL", "BEO")
717
+ and t[3].intersection(s[3])]
718
+ if not lbe:
719
+ i += 1
720
+ continue
721
+
722
+ #If there is more than one possible LBE, take most similar one
723
+ if len(lbe) > 1:
724
+ t = _max_sim(s, lbe)
725
+ else:
726
+ t = lbe[0]
727
+
728
+ #Overall: count as LBE
729
+ eval_dict["overall"]["fair"]["LBE"] += 1
730
+
731
+ #Per label: depending on focus count as LBE for target or system label
732
+ if focus == "target":
733
+ eval_dict["per_label"]["fair"][t[0]]["LBE"] += 1
734
+ elif focus == "system":
735
+ eval_dict["per_label"]["fair"][s[0]]["LBE"] += 1
736
+
737
+ #Add to confusion matrix
738
+ eval_dict["conf"][t[0]][s[0]] += 1
739
+
740
+ #Remove matched span from input list
741
+ system_spans.remove(s)
742
+
743
+ #Remove matched tokens from spans
744
+ matching_tokens = t[3].intersection(s[3])
745
+ counted_target[counted_target.index(t)][3] = t[3].difference(matching_tokens)
746
+ s[3] = s[3].difference(matching_tokens)
747
+
748
+ #Move matched system span to counted list
749
+ counted_system.append(s)
750
+
751
+ ### 1:0 and 0:1 mappings
752
+
753
+ #FN: identify false negatives
754
+ for t in target_spans:
755
+ eval_dict["overall"]["fair"]["FN"] += 1
756
+ eval_dict["per_label"]["fair"][t[0]]["FN"] += 1
757
+ eval_dict["conf"][t[0]]["_"] += 1
758
+
759
+ #FP: identify false positives
760
+ for s in system_spans:
761
+ eval_dict["overall"]["fair"]["FP"] += 1
762
+ eval_dict["per_label"]["fair"][s[0]]["FP"] += 1
763
+ eval_dict["conf"]["_"][s[0]] += 1
764
+
765
+ return eval_dict
766
+
767
+ ############################
768
+
769
+ def annotation_stats(target_spans, **config):
770
+ """
771
+ Count the target annotations to display simple statistics.
772
+
773
+ The function takes a list of target spans
774
+ with each span being a 4-tuple [label, begin, end, token_set]
775
+ and adds the included labels to the general data stats dictionary.
776
+
777
+ Input:
778
+ - List of target spans
779
+ - Config dictionary
780
+
781
+ Output: The config dictionary is modified in-place.
782
+ """
783
+ stats_dict = config.get("data_stats", {})
784
+ for span in target_spans:
785
+ if span[0] in stats_dict:
786
+ stats_dict[span[0]] += 1
787
+ else:
788
+ stats_dict[span[0]] = 1
789
+ config["data_stats"] = stats_dict
790
+
791
+ ############################
792
+
793
+ def get_spans(sentence, **config):
794
+ """
795
+ Return spans from CoNLL2000 or span files.
796
+
797
+ The function determines the data format of the input sentence
798
+ and extracts the spans from it accordingly.
799
+
800
+ If desired, punctuation can be ignored (config['ignore_punct'] == True)
801
+ for files in the CoNLL2000 format that include POS information.
802
+ The following list of tags is considered as punctuation:
803
+ ['$.', '$,', '$(', #STTS
804
+ 'PUNCT', #UPOS
805
+ 'PUNKT', 'KOMMA', 'COMMA', 'KLAMMER', #custom
806
+ '.', ',', ':', '(', ')', '"', '‘', '“', '’', '”' #PTB
807
+ ]
808
+
809
+ Labels that should be ignored (included in config['exclude']
810
+ or not included in config['labels'] if config['labels'] != 'all')
811
+ are also removed from the resulting list.
812
+
813
+ Input:
814
+ - List of lines for a given sentence
815
+ - Config dictionary
816
+
817
+ Output: List of spans that are included in the sentence.
818
+ """
819
+
820
+ ################
821
+
822
+ def spans_from_conll(sentence):
823
+ """
824
+ Read annotation spans from a CoNLL2000 file.
825
+
826
+ The function takes a list of lines (belonging to one sentence)
827
+ and extracts the annotated spans. The lines are expected to
828
+ contain three space-separated columns:
829
+
830
+ Form XPOS Annotation
831
+
832
+ Form: Word form
833
+ XPOS: POS tag of the word (ideally STTS, UPOS, or PTB)
834
+ Annotation: Span annotation in BIO format (see below);
835
+ multiple spans are separated with the pipe symbol '|'
836
+
837
+ BIO tags consist of the token's position in the span
838
+ (begin 'B', inside 'I', outside 'O'), a dash '-' and the span label,
839
+ e.g., B-NP, I-AC, or in the case of stacked annotations I-RELC|B-NP.
840
+
841
+ The function accepts 'O', '_' and '' as annotations outside of spans.
842
+
843
+ Input: List of lines belonging to one sentence.
844
+ Output: List of spans as 4-tuples [label, begin, end, token_set]
845
+ """
846
+ spans = []
847
+ span_stack = []
848
+
849
+ #For each token
850
+ for t, tok in enumerate(sentence):
851
+
852
+ #Token is [Form, XPOS, Annotation]
853
+ tok = tok.split()
854
+
855
+ #Token is not annotated
856
+ if tok[-1] in ["O", "_", ""]:
857
+ #Add previous stack to span list
858
+ #(sorted from left to right)
859
+ while span_stack:
860
+ spans.append(span_stack.pop(0))
861
+ span_stack = []
862
+ continue
863
+
864
+ #Token is annotated
865
+ #Split stacked annotations at pipe
866
+ annotations = tok[-1].strip().split("|")
867
+
868
+ #While there are more annotation levels on
869
+ #the stack than at the current token,
870
+ #close annotations on the stack (i.e., move
871
+ #them to result list)
872
+ while len(span_stack) > len(annotations):
873
+ spans.append(span_stack.pop())
874
+
875
+ #For each annotation of the current token
876
+ for i, annotation in enumerate(annotations):
877
+
878
+ #New span
879
+ if annotation.startswith("B-"):
880
+
881
+ #If it's the first annotation level and there is
882
+ #something on the stack, move it to result list
883
+ if i == 0 and span_stack:
884
+ while span_stack:
885
+ spans.append(span_stack.pop(0))
886
+ #Otherwise, end same-level annotation on the
887
+ #stack (because a new span begins here) and
888
+ #move it to the result list
889
+ else:
890
+ while len(span_stack) > i:
891
+ spans.append(span_stack.pop())
892
+
893
+ #Last part of BIO tag is the label
894
+ label = annotation.split("-")[1]
895
+
896
+ #Create a new span with this token's
897
+ #index as start and end (incremented by one).
898
+ s = [label, t+1, t+1, {t+1}]
899
+
900
+ #Add on top of stack
901
+ span_stack.append(s)
902
+
903
+ #Span continues
904
+ elif annotation.startswith("I-"):
905
+ #Increment the end index of the span
906
+ #at the level of this annotation on the stack
907
+ span_stack[i][2] = t+1
908
+ #Also, add the index to the token set
909
+ span_stack[i][-1].add(t+1)
910
+
911
+ #Add sentence final span(s)
912
+ while span_stack:
913
+ spans.append(span_stack.pop(0))
914
+
915
+ return spans
916
+
917
+ ################
918
+
919
+ def spans_from_lines(sentence):
920
+ """
921
+ Read annotation spans from a span file.
922
+
923
+ The function takes a list of lines (belonging to one sentence)
924
+ and extracts the annotated spans. The lines are expected to
925
+ contain four tab-separated columns:
926
+
927
+ Label Begin End Tokens
928
+
929
+ Label: Span label
930
+ Begin: Index of the first included token (must be convertible to int)
931
+ End: Index of the last included token (must be convertible to int
932
+ and equal or greater than begin)
933
+ Tokens: Comma-separated list of indices of the tokens in the span
934
+ (must be convertible to int with begin <= i <= end);
935
+ if no (valid) indices are given, the range begin:end is used
936
+
937
+ Input: List of lines belonging to one sentence.
938
+ Output: List of spans as 4-tuples [label, begin, end, token_set]
939
+ """
940
+ spans = []
941
+ for line in sentence:
942
+ vals = line.split("\t")
943
+ label = vals[0]
944
+ if not label:
945
+ print("ERROR: Missing label in input.")
946
+ return []
947
+ try:
948
+ begin = int(vals[1])
949
+ if begin < 1: raise ValueError
950
+ except ValueError:
951
+ print("ERROR: Begin {0} is not a legal index.".format(vals[1]))
952
+ return []
953
+ try:
954
+ end = int(vals[2])
955
+ if end < 1: raise ValueError
956
+ if end < begin: begin, end = end, begin
957
+ except ValueError:
958
+ print("ERROR: End {0} is not a legal index.".format(vals[2]))
959
+ return []
960
+ try:
961
+ toks = [int(v.strip()) for v in vals[-1].split(",")
962
+ if int(v.strip()) >= begin and int(v.strip()) <= end]
963
+ toks = set(toks)
964
+ except ValueError:
965
+ toks = []
966
+ if not toks:
967
+ toks = [i for i in range(begin, end+1)]
968
+ spans.append([label, begin, end, toks])
969
+ return spans
970
+
971
+ ################
972
+
973
+ #Determine data format
974
+
975
+ #Span files contain 4 tab-separated columns
976
+ if len(sentence[0].split("\t")) == 4:
977
+ format = "spans"
978
+ spans = spans_from_lines(sentence)
979
+
980
+ #CoNLL2000 files contain 3 space-separated columns
981
+ elif len(sentence[0].split(" ")) == 3:
982
+ format = "conll2000"
983
+ spans = spans_from_conll(sentence)
984
+ else:
985
+ print("ERROR: Unknown input format")
986
+ return []
987
+
988
+ #Exclude punctuation from CoNLL2000, if desired
989
+ if format == "conll2000" \
990
+ and config.get("ignore_punct") == True:
991
+
992
+ #For each punctuation token
993
+ for i, line in enumerate(sentence):
994
+ if line.split(" ")[1] in ["$.", "$,", "$(", #STTS
995
+ "PUNCT", #UPOS
996
+ "PUNKT", "KOMMA", "COMMA", "KLAMMER", #custom
997
+ ".", ",", ":", "(", ")", "\"", "‘", "“", "’", "”" #PTB
998
+ ]:
999
+
1000
+ for s in range(len(spans)):
1001
+ #Remove punctuation token from the token set
1002
+ spans[s][-1].discard(i+1)
1003
+
1004
+ #If span begins with punctuation, move begin
1005
+ if spans[s][1] == i+1:
1006
+ if spans[s][2] != None and spans[s][2] > i+1:
1007
+ spans[s][1] = i+2
1008
+ else:
1009
+ spans[s][1] = None
1010
+
1011
+ #If span ends with punctuation, move end
1012
+ if spans[s][2] == i+1:
1013
+ if spans[s][1] != None and spans[s][1] <= i:
1014
+ spans[s][2] = i
1015
+ else:
1016
+ spans[s][2] = None
1017
+
1018
+ #Remove empty spans
1019
+ spans = [s for s in spans if s[1] != None and s[2] != None and len(s[3]) > 0]
1020
+
1021
+ #Exclude unwanted labels
1022
+ spans = [s for s in spans
1023
+ if not s[0] in config.get("exclude", [])
1024
+ and ("all" in config.get("labels", [])
1025
+ or s[0] in config.get("labels", []))]
1026
+
1027
+ return spans
1028
+
1029
+ ############################
1030
+
1031
+ def get_sentences(filename):
1032
+ """
1033
+ Reads sentences from input files.
1034
+
1035
+ The function iterates through the input file and
1036
+ yields a list of lines that belong to one sentence.
1037
+ Sentences are expected to be separated by an empty line.
1038
+
1039
+ Input: Filename of the input file.
1040
+ Output: Yields a list of lines for each sentence.
1041
+ """
1042
+ file = open(filename, mode="r", encoding="utf-8")
1043
+ sent = []
1044
+
1045
+ for line in file:
1046
+ #Empty line: yield collected lines
1047
+ if sent and not line.strip():
1048
+ yield sent
1049
+ sent = []
1050
+ #Empty line but nothing to yield
1051
+ elif not line.strip():
1052
+ continue
1053
+ #Collect line of current sentence
1054
+ else:
1055
+ sent.append(line.strip())
1056
+
1057
+ #Last sentence if file doesn't end with empty line
1058
+ if sent:
1059
+ yield sent
1060
+
1061
+ file.close()
1062
+
1063
+ #############################
1064
+
1065
+ def add_dict(base_dict, dict_to_add):
1066
+ """
1067
+ Take a base dictionary and add the values
1068
+ from another dictionary to it.
1069
+
1070
+ Contrary to standard dict update methods,
1071
+ this function does not overwrite values in the
1072
+ base dictionary. Instead, it is meant to add
1073
+ the values of the second dictionary to the values
1074
+ in the base dictionary. The dictionary is modified in-place.
1075
+
1076
+ For example:
1077
+
1078
+ >> base = {"A" : 1, "B" : {"c" : 2, "d" : 3}, "C" : [1, 2, 3]}
1079
+ >> add = {"A" : 1, "B" : {"c" : 1, "e" : 1}, "C" : [4], "D" : 2}
1080
+ >> add_dict(base, add)
1081
+
1082
+ will create a base dictionary:
1083
+
1084
+ >> base
1085
+ {'A': 2, 'B': {'c': 3, 'd': 3, 'e': 1}, 'C': [1, 2, 3, 4], 'D': 2}
1086
+
1087
+ The function can handle different types of nested structures.
1088
+ - Integers and float values are summed up.
1089
+ - Lists are appended
1090
+ - Sets are added (set union)
1091
+ - Dictionaries are added recursively
1092
+ For other value types, the base dictionary is left unchanged.
1093
+
1094
+ Input: Base dictionary and dictionary to be added.
1095
+ Output: Base dictionary.
1096
+ """
1097
+
1098
+ #For each key in second dict
1099
+ for key, val in dict_to_add.items():
1100
+
1101
+ #It is already in the base dict
1102
+ if key in base_dict:
1103
+
1104
+ #It has an integer or float value
1105
+ if isinstance(val, (int, float)) \
1106
+ and isinstance(base_dict[key], (int, float)):
1107
+
1108
+ #Increment value in base dict
1109
+ base_dict[key] += val
1110
+
1111
+ #It has an iterable as value
1112
+ elif isinstance(val, Iterable) \
1113
+ and isinstance(base_dict[key], Iterable):
1114
+
1115
+ #List
1116
+ if isinstance(val, list) \
1117
+ and isinstance(base_dict[key], list):
1118
+ #Append
1119
+ base_dict[key].extend(val)
1120
+
1121
+ #Set
1122
+ elif isinstance(val, set) \
1123
+ and isinstance(base_dict[key], set):
1124
+ #Set union
1125
+ base_dict[key].update(val)
1126
+
1127
+ #Dict
1128
+ elif isinstance(val, dict) \
1129
+ and isinstance(base_dict[key], dict):
1130
+ #Recursively repeat
1131
+ add_dict(base_dict[key], val)
1132
+
1133
+ #Something else
1134
+ else:
1135
+ #Do nothing
1136
+ pass
1137
+
1138
+ #It has something else as value
1139
+ else:
1140
+ #Do nothing
1141
+ pass
1142
+
1143
+ #It is not in the base dict
1144
+ else:
1145
+ #Insert values from second dict into base
1146
+ base_dict[key] = deepcopy(val)
1147
+
1148
+ return base_dict
1149
+
1150
+ #############################
1151
+
1152
+ def calculate_results(eval_dict, **config):
1153
+ """
1154
+ Calculate overall precision, recall, and F-scores.
1155
+
1156
+ The function takes an evaluation dictionary with error counts
1157
+ and applies the precision, recall and fscore functions.
1158
+
1159
+ It will calculate the traditional metrics
1160
+ and fair and/or weighted metrics, depending on the
1161
+ value of config['eval_method'].
1162
+
1163
+ The results are stored in the eval dict as 'Prec', 'Rec' and 'F1'
1164
+ for overall and per-label counts.
1165
+
1166
+ Input: Evaluation dict and config dict.
1167
+ Output: Evaluation dict with added precision, recall and F1 values.
1168
+ """
1169
+
1170
+ #If weighted evaluation should be performed
1171
+ #copy error counts from fair evaluation
1172
+ if "weighted" in config.get("eval_method", []):
1173
+ eval_dict["overall"]["weighted"] = {}
1174
+ for err_type in eval_dict["overall"]["fair"]:
1175
+ eval_dict["overall"]["weighted"][err_type] = eval_dict["overall"]["fair"][err_type]
1176
+ for label in eval_dict["per_label"]["fair"]:
1177
+ eval_dict["per_label"]["weighted"][label] = {}
1178
+ for err_type in eval_dict["per_label"]["fair"][label]:
1179
+ eval_dict["per_label"]["weighted"][label][err_type] = eval_dict["per_label"]["fair"][label][err_type]
1180
+
1181
+ #For each evaluation method
1182
+ for version in config.get("eval_method", ["traditional", "fair"]):
1183
+
1184
+ #Overall results
1185
+ eval_dict["overall"][version]["Prec"] = precision(eval_dict["overall"][version],
1186
+ version,
1187
+ config.get("weights", {}))
1188
+ eval_dict["overall"][version]["Rec"] = recall(eval_dict["overall"][version],
1189
+ version,
1190
+ config.get("weights", {}))
1191
+ eval_dict["overall"][version]["F1"] = fscore(eval_dict["overall"][version])
1192
+
1193
+ #Per label results
1194
+ for label in eval_dict["per_label"][version]:
1195
+ eval_dict["per_label"][version][label]["Prec"] = precision(eval_dict["per_label"][version][label],
1196
+ version,
1197
+ config.get("weights", {}))
1198
+ eval_dict["per_label"][version][label]["Rec"] = recall(eval_dict["per_label"][version][label],
1199
+ version,
1200
+ config.get("weights", {}))
1201
+ eval_dict["per_label"][version][label]["F1"] = fscore(eval_dict["per_label"][version][label])
1202
+
1203
+ return eval_dict
1204
+
1205
+ #############################
1206
+
1207
+ def output_results(eval_dict, **config):
1208
+ """
1209
+ Write evaluation results to the output (file).
1210
+
1211
+ The function takes an evaluation dict and writes
1212
+ all results to the specified output (file):
1213
+
1214
+ 1. Traditional evaluation results
1215
+ 2. Additional evaluation results (fair and/or weighted)
1216
+ 3. Result comparison for different evaluation methods
1217
+ 4. Confusion matrix
1218
+ 5. Data statistics
1219
+
1220
+ Input: Evaluation dict and config dict.
1221
+ """
1222
+ outfile = config.get("eval_out", sys.stdout)
1223
+
1224
+ ### Output results for each evaluation method
1225
+ for version in config.get("eval_method", ["traditional", "fair"]):
1226
+ print(file=outfile)
1227
+ print("### {0} evaluation:".format(version.title()), file=outfile)
1228
+
1229
+ #Determine error categories to output
1230
+ if version == "traditional":
1231
+ cats = ["TP", "FP", "FN"]
1232
+ elif version == "fair" or not config.get("weights", {}):
1233
+ cats = ["TP", "FP", "LE", "BE", "LBE", "FN"]
1234
+ else:
1235
+ cats = list(config.get("weights").keys())
1236
+
1237
+ #Print header
1238
+ print("Label", "\t".join(cats), "Prec", "Rec", "F1", sep="\t", file=outfile)
1239
+
1240
+ #Output results for each label
1241
+ for label,val in sorted(eval_dict["per_label"][version].items()):
1242
+ print(label,
1243
+ "\t".join([str(val.get(cat, eval_dict["per_label"]["fair"].get(cat, 0)))
1244
+ for cat in cats]),
1245
+ "\t".join(["{:04.2f}".format(val.get(metric, 0)*100)
1246
+ for metric in ["Prec", "Rec", "F1"]]),
1247
+ sep="\t", file=outfile)
1248
+
1249
+ #Output overall results
1250
+ print("overall",
1251
+ "\t".join([str(eval_dict["overall"][version].get(cat, eval_dict["overall"]["fair"].get(cat, 0)))
1252
+ for cat in cats]),
1253
+ "\t".join(["{:04.2f}".format(eval_dict["overall"][version].get(metric, 0)*100)
1254
+ for metric in ["Prec", "Rec", "F1"]]),
1255
+ sep="\t", file=outfile)
1256
+
1257
+ ### Output result comparison
1258
+ print(file=outfile)
1259
+ print("### Comparison:", file=outfile)
1260
+ print("Version", "Prec", "Rec", "F1", sep="\t", file=outfile)
1261
+ for version in config.get("eval_method", ["traditional", "fair"]):
1262
+ print(version.title(),
1263
+ "\t".join(["{:04.2f}".format(eval_dict["overall"][version].get(metric, 0)*100)
1264
+ for metric in ["Prec", "Rec", "F1"]]),
1265
+ sep="\t", file=outfile)
1266
+
1267
+ ### Output confusion matrix
1268
+ print(file=outfile)
1269
+ print("### Confusion matrix:", file=outfile)
1270
+
1271
+ #Get set of target labels
1272
+ labels = {lab for lab in eval_dict["conf"]}
1273
+
1274
+ #Add system labels
1275
+ labels = list(labels.union({syslab
1276
+ for lab in eval_dict["conf"]
1277
+ for syslab in eval_dict["conf"][lab]}))
1278
+
1279
+ #Sort alphabetically for output
1280
+ labels.sort()
1281
+
1282
+ #Print top row with system labels
1283
+ print(r"Target\System", "\t".join(labels), sep="\t", file=outfile)
1284
+
1285
+ #Print rows with target labels and counts
1286
+ for targetlab in labels:
1287
+ print(targetlab,
1288
+ "\t".join([str(eval_dict["conf"][targetlab].get(syslab, 0))
1289
+ for syslab in labels]),
1290
+ sep="\t", file=outfile)
1291
+
1292
+ #Output data statistics
1293
+ print(file=outfile)
1294
+ print("### Target data stats:", file=outfile)
1295
+ print("Label", "Freq", "%", sep="\t", file=outfile)
1296
+ total = sum(config.get("data_stats", {}).values())
1297
+ for lab, freq in config.get("data_stats", {}).items():
1298
+ print(lab, freq, "{:04.2f}".format(freq/total*100), sep="\t", file=outfile)
1299
+
1300
+ #Close output if it is a file
1301
+ if isinstance(config.get("eval_out"), TextIOWrapper):
1302
+ outfile.close()
1303
+
1304
+ #############################
1305
+
1306
+ def read_config(config_file):
1307
+ """
1308
+ Function to set program parameters as specified in the config file.
1309
+
1310
+ The following parameters are handled:
1311
+
1312
+ - target_in: path to the target file(s) with gold standard annotation
1313
+ -> output: 'target_files' : [list of target file paths]
1314
+
1315
+ - system_in: path to the system's output file(s), which are evaluated
1316
+ -> output: 'system_files' : [list of system file paths]
1317
+
1318
+ - eval_out: path or filename, where evaluation results should be stored
1319
+ if value is a path, output file 'path/eval.csv' is created
1320
+ if value is 'cmd' or missing, output is set to sys.stdout
1321
+ -> output: 'eval_out' : output file or sys.stdout
1322
+
1323
+ - labels: comma-separated list of labels to evaluate
1324
+ defaults to 'all'
1325
+ -> output: 'labels' : [list of labels as strings]
1326
+
1327
+ - exclude: comma-separated list of labels to exclude from evaluation
1328
+ always contains 'NONE' and 'EMPTY'
1329
+ -> output: 'exclude' : [list of labels as strings]
1330
+
1331
+ - ignore_punct: whether to ignore punctuation during evaluation (true/false)
1332
+ -> output: 'ignore_punct' : True/False
1333
+
1334
+ - focus: whether to focus the evaluation on 'target' or 'system' annotations
1335
+ defaults to 'target'
1336
+ -> output: 'focus' : 'target' or 'system'
1337
+
1338
+ - weights: weights that should be applied during calculation of precision
1339
+ and recall; at the same time can serve as a list of additional
1340
+ error types to include in the evaluation
1341
+ the weights are parsed from comma-separated input formulas of the form
1342
+
1343
+ error_type = weight * TP + weight2 * FP + weight3 * FN
1344
+
1345
+ -> output: 'weights' : { 'error type' : {
1346
+ 'TP' : weight,
1347
+ 'FP' : weight,
1348
+ 'FN' : weight
1349
+ },
1350
+ 'another error type' : {...}
1351
+ }
1352
+
1353
+ - eval_method: defines which evaluation method(s) to use
1354
+ one or more of: 'traditional', 'fair', 'weighted'
1355
+ if value is 'all' or missing, all available methods are returned
1356
+ -> output: 'eval_method' : [list of eval methods]
1357
+
1358
+ Input: Filename of the config file.
1359
+ Output: Settings dictionary.
1360
+ """
1361
+
1362
+ ############################
1363
+
1364
+ def _parse_config(key, val):
1365
+ """
1366
+ Internal function to set specific values for the given keys.
1367
+ In case of illegal values, prints error message and sets key and/or value to None.
1368
+ Input: Key and value from config file
1369
+ Output: Modified key and value
1370
+ """
1371
+ if key in ["target_in", "system_in"]:
1372
+ if os.path.isdir(val):
1373
+ val = os.path.normpath(val)
1374
+ files = [os.path.join(val, f) for f in os.listdir(val)]
1375
+ elif os.path.isfile(val):
1376
+ files = [os.path.normpath(val)]
1377
+ else:
1378
+ print("Error: '{0} = {1}' is not a file/directory.".format(key, val))
1379
+ return None, None
1380
+ if key == "target_in":
1381
+ return "target_files", files
1382
+ elif key == "system_in":
1383
+ return "system_files", files
1384
+
1385
+ elif key == "eval_out":
1386
+ if os.path.isdir(val):
1387
+ val = os.path.normpath(val)
1388
+ outfile = os.path.join(val, "eval.csv")
1389
+ elif os.path.isfile(val):
1390
+ outfile = os.path.normpath(val)
1391
+ elif val == "cmd":
1392
+ outfile = sys.stdout
1393
+ else:
1394
+ try:
1395
+ p, f = os.path.split(val)
1396
+ if not os.path.isdir(p):
1397
+ os.makedirs(p)
1398
+ outfile = os.path.join(p, f)
1399
+ except:
1400
+ print("Error: '{0} = {1}' is not a file/directory.".format(key, val))
1401
+ return None, None
1402
+ return key, outfile
1403
+
1404
+ elif key in ["labels", "exclude"]:
1405
+ labels = list(set([v.strip() for v in val.split(",") if v.strip()]))
1406
+ if key == "exclude":
1407
+ labels.append("NONE")
1408
+ labels.append("EMPTY")
1409
+ return key, labels
1410
+
1411
+ elif key == "ignore_punct":
1412
+ if val.strip().lower() == "false":
1413
+ return key, False
1414
+ else:
1415
+ return key, True
1416
+
1417
+ elif key == "focus":
1418
+ if val.strip().lower() == "system":
1419
+ return key, "system"
1420
+ else:
1421
+ return key, "target"
1422
+
1423
+ elif key == "weights":
1424
+ if val == "default":
1425
+ return key, {"TP" : {"TP" : 1},
1426
+ "FP" : {"FP" : 1},
1427
+ "FN" : {"FN" : 1},
1428
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
1429
+ "BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
1430
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}}
1431
+ else:
1432
+ formulas = val.split(",")
1433
+ weights = {}
1434
+
1435
+ #For each given formula, i.e., for each error type
1436
+ for f in formulas:
1437
+
1438
+ #Match error type as string-initial letters before equal sign =
1439
+ error_type = re.match(r"\s*(?P<Error>\w+)\s*=", f)
1440
+ if error_type == None:
1441
+ print("WARNING: No error type found in weight formula '{0}'.".format(f))
1442
+ continue
1443
+ else:
1444
+ error_type = error_type.group("Error")
1445
+
1446
+ weights[error_type] = {}
1447
+
1448
+ #Match weight for TP
1449
+ w_tp = re.search(r"(?P<TP>\d*\.?\d+)\s*\*?\s*TP", f)
1450
+ if w_tp == None:
1451
+ print("WARNING: Missing weight for TP for error type {0}. Set to 0.".format(error_type))
1452
+ weights[error_type]["TP"] = 0
1453
+ else:
1454
+ try:
1455
+ w_tp = w_tp.group("TP")
1456
+ w_tp = float(w_tp)
1457
+ weights[error_type]["TP"] = w_tp
1458
+ except ValueError:
1459
+ print("WARNING: Weight for TP for error type {0} is not a number. Set to 0.".format(error_type))
1460
+ weights[error_type]["TP"] = 0
1461
+
1462
+ #Match weight for FP
1463
+ w_fp = re.search(r"(?P<FP>\d*\.?\d+)\s*\*?\s*FP", f)
1464
+ if w_fp == None:
1465
+ print("WARNING: Missing weight for FP for error type {0}. Set to 0.".format(error_type))
1466
+ weights[error_type]["FP"] = 0
1467
+ else:
1468
+ try:
1469
+ w_fp = w_fp.group("FP")
1470
+ w_fp = float(w_fp)
1471
+ weights[error_type]["FP"] = w_fp
1472
+ except ValueError:
1473
+ print("WARNING: Weight for FP for error type {0} is not a number. Set to 0.".format(error_type))
1474
+ weights[error_type]["FP"] = 0
1475
+
1476
+ #Match weight for FN
1477
+ w_fn = re.search(r"(?P<FN>\d*\.?\d+)\s*\*?\s*FN", f)
1478
+ if w_fn == None:
1479
+ print("WARNING: Missing weight for FN for error type {0}. Set to 0.".format(error_type))
1480
+ weights[error_type]["FN"] = 0
1481
+ else:
1482
+ try:
1483
+ w_fn = w_fn.group("FN")
1484
+ w_fn = float(w_fn)
1485
+ weights[error_type]["FN"] = w_fn
1486
+ except ValueError:
1487
+ print("WARNING: Weight for FN for error type {0} is not a number. Set to 0.".format(error_type))
1488
+ weights[error_type]["FN"] = 0
1489
+ if weights:
1490
+ #Add default weights for traditional categories if needed
1491
+ if not "TP" in weights:
1492
+ weights["TP"] = {"TP" : 1}
1493
+ if not "FP" in weights:
1494
+ weights["FP"] = {"FP" : 1}
1495
+ if not "FN" in weights:
1496
+ weights["FN"] = {"FN" : 1}
1497
+ return key, weights
1498
+ else:
1499
+ print("WARNING: No valid weights found. Using default weights.")
1500
+ return key, {"TP" : {"TP" : 1},
1501
+ "FP" : {"FP" : 1},
1502
+ "FN" : {"FN" : 1},
1503
+ "LE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
1504
+ "BE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5},
1505
+ "LBE" : {"TP" : 0, "FP" : 0.5, "FN" : 0.5}}
1506
+
1507
+ elif key == "eval_method":
1508
+ available_methods = ["traditional", "fair", "weighted"]
1509
+ if val == "all":
1510
+ return key, available_methods
1511
+ else:
1512
+ methods = []
1513
+ for m in available_methods:
1514
+ if m in [v.strip().lower() for v in val.split(",")
1515
+ if v.strip()]:
1516
+ methods.append(m)
1517
+ if methods:
1518
+ return key, methods
1519
+ else:
1520
+ print("WARNING: No evaluation method specified. Applying all methods.")
1521
+ return key, available_methods
1522
+
1523
+ #############################
1524
+
1525
+ config = dict()
1526
+
1527
+ f = open(config_file, mode="r", encoding="utf-8")
1528
+
1529
+ for line in f:
1530
+
1531
+ line = line.strip()
1532
+
1533
+ #Skip empty lines and comments
1534
+ if not line or line.startswith("#"):
1535
+ continue
1536
+
1537
+ line = line.split("=")
1538
+ key = line[0].strip()
1539
+ val = "=".join(line[1:]).strip()
1540
+
1541
+ #Store original paths of input files
1542
+ if key in ["target_in", "system_in"]:
1543
+ print("{0}: {1}".format(key, val))
1544
+ config[key] = val
1545
+
1546
+ #Parse config
1547
+ key, val = _parse_config(key, val)
1548
+
1549
+ #Skip illegal configs
1550
+ if key is None or val is None:
1551
+ continue
1552
+
1553
+ #Warn before overwriting duplicate config items.
1554
+ if key in config:
1555
+ print("WARNING: duplicate config item '{0}' found.".format(key))
1556
+
1557
+ config[key] = val
1558
+
1559
+ f.close()
1560
+
1561
+ #Stop evaluation if either target or system files are missing
1562
+ if not "target_files" in config or not "system_files" in config:
1563
+ print("ERROR: Cannot evaluate without target AND system file(s). Quitting.")
1564
+ return None
1565
+
1566
+ #Output to sys.stdout if no evaluation file is specified
1567
+ elif config.get("eval_out", None) == None:
1568
+ config["eval_out"] = sys.stdout
1569
+ #Otherwise open eval file
1570
+ else:
1571
+ config["eval_out"] = open(config.get("eval_out"), mode="w", encoding="utf-8")
1572
+
1573
+ #Set labels to 'all' if no specific labels are given
1574
+ if config.get("labels", None) == None:
1575
+ config["labels"] = ["all"]
1576
+
1577
+ if config.get("eval_method", None) == None:
1578
+ config["eval_method"] = ["traditional", "fair", "weighted"]
1579
+ if not config.get("weights", {}) and "weighted" in config.get("eval_method"):
1580
+ if not "fair" in config["eval_method"]:
1581
+ config["eval_method"].append("fair")
1582
+ del config["eval_method"][config["eval_method"].index("weighted")]
1583
+
1584
+ #Output settings at the top of evaluation file
1585
+ print("### Evaluation settings:", file=config.get("eval_out"))
1586
+ for key in sorted(config.keys()):
1587
+ if key in ["target_files", "system_files", "eval_out"]:
1588
+ continue
1589
+ print("{0}: {1}".format(key, config.get(key)), file=config.get("eval_out"))
1590
+ print(file=config.get("eval_out"))
1591
+
1592
+ return config
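Putting the recognized options together, a minimal configuration file for `read_config` might look like the sketch below. The option names are taken from the parser above; the paths and values are placeholders, and the exact input file format is handled elsewhere in the script:

```
# FairEval configuration (sketch with placeholder paths)
target_in    = gold/dev.conll
system_in    = system/dev.conll
labels       = all
ignore_punct = true
focus        = target
eval_method  = traditional, fair, weighted
weights      = LE = 0*TP + 0.5*FP + 0.5*FN
eval_out     = results/eval.txt
```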
1593
+
1594
+ ###########################
1595
+
1596
+ if __name__ == '__main__':
1597
+ parser = argparse.ArgumentParser()
1598
+ parser.add_argument('--config', help='Configuration File', required=True)
1599
+
1600
+ args = parser.parse_args()
1601
+
1602
+ #Read config file into dict
1603
+ config = read_config(args.config)
1604
+
1605
+ #Create empty eval dict
1606
+ eval_dict = {"overall" : {"traditional" : {}, "fair" : {}},
1607
+ "per_label" : {"traditional" : {}, "fair" : {}},
1608
+ "conf" : {}}
1609
+ for method in config.get("eval_method", ["traditional", "fair"]):
1610
+ eval_dict["overall"][method] = {}
1611
+ eval_dict["per_label"][method] = {}
1612
+
1613
+ #Create dict to count target annotations
1614
+ config["data_stats"] = {}
1615
+
1616
+ #Get system and target files to compare
1617
+ #The files must have the same name to be compared
1618
+ file_pairs = []
1619
+ for t in config.get("target_files", []):
1620
+ s = [f for f in config.get("system_files", [])
1621
+ if os.path.split(t)[-1] == os.path.split(f)[-1]]
1622
+ if s:
1623
+ file_pairs.append((t, s[0]))
1624
+
1625
+ #Go through target and system files in parallel
1626
+ for target_file, system_file in file_pairs:
1627
+
1628
+ #For each sentence pair
1629
+ for target_sentence, system_sentence in zip(get_sentences(target_file),
1630
+ get_sentences(system_file)):
1631
+
1632
+ #Get spans
1633
+ target_spans = get_spans(target_sentence, **config)
1634
+ system_spans = get_spans(system_sentence, **config)
1635
+
1636
+ #Count target annotations for simple statistics.
1637
+ #Result is stored in data_stats key of config dict.
1638
+ annotation_stats(target_spans, **config)
1639
+
1640
+ #Evaluate spans
1641
+ sent_counts = compare_spans(target_spans, system_spans,
1642
+ config.get("focus", "target"))
1643
+
1644
+ #Add results to eval dict
1645
+ eval_dict = add_dict(eval_dict, sent_counts)
1646
+
1647
+ #Calculate overall results
1648
+ eval_dict = calculate_results(eval_dict, **config)
1649
+
1650
+ #Output results
1651
+ output_results(eval_dict, **config)
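With such a configuration file in place, the script is invoked through its single required argument, e.g. `python FairEval.py --config eval.cfg` (the config file name here is a placeholder).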
README.md ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ title: FairEvaluation
3
+ tags:
4
+ - evaluate
5
+ - metric
6
+ description: "TODO: add a description here"
7
+ sdk: gradio
8
+ sdk_version: 3.0.2
9
+ app_file: app.py
10
+ pinned: false
11
+ ---
12
+
13
+ # Metric: Fair Evaluation
14
+
15
+ ## Metric Description
16
+ The traditional evaluation of NLP labeled spans with precision, recall, and F1-score leads to double penalties for
17
+ close-to-correct annotations. As Manning (2006) argues in an article about named entity recognition, this can lead to
18
+ undesirable effects when systems are optimized for these traditional metrics.
19
+
20
+ Building on his ideas, Katrin Ortmann (2022) develops FairEval: a new evaluation method that more accurately reflects
21
+ true annotation quality by ensuring that every error is counted only once. In addition to the traditional categories of
22
+ true positives (TP), false positives (FP), and false negatives (FN), the new method takes into account the more
23
+ fine-grained error types suggested by Manning: labeling errors (LE), boundary errors (BE), and labeling-boundary
24
+ errors (LBE). The system additionally distinguishes three subtypes of boundary errors, illustrated in the sketch below:
25
+ - BES: the system's annotation is smaller than the target span
26
+ - BEL: the system's annotation is larger than the target span
27
+ - BEO: the system span overlaps with the target span
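To make the three subtypes concrete, here is a small illustration with invented IOB sequences (the reference PER span covers tokens 1-3):

```python
reference = ['O',     'B-PER', 'I-PER', 'I-PER', 'O']      # target span: tokens 1-3
bes_pred  = ['O',     'B-PER', 'I-PER', 'O',     'O']      # BES: predicted span 1-2 is smaller than the target
bel_pred  = ['O',     'B-PER', 'I-PER', 'I-PER', 'I-PER']  # BEL: predicted span 1-4 is larger than the target
beo_pred  = ['B-PER', 'I-PER', 'I-PER', 'O',     'O']      # BEO: predicted span 0-2 only partially overlaps the target
```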
28
+
29
+ For more information on the reasoning behind the fair metrics and on how they are computed from the redefined error counts, please refer to the [original paper](https://aclanthology.org/2022.lrec-1.150.pdf).
30
+
31
+ ## How to Use
32
+ The current HuggingFace implementation accepts input for the predictions and references as sentences in IOB format.
33
+ The simplest use example would be:
34
+
35
+ ```python
36
+ >>> faireval = evaluate.load("illorca/fairevaluation")
37
+ >>> pred = ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
38
+ >>> ref = ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
39
+ >>> results = faireval.compute(predictions=pred, references=ref)
40
+ ```
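In this example the predicted MISC span starts one token earlier than the reference span, so fair evaluation records a single boundary error (BEL, because the prediction is larger than the reference) instead of the false positive plus false negative that traditional evaluation would count; the PER span is an exact match and counts as a true positive.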
41
+
42
+ ### Inputs
43
+ - **predictions** *(list)*: list of predictions to score. Each predicted sentence
44
+ should be a list of IOB-formatted labels corresponding to each sentence token.
45
+ Predicted sentences must have the same number of tokens as the references'.
46
+ - **references** *(list)*: list of references, one for each prediction. Each reference sentence
47
+ should be a list of IOB-formatted labels corresponding to each sentence token.
48
+
49
+ ### Output Values
50
+ A dictionary with:
51
+ - TP: count of True Positives
52
+ - FP: count of False Positives
53
+ - FN: count of False Negatives
54
+ - LE: count of Labeling Errors
55
+ - BE: count of Boundary Errors
56
+ - BEO: segment of the BE where the prediction overlaps with the reference
57
+ - BES: segment of the BE where the prediction is smaller than the reference
58
+ - BEL: segment of the BE where the prediction is larger than the reference
59
+ - LBE: count of Label-and-Boundary Errors
60
+ - Prec: fair precision
61
+ - Rec: fair recall
62
+ - F1: fair F1-score
63
+
64
+ #### Values from Popular Papers
65
+ *Examples, preferably with links to leaderboards or publications, of papers that have reported this metric, along with the values they have reported.*
66
+
67
+ *Under construction*
68
+
69
+ ### Examples
70
+ *Code examples of the metric being used. Try to include examples that clear up any potential ambiguity left from the metric description above. If possible, provide a range of examples that show both typical and atypical results, as well as examples where a variety of input parameters are passed.*
71
+
72
+ *Under construction*
73
+
74
+ ## Limitations and Bias
75
+ *Note any known limitations or biases that the metric has, with links and references if possible.*
76
+
77
+ *Under construction*
78
+
79
+ ## Citation
80
+ Ortmann, Katrin. 2022. Fine-Grained Error Analysis and Fair Evaluation of Labeled Spans. In *Proceedings of the Language Resources and Evaluation Conference (LREC)*, Marseille, France, pages 1400–1407. [PDF](https://aclanthology.org/2022.lrec-1.150.pdf)
81
+
82
+ ```bibtex
83
+ @inproceedings{ortmann2022,
84
+ title = {Fine-Grained Error Analysis and Fair Evaluation of Labeled Spans},
85
+ author = {Katrin Ortmann},
86
+ url = {https://aclanthology.org/2022.lrec-1.150},
87
+ year = {2022},
88
+ date = {2022-06-21},
89
+ booktitle = {Proceedings of the Language Resources and Evaluation Conference (LREC)},
90
+ pages = {1400-1407},
91
+ publisher = {European Language Resources Association},
92
+ address = {Marseille, France},
93
+ pubstate = {published},
94
+ type = {inproceedings}
95
+ }
96
+ ```
app.py ADDED
@@ -0,0 +1,6 @@
1
+ import evaluate
2
+ from evaluate.utils import launch_gradio_widget
3
+
4
+
5
+ module = evaluate.load("illorca/fairevaluation")
6
+ launch_gradio_widget(module)
fairevaluation.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # huggingface packages
16
+ import evaluate
17
+ import datasets
18
+
19
+ # faireval functions
20
+ from .FairEval import *
21
+
22
+ # packages to manage input formats
23
+ import importlib
24
+ from typing import List, Optional, Union
25
+ from seqeval.metrics.v1 import check_consistent_length
26
+ from seqeval.scheme import Entities, Token, auto_detect
27
+
28
+ _CITATION = """\
29
+ @inproceedings{ortmann2022,
30
+ title = {Fine-Grained Error Analysis and Fair Evaluation of Labeled Spans},
31
+ author = {Katrin Ortmann},
32
+ url = {https://aclanthology.org/2022.lrec-1.150},
33
+ year = {2022},
34
+ date = {2022-06-21},
35
+ booktitle = {Proceedings of the Language Resources and Evaluation Conference (LREC)},
36
+ pages = {1400-1407},
37
+ publisher = {European Language Resources Association},
38
+ address = {Marseille, France},
39
+ pubstate = {published},
40
+ type = {inproceedings}
41
+ }
42
+ """
43
+
44
+ _DESCRIPTION = """\
45
+ New evaluation method that more accurately reflects true annotation quality by ensuring that every error is counted
46
+ only once, avoiding the double penalty that traditional evaluation assigns to close-to-target annotations.
47
+ In addition to the traditional categories of true positives (TP), false positives (FP), and false negatives
48
+ (FN), the new method takes into account the more fine-grained error types suggested by Manning: labeling errors (LE),
49
+ boundary errors (BE), and labeling-boundary errors (LBE). Additionally, the system also distinguishes different types
50
+ of boundary errors: BES (the system's annotation is smaller than the target span), BEL (the system's annotation is
51
+ larger than the target span), and BEO (the system span overlaps with the target span).
52
+ """
53
+
54
+ _KWARGS_DESCRIPTION = """
55
+ Counts the number of redefined traditional errors (FP, FN), newly defined errors (BE, LE, LBE) and fine-grained
56
+ boundary errors (BES, BEL, BEO). Then computes the fair Precision, Recall and F1-Score.
57
+ For the computation of the metrics from the error count please refer to: https://aclanthology.org/2022.lrec-1.150.pdf
58
+ Args:
59
+ predictions: list of predictions to score. Each predicted sentence
60
+ should be a list of IOB-formatted labels corresponding to each sentence token.
61
+ Predicted sentences must have the same number of tokens as the references'.
62
+ references: list of references, one for each prediction. Each reference sentence
63
+ should be a list of IOB-formatted labels corresponding to each sentence token.
64
+ Returns:
65
+ A dictionary with:
66
+ TP: count of True Positives
67
+ FP: count of False Positives
68
+ FN: count of False Negatives
69
+ LE: count of Labeling Errors
70
+ BE: count of Boundary Errors
71
+ BEO: segment of the BE where the prediction overlaps with the reference
72
+ BES: segment of the BE where the prediction is smaller than the reference
73
+ BEL: segment of the BE where the prediction is larger than the reference
74
+ LBE: count of Label-and-Boundary Errors
75
+ Prec: fair precision
76
+ Rec: fair recall
77
+ F1: fair F1-score
78
+ Examples:
79
+ >>> faireval = evaluate.load("illorca/fairevaluation")
80
+ >>> pred = ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
81
+ >>> ref = ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
82
+ >>> results = faireval.compute(predictions=pred, references=ref)
83
+ >>> print(results)
84
+ {'TP': 1,
85
+ 'FP': 0,
86
+ 'FN': 0,
87
+ 'LE': 0,
88
+ 'BE': 1,
89
+ 'BEO': 0,
90
+ 'BES': 0,
91
+ 'BEL': 1,
92
+ 'LBE': 0,
93
+ 'Prec': 0.6666666666666666,
94
+ 'Rec': 0.6666666666666666,
95
+ 'F1': 0.6666666666666666}
96
+ """
97
+
98
+
99
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
100
+ class FairEvaluation(evaluate.Metric):
101
+ """Counts the number of redefined traditional errors (FP, FN), newly defined errors (BE, LE, LBE) and fine-grained
102
+ boundary errors (BES, BEL, BEO). Then computes the fair Precision, Recall and F1-Score. """
103
+
104
+ def _info(self):
105
+ return evaluate.MetricInfo(
106
+ # This is the description that will appear on the modules page.
107
+ module_type="metric",
108
+ description=_DESCRIPTION,
109
+ citation=_CITATION,
110
+ inputs_description=_KWARGS_DESCRIPTION,
111
+ # This defines the format of each prediction and reference
112
+ features=datasets.Features({
113
+ "predictions": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
114
+ "references": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
115
+ }),
116
+ # Homepage of the module for documentation
117
+ homepage="https://huggingface.co/spaces/illorca/fairevaluation",
118
+ # Additional links to the codebase or references
119
+ codebase_urls=["https://github.com/rubcompling/FairEval#acknowledgement"],
120
+ reference_urls=["https://aclanthology.org/2022.lrec-1.150.pdf"]
121
+ )
122
+
123
+ def _compute(
124
+ self,
125
+ predictions,
126
+ references,
127
+ suffix: bool = False,
128
+ scheme: Optional[str] = None,
129
+ mode: Optional[str] = 'fair',
130
+ error_format: Optional[str] = 'count',
131
+ sample_weight: Optional[List[int]] = None,
132
+ zero_division: Union[str, int] = "warn",
133
+ ):
134
+ """Returns the error counts and fair scores"""
135
+ # (1) SEQEVAL INPUT MANAGEMENT
136
+ if scheme is not None:
137
+ try:
138
+ scheme_module = importlib.import_module("seqeval.scheme")
139
+ scheme = getattr(scheme_module, scheme)
140
+ except AttributeError:
141
+ raise ValueError(f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}")
142
+
143
+ y_true = references
144
+ y_pred = predictions
145
+
146
+ check_consistent_length(y_true, y_pred)
147
+
148
+ if scheme is None or not issubclass(scheme, Token):
149
+ scheme = auto_detect(y_true, suffix)
150
+
151
+ true_spans = Entities(y_true, scheme, suffix).entities
152
+ pred_spans = Entities(y_pred, scheme, suffix).entities
153
+
154
+ # (2) TRANSFORM FROM SEQEVAL TO FAIREVAL SPAN FORMAT
155
+ true_spans = seq_to_fair(true_spans)
156
+ pred_spans = seq_to_fair(pred_spans)
157
+
158
+ # (3) COUNT ERRORS AND CALCULATE SCORES
159
+ total_errors = compare_spans([], []) # initialize empty error count dictionary
160
+
161
+ for i in range(len(true_spans)):
162
+ sentence_errors = compare_spans(true_spans[i], pred_spans[i])
163
+ total_errors = add_dict(total_errors, sentence_errors)
164
+
165
+ results = calculate_results(total_errors)
166
+ del results['conf']
167
+
168
+ # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL HUGGINGFACE OUTPUT
169
+ output = {}
170
+ total_trad_errors = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
171
+ total_fair_errors = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
172
+ results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
173
+ results['overall']['fair']['LBE']
174
+
175
+ assert mode in ['traditional', 'fair'], 'mode must be \'traditional\' or \'fair\''
176
+ assert error_format in ['count', 'proportion'], 'error_format must be \'count\' or \'proportion\''
177
+
178
+ if mode == 'traditional':
179
+ for k, v in results['per_label'][mode].items():
180
+ if error_format == 'count':
181
+ output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
182
+ 'FP': v['FP'], 'FN': v['FN']}
183
+ elif error_format == 'proportion':
184
+ output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
185
+ 'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
186
+ elif mode == 'fair':
187
+ for k, v in results['per_label'][mode].items():
188
+ if error_format == 'count':
189
+ output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
190
+ 'FP': v['FP'], 'FN': v['FN'], 'LE': v['LE'], 'BE': v['BE'], 'LBE': v['LBE']}
191
+ elif error_format == 'proportion':
192
+ output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
193
+ 'FP': v['FP'] / total_fair_errors, 'FN': v['FN'] / total_fair_errors,
194
+ 'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
195
+ 'LBE': v['LBE'] / total_fair_errors}
196
+
197
+ output['overall_precision'] = results['overall'][mode]['Prec']
198
+ output['overall_recall'] = results['overall'][mode]['Rec']
199
+ output['overall_f1'] = results['overall'][mode]['F1']
200
+
201
+ if mode == 'traditional':
202
+ output['TP'] = results['overall'][mode]['TP']
203
+ output['FP'] = results['overall'][mode]['FP']
204
+ output['FN'] = results['overall'][mode]['FN']
205
+ if error_format == 'proportion':
206
+ output['FP'] = output['FP'] / total_trad_errors
207
+ output['FN'] = output['FN'] / total_trad_errors
208
+ elif mode == 'fair':
209
+ output['TP'] = results['overall'][mode]['TP']
210
+ output['FP'] = results['overall'][mode]['FP']
211
+ output['FN'] = results['overall'][mode]['FN']
212
+ output['LE'] = results['overall'][mode]['LE']
213
+ output['BE'] = results['overall'][mode]['BE']
214
+ output['LBE'] = results['overall'][mode]['LBE']
215
+ if error_format == 'proportion':
216
+ output['FP'] = output['FP'] / total_fair_errors
217
+ output['FN'] = output['FN'] / total_fair_errors
218
+ output['LE'] = output['LE'] / total_fair_errors
219
+ output['BE'] = output['BE'] / total_fair_errors
220
+ output['LBE'] = output['LBE'] / total_fair_errors
221
+
222
+ return output
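As a usage sketch for the extra keyword arguments (names taken from the signature above, mirroring the docstring example; the exact values returned are not reproduced here):

```python
import evaluate

faireval = evaluate.load("illorca/fairevaluation")
pred = ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
ref  = ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']

# 'mode' selects the traditional or the fair error scheme,
# 'error_format' reports errors either as raw counts or as proportions
results = faireval.compute(predictions=pred, references=ref,
                           mode='traditional', error_format='proportion')
```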
223
+
224
+
225
+ def seq_to_fair(seq_sentences):
226
+ out = []
227
+ for seq_sentence in seq_sentences:
228
+ sentence = []
229
+ for entity in seq_sentence:
230
+ span = str(entity).replace('(', '').replace(')', '').replace(' ', '').split(',')  # split the entity's "(sent_id, LABEL, start, end)" string form into fields
231
+ span = span[1:]  # drop the sentence id, keep [label, start, end]
232
+ span[-1] = int(span[-1]) - 1  # make the end index inclusive
233
+ span[1] = int(span[1])  # start index as int
234
+ span.append({i for i in range(span[1], span[2] + 1)})  # set of covered token indices
235
+ sentence.append(span)
236
+ out.append(sentence)
237
+ return out
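For orientation: assuming seqeval's usual "(sent_id, LABEL, start, end)" string form with an exclusive end index, an entity printed as `(0, PER, 7, 9)` would leave this helper as `['PER', 7, 8, {7, 8}]`.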
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ git+https://github.com/huggingface/evaluate@main
2
+
3
+ seqeval