Spaces:
Build error
Build error
Commit
·
bbf73b7
1
Parent(s):
17654dc
Delete validate.py
Browse files- validate.py +0 -168
validate.py
DELETED
@@ -1,168 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import re
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from tqdm.auto import tqdm
|
6 |
-
|
7 |
-
from network import EntNet
|
8 |
-
from utils import read_conll_ner, split_conll_docs, create_context_data, extract_spans
|
9 |
-
|
10 |
-
# Select the compute device once at import time: the GPU when CUDA is usable,
# otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
|
12 |
-
|
13 |
-
|
14 |
-
def classify(model, sents, pos, batch_size):
    """Run the model over all sentences in batches and pair tokens with tags.

    Args:
        model: Object exposing ``eval()`` and callable with the keyword
            arguments ``sentences=`` and ``pos=``; its result must support
            ``['pred_tags']`` and yield one tag sequence per input sentence.
        sents: List of token sequences, one per sentence.
        pos: List aligned with ``sents``; it is sliced with the same batch
            boundaries and handed to the model unchanged.
        batch_size: Number of sentences passed to the model per call.

    Returns:
        One list per sentence, each containing ``[token, tag]`` pairs.
    """
    model.eval()
    result = []
    # Pure inference: disabling autograd avoids building graphs (and holding
    # activations) for every batch.
    with torch.no_grad():
        for start in tqdm(range(0, len(sents), batch_size), desc='classifying... '):
            stop = start + batch_size
            tag_seqs = model(sentences=sents[start:stop], pos=pos[start:stop])
            result.extend(tag_seqs['pred_tags'])
    return [[[w, t] for w, t in zip(s, r)] for s, r in zip(sents, result)]
|
23 |
-
|
24 |
-
|
25 |
-
def entities_from_token_classes(tokens):
    """Collapse a sequence of B/I/other tags into entity spans.

    A tag starting with ``B`` opens a new entity (closing any open one), a
    tag starting with ``I`` extends the currently open entity, and any other
    tag closes it. An ``I`` tag with no open entity is ignored. The type of
    an ``I`` tag is not compared with the open entity's type.

    Args:
        tokens: Iterable of tag strings such as ``"B-PER"``, ``"I-PER"``
            or ``"O"``.

    Returns:
        A list of dicts ``{"type": str, "index": [start, end]}`` where
        ``start`` and ``end`` are inclusive token positions. ``type`` is
        everything after the first ``-`` of the opening tag (so a type that
        itself contains a hyphen is kept whole), or ``''`` when the tag has
        no ``-``.
    """
    def make_span(entity_type, first, last):
        return {"type": entity_type, "index": [first, last]}

    entities = []
    current_entity = None
    start = end = 0

    for i, kls in enumerate(tokens):
        if kls.startswith("B"):
            # A new entity begins; flush the one that was still open.
            if current_entity is not None:
                entities.append(make_span(current_entity, start, end))
            # partition keeps the full remainder, unlike split('-')[1],
            # which would cut a hyphenated type at its second hyphen.
            current_entity = kls.partition("-")[2]
            start = end = i
        elif current_entity is not None:
            if kls.startswith("I"):
                # Still inside the open entity.
                end = i
            else:
                # First tag after the entity: close it.
                entities.append(make_span(current_entity, start, end))
                current_entity = None

    # An entity running up to the last token is still open here.
    if current_entity is not None:
        entities.append(make_span(current_entity, start, end))

    return entities
|
71 |
-
|
72 |
-
|
73 |
-
def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is zero."""
    return numerator / denominator if denominator else 0


def _f1(precision, recall):
    """Return the harmonic mean of precision and recall, or 0 if both are 0."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0


def calc_f1(targs, preds):
    """Compute span-level precision, recall and F1 for predicted entities.

    Args:
        targs: One list per sentence of gold entity dicts, each with a
            ``'type'`` and an ``'index'`` of ``[start, end]``.
        preds: One list per sentence of predicted entity dicts in the same
            format; paired with ``targs`` by position via ``zip``.

    Returns:
        A dict keyed by ``'overall'`` plus every entity type seen in either
        input. Each entry holds the raw counts (``lab_tp``, ``targs``,
        ``preds``) and ``lab_p`` / ``lab_r`` / ``lab_f1``. ``'overall'``
        additionally holds ``unl_tp``, ``unl_p`` / ``unl_r`` / ``unl_f1``
        (span boundaries match, type ignored) and ``macro_lab_f1``, the
        mean of the per-type ``lab_f1`` values (0 when no type was seen).
    """
    overall = {'unl_tp': 0, 'lab_tp': 0, 'targs': 0, 'preds': 0}
    stat_dict = {'overall': overall}

    for sent_targs, sent_preds in zip(targs, preds):
        overall['targs'] += len(sent_targs)
        overall['preds'] += len(sent_preds)

        for pred in sent_preds:
            per_type = stat_dict.setdefault(
                pred['type'], {'lab_tp': 0, 'targs': 0, 'preds': 0})
            per_type['preds'] += 1

        for targ in sent_targs:
            per_type = stat_dict.setdefault(
                targ['type'], {'lab_tp': 0, 'targs': 0, 'preds': 0})
            per_type['targs'] += 1
            # Look for a predicted span with exactly the same boundaries.
            for pred in sent_preds:
                if pred['index'][0] == targ['index'][0] and pred['index'][1] == targ['index'][1]:
                    overall['unl_tp'] += 1
                    # Boundaries agree; count it as labelled only if the
                    # types agree too.
                    if pred['type'] == targ['type']:
                        overall['lab_tp'] += 1
                        stat_dict[targ['type']]['lab_tp'] += 1

    for k, stats in stat_dict.items():
        if k == 'overall':
            # Unlabelled scores only exist for the aggregate entry.
            stats['unl_p'] = _ratio(stats['unl_tp'], stats['preds'])
            stats['unl_r'] = _ratio(stats['unl_tp'], stats['targs'])
            stats['unl_f1'] = _f1(stats['unl_p'], stats['unl_r'])
        stats['lab_p'] = _ratio(stats['lab_tp'], stats['preds'])
        stats['lab_r'] = _ratio(stats['lab_tp'], stats['targs'])
        stats['lab_f1'] = _f1(stats['lab_p'], stats['lab_r'])

    class_f1s = [v['lab_f1'] for k, v in stat_dict.items() if k != 'overall']
    # With no entity types at all there is nothing to average; report 0
    # instead of dividing by zero.
    overall['macro_lab_f1'] = sum(class_f1s) / len(class_f1s) if class_f1s else 0
    return stat_dict
|
114 |
-
|
115 |
-
|
116 |
-
def main(args):
    """Evaluate one model (or an ensemble) on a CoNLL test file and print scores.

    Reads ``args.test_path``, builds context data with ``args.context_size``,
    tags every sentence in batches of ``args.batch_size``, converts gold and
    predicted tags into entity spans and prints overall and per-type
    precision, recall and F1.

    Args:
        args: Parsed command-line namespace providing ``model_path`` (list),
            ``span_model_path`` (list or ``None``), ``test_path``,
            ``context_size`` and ``batch_size``.
    """
    global device
    device = torch.device('cuda' if use_cuda else 'cpu')

    test_columns = read_conll_ner(args.test_path)
    test_docs = split_conll_docs(test_columns[0])
    test_data = create_context_data(test_docs, args.context_size)

    # Each test_data item is indexed 0/1/2 below; presumably
    # (tokens, pos input, gold tags) -- confirm against create_context_data.
    sents = [td[0] for td in test_data]
    pos = [td[1] for td in test_data]

    if len(args.model_path) > 1 or args.span_model_path is not None:
        # NOTE(review): StagedEnsemble is not imported anywhere in the visible
        # part of this file, so this branch raises NameError unless the name
        # is brought into scope elsewhere -- add the proper import.
        model = StagedEnsemble(model_paths=args.model_path, span_model_paths=args.span_model_path, device=device)
    else:
        model = EntNet.load_model(args.model_path[0], device=device)
        # NOTE(review): assumed to apply to the single-model branch only (the
        # ensemble already receives ``device``) -- confirm.
        model.to(device)

    res = classify(model, sents, pos, args.batch_size)

    targ_tags = [entities_from_token_classes(td[2]) for td in test_data]
    pred_tags = [entities_from_token_classes([t[1] for t in r]) for r in res]
    result = calc_f1(targ_tags, pred_tags)

    print(f'Overall unlabelled - F1:{result["overall"]["unl_f1"]}, '
          f'P:{result["overall"]["unl_p"]}, '
          f'R:{result["overall"]["unl_r"]}')
    print(f'Overall labelled - Micro F1:{result["overall"]["lab_f1"]}, '
          f'P:{result["overall"]["lab_p"]}, '
          f'R:{result["overall"]["lab_r"]}')
    print(f'Overall labelled - Macro F1:{result["overall"]["macro_lab_f1"]}')
    for k, v in result.items():
        if k == 'overall':
            continue
        print(f'{k} - F1:{v["lab_f1"]}, P:{v["lab_p"]}, R:{v["lab_r"]}')
|
152 |
-
|
153 |
-
|
154 |
-
if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, nargs='+', required=True,
                        help='One or more model paths; passing more than one '
                             'selects the ensemble.')
    parser.add_argument('--span_model_path', type=str, nargs='*', default=None,
                        help='Optional span model paths; supplying this option '
                             'selects the ensemble.')
    # main() reads this path unconditionally, so a missing value is rejected
    # here instead of reaching the reader as None.
    parser.add_argument('--test_path', type=str, required=True,
                        help='Path to the CoNLL-format test file.')
    parser.add_argument('--context_size', type=int, default=1,
                        help='Context size passed to create_context_data.')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Number of sentences per model call.')

    args = parser.parse_args()
    main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|