add evaluate_data.py file
evaluate_data.py +125 -0
evaluate_data.py
ADDED
@@ -0,0 +1,125 @@
from evaluate_model import compute_metrics
from datasets import load_from_disk
from transformers import AutoTokenizer
import os
import pickle
import json
from transformers import AutoModelForTokenClassification
# from transformers import DataCollatorForTokenClassification
from utils import tokenize_and_align_labels
from rich import print
from tqdm import tqdm
import huggingface_hub
import torch


# _ = load_dotenv(find_dotenv())  # read local .env file

# Authenticate against the Hugging Face Hub with the token from the environment.
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

checkpoint = 'elshehawy/finer-ord-transformers'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

data_path = './data/merged_dataset/'

# Only the test split is evaluated here.
test = load_from_disk(data_path)['test']

feature_path = './data/ner_feature.pickle'

# ClassLabel feature used to map integer NER tags back to label strings.
with open(feature_path, 'rb') as f:
    ner_feature = pickle.load(f)


# data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)


# Tokenize the test split and align word-level NER tags with sub-word tokens.
tokenized_dataset = test.map(
    tokenize_and_align_labels,
    batched=True,
    batch_size=None,
    remove_columns=test.column_names[2:],
    fn_kwargs={'tokenizer': tokenizer}
)

# tokenized_dataset.set_format('torch')


def collate_fn(data):
    """Collect the model inputs and labels of a batch into plain lists."""
    input_ids = [element['input_ids'] for element in data]
    attention_mask = [element['attention_mask'] for element in data]
    token_type_ids = [element['token_type_ids'] for element in data]
    labels = [element['labels'] for element in data]

    return input_ids, token_type_ids, attention_mask, labels


loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=32, collate_fn=collate_fn)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Evaluation mode; move the model to the device once, before the loop.
ner_model = ner_model.eval().to(device)


def get_metrics_trf():
    """Run the fine-tuned transformer over the test loader and compute metrics."""
    y_true, logits = [], []

    for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
        with torch.no_grad():
            logits.extend(
                ner_model(
                    input_ids=torch.tensor(input_ids).to(device),
                    token_type_ids=torch.tensor(token_type_ids).to(device),
                    attention_mask=torch.tensor(attention_mask).to(device)
                ).logits.cpu().numpy()
            )

        y_true.extend(labels)

    all_metrics = compute_metrics((logits, y_true))
    return all_metrics

# with open('./metrics/trf/metrics.json', 'w') as f:
#     json.dump(all_metrics, f)


def find_orgs(tokens, labels):
    """Group consecutive B-ORG/I-ORG tokens into organization strings."""
    orgs = []
    prev_tok_id = 0
    for i, (token, label) in enumerate(zip(tokens, labels)):
        if label == 'B-ORG':
            # Start a new organization span.
            orgs.append([token])
            prev_tok_id = i

        if label == 'I-ORG' and (i - 1) == prev_tok_id:
            # Extend the most recent span only if this token is directly adjacent.
            orgs[-1].append(token)
            prev_tok_id = i

    return [tokenizer.convert_tokens_to_string(org) for org in orgs]


def store_sample_data():
    """Dump the test sentences with their gold organization spans to JSON."""
    test_data = []

    for sent in test:  # the test split loaded above
        labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
        sent_orgs = find_orgs(sent['tokens'], labels)

        sent_text = tokenizer.convert_tokens_to_string(sent['tokens'])
        test_data.append({
            'id': sent['id'],
            'text': sent_text,
            'orgs': sent_orgs
        })

    with open('data/sample_data.json', 'w') as f:
        json.dump(test_data, f)
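
For context, a minimal way to drive this script end to end might look like the sketch below. It is not part of the committed file; it simply calls the two top-level functions defined above.

# Hypothetical driver (not in the commit): evaluate the checkpoint, then export sample data.
if __name__ == '__main__':
    all_metrics = get_metrics_trf()   # metrics from compute_metrics over the test loader
    print(all_metrics)                # rich.print gives readable nested output
    store_sample_data()               # writes data/sample_data.json with gold ORG spans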