elshehawy committed
Commit fffc1e9 · 1 Parent(s): 894b24d

add evaluate_data.py file

Files changed (1)
  1. evaluate_data.py +125 -0
evaluate_data.py ADDED
@@ -0,0 +1,125 @@
+ from evaluate_model import compute_metrics
+ from datasets import load_from_disk
+ from transformers import AutoTokenizer
+ import os
+ import json
+ import pickle
+ from transformers import AutoModelForTokenClassification
+ # from transformers import DataCollatorForTokenClassification
+ from utils import tokenize_and_align_labels
+ from rich import print
+ from tqdm import tqdm
+ import huggingface_hub
+ import torch
+
+
+ # _ = load_dotenv(find_dotenv())  # read local .env file
+ hf_token = os.environ['HF_TOKEN']
+ huggingface_hub.login(hf_token)
+
+ checkpoint = 'elshehawy/finer-ord-transformers'
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+ # the merged dataset saved to disk with datasets; only the test split is used
+ data_path = './data/merged_dataset/'
+
+ test = load_from_disk(data_path)['test']
+
+ feature_path = './data/ner_feature.pickle'
+
+ with open(feature_path, 'rb') as f:
+     ner_feature = pickle.load(f)
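+ # `ner_feature` is assumed to be the dataset's NER tag Sequence(ClassLabel)
+ # feature; it is used below (ner_feature.feature.int2str) to map integer
+ # tag ids back to label strings such as 'B-ORG' / 'I-ORG'.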
+
+
+ # data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
+
+ ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
+
+
+ tokenized_dataset = test.map(
+     tokenize_and_align_labels,
+     batched=True,
+     batch_size=None,
+     remove_columns=test.column_names[2:],
+     fn_kwargs={'tokenizer': tokenizer}
+ )
+
+ # tokenized_dataset.set_format('torch')
+
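+ # tokenize_and_align_labels (imported from utils, not shown in this commit)
+ # presumably follows the standard token-classification recipe: re-tokenize
+ # each word with the wordpiece tokenizer, repeat the word-level NER tag onto
+ # the resulting subword tokens, and label special tokens with -100 so the
+ # metric computation ignores them.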
+
+ def collate_fn(data):
+     input_ids = [element['input_ids'] for element in data]
+     attention_mask = [element['attention_mask'] for element in data]
+     token_type_ids = [element['token_type_ids'] for element in data]
+     labels = [element['labels'] for element in data]
+
+     return input_ids, token_type_ids, attention_mask, labels
+
+
+ loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=32, collate_fn=collate_fn)
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+ ner_model = ner_model.to(device).eval()
+
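+ # collate_fn returns plain Python lists; tensors are built per batch inside
+ # the evaluation loop below. torch.tensor(input_ids) there assumes the
+ # tokenizer already padded every sequence to a common length (the whole
+ # split is tokenized in a single pass via batch_size=None above).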
+
+ def get_metrics_trf():
+     y_true, logits = [], []
+
+     for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
+         with torch.no_grad():
+             logits.extend(
+                 ner_model(
+                     input_ids=torch.tensor(input_ids).to(device),
+                     token_type_ids=torch.tensor(token_type_ids).to(device),
+                     attention_mask=torch.tensor(attention_mask).to(device)
+                 ).logits.cpu().numpy()
+             )
+
+         y_true.extend(labels)
+
+     all_metrics = compute_metrics((logits, y_true))
+     return all_metrics
+
+ # with open('./metrics/trf/metrics.json', 'w') as f:
+ #     json.dump(all_metrics, f)
+
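+ # compute_metrics (from evaluate_model, not shown in this commit) is
+ # expected to take the (logits, labels) pair, argmax the logits into
+ # predicted tag ids, drop the -100 positions, and return seqeval-style
+ # precision/recall/F1.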
+
+
+ def find_orgs(tokens, labels):
+     """Group consecutive B-ORG/I-ORG tokens into organization strings."""
+     orgs = []
+     prev_tok_id = 0
+     for i, (token, label) in enumerate(zip(tokens, labels)):
+         if label == 'B-ORG':
+             # start a new organization span
+             orgs.append([token])
+             prev_tok_id = i
+
+         if label == 'I-ORG' and orgs and (i - 1) == prev_tok_id:
+             # extend the current span only if this token directly
+             # follows the previous ORG token
+             orgs[-1].append(token)
+             prev_tok_id = i
+
+     return [tokenizer.convert_tokens_to_string(org) for org in orgs]
+
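+ # Example (illustrative): with
+ #     tokens = ['Shares', 'of', 'Apple', 'Inc', 'rose']
+ #     labels = ['O', 'O', 'B-ORG', 'I-ORG', 'O']
+ # find_orgs(tokens, labels) groups ['Apple', 'Inc'] into one span and
+ # returns something like ['Apple Inc'].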
+
+
+ def store_sample_data():
+     test_data = []
+
+     for sent in test:
+         labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
+         sent_orgs = find_orgs(sent['tokens'], labels)
+
+         sent_text = tokenizer.convert_tokens_to_string(sent['tokens'])
+         test_data.append({
+             'id': sent['id'],
+             'text': sent_text,
+             'orgs': sent_orgs
+         })
+
+     with open('data/sample_data.json', 'w') as f:
+         json.dump(test_data, f)
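+
+ # Hypothetical entry point (an assumption; the committed file defines the
+ # helpers but never calls them):
+ # if __name__ == '__main__':
+ #     all_metrics = get_metrics_trf()
+ #     with open('./metrics/trf/metrics.json', 'w') as f:
+ #         json.dump(all_metrics, f)
+ #     store_sample_data()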