SmitaGautam committed on
Commit
0732d74
1 Parent(s): 25b106e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +322 -233
train.py CHANGED
@@ -1,233 +1,322 @@
1
- from datasets import load_dataset
2
- import numpy as np
3
- from sklearn.svm import SVC
4
- from tqdm.notebook import tqdm
5
- from sklearn.preprocessing import StandardScaler
6
- from sklearn.metrics import classification_report
7
- import nltk
8
- from nltk.corpus import stopwords
9
- from nltk import word_tokenize
10
- from nltk import pos_tag
11
- import pickle
12
- import time
13
- from nltk.corpus import names, gazetteers
14
- from sklearn.model_selection import KFold
15
- from itertools import chain
16
- from sklearn.metrics import precision_score, recall_score, fbeta_score, confusion_matrix
17
- import matplotlib.pyplot as plt
18
- import seaborn as sns
19
-
20
-
21
# Fetch the stop-word corpus used by the feature extractors (no-op once cached).
nltk.download('stopwords')
# NOTE: this rebinds `stopwords` from the imported corpus reader to the English
# word collection. It is stored as a frozenset because the feature functions
# test `word.lower() in stopwords` once per token, and a list would make every
# one of those checks a linear scan.
stopwords = frozenset(stopwords.words('english'))

# The 37 word-level POS tag names. create_data() subtracts 10 from the dataset's
# integer tag ids, which presumably yields an index into this list -- confirm
# against the dataset's tag vocabulary.
pos_tags = [
    'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS',
    'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD',
    'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB',
]
28
-
29
def feature_vector(word, prev_word_pos_tag, next_word_pos_tag, current_word_pos_tag):
    """Build the 116-dimensional float32 feature vector for one token.

    Layout:
        0        title-case flag
        1        stop-word flag (case-insensitive)
        2        all-caps flag
        3        token length
        4        all-digits flag
        5-41     one-hot of the previous token's tag index
        42-78    one-hot of the next token's tag index
        79-115   one-hot of the current token's tag index

    Args:
        word: Token text.
        prev_word_pos_tag: Tag index (0..36) of the previous token, or a
            negative value when there is none.
        next_word_pos_tag: Same, for the following token.
        current_word_pos_tag: Same, for the token itself.

    Returns:
        np.ndarray of shape (116,) and dtype float32.
    """
    n_tags = 37
    vec = np.zeros(5 + 3 * n_tags, dtype='float32')
    vec[0] = word.istitle()
    vec[1] = word.lower() in stopwords
    vec[2] = word.isupper()
    vec[3] = len(word)
    vec[4] = word.isdigit()

    # Any negative index means "no tag". The previous `!= -1` test let other
    # negative sentinels through (create_data passed -11), and NumPy's negative
    # indexing then silently set an unrelated slot.
    one_hot_blocks = (
        (5, prev_word_pos_tag),
        (5 + n_tags, next_word_pos_tag),
        (5 + 2 * n_tags, current_word_pos_tag),
    )
    for offset, tag_index in one_hot_blocks:
        if tag_index >= 0:
            vec[offset + tag_index] = 1

    return vec
50
-
51
-
52
def feature_vector2(word, prev_word_pos_tag, next_word_pos_tag, current_word_pos_tag):
    """Build a 9-dimensional float32 feature vector for one token.

    Slots: title-case, stop-word, all-caps, length, all-digits, then
    membership in the `places`, `people`, `countries` and `nationalities`
    gazetteer sets. The three POS-tag arguments are accepted for signature
    compatibility with the other extractors but are not used.
    """
    slot_values = (
        word.istitle(),
        word.lower() in stopwords,
        word.isupper(),
        len(word),
        word.isdigit(),
        word in places,
        word in people,
        word in countries,
        word in nationalities,
    )
    vec = np.zeros(len(slot_values), dtype='float32')
    for slot, value in enumerate(slot_values):
        vec[slot] = value
    return vec
77
-
78
-
79
# Builds the per-token feature vectors and binary entity labels for one dataset split.
def create_data(data):
    """Turn a split of tagged sentences into token-level features and labels.

    Args:
        data: Iterable of records with parallel sequences under the keys
            'tokens', 'pos_tags' (integer tag ids) and 'ner_tags'.

    Returns:
        (x_train, y_train): a list of feature vectors from feature_vector()
        and a list of labels, 1 when the token's NER tag is non-zero, else 0.
    """
    def word_tag_index(tag_id):
        # Tag ids below 10 are treated as "no word-level tag" (presumably the
        # punctuation/symbol tags -- confirm against the dataset's tag list).
        # The sentinel stays -1; the old code subtracted 10 from it as well,
        # handing -11 to a function that only recognised -1.
        return tag_id - 10 if tag_id >= 10 else -1

    x_train = []
    y_train = []
    for x in data:
        tokens = x['tokens']
        tag_ids = x['pos_tags']
        ner_ids = x['ner_tags']
        last = len(tokens) - 1
        for y, token in enumerate(tokens):
            prev_pos = word_tag_index(tag_ids[y - 1]) if y > 0 else -1
            next_pos = word_tag_index(tag_ids[y + 1]) if y < last else -1
            current_pos = word_tag_index(tag_ids[y])
            x_train.append(feature_vector(token, prev_pos, next_pos, current_pos))
            y_train.append(1 if ner_ids[y] != 0 else 0)
    return x_train, y_train
93
-
94
def evaluate_overall_metrics(predictions, folds):
    """Print weighted precision, recall and F-beta scores averaged over folds.

    Args:
        predictions: Iterable of (true_labels, predicted_labels) pairs, one
            per evaluated fold.
        folds: Divisor used for the average.
    """
    totals = {'f0.5': 0, 'f1': 0, 'f2': 0, 'precision': 0, 'recall': 0}

    for y_true, y_hat in predictions:
        totals['f0.5'] += fbeta_score(y_true, y_hat, beta=0.5, average='weighted')
        totals['f1'] += fbeta_score(y_true, y_hat, beta=1, average='weighted')
        totals['f2'] += fbeta_score(y_true, y_hat, beta=2, average='weighted')
        totals['precision'] += precision_score(y_true, y_hat, average='weighted')
        totals['recall'] += recall_score(y_true, y_hat, average='weighted')

    # Mean of each accumulated score over the requested number of folds.
    precision = totals['precision'] / folds
    recall = totals['recall'] / folds
    f0_5_score = totals['f0.5'] / folds
    f1_score = totals['f1'] / folds
    f2_score = totals['f2'] / folds

    print('Overall Metrics:')
    print(f'Precision : {precision:.3f}')
    print(f'Recall : {recall:.3f}')
    print(f'F0.5 Score : {f0_5_score:.3f}')
    print(f'F1 Score : {f1_score:.3f}')
    print(f'F2 Score : {f2_score:.3f}\n')
121
-
122
def evaluate_per_pos_metrics(predictions, labels):
    """Print one-vs-rest precision, recall and F1 for every label.

    Args:
        predictions: Iterable of (true_labels, predicted_labels) pairs; all
            pairs are pooled before scoring.
        labels: Label values to report on, each treated in turn as the
            positive class.
    """
    combined_true = [t for fold_true, _ in predictions for t in fold_true]
    combined_pred = [p for _, fold_pred in predictions for p in fold_pred]

    for tag in labels:
        # Collapse to a binary problem: this label versus everything else.
        true_binary = [int(t == tag) for t in combined_true]
        pred_binary = [int(p == tag) for p in combined_pred]

        precision = precision_score(true_binary, pred_binary, average='binary', zero_division=0)
        recall = recall_score(true_binary, pred_binary, average='binary', zero_division=0)
        f1_score = fbeta_score(true_binary, pred_binary, beta=1, average='binary', zero_division=0)

        print(f"Metrics for {tag}:")
        print(f'Precision : {precision:.3f}')
        print(f'Recall : {recall:.3f}')
        print(f'F1 Score : {f1_score:.3f}\n')
145
-
146
def plot_confusion_matrix(predictions, labels, folds):
    """Show a row-normalised confusion-matrix heatmap summed over folds.

    Args:
        predictions: Iterable of (true_labels, predicted_labels) pairs.
        labels: Label order for the matrix rows and columns.
        folds: Divisor applied to the summed counts before normalising.
    """
    matrix = None
    for fold_true, fold_pred in predictions:
        fold_cm = confusion_matrix(fold_true, fold_pred, labels=labels)
        matrix = fold_cm if matrix is None else matrix + fold_cm

    matrix = matrix.astype('float') / folds
    # Each row is divided by its own total so cells read as per-class rates.
    row_totals = matrix.sum(axis=1, keepdims=True)
    matrix = matrix / row_totals

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Normalized Confusion Matrix for NER')
    plt.show()
169
-
170
if __name__ == "__main__":
    # Load CoNLL-2003 and keep handles to its three splits.
    data = load_dataset("conll2003", trust_remote_code=True)
    d_train = data['train']
    d_validation = data['validation']
    d_test = data['test']

    # Gazetteer sets read by feature_vector2(). names.words() needs the
    # 'names' corpus, which was never downloaded before.
    nltk.download('gazetteers')
    nltk.download('names')
    places = set(gazetteers.words())
    people = set(names.words())
    countries = set(gazetteers.words('countries.txt'))
    nationalities = set(gazetteers.words('nationalities.txt'))

    x_train, y_train = create_data(d_train)
    x_val, y_val = create_data(d_validation)
    x_test, y_test = create_data(d_test)
    # All three splits are pooled; the K-fold below re-splits them per token.
    all_X_train = np.concatenate((x_train, x_val, x_test))
    all_y_train = np.concatenate((y_train, y_val, y_test))

    # K-Fold
    num_fold = 5
    kf = KFold(n_splits=num_fold, random_state=42, shuffle=True)
    indices = np.arange(len(all_X_train))

    predictions = []
    all_models = []

    for i, (train_index, test_index) in enumerate(kf.split(indices)):
        print(f"Fold {i} Train Length: {len(train_index)} Test Length: {len(test_index)}")
        fold_X_train = all_X_train[train_index]
        fold_y_train = all_y_train[train_index]

        fold_X_test = all_X_train[test_index]
        fold_y_test = all_y_train[test_index]

        # NOTE(review): feature standardisation (StandardScaler) was disabled
        # here, so the raw token-length feature is fed to the SVM unscaled.
        model = SVC(random_state=42, verbose=True)
        model.fit(fold_X_train, fold_y_train)

        y_pred_val = model.predict(fold_X_test)

        print("-------"*6)
        print(classification_report(y_true=fold_y_test, y_pred=y_pred_val))
        print("-------"*6)

        # Context manager so the pickle file is flushed and closed.
        with open(f"ner_svm_{i}.pkl", 'wb') as model_file:
            pickle.dump(model, model_file)

        predictions.append((fold_y_test, y_pred_val))
        all_models.append(model)
        # Only the first fold is trained (presumably to keep the run short).
        break

    # Average over the folds actually evaluated. Dividing by a fixed 5 while
    # the loop stops after one fold reported every score at one fifth of its value.
    FOLDS = len(predictions)
    labels = sorted(all_models[-1].classes_)
    evaluate_overall_metrics(predictions, FOLDS)
    evaluate_per_pos_metrics(predictions, labels)
    plot_confusion_matrix(predictions, labels, FOLDS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import numpy as np
3
+ from sklearn.svm import SVC
4
+ from tqdm.notebook import tqdm
5
+ from sklearn.preprocessing import StandardScaler
6
+ from sklearn.metrics import classification_report
7
+ import nltk
8
+ from nltk.corpus import stopwords
9
+ from nltk import word_tokenize
10
+ from nltk import pos_tag
11
+ import pickle
12
+ import time
13
+ from nltk.corpus import names, gazetteers
14
+ from sklearn.model_selection import KFold
15
+ from itertools import chain
16
+ from sklearn.metrics import precision_score, recall_score, fbeta_score, confusion_matrix
17
+ import matplotlib.pyplot as plt
18
+ import seaborn as sns
19
+ from string import punctuation
20
+
21
+
22
# Fetch the stop-word corpus used by the feature extractors (no-op once cached).
nltk.download('stopwords')
# NOTE: this rebinds `stopwords` from the imported corpus reader to the English
# word collection. It is stored as a frozenset because the feature functions
# test `w.lower() in stopwords` once per token, and a list would make every
# one of those checks a linear scan.
stopwords = frozenset(stopwords.words('english'))
# Kept as a list of single characters: `w in PUNCT` must match whole tokens,
# not substrings of the punctuation string.
PUNCT = list(punctuation)

nltk.download('gazetteers')
nltk.download('names')
from nltk.corpus import names, gazetteers

# Gazetteer lookups, built once as sets for the per-token membership checks.
places = set(gazetteers.words())
people = set(names.words())
countries = set(gazetteers.words('countries.txt'))
nationalities = set(gazetteers.words('nationalities.txt'))

# The 37 word-level POS tag names. create_data() subtracts 10 from the dataset's
# integer tag ids, which presumably yields an index into this list -- confirm
# against the dataset's tag vocabulary.
pos_tags = [
    'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS',
    'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD',
    'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB',
]
39
+
40
+
41
def feature_vector(w, scaled_position, pos_tag):
    """Build the 12-dimensional float32 feature vector for one token.

    Args:
        w: Token text.
        scaled_position: Stored unchanged in slot 3 (presumably the token's
            position scaled by sentence length -- confirm with the caller).
        pos_tag: POS tag name, compared against the verb and proper-noun tags.

    Returns:
        np.ndarray of shape (12,) and dtype float32 laid out as: all-caps,
        length, punctuation, scaled position, stop-word, all-digits, verb,
        proper noun, place, person name, country, nationality.
    """
    verb_tags = ('VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ')
    proper_noun_tags = ('NNP', 'NNPS')

    vec = np.zeros(12, dtype=np.float32)
    vec[0] = w.isupper()
    vec[1] = len(w)
    vec[2] = w in PUNCT
    vec[3] = scaled_position
    vec[4] = w.lower() in stopwords
    vec[5] = w.isdigit()
    vec[6] = pos_tag in verb_tags
    vec[7] = pos_tag in proper_noun_tags
    # Gazetteer membership is case-sensitive: the token is looked up as-is.
    vec[8] = w in places
    vec[9] = w in people
    vec[10] = w in countries
    vec[11] = w in nationalities
    return vec
116
+
117
+
118
def feature_vector_d(word, prev_word_pos_tag, next_word_pos_tag, current_word_pos_tag):
    """Build the 116-dimensional float32 feature vector for one token.

    Layout:
        0        title-case flag
        1        stop-word flag (case-insensitive)
        2        all-caps flag
        3        token length
        4        all-digits flag
        5-41     one-hot of the previous token's tag index
        42-78    one-hot of the next token's tag index
        79-115   one-hot of the current token's tag index

    Args:
        word: Token text.
        prev_word_pos_tag: Tag index (0..36) of the previous token, or a
            negative value when there is none.
        next_word_pos_tag: Same, for the following token.
        current_word_pos_tag: Same, for the token itself.

    Returns:
        np.ndarray of shape (116,) and dtype float32.
    """
    n_tags = 37
    vec = np.zeros(5 + 3 * n_tags, dtype='float32')
    vec[0] = word.istitle()
    vec[1] = word.lower() in stopwords
    vec[2] = word.isupper()
    vec[3] = len(word)
    vec[4] = word.isdigit()

    # Any negative index means "no tag". The previous `!= -1` test let other
    # negative sentinels through (create_data passed -11), and NumPy's negative
    # indexing then silently set an unrelated slot.
    one_hot_blocks = (
        (5, prev_word_pos_tag),
        (5 + n_tags, next_word_pos_tag),
        (5 + 2 * n_tags, current_word_pos_tag),
    )
    for offset, tag_index in one_hot_blocks:
        if tag_index >= 0:
            vec[offset + tag_index] = 1

    return vec
139
+
140
+
141
def feature_vector2(word, prev_word_pos_tag, next_word_pos_tag, current_word_pos_tag):
    """Build a 9-dimensional float32 feature vector for one token.

    Slots: title-case, stop-word, all-caps, length, all-digits, then
    membership in the `places`, `people`, `countries` and `nationalities`
    gazetteer sets. The three POS-tag arguments are accepted for signature
    compatibility with the other extractors but are not used.
    """
    slot_values = (
        word.istitle(),
        word.lower() in stopwords,
        word.isupper(),
        len(word),
        word.isdigit(),
        word in places,
        word in people,
        word in countries,
        word in nationalities,
    )
    vec = np.zeros(len(slot_values), dtype='float32')
    for slot, value in enumerate(slot_values):
        vec[slot] = value
    return vec
166
+
167
+
168
# Builds the per-token feature vectors and binary entity labels for one dataset split.
def create_data(data):
    """Turn a split of tagged sentences into token-level features and labels.

    Features come from feature_vector_d(), the extractor whose four-argument
    signature matches this call. feature_vector() now takes
    (w, scaled_position, pos_tag), so calling it from here raised TypeError.

    Args:
        data: Iterable of records with parallel sequences under the keys
            'tokens', 'pos_tags' (integer tag ids) and 'ner_tags'.

    Returns:
        (x_train, y_train): a list of 116-dimensional feature vectors and a
        list of labels, 1 when the token's NER tag is non-zero, else 0.
    """
    def word_tag_index(tag_id):
        # Tag ids below 10 are treated as "no word-level tag" (presumably the
        # punctuation/symbol tags -- confirm against the dataset's tag list).
        # The sentinel stays -1; the old code subtracted 10 from it as well,
        # handing -11 to a function that only recognised -1.
        return tag_id - 10 if tag_id >= 10 else -1

    x_train = []
    y_train = []
    for x in data:
        tokens = x['tokens']
        tag_ids = x['pos_tags']
        ner_ids = x['ner_tags']
        last = len(tokens) - 1
        for y, token in enumerate(tokens):
            prev_pos = word_tag_index(tag_ids[y - 1]) if y > 0 else -1
            next_pos = word_tag_index(tag_ids[y + 1]) if y < last else -1
            current_pos = word_tag_index(tag_ids[y])
            x_train.append(feature_vector_d(token, prev_pos, next_pos, current_pos))
            y_train.append(1 if ner_ids[y] != 0 else 0)
    return x_train, y_train
182
+
183
def evaluate_overall_metrics(predictions, folds):
    """Print weighted precision, recall and F-beta scores averaged over folds.

    Args:
        predictions: Iterable of (true_labels, predicted_labels) pairs, one
            per evaluated fold.
        folds: Divisor used for the average.
    """
    totals = {'f0.5': 0, 'f1': 0, 'f2': 0, 'precision': 0, 'recall': 0}

    for y_true, y_hat in predictions:
        totals['f0.5'] += fbeta_score(y_true, y_hat, beta=0.5, average='weighted')
        totals['f1'] += fbeta_score(y_true, y_hat, beta=1, average='weighted')
        totals['f2'] += fbeta_score(y_true, y_hat, beta=2, average='weighted')
        totals['precision'] += precision_score(y_true, y_hat, average='weighted')
        totals['recall'] += recall_score(y_true, y_hat, average='weighted')

    # Mean of each accumulated score over the requested number of folds.
    precision = totals['precision'] / folds
    recall = totals['recall'] / folds
    f0_5_score = totals['f0.5'] / folds
    f1_score = totals['f1'] / folds
    f2_score = totals['f2'] / folds

    print('Overall Metrics:')
    print(f'Precision : {precision:.3f}')
    print(f'Recall : {recall:.3f}')
    print(f'F0.5 Score : {f0_5_score:.3f}')
    print(f'F1 Score : {f1_score:.3f}')
    print(f'F2 Score : {f2_score:.3f}\n')
210
+
211
def evaluate_per_pos_metrics(predictions, labels):
    """Print one-vs-rest precision, recall and F1 for every label.

    Args:
        predictions: Iterable of (true_labels, predicted_labels) pairs; all
            pairs are pooled before scoring.
        labels: Label values to report on, each treated in turn as the
            positive class.
    """
    combined_true = [t for fold_true, _ in predictions for t in fold_true]
    combined_pred = [p for _, fold_pred in predictions for p in fold_pred]

    for tag in labels:
        # Collapse to a binary problem: this label versus everything else.
        true_binary = [int(t == tag) for t in combined_true]
        pred_binary = [int(p == tag) for p in combined_pred]

        precision = precision_score(true_binary, pred_binary, average='binary', zero_division=0)
        recall = recall_score(true_binary, pred_binary, average='binary', zero_division=0)
        f1_score = fbeta_score(true_binary, pred_binary, beta=1, average='binary', zero_division=0)

        print(f"Metrics for {tag}:")
        print(f'Precision : {precision:.3f}')
        print(f'Recall : {recall:.3f}')
        print(f'F1 Score : {f1_score:.3f}\n')
234
+
235
def plot_confusion_matrix(predictions, labels, folds):
    """Show a row-normalised confusion-matrix heatmap summed over folds.

    Args:
        predictions: Iterable of (true_labels, predicted_labels) pairs.
        labels: Label order for the matrix rows and columns.
        folds: Divisor applied to the summed counts before normalising.
    """
    matrix = None
    for fold_true, fold_pred in predictions:
        fold_cm = confusion_matrix(fold_true, fold_pred, labels=labels)
        matrix = fold_cm if matrix is None else matrix + fold_cm

    matrix = matrix.astype('float') / folds
    # Each row is divided by its own total so cells read as per-class rates.
    row_totals = matrix.sum(axis=1, keepdims=True)
    matrix = matrix / row_totals

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Normalized Confusion Matrix for NER')
    plt.show()
258
+
259
if __name__ == "__main__":
    # Load CoNLL-2003 and keep handles to its three splits.
    data = load_dataset("conll2003", trust_remote_code=True)
    d_train = data['train']
    d_validation = data['validation']
    d_test = data['test']

    # The gazetteer corpora and the places/people/countries/nationalities sets
    # are already prepared at module level, so they are not rebuilt here.
    x_train, y_train = create_data(d_train)
    x_val, y_val = create_data(d_validation)
    x_test, y_test = create_data(d_test)
    # All three splits are pooled; the K-fold below re-splits them per token.
    all_X_train = np.concatenate((x_train, x_val, x_test))
    all_y_train = np.concatenate((y_train, y_val, y_test))

    # K-Fold
    num_fold = 5
    kf = KFold(n_splits=num_fold, random_state=42, shuffle=True)
    indices = np.arange(len(all_X_train))

    predictions = []
    all_models = []

    for i, (train_index, test_index) in enumerate(kf.split(indices)):
        print(f"Fold {i} Train Length: {len(train_index)} Test Length: {len(test_index)}")
        fold_X_train = all_X_train[train_index]
        fold_y_train = all_y_train[train_index]

        fold_X_test = all_X_train[test_index]
        fold_y_test = all_y_train[test_index]

        # NOTE(review): feature standardisation (StandardScaler) was disabled
        # here, so the raw token-length feature is fed to the SVM unscaled.
        model = SVC(random_state=42, verbose=True)
        model.fit(fold_X_train, fold_y_train)

        y_pred_val = model.predict(fold_X_test)

        print("-------"*6)
        print(classification_report(y_true=fold_y_test, y_pred=y_pred_val))
        print("-------"*6)

        # Context manager so the pickle file is flushed and closed.
        with open(f"ner_svm_{i}.pkl", 'wb') as model_file:
            pickle.dump(model, model_file)

        predictions.append((fold_y_test, y_pred_val))
        all_models.append(model)
        # Only the first fold is trained (presumably to keep the run short).
        break

    # Average over the folds actually evaluated. Dividing by a fixed 5 while
    # the loop stops after one fold reported every score at one fifth of its value.
    FOLDS = len(predictions)
    labels = sorted(all_models[-1].classes_)
    evaluate_overall_metrics(predictions, FOLDS)
    evaluate_per_pos_metrics(predictions, labels)
    plot_confusion_matrix(predictions, labels, FOLDS)