nehalelkaref committed on
Commit
bbf73b7
·
1 Parent(s): 17654dc

Delete validate.py

Browse files
Files changed (1) hide show
  1. validate.py +0 -168
validate.py DELETED
@@ -1,168 +0,0 @@
1
- import argparse
2
- import re
3
-
4
- import torch
5
- from tqdm.auto import tqdm
6
-
7
- from network import EntNet
8
- from utils import read_conll_ner, split_conll_docs, create_context_data, extract_spans
9
-
10
# Pick the GPU when torch reports one as available; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
12
-
13
-
14
def classify(model, sents, pos, batch_size):
    """Tag every sentence with ``model`` and pair each token with its tag.

    Args:
        model: Object exposing ``eval()`` and callable as
            ``model(sentences=..., pos=...)``; the call must return a mapping
            whose ``'pred_tags'`` entry is an iterable with one tag sequence
            per sentence in the batch.
        sents: Sequence of sentences (each a sequence of tokens).
        pos: Sequence sliced in lockstep with ``sents`` and forwarded to the
            model unchanged.
        batch_size: Number of sentences handed to the model per call.

    Returns:
        A list with one entry per sentence; each entry is a list of
        ``[token, tag]`` pairs. Pairing uses ``zip``, so a tag sequence that
        is shorter than its sentence silently truncates that sentence.
    """
    model.eval()
    result = []
    # Pure inference: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for i in tqdm(range(0, len(sents), batch_size), desc='classifying... '):
            tag_seqs = model(sentences=sents[i:i + batch_size],
                             pos=pos[i:i + batch_size])
            result.extend(tag_seqs['pred_tags'])
    return [[[w, t] for w, t in zip(s, r)] for s, r in zip(sents, result)]
23
-
24
-
25
def entities_from_token_classes(tokens):
    """Collapse a sequence of BIO-style tags into entity spans.

    A tag starting with ``B`` opens a new entity (closing any open one), a tag
    starting with ``I`` extends the currently open entity, and any other tag
    closes it. An ``I`` tag with no open entity is ignored. The type carried
    by an ``I`` tag is not compared with the open entity's type.

    Args:
        tokens: Iterable of tag strings such as ``'B-PER'``, ``'I-PER'``,
            ``'O'``.

    Returns:
        A list of ``{"type": str, "index": [start, end]}`` dicts in order of
        appearance, where ``start``/``end`` are inclusive token positions and
        ``type`` is everything after the first ``-`` of the opening tag
        (``''`` when the tag has no ``-``).
    """
    entities = []
    current_entity = None
    start_index = 0
    end_index = 0

    for i, kls in enumerate(tokens):
        if kls.startswith('B'):
            # A new entity begins; flush the one that was still open.
            if current_entity is not None:
                entities.append({
                    "type": current_entity,
                    "index": [start_index, end_index],
                })
            # Split on the FIRST hyphen only so that types which themselves
            # contain hyphens (e.g. 'B-WORK-OF-ART') are kept intact.
            current_entity = kls.partition('-')[2]
            start_index = i
            end_index = i
        elif current_entity is not None:
            if kls.startswith('I'):
                # Still inside the open entity.
                end_index = i
            else:
                # First tag after the entity: close it.
                entities.append({
                    "type": current_entity,
                    "index": [start_index, end_index],
                })
                current_entity = None

    # Flush an entity that runs to the end of the sequence.
    if current_entity is not None:
        entities.append({
            "type": current_entity,
            "index": [start_index, end_index],
        })

    return entities
71
-
72
-
73
def calc_f1(targs, preds):
    """Compute span-level precision, recall and F1 for predicted entities.

    A prediction counts as an unlabelled true positive when its ``index``
    (start and end) equals that of a target in the same sentence, and as a
    labelled true positive when the ``type`` matches as well.

    Args:
        targs: One list per sentence of target entities, each a dict with
            ``'type'`` and ``'index'`` (``[start, end]``) keys.
        preds: Predicted entities in the same layout, aligned with ``targs``
            sentence by sentence (paired with ``zip``).

    Returns:
        A dict keyed by ``'overall'`` plus every entity type seen. Each
        per-type entry holds ``lab_tp``, ``targs``, ``preds``, ``lab_p``,
        ``lab_r`` and ``lab_f1``. ``'overall'`` additionally holds ``unl_tp``,
        ``unl_p``, ``unl_r``, ``unl_f1`` and ``macro_lab_f1`` (the unweighted
        mean of the per-type ``lab_f1`` values, or ``0`` when no entity type
        was seen). Any ratio with a zero denominator is reported as ``0``.

    Note:
        An entity type literally named ``'overall'`` would share the
        aggregate bucket.
    """

    def _ratio(numerator, denominator):
        # Division that reports 0 instead of raising on an empty denominator.
        return numerator / denominator if denominator else 0

    def _f1(precision, recall):
        # Harmonic mean of precision and recall, 0 when both are 0.
        total = precision + recall
        return 2 * precision * recall / total if total else 0

    stat_dict = {
        'overall': {'unl_tp': 0, 'lab_tp': 0, 'targs': 0, 'preds': 0}
    }
    overall = stat_dict['overall']

    for sent_targs, sent_preds in zip(targs, preds):
        overall['targs'] += len(sent_targs)
        overall['preds'] += len(sent_preds)

        for pred in sent_preds:
            if pred['type'] not in stat_dict:
                stat_dict[pred['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
            stat_dict[pred['type']]['preds'] += 1

        for targ in sent_targs:
            if targ['type'] not in stat_dict:
                stat_dict[targ['type']] = {'lab_tp': 0, 'targs': 0, 'preds': 0}
            stat_dict[targ['type']]['targs'] += 1
            # Is there a predicted span that matches this target exactly?
            for pred in sent_preds:
                if pred['index'][0] == targ['index'][0] and pred['index'][1] == targ['index'][1]:
                    overall['unl_tp'] += 1
                    # If so, do the types match as well?
                    if pred['type'] == targ['type']:
                        overall['lab_tp'] += 1
                        stat_dict[targ['type']]['lab_tp'] += 1

    for k, stats in stat_dict.items():
        if k == 'overall':
            # Only the aggregate bucket tracks unlabelled matches.
            stats['unl_p'] = _ratio(stats['unl_tp'], stats['preds'])
            stats['unl_r'] = _ratio(stats['unl_tp'], stats['targs'])
            stats['unl_f1'] = _f1(stats['unl_p'], stats['unl_r'])
        stats['lab_p'] = _ratio(stats['lab_tp'], stats['preds'])
        stats['lab_r'] = _ratio(stats['lab_tp'], stats['targs'])
        stats['lab_f1'] = _f1(stats['lab_p'], stats['lab_r'])

    class_f1s = [v['lab_f1'] for k, v in stat_dict.items() if k != 'overall']
    # With no entity type at all there is nothing to average; report 0
    # instead of dividing by zero.
    overall['macro_lab_f1'] = _ratio(sum(class_f1s), len(class_f1s))
    return stat_dict
114
-
115
-
116
def main(args):
    """Evaluate one model (or an ensemble) on a test file and print scores.

    Reads the test data, tags every sentence with ``classify``, converts the
    gold and predicted tags to entity spans and prints overall unlabelled,
    overall labelled (micro and macro) and per-type precision/recall/F1.

    Args:
        args: Parsed command-line namespace providing ``test_path``,
            ``context_size``, ``model_path`` (a list), ``span_model_path``
            (a list or ``None``) and ``batch_size``.
    """
    global device
    device = torch.device('cuda' if use_cuda else 'cpu')

    test_columns = read_conll_ner(args.test_path)
    test_docs = split_conll_docs(test_columns[0])
    test_data = create_context_data(test_docs, args.context_size)

    # Each test item is indexed as (sentence, pos, gold tags).
    sents = [td[0] for td in test_data]
    pos = [td[1] for td in test_data]

    if len(args.model_path) > 1 or args.span_model_path is not None:
        # NOTE(review): StagedEnsemble is not imported anywhere in this file,
        # so this branch raises NameError as written -- add the import from
        # whichever module defines it.
        model = StagedEnsemble(model_paths=args.model_path, span_model_paths=args.span_model_path, device=device)
    else:
        model = EntNet.load_model(args.model_path[0], device=device)
        # NOTE(review): kept in the single-model branch because the ensemble
        # already receives `device` at construction -- confirm the intended
        # nesting of this call.
        model.to(device)

    BATCH_SIZE = args.batch_size
    res = classify(model, sents, pos, BATCH_SIZE)

    targ_tags = [entities_from_token_classes(td[2]) for td in test_data]
    pred_tags = [entities_from_token_classes([t[1] for t in r]) for r in res]
    result = calc_f1(targ_tags, pred_tags)

    print(f'Overall unlabelled - F1:{result["overall"]["unl_f1"]}, '
          f'P:{result["overall"]["unl_p"]}, '
          f'R:{result["overall"]["unl_r"]}')
    print(f'Overall labelled - Micro F1:{result["overall"]["lab_f1"]}, '
          f'P:{result["overall"]["lab_p"]}, '
          f'R:{result["overall"]["lab_r"]}')
    print(f'Overall labelled - Macro F1:{result["overall"]["macro_lab_f1"]}')
    for k, v in result.items():
        if k == 'overall':
            continue
        print(f'{k} - F1:{v["lab_f1"]}, P:{v["lab_p"]}, R:{v["lab_r"]}')
152
-
153
-
154
def build_arg_parser():
    """Create the command-line parser for the validation script.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--model_path`` (required,
        one or more values), ``--span_model_path`` (zero or more values),
        ``--test_path``, ``--context_size`` (default 1) and ``--batch_size``
        (default 8).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a trained model on a test file and print '
                    'precision, recall and F1.')
    parser.add_argument('--model_path', type=str, nargs='+', default=None, required=True,
                        help='Path(s) to saved model(s). A single path is loaded with '
                             'EntNet.load_model; several paths select the ensemble.')
    parser.add_argument('--span_model_path', type=str, nargs='*', default=None,
                        help='Path(s) to span model(s). Supplying this option selects '
                             'the ensemble.')
    parser.add_argument('--test_path', type=str, default=None,
                        help='Path to the test file passed to read_conll_ner.')
    parser.add_argument('--context_size', type=int, default=1,
                        help='Context size passed to create_context_data.')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Number of sentences per model call.')
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()
    main(args)