BecomeAllan committed · Commit 8bf76cf
Parent(s): 9c0c4aa
update funs

Files changed:
- .vscode/settings.json  +7 -0
- ML_SLRC.py             +382 -44
- Util_funs.py           +305 -418
.vscode/settings.json
ADDED
@@ -0,0 +1,7 @@
+{
+    "workbench.colorCustomizations": {
+        "activityBar.background": "#093518",
+        "titleBar.activeBackground": "#0D4A21",
+        "titleBar.activeForeground": "#F3FDF6"
+    }
+}
ML_SLRC.py
CHANGED
@@ -1,33 +1,18 @@
-
-import torch.nn as nn
-import math
 import torch
 import numpy as np
-
-import time
-import transformers
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-from sklearn.manifold import TSNE
-from copy import deepcopy, copy
-import seaborn as sns
-import matplotlib.pylab as plt
-from pprint import pprint
-import shutil
-import datetime
 import re
-import json
-from pathlib import Path
-
-import torch
-import torch.nn as nn
-from torch.utils.data import Dataset, DataLoader
 import unicodedata
-import
-
 import torch
-import
-from
-
 
 
 # Pre-trained model
@@ -117,7 +102,6 @@ class SLR_Classifier(nn.Module):
 
         return [loss, [feature, logit], predict]
 
-
 # Undesirable patterns within texts
 patterns = {
     'CONCLUSIONS AND IMPLICATIONS':'',
@@ -157,27 +141,50 @@ patterns = {
     '</p>':'',
     '<<ETX>>':'',
     '+/-':'',
 }
 
 patterns = {x.lower():y for x,y in patterns.items()}
 
-
 
 class SLR_DataSet(Dataset):
-    def __init__(self, **args):
         self.tokenizer = args.get('tokenizer')
         self.data = args.get('data')
         self.max_seq_length = args.get("max_seq_length", 512)
         self.INPUT_NAME = args.get("input", 'x')
         self.LABEL_NAME = args.get("output", 'y')
 
     # Tokenizing and processing text
     def encode_text(self, example):
         comment_text = example[self.INPUT_NAME]
-
 
         try:
-            labels = LABEL_MAP[example[self.LABEL_NAME]]
         except:
             labels = -1
 
@@ -200,15 +207,6 @@ class SLR_DataSet(Dataset):
             torch.tensor([torch.tensor(labels).to(int)])
         ))
 
-    # Text processing function
-    def treat_text(self, text):
-        text = unicodedata.normalize("NFKD",str(text))
-        text = multiple_replace(patterns,text.lower())
-        text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
-        text = re.sub('( +)',' ', text)
-        text = re.sub('(, ,)|(,,)',',', text)
-        text = re.sub('(%)|(per cent)',' percent', text)
-        return text
 
     def __len__(self):
         return len(self.data)
@@ -221,6 +219,350 @@ class SLR_DataSet(Dataset):
         return temp_data
 
 
 
 # Regex multiple replace function
 def multiple_replace(dict, text):
@@ -229,8 +571,4 @@ def multiple_replace(dict, text):
     regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
 
     # Substitution
-    return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
-
-# Undesirable patterns within texts
-
-
+from torch import nn
 import torch
 import numpy as np
+from copy import deepcopy
 import re
 import unicodedata
+from torch.utils.data import Dataset, DataLoader,TensorDataset, RandomSampler
+from sklearn.model_selection import train_test_split
+from torch.optim import Adam
+from copy import deepcopy
+import gc
 import torch
+import numpy as np
+from torchmetrics import functional as fn
+import random
 
 
 # Pre-trained model
 
         return [loss, [feature, logit], predict]
 
 # Undesirable patterns within texts
 patterns = {
     'CONCLUSIONS AND IMPLICATIONS':'',
     '</p>':'',
     '<<ETX>>':'',
     '+/-':'',
+    '\(.+\)':'',
+    '\[.+\]':'',
+    ' \d ':'',
+    '<':'',
+    '>':'',
+    '- ':'',
+    ' +':' ',
+    ', ,':',',
+    ',,':',',
+    '%':' percent',
+    'per cent':' percent'
 }
 
 patterns = {x.lower():y for x,y in patterns.items()}
 
+
+LABEL_MAP = {'negative': 0,
+             'not included':0,
+             '0':0,
+             0:0,
+             'excluded':0,
+             'positive': 1,
+             'included':1,
+             '1':1,
+             1:1,
+             }
 
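A small illustration of how the LABEL_MAP added above collapses the mixed label spellings to the two classes; the sample values below are invented for illustration only.

# Minimal sketch: LABEL_MAP normalizes label spellings to {0, 1}.
# String labels are lower-cased first, mirroring SLR_DataSet.encode_text below.
for raw in ['included', 'EXCLUDED', 1, '0']:
    key = raw.lower() if isinstance(raw, str) else raw
    print(raw, '->', LABEL_MAP[key])
# included -> 1, EXCLUDED -> 0, 1 -> 1, 0 -> 0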
 class SLR_DataSet(Dataset):
+    def __init__(self,treat_text =None, **args):
         self.tokenizer = args.get('tokenizer')
         self.data = args.get('data')
         self.max_seq_length = args.get("max_seq_length", 512)
         self.INPUT_NAME = args.get("input", 'x')
         self.LABEL_NAME = args.get("output", 'y')
+        self.treat_text = treat_text
 
     # Tokenizing and processing text
     def encode_text(self, example):
         comment_text = example[self.INPUT_NAME]
+        if self.treat_text:
+            comment_text = self.treat_text(comment_text)
 
         try:
+            labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
         except:
             labels = -1
 
             torch.tensor([torch.tensor(labels).to(int)])
         ))
 
     def __len__(self):
         return len(self.data)
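A minimal sketch of how the updated SLR_DataSet might be instantiated and batched; the tokenizer checkpoint, the toy DataFrame and its contents are assumptions for illustration, not values fixed by this commit.

# Hypothetical usage sketch (checkpoint and data are placeholders).
import pandas as pd
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # assumed backbone
df = pd.DataFrame({"text": ["a systematic review abstract", "an unrelated paper"],
                   "label": ["included", "excluded"]})

dataset = SLR_DataSet(data=df, input="text", output="label",
                      tokenizer=tokenizer, max_seq_length=128,
                      treat_text=None)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
input_ids, attention_mask, token_type_ids, label_ids = next(iter(loader))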
         return temp_data
 
 
+class Learner(nn.Module):
+
+    def __init__(self, **args):
+        """
+        :param args:
+        """
+        super(Learner, self).__init__()
+
+        self.inner_print = args.get('inner_print')
+        self.inner_batch_size = args.get('inner_batch_size')
+        self.outer_update_lr = args.get('outer_update_lr')
+        self.inner_update_lr = args.get('inner_update_lr')
+        self.inner_update_step = args.get('inner_update_step')
+        self.inner_update_step_eval = args.get('inner_update_step_eval')
+        self.model = args.get('model')
+        self.device = args.get('device')
+
+        # Outer optimizer
+        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
+        self.model.train()
+
+    def forward(self, batch_tasks, training = True, valid_train = True):
+        """
+        batch = [(support TensorDataset, query TensorDataset),
+                 (support TensorDataset, query TensorDataset),
+                 (support TensorDataset, query TensorDataset),
+                 (support TensorDataset, query TensorDataset)]
+
+        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
+        """
+        task_accs = []
+        task_f1 = []
+        task_recall = []
+        sum_gradients = []
+        num_task = len(batch_tasks)
+        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval
+
+        # Outer loop tasks
+        for task_id, task in enumerate(batch_tasks):
+            support = task[0]
+            query = task[1]
+            name = task[2]
+
+            # Copying model
+            fast_model = deepcopy(self.model)
+            fast_model.to(self.device)
+
+            # Inner trainer optimizer
+            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
+
+            # Creating training data loaders
+            if len(support) % self.inner_batch_size == 1 :
+                support_dataloader = DataLoader(support, sampler=RandomSampler(support),
+                                                batch_size=self.inner_batch_size,
+                                                drop_last=True)
+            else:
+                support_dataloader = DataLoader(support, sampler=RandomSampler(support),
+                                                batch_size=self.inner_batch_size,
+                                                drop_last=False)
+
+            # steps_per_epoch=len(support) // self.inner_batch_size
+            # total_training_steps = steps_per_epoch * 5
+            # warmup_steps = total_training_steps // 3
+            #
+
+            # scheduler = get_linear_schedule_with_warmup(
+            #     inner_optimizer,
+            #     num_warmup_steps=warmup_steps,
+            #     num_training_steps=total_training_steps
+            # )
+
+            fast_model.train()
+
+            # Inner loop training epoch (support set)
+            if valid_train:
+                print('----Task',task_id,":", name, '----')
+
+            for i in range(0, num_inner_update_step):
+                all_loss = []
+
+                # Inner loop training batch (support set)
+                for inner_step, batch in enumerate(support_dataloader):
+                    batch = tuple(t.to(self.device) for t in batch)
+                    input_ids, attention_mask, token_type_ids, label_id = batch
+
+                    # Feed Foward
+                    loss, _, _ = fast_model(input_ids, attention_mask, token_type_ids=token_type_ids, labels = label_id)
+
+                    # Computing gradients
+                    loss.backward()
+                    # torch.nn.utils.clip_grad_norm_(fast_model.parameters(), max_norm=1)
+
+                    # Updating inner training parameters
+                    inner_optimizer.step()
+                    inner_optimizer.zero_grad()
+
+                    # Appending losses
+                    all_loss.append(loss.item())
+
+                    del batch, input_ids, attention_mask, label_id
+                    torch.cuda.empty_cache()
+
+                if valid_train:
+                    if (i+1) % self.inner_print == 0:
+                        print("Inner Loss: ", np.mean(all_loss))
+
+            fast_model.to(torch.device('cpu'))
+
+            # Inner training phase weights
+            if training:
+                meta_weights = list(self.model.parameters())
+                fast_weights = list(fast_model.parameters())
+
+                # Appending gradients
+                gradients = []
+                for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
+                    gradient = meta_params - fast_params
+                    if task_id == 0:
+                        sum_gradients.append(gradient)
+                    else:
+                        sum_gradients[i] += gradient
+
+
+            # Inner test (query set)
+            fast_model.to(self.device)
+            fast_model.eval()
+
+            if valid_train:
+                # Inner test (query set)
+                fast_model.to(self.device)
+                fast_model.eval()
+
+                with torch.no_grad():
+                    # Data loader
+                    query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
+                    query_batch = iter(query_dataloader).next()
+                    query_batch = tuple(t.to(self.device) for t in query_batch)
+                    q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch
+
+                    # Feedfoward
+                    _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask, q_token_type_ids, labels = q_label_id)
+
+                    # Predictions
+                    pre_label_id = pre_label_id.detach().cpu().squeeze()
+                    # Labels
+                    q_label_id = q_label_id.detach().cpu()
+
+                    # Calculating metrics
+                    acc = fn.accuracy(pre_label_id, q_label_id).item()
+                    recall = fn.recall(pre_label_id, q_label_id).item(),
+                    f1 = fn.f1_score(pre_label_id, q_label_id).item()
+
+                    # appending metrics
+                    task_accs.append(acc)
+                    task_f1.append(f1)
+                    task_recall.append(recall)
+
+                fast_model.to(torch.device('cpu'))
+
+            del fast_model, inner_optimizer
+            torch.cuda.empty_cache()
+
+        print("\n")
+        print("f1:",np.mean(task_f1))
+        print("recall:",np.mean(task_recall))
+
+        # Updating outer training parameters
+        if training:
+            # Mean of gradients
+            for i in range(0,len(sum_gradients)):
+                sum_gradients[i] = sum_gradients[i] / float(num_task)
+
+            # Indexing parameters to model
+            for i, params in enumerate(self.model.parameters()):
+                params.grad = sum_gradients[i]
+
+            # Updating parameters
+            self.outer_optimizer.step()
+            self.outer_optimizer.zero_grad()
+
+            del sum_gradients
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        if valid_train:
+            return np.mean(task_accs)
+        else:
+            return np.array(0)
+
+
+
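The outer update above hands Adam the per-task parameter difference (meta weights minus task-adapted weights) as a gradient, i.e. a first-order, Reptile-style meta step. A plain-SGD sketch of the same idea follows; the function name and learning rate are assumptions, not taken from this commit.

# Reptile-style view of the outer step done in Learner.forward:
# stepping against (meta - fast) pulls the meta weights toward the adapted weights.
def reptile_outer_step(meta_model, fast_model, outer_lr=1e-4):
    with torch.no_grad():
        for meta_p, fast_p in zip(meta_model.parameters(), fast_model.parameters()):
            grad = meta_p - fast_p          # averaged over tasks in the real code
            meta_p.sub_(outer_lr * grad)    # meta <- meta - lr * (meta - fast)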
+# Creating Meta Tasks
+class MetaTask(Dataset):
+    def __init__(self, examples, num_task, k_support, k_query,
+                 tokenizer, training=True, max_seq_length=512,
+                 treat_text =None, **args):
+        """
+        :param samples: list of samples
+        :param num_task: number of training tasks.
+        :param k_support: number of classes support samples per task
+        :param k_query: number of classes query sample per task
+        """
+        self.examples = examples
+
+        self.num_task = num_task
+        self.k_support = k_support
+        self.k_query = k_query
+        self.tokenizer = tokenizer
+        self.max_seq_length = max_seq_length
+        self.treat_text = treat_text
+
+        # Randomly generating tasks
+        self.create_batch(self.num_task, training)
+
+    # Creating batch
+    def create_batch(self, num_task, training):
+        self.supports = []  # support set
+        self.queries = []  # query set
+        self.task_names = []  # Name of task
+        self.supports_indexs = []  # index of supports
+        self.queries_indexs = []  # index of queries
+        self.num_task=num_task
+
+        # Available tasks
+        domains = self.examples['domain'].unique()
+
+        # If not training, create all tasks
+        if not(training):
+            self.task_names = domains
+            num_task = len(self.task_names)
+            self.num_task=num_task
+
+
+        for b in range(num_task):  # For each task,
+            total_per_class = self.k_support + self.k_query
+            task_size = 2*self.k_support + 2*self.k_query
+
+            # Select a task at random
+            if training:
+                domain = random.choice(domains)
+                self.task_names.append(domain)
+            else:
+                domain = self.task_names[b]
+
+            # Task data
+            domainExamples = self.examples[self.examples['domain'] == domain]
+
+            # Minimal label quantity
+            min_per_class = min(domainExamples['label'].value_counts())
+
+            if total_per_class > min_per_class:
+                total_per_class = min_per_class
+
+            # Select k_support + k_query task examples
+            # Sample (n) from each label(class)
+            selected_examples = domainExamples.groupby("label").sample(total_per_class, replace = False)
+
+            # Split data into support (training) and query (testing) sets
+            s, q = train_test_split(selected_examples,
+                                    stratify= selected_examples["label"],
+                                    test_size= 2*self.k_query/task_size,
+                                    shuffle=True)
+
+            # Permutating data
+            s = s.sample(frac=1)
+            q = q.sample(frac=1)
+
+            # Appending indexes
+            if not(training):
+                self.supports_indexs.append(s.index)
+                self.queries_indexs.append(q.index)
+
+            # Creating list of support (training) and query (testing) tasks
+            self.supports.append(s.to_dict('records'))
+            self.queries.append(q.to_dict('records'))
+
+    # Creating task tensors
+    def create_feature_set(self, examples):
+        all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
+        all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
+        all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
+        all_label_ids = torch.empty(len(examples), dtype = torch.long)
+
+        for _id, e in enumerate(examples):
+            all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)
+
+        return TensorDataset(
+            all_input_ids,
+            all_attention_mask,
+            all_token_type_ids,
+            all_label_ids
+        )
+
+    # Data encoding
+    def encode_text(self, example):
+        comment_text = example["text"]
+
+        if self.treat_text:
+            comment_text = self.treat_text(comment_text)
+
+        labels = LABEL_MAP[example["label"]]
+
+        encoding = self.tokenizer.encode_plus(
+            (comment_text, "It is a great text."),
+            add_special_tokens=True,
+            max_length=self.max_seq_length,
+            return_token_type_ids=True,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            return_tensors='pt',
+        )
+
+        return tuple((
+            encoding["input_ids"].flatten(),
+            encoding["attention_mask"].flatten(),
+            encoding["token_type_ids"].flatten(),
+            torch.tensor([torch.tensor(labels).to(int)])
+        ))
+
+    # Returns data upon calling
+    def __getitem__(self, index):
+        support_set = self.create_feature_set(self.supports[index])
+        query_set = self.create_feature_set(self.queries[index])
+        name = self.task_names[index]
+        return support_set, query_set, name
+
+    def __len__(self):
+        return self.num_task
+
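A minimal end-to-end sketch of how MetaTask and Learner added above might be wired together. The backbone checkpoint, the toy DataFrame, and all hyper-parameter values are illustrative assumptions, not values taken from this commit; SLR_Classifier is the classifier defined elsewhere in ML_SLRC.py.

# Hypothetical usage sketch (checkpoint, data and hyper-parameters are placeholders).
import torch
import pandas as pd
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # assumed
backbone  = AutoModel.from_pretrained("bert-base-uncased")       # assumed
model = SLR_Classifier(model=backbone, bert_layers=range(4), freeze_bert=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learner = Learner(model=model, device=device,
                  inner_print=1, inner_batch_size=4,
                  outer_update_lr=1e-4, inner_update_lr=2e-5,
                  inner_update_step=5, inner_update_step_eval=5)

# `df` is assumed to have 'domain', 'text' and 'label' columns.
df = pd.DataFrame({"domain": ["covid"] * 8,
                   "text": ["abstract %d" % i for i in range(8)],
                   "label": ["included", "excluded"] * 4})
tasks = MetaTask(df, num_task=2, k_support=2, k_query=2,
                 tokenizer=tokenizer, training=True, max_seq_length=64)

# Each task is (support TensorDataset, query TensorDataset, task name).
batch_tasks = [tasks[i] for i in range(len(tasks))]
mean_acc = learner(batch_tasks, training=True, valid_train=True)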
+
+class treat_text:
+    def __init__(self, patterns):
+        self.patterns = patterns
+
+    def __call__(self,text):
+        text = unicodedata.normalize("NFKD",str(text))
+        text = multiple_replace(self.patterns,text.lower())
+        text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
+        text = re.sub('( +)',' ', text)
+        text = re.sub('(, ,)|(,,)',',', text)
+        text = re.sub('(%)|(per cent)',' percent', text)
+        return text
+
 
 # Regex multiple replace function
 def multiple_replace(dict, text):
     regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
 
     # Substitution
+    return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
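A small sketch of the cleaning path defined by the treat_text class and multiple_replace above; the sample abstract is invented and the printed result is only indicative of the kind of normalization applied.

# Invented example text; output shown only to indicate the cleaning applied.
cleaner = treat_text(patterns)
raw = "BACKGROUND: Mortality fell by 12% in the treated group (n = 40)."
print(cleaner(raw))
# -> roughly: " mortality fell by 12 percent in the treated group ."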
Util_funs.py
CHANGED
@@ -1,49 +1,49 @@
 import os
-import torch
 import numpy as np
 import random
-import json, pickle
 
-
-
-
 import torch
-
-import pandas as pd
 import time
-import transformers
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from sklearn.manifold import TSNE
-from copy import deepcopy
 import seaborn as sns
 import matplotlib.pylab as plt
-from pprint import pprint
-import shutil
-import datetime
-import re
 import json
 from pathlib import Path
-
-import
-from
-
-
-
-
-
-from transformers import BertForSequenceClassification
-from copy import deepcopy
-import gc
-from sklearn.metrics import accuracy_score
-import torch
-import numpy as np
-import torchmetrics
-from torchmetrics import functional as fn
 
 
-SEED = 2222
 
-gen_seed = torch.Generator().manual_seed(SEED)
 
 
 # Random seed function
@@ -54,7 +54,7 @@ def random_seed(value):
     np.random.seed(value)
     random.seed(value)
 
-#
 def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
     idxs = list(range(0,len(taskset)))
     if is_shuffle:
@@ -63,48 +63,51 @@ def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
|
|
63 |
yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
|
64 |
|
65 |
|
66 |
-
|
67 |
-
def prepare_data(data, batch_size,tokenizer,max_seq_length,
|
68 |
input = 'text', output = 'label',
|
69 |
-
train_size_per_class = 5
|
|
|
70 |
data = data.reset_index().drop("index", axis=1)
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
data_train = labaled_data.groupby('label').sample(train_size_per_class)
|
75 |
|
76 |
-
|
77 |
-
|
|
|
78 |
|
79 |
-
|
|
|
80 |
|
81 |
|
82 |
-
#
|
83 |
-
##
|
84 |
dataset_train = SLR_DataSet(
|
85 |
data = data_train.sample(frac=1),
|
86 |
input = input,
|
87 |
output = output,
|
88 |
tokenizer=tokenizer,
|
89 |
-
max_seq_length =max_seq_length
|
|
|
90 |
|
91 |
-
|
92 |
-
# Dataloaders
|
93 |
-
## Transforma em dataset
|
94 |
dataset_test = SLR_DataSet(
|
95 |
data = data_test,
|
96 |
input = input,
|
97 |
output = output,
|
98 |
tokenizer=tokenizer,
|
99 |
-
max_seq_length =max_seq_length
|
|
|
100 |
|
101 |
# Dataloaders
|
102 |
-
##
|
103 |
data_train_loader = DataLoader(dataset_train,
|
104 |
shuffle=True,
|
105 |
batch_size=batch_size['train']
|
106 |
)
|
107 |
|
|
|
108 |
if len(dataset_test) % batch_size['test'] == 1 :
|
109 |
data_test_loader = DataLoader(dataset_test,
|
110 |
batch_size=batch_size['test'],
|
@@ -117,50 +120,54 @@ def prepare_data(data, batch_size,tokenizer,max_seq_length,
|
|
117 |
return data_train_loader, data_test_loader, data_train, data_test
|
118 |
|
119 |
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
from tqdm import tqdm
|
124 |
-
|
125 |
-
def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_resource =None):
|
126 |
-
|
127 |
learner = Learner(model = model, device = device, **Info)
|
128 |
|
129 |
# Testing tasks
|
130 |
if isinstance(Test_resource, pd.DataFrame):
|
131 |
test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
|
132 |
-
training=False, **Info)
|
133 |
|
134 |
|
135 |
torch.clear_autocast_cache()
|
136 |
gc.collect()
|
137 |
torch.cuda.empty_cache()
|
138 |
|
139 |
-
# Meta
|
140 |
for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
|
141 |
-
# print("Meta Epoca:", epoch)
|
142 |
|
143 |
-
#
|
144 |
train = MetaTask(data,
|
145 |
num_task = Info['num_task_train'],
|
146 |
k_support=Info['k_qry'],
|
147 |
-
k_query=Info['k_spt'],
|
|
|
148 |
|
149 |
-
#
|
150 |
db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
|
151 |
|
152 |
if print_epoch:
|
153 |
# Outer loop bach training
|
154 |
for step, task_batch in enumerate(db):
|
155 |
print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
|
156 |
-
|
|
|
157 |
acc = learner(task_batch, valid_train= print_epoch)
|
158 |
print('Step:', step, '\ttraining Acc:', acc)
|
|
|
159 |
if isinstance(Test_resource, pd.DataFrame):
|
160 |
-
# Validating Model
|
161 |
if ((epoch+1) % 4) + step == 0:
|
162 |
random_seed(123)
|
163 |
print("\n-----------------Testing Mode-----------------\n")
|
|
|
|
|
164 |
db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
|
165 |
acc_all_test = []
|
166 |
|
@@ -174,10 +181,10 @@ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_
|
|
174 |
|
175 |
# Restarting training randomly
|
176 |
random_seed(int(time.time() % 10))
|
177 |
-
|
178 |
-
|
179 |
else:
|
180 |
for step, task_batch in enumerate(db):
|
|
|
181 |
acc = learner(task_batch, print_epoch, valid_train= print_epoch)
|
182 |
|
183 |
torch.clear_autocast_cache()
|
@@ -187,14 +194,14 @@ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_
|
|
187 |
|
188 |
|
189 |
def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
|
190 |
-
#
|
191 |
model_meta = deepcopy(model)
|
192 |
optimizer = Adam(model_meta.parameters(), lr=lr)
|
193 |
|
194 |
model_meta.to(device)
|
195 |
model_meta.train()
|
196 |
|
197 |
-
#
|
198 |
for i in range(0, epoch):
|
199 |
all_loss = []
|
200 |
|
@@ -203,13 +210,13 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
|
|
203 |
batch = tuple(t.to(device) for t in batch)
|
204 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
205 |
|
206 |
-
# Feedfoward
|
207 |
loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
208 |
|
209 |
-
#
|
210 |
loss.backward()
|
211 |
|
212 |
-
#
|
213 |
optimizer.step()
|
214 |
optimizer.zero_grad()
|
215 |
|
@@ -220,39 +227,43 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
|
|
220 |
print("Loss: ", np.mean(all_loss))
|
221 |
|
222 |
|
223 |
-
#
|
224 |
model_meta.eval()
|
225 |
all_loss = []
|
226 |
-
|
227 |
features = []
|
228 |
labels = []
|
229 |
predi_logit = []
|
230 |
|
231 |
with torch.no_grad():
|
|
|
232 |
for inner_step, batch in enumerate(tqdm(data_test_loader,
|
233 |
desc="Test validation | " + name,
|
234 |
ncols=80)) :
|
235 |
batch = tuple(t.to(device) for t in batch)
|
236 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
237 |
|
238 |
-
#
|
239 |
_, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
240 |
|
|
|
241 |
prediction = prediction.detach().cpu().squeeze()
|
242 |
label_id = label_id.detach().cpu()
|
|
|
|
|
243 |
logit = feature[1].detach().cpu()
|
244 |
-
|
245 |
|
246 |
-
|
247 |
features.append(feature_lat.numpy())
|
248 |
-
predi_logit.append(logit.numpy())
|
249 |
|
250 |
-
#
|
251 |
-
|
|
|
252 |
del input_ids, attention_mask, label_id, batch
|
253 |
|
254 |
-
|
255 |
-
|
256 |
|
257 |
model_meta.to('cpu')
|
258 |
gc.collect()
|
@@ -260,26 +271,32 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
|
|
260 |
|
261 |
del model_meta, optimizer
|
262 |
|
|
|
263 |
|
|
|
|
|
|
|
264 |
features = np.concatenate(np.array(features,dtype=object))
|
265 |
-
labels = np.concatenate(np.array(labels,dtype=object))
|
266 |
-
logits = np.concatenate(np.array(predi_logit,dtype=object))
|
267 |
-
|
268 |
features = torch.tensor(features.astype(np.float32)).detach().clone()
|
|
|
|
|
269 |
labels = torch.tensor(labels.astype(int)).detach().clone()
|
|
|
|
|
270 |
logits = torch.tensor(logits.astype(np.float32)).detach().clone()
|
271 |
|
272 |
-
#
|
273 |
X_embedded = TSNE(n_components=2, learning_rate='auto',
|
274 |
init='random').fit_transform(features.detach().clone())
|
275 |
|
276 |
return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
|
277 |
-
|
278 |
-
|
279 |
def wss_calc(logit, labels, trsh = 0.5):
|
280 |
|
281 |
-
#
|
282 |
predict_trash = torch.sigmoid(logit).squeeze() >= trsh
|
|
|
|
|
283 |
CM = confusion_matrix(labels, predict_trash.to(int) )
|
284 |
tn, fp, fne, tp = CM.ravel()
|
285 |
|
@@ -287,36 +304,22 @@ def wss_calc(logit, labels, trsh = 0.5):
|
|
287 |
N = (tn + fp)
|
288 |
recall = tp/(tp+fne)
|
289 |
|
290 |
-
#
|
291 |
-
|
292 |
|
293 |
-
#
|
294 |
-
|
295 |
|
296 |
return {
|
297 |
-
"wss": round(
|
298 |
-
"awss": round(
|
299 |
"R": round(recall,4),
|
300 |
"CM": CM
|
301 |
}
|
302 |
|
303 |
|
304 |
-
|
305 |
-
|
306 |
-
from sklearn.metrics import confusion_matrix
|
307 |
-
from torchmetrics import functional as fn
|
308 |
-
import matplotlib.pyplot as plt
|
309 |
-
from sklearn.metrics import roc_curve, auc
|
310 |
-
from sklearn.metrics import roc_auc_score
|
311 |
-
import ipywidgets as widgets
|
312 |
-
from IPython.display import HTML, display, clear_output
|
313 |
-
import matplotlib.pyplot as plt
|
314 |
-
import seaborn as sns
|
315 |
-
import warnings
|
316 |
-
|
317 |
-
warnings.simplefilter(action='ignore', category=FutureWarning)
|
318 |
-
|
319 |
-
def plot(logits, X_embedded, labels, tresh, show = True,
|
320 |
namefig = "plot", make_plot = True, print_stats = True, save = True):
|
321 |
col = pd.MultiIndex.from_tuples([
|
322 |
("Predict", "0"),
|
@@ -329,30 +332,27 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
329 |
|
330 |
predict = torch.sigmoid(logits).detach().clone()
|
331 |
|
332 |
-
|
333 |
-
|
334 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
335 |
|
336 |
-
#
|
337 |
-
|
338 |
-
|
339 |
idx_wss95 = sum(tpr < 0.95)
|
|
|
340 |
thresholds95 = thresholds[idx_wss95]
|
341 |
|
|
|
342 |
wss95_info = wss_calc(logits,labels, thresholds95 )
|
343 |
acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
|
344 |
f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
|
345 |
|
346 |
|
347 |
-
#
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
wss_info = wss_calc(logits,labels, tresh )
|
353 |
-
# Accuraci
|
354 |
-
acc_wssR = fn.accuracy(predict, labels, threshold=tresh)
|
355 |
-
f1_wssR = fn.f1_score(predict, labels, threshold=tresh)
|
356 |
|
357 |
|
358 |
metrics= {
|
@@ -370,12 +370,11 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
370 |
# f1
|
371 |
"f1@95": f1_wss95.item(),
|
372 |
"f1@R": f1_wssR.item(),
|
373 |
-
#
|
374 |
-
"
|
375 |
}
|
376 |
|
377 |
-
#
|
378 |
-
|
379 |
if print_stats:
|
380 |
wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
|
381 |
wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
|
@@ -383,14 +382,14 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
383 |
print(wss95_adj)
|
384 |
print('Acc.:', round(acc_wss95.item(), 4))
|
385 |
print('F1-score:', round(f1_wss95.item(), 4))
|
386 |
-
print(f"
|
387 |
cm = pd.DataFrame(wss95_info['CM'],
|
388 |
index=index,
|
389 |
columns=col)
|
390 |
|
391 |
print("\nConfusion matrix:")
|
392 |
print(cm)
|
393 |
-
print("\n---Metrics with threshold:",
|
394 |
wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
|
395 |
print(wss)
|
396 |
wss_adj= f"AWSS@R:{wss_info['awss']}"
|
@@ -405,51 +404,53 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
405 |
print(cm)
|
406 |
|
407 |
|
408 |
-
#
|
409 |
|
410 |
if make_plot:
|
411 |
|
412 |
fig, axes = plt.subplots(1, 4, figsize=(25,10))
|
413 |
alpha = torch.squeeze(predict).numpy()
|
414 |
|
415 |
-
#
|
416 |
-
|
417 |
p1 = sns.scatterplot(x=X_embedded[:, 0],
|
418 |
y=X_embedded[:, 1],
|
419 |
hue=labels,
|
420 |
-
alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE')
|
421 |
|
|
|
|
|
422 |
t_wss = predict >= thresholds95
|
423 |
t_wss = t_wss.squeeze().numpy()
|
424 |
-
|
425 |
p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
|
426 |
y=X_embedded[t_wss, 1],
|
427 |
hue=labels[t_wss],
|
428 |
-
alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95')
|
429 |
|
430 |
-
|
|
|
431 |
t = t.squeeze().numpy()
|
432 |
-
|
433 |
p3 = sns.scatterplot(x=X_embedded[t, 0],
|
434 |
y=X_embedded[t, 1],
|
435 |
hue=labels[t],
|
436 |
-
alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-
|
437 |
-
|
438 |
|
|
|
439 |
roc_auc = auc(fpr, tpr)
|
440 |
lw = 2
|
441 |
-
|
442 |
axes[3].plot(
|
443 |
fpr,
|
444 |
tpr,
|
445 |
color="darkorange",
|
446 |
lw=lw,
|
447 |
label="ROC curve (area = %0.2f)" % roc_auc)
|
448 |
-
|
449 |
axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
|
450 |
axes[3].axhline(y=0.95, color='r', linestyle='-')
|
451 |
-
axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate"
|
452 |
axes[3].legend(loc="lower right")
|
|
|
|
|
|
|
|
|
453 |
|
454 |
if show:
|
455 |
plt.show()
|
@@ -459,6 +460,7 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
459 |
|
460 |
return metrics
|
461 |
|
|
|
462 |
def auc_plot(logits,labels, color = "darkorange", label = "test"):
|
463 |
predict = torch.sigmoid(logits).detach().clone()
|
464 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
@@ -478,45 +480,40 @@ def auc_plot(logits,labels, color = "darkorange", label = "test"):
|
|
478 |
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
|
479 |
plt.axhline(y=0.95, color='r', linestyle='-')
|
480 |
|
481 |
-
|
482 |
-
from sklearn.metrics import confusion_matrix
|
483 |
-
from torchmetrics import functional as fn
|
484 |
-
import matplotlib.pyplot as plt
|
485 |
-
from sklearn.metrics import roc_curve, auc
|
486 |
-
from sklearn.metrics import roc_auc_score
|
487 |
-
import ipywidgets as widgets
|
488 |
-
from IPython.display import HTML, display, clear_output
|
489 |
-
import matplotlib.pyplot as plt
|
490 |
-
import seaborn as sns
|
491 |
-
import warnings
|
492 |
-
|
493 |
-
|
494 |
class diagnosis():
|
495 |
-
def __init__(self, names, Valid_resource, batch_size_test,
|
|
|
496 |
self.names=names
|
497 |
self.Valid_resource=Valid_resource
|
498 |
self.batch_size_test=batch_size_test
|
499 |
self.model=model
|
500 |
-
self.start=start
|
|
|
|
|
|
|
|
|
501 |
|
|
|
502 |
self.value_trash = widgets.FloatText(
|
503 |
value=0.95,
|
504 |
-
description='
|
505 |
disabled=False
|
506 |
)
|
507 |
-
|
508 |
self.valueb = widgets.IntText(
|
509 |
value=10,
|
510 |
description='size',
|
511 |
disabled=False
|
512 |
)
|
513 |
|
|
|
514 |
self.train_b = widgets.Button(description="Train")
|
515 |
self.next_b = widgets.Button(description="Next")
|
516 |
self.eval_b = widgets.Button(description="Evaluation")
|
517 |
|
518 |
self.hbox = widgets.HBox([self.train_b, self.valueb])
|
519 |
|
|
|
520 |
self.next_b.on_click(self.Next_button)
|
521 |
self.train_b.on_click(self.Train_button)
|
522 |
self.eval_b.on_click(self.Evaluation_button)
|
@@ -527,36 +524,37 @@ class diagnosis():
|
|
527 |
clear_output()
|
528 |
self.i=self.i+1
|
529 |
|
530 |
-
#
|
531 |
-
self.domain = names[self.i]
|
532 |
-
print("Name:", self.domain)
|
533 |
-
|
534 |
-
# global data
|
535 |
self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
|
|
|
|
|
536 |
print(self.data['label'].value_counts())
|
537 |
-
|
538 |
display(self.hbox)
|
539 |
display(self.next_b)
|
540 |
|
|
|
541 |
# Train button
|
542 |
def Train_button(self, y):
|
543 |
clear_output()
|
544 |
print(self.domain)
|
545 |
|
546 |
-
#
|
547 |
self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
|
548 |
train_size_per_class = self.valueb.value,
|
549 |
-
batch_size = {'train': Info['inner_batch_size'],
|
550 |
-
'test': batch_size_test},
|
551 |
-
max_seq_length = Info['max_seq_length'],
|
552 |
-
tokenizer = Info['tokenizer'],
|
553 |
input = "text",
|
554 |
-
output = "label"
|
|
|
555 |
|
|
|
556 |
self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
|
557 |
-
model, device,
|
558 |
-
epoch = Info['inner_update_step'],
|
559 |
-
lr=Info['inner_update_lr'],
|
560 |
print_info=True,
|
561 |
name = self.domain)
|
562 |
|
@@ -565,6 +563,7 @@ class diagnosis():
|
|
565 |
display(tresh_box)
|
566 |
display(self.next_b)
|
567 |
|
|
|
568 |
# Evaluation button
|
569 |
def Evaluation_button(self, te):
|
570 |
clear_output()
|
@@ -573,19 +572,18 @@ class diagnosis():
|
|
573 |
print(self.domain)
|
574 |
# print("\n")
|
575 |
print("-------Train data-------")
|
576 |
-
print(
|
577 |
print("-------Test data-------")
|
578 |
-
print(
|
579 |
# print("\n")
|
580 |
|
581 |
display(self.next_b)
|
582 |
display(tresh_box)
|
583 |
display(self.hbox)
|
584 |
|
585 |
-
|
586 |
metrics = plot(self.logits, self.X_embedded, self.labels,
|
587 |
-
|
588 |
-
# namefig= "./"+base_path +"/"+"Results/size_layer/"+ name_domain+'/' +str(n_layers) + '/img/' + str(attempt) + 'plots',
|
589 |
namefig= 'test',
|
590 |
make_plot = True,
|
591 |
print_stats = True,
|
@@ -593,261 +591,150 @@ class diagnosis():
|
|
593 |
|
594 |
def __call__(self):
|
595 |
self.i= self.start-1
|
596 |
-
|
597 |
clear_output()
|
598 |
display(self.next_b)
|
599 |
|
600 |
|
601 |
|
602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
603 |
|
|
|
|
|
|
|
|
|
604 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
605 |
|
606 |
|
607 |
-
|
608 |
-
|
609 |
-
import torch.nn.functional as F
|
610 |
-
import torch.nn as nn
|
611 |
-
import math
|
612 |
-
import torch
|
613 |
-
import numpy as np
|
614 |
-
import pandas as pd
|
615 |
-
import time
|
616 |
-
import transformers
|
617 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
618 |
-
from sklearn.manifold import TSNE
|
619 |
-
from copy import deepcopy, copy
|
620 |
-
import seaborn as sns
|
621 |
-
import matplotlib.pylab as plt
|
622 |
-
from pprint import pprint
|
623 |
-
import shutil
|
624 |
-
import datetime
|
625 |
-
import re
|
626 |
-
import json
|
627 |
-
from pathlib import Path
|
628 |
-
|
629 |
-
import torch
|
630 |
-
import torch.nn as nn
|
631 |
-
from torch.utils.data import Dataset, DataLoader
|
632 |
-
import unicodedata
|
633 |
-
import re
|
634 |
-
|
635 |
-
import torch
|
636 |
-
import torch.nn as nn
|
637 |
-
from torch.utils.data import Dataset, DataLoader
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
# Pre-trained model
|
642 |
-
class Encoder(nn.Module):
|
643 |
-
def __init__(self, layers, freeze_bert, model):
|
644 |
-
super(Encoder, self).__init__()
|
645 |
-
|
646 |
-
# Dummy Parameter
|
647 |
-
self.dummy_param = nn.Parameter(torch.empty(0))
|
648 |
-
|
649 |
-
# Pre-trained model
|
650 |
-
self.model = deepcopy(model)
|
651 |
-
|
652 |
-
# Freezing bert parameters
|
653 |
-
if freeze_bert:
|
654 |
-
for param in self.model.parameters():
|
655 |
-
param.requires_grad = freeze_bert
|
656 |
-
|
657 |
-
# Selecting hidden layers of the pre-trained model
|
658 |
-
old_model_encoder = self.model.encoder.layer
|
659 |
-
new_model_encoder = nn.ModuleList()
|
660 |
-
|
661 |
-
for i in layers:
|
662 |
-
new_model_encoder.append(old_model_encoder[i])
|
663 |
-
|
664 |
-
self.model.encoder.layer = new_model_encoder
|
665 |
|
666 |
-
# Feed forward
|
667 |
-
def forward(self, **x):
|
668 |
-
return self.model(**x)['pooler_output']
|
669 |
-
|
670 |
-
# Complete model
|
671 |
-
class SLR_Classifier(nn.Module):
|
672 |
-
def __init__(self, **data):
|
673 |
-
super(SLR_Classifier, self).__init__()
|
674 |
-
|
675 |
-
# Dummy Parameter
|
676 |
-
self.dummy_param = nn.Parameter(torch.empty(0))
|
677 |
-
|
678 |
-
# Loss function
|
679 |
-
# Binary Cross Entropy with logits reduced to mean
|
680 |
-
self.loss_fn = nn.BCEWithLogitsLoss(reduction = 'mean',
|
681 |
-
pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))
|
682 |
-
|
683 |
-
# Pre-trained model
|
684 |
-
self.Encoder = Encoder(layers = data.get("bert_layers", range(12)),
|
685 |
-
freeze_bert = data.get("freeze_bert", False),
|
686 |
-
model = data.get("model"),
|
687 |
-
)
|
688 |
-
|
689 |
-
# Feature Map Layer
|
690 |
-
self.feature_map = nn.Sequential(
|
691 |
-
# nn.LayerNorm(self.Encoder.model.config.hidden_size),
|
692 |
-
nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
|
693 |
-
# nn.Dropout(data.get("drop", 0.5)),
|
694 |
-
nn.Linear(self.Encoder.model.config.hidden_size, 200),
|
695 |
-
nn.Dropout(data.get("drop", 0.5)),
|
696 |
-
)
|
697 |
-
|
698 |
-
# Classifier Layer
|
699 |
-
self.classifier = nn.Sequential(
|
700 |
-
# nn.LayerNorm(self.Encoder.model.config.hidden_size),
|
701 |
-
# nn.Dropout(data.get("drop", 0.5)),
|
702 |
-
# nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
|
703 |
-
# nn.Dropout(data.get("drop", 0.5)),
|
704 |
-
nn.Tanh(),
|
705 |
-
nn.Linear(200, 1)
|
706 |
-
)
|
707 |
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
# Feed forward
|
713 |
-
def forward(self, input_ids, attention_mask, token_type_ids, labels):
|
714 |
-
|
715 |
-
predict = self.Encoder(**{"input_ids":input_ids,
|
716 |
-
"attention_mask":attention_mask,
|
717 |
-
"token_type_ids":token_type_ids})
|
718 |
-
feature = self.feature_map(predict)
|
719 |
-
logit = self.classifier(feature)
|
720 |
-
|
721 |
-
predict = torch.sigmoid(logit)
|
722 |
|
723 |
-
#
|
724 |
-
|
725 |
-
|
726 |
-
return [loss, [feature, logit], predict]
|
727 |
-
|
728 |
-
|
729 |
-
# Undesirable patterns within texts
|
730 |
-
patterns = {
|
731 |
-
'CONCLUSIONS AND IMPLICATIONS':'',
|
732 |
-
'BACKGROUND AND PURPOSE':'',
|
733 |
-
'EXPERIMENTAL APPROACH':'',
|
734 |
-
'KEY RESULTS AEA':'',
|
735 |
-
'©':'',
|
736 |
-
'®':'',
|
737 |
-
'μ':'',
|
738 |
-
'(C)':'',
|
739 |
-
'OBJECTIVE:':'',
|
740 |
-
'MATERIALS AND METHODS:':'',
|
741 |
-
'SIGNIFICANCE:':'',
|
742 |
-
'BACKGROUND:':'',
|
743 |
-
'RESULTS:':'',
|
744 |
-
'METHODS:':'',
|
745 |
-
'CONCLUSIONS:':'',
|
746 |
-
'AIM:':'',
|
747 |
-
'STUDY DESIGN:':'',
|
748 |
-
'CLINICAL RELEVANCE:':'',
|
749 |
-
'CONCLUSION:':'',
|
750 |
-
'HYPOTHESIS:':'',
|
751 |
-
'CLINICAL RELEVANCE:':'',
|
752 |
-
'Questions/Purposes:':'',
|
753 |
-
'Introduction:':'',
|
754 |
-
'PURPOSE:':'',
|
755 |
-
'PATIENTS AND METHODS:':'',
|
756 |
-
'FINDINGS:':'',
|
757 |
-
'INTERPRETATIONS:':'',
|
758 |
-
'FUNDING:':'',
|
759 |
-
'PROGRESS:':'',
|
760 |
-
'CONTEXT:':'',
|
761 |
-
'MEASURES:':'',
|
762 |
-
'DESIGN:':'',
|
763 |
-
'BACKGROUND AND OBJECTIVES:':'',
|
764 |
-
'<p>':'',
|
765 |
-
'</p>':'',
|
766 |
-
'<<ETX>>':'',
|
767 |
-
'+/-':'',
|
768 |
-
}
|
769 |
-
|
770 |
-
patterns = {x.lower():y for x,y in patterns.items()}
|
771 |
-
|
772 |
-
LABEL_MAP = {'negative': 0,
|
773 |
-
'not included':0,
|
774 |
-
'0':0,
|
775 |
-
0:0,
|
776 |
-
'excluded':0,
|
777 |
-
'positive': 1,
|
778 |
-
'included':1,
|
779 |
-
'1':1,
|
780 |
-
1:1,
|
781 |
-
}
|
782 |
-
|
783 |
-
class SLR_DataSet(Dataset):
|
784 |
-
def __init__(self, **args):
|
785 |
-
self.tokenizer = args.get('tokenizer')
|
786 |
-
self.data = args.get('data')
|
787 |
-
self.max_seq_length = args.get("max_seq_length", 512)
|
788 |
-
self.INPUT_NAME = args.get("input", 'x')
|
789 |
-
self.LABEL_NAME = args.get("output", 'y')
|
790 |
-
|
791 |
-
# Tokenizing and processing text
|
792 |
-
def encode_text(self, example):
|
793 |
-
comment_text = example[self.INPUT_NAME]
|
794 |
-
comment_text = self.treat_text(comment_text)
|
795 |
-
|
796 |
-
try:
|
797 |
-
labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
|
798 |
-
except:
|
799 |
-
labels = -1
|
800 |
-
|
801 |
-
encoding = self.tokenizer.encode_plus(
|
802 |
-
(comment_text, "It is great text"),
|
803 |
-
add_special_tokens=True,
|
804 |
-
max_length=self.max_seq_length,
|
805 |
-
return_token_type_ids=True,
|
806 |
-
padding="max_length",
|
807 |
-
truncation=True,
|
808 |
-
return_attention_mask=True,
|
809 |
-
return_tensors='pt',
|
810 |
-
)
|
811 |
-
|
812 |
-
|
813 |
-
return tuple((
|
814 |
-
encoding["input_ids"].flatten(),
|
815 |
-
encoding["attention_mask"].flatten(),
|
816 |
-
encoding["token_type_ids"].flatten(),
|
817 |
-
torch.tensor([torch.tensor(labels).to(int)])
|
818 |
-
))
|
819 |
-
|
820 |
-
# Text processing function
|
821 |
-
def treat_text(self, text):
|
822 |
-
text = unicodedata.normalize("NFKD",str(text))
|
823 |
-
text = multiple_replace(patterns,text.lower())
|
824 |
-
text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
|
825 |
-
text = re.sub('( +)',' ', text)
|
826 |
-
text = re.sub('(, ,)|(,,)',',', text)
|
827 |
-
text = re.sub('(%)|(per cent)',' percent', text)
|
828 |
-
return text
|
829 |
-
|
830 |
-
def __len__(self):
|
831 |
-
return len(self.data)
|
832 |
-
|
833 |
-
# Returning data
|
834 |
-
def __getitem__(self, index: int):
|
835 |
-
# print(index)
|
836 |
-
data_row = self.data.reset_index().iloc[index]
|
837 |
-
temp_data = self.encode_text(data_row)
|
838 |
-
return temp_data
|
839 |
-
|
840 |
|
|
|
841 |
|
842 |
-
#
|
843 |
-
|
|
|
|
|
|
|
844 |
|
845 |
-
|
846 |
-
|
|
|
|
|
|
|
|
|
|
|
847 |
|
848 |
-
|
849 |
-
|
|
|
|
|
|
|
850 |
|
851 |
-
# Undesirable patterns within texts
|
852 |
|
853 |
|
|
|
1 |
+
from ML_SLRC import *
|
2 |
+
|
3 |
import os
|
|
|
4 |
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from torch.optim import Adam
|
10 |
+
|
11 |
+
import gc
|
12 |
+
from torchmetrics import functional as fn
|
13 |
+
|
14 |
import random
|
|
|
15 |
|
16 |
+
|
17 |
+
warnings.simplefilter(action='ignore', category=FutureWarning)
|
18 |
+
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from sklearn.metrics import confusion_matrix
|
22 |
+
from sklearn.metrics import roc_curve, auc
|
23 |
+
import ipywidgets as widgets
|
24 |
+
from IPython.display import display, clear_output
|
25 |
+
import matplotlib.pyplot as plt
|
26 |
+
import warnings
|
27 |
import torch
|
28 |
+
|
|
|
29 |
import time
|
|
|
|
|
30 |
from sklearn.manifold import TSNE
|
31 |
+
from copy import deepcopy
|
32 |
import seaborn as sns
|
33 |
import matplotlib.pylab as plt
|
|
|
|
|
|
|
|
|
34 |
import json
|
35 |
from pathlib import Path
|
36 |
+
|
37 |
+
import re
|
38 |
+
from collections import defaultdict
|
39 |
+
|
40 |
+
# SEED = 2222
|
41 |
+
|
42 |
+
# gen_seed = torch.Generator().manual_seed(SEED)
|
43 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
|
|
46 |
|
|
|
47 |
|
48 |
|
49 |
# Random seed function
|
|
|
54 |
np.random.seed(value)
|
55 |
random.seed(value)
|
56 |
|
57 |
+
# Tasks for meta-learner
|
58 |
def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
|
59 |
idxs = list(range(0,len(taskset)))
|
60 |
if is_shuffle:
|
|
|
63 |
yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
|
64 |
|
65 |
|
66 |
+
# Prepare data to process by Domain-learner
|
67 |
+
def prepare_data(data, batch_size, tokenizer,max_seq_length,
|
68 |
input = 'text', output = 'label',
|
69 |
+
train_size_per_class = 5, global_datasets = False,
|
70 |
+
treat_text_fun =None):
|
71 |
data = data.reset_index().drop("index", axis=1)
|
72 |
|
73 |
+
if global_datasets:
|
74 |
+
global data_train, data_test
|
|
|
75 |
|
76 |
+
# Sample task for training
|
77 |
+
data_train = data.groupby('label').sample(train_size_per_class, replace=False)
|
78 |
+
idex = data.index.isin(data_train.index)
|
79 |
|
80 |
+
# The Test set to label by the model
|
81 |
+
data_test = data[~idex].reset_index()
|
82 |
|
83 |
|
84 |
+
# Transform in dataset to model
|
85 |
+
## Train
|
86 |
dataset_train = SLR_DataSet(
|
87 |
data = data_train.sample(frac=1),
|
88 |
input = input,
|
89 |
output = output,
|
90 |
tokenizer=tokenizer,
|
91 |
+
max_seq_length =max_seq_length,
|
92 |
+
treat_text =treat_text_fun)
|
93 |
|
94 |
+
## Test
|
|
|
|
|
95 |
dataset_test = SLR_DataSet(
|
96 |
data = data_test,
|
97 |
input = input,
|
98 |
output = output,
|
99 |
tokenizer=tokenizer,
|
100 |
+
max_seq_length =max_seq_length,
|
101 |
+
treat_text =treat_text_fun)
|
102 |
|
103 |
# Dataloaders
|
104 |
+
## Train
|
105 |
data_train_loader = DataLoader(dataset_train,
|
106 |
shuffle=True,
|
107 |
batch_size=batch_size['train']
|
108 |
)
|
109 |
|
110 |
+
## Test
|
111 |
if len(dataset_test) % batch_size['test'] == 1 :
|
112 |
data_test_loader = DataLoader(dataset_test,
|
113 |
batch_size=batch_size['test'],
|
|
|
120 |
return data_train_loader, data_test_loader, data_train, data_test
|
121 |
|
122 |
|
123 |
+
# Meta trainer
|
124 |
+
def meta_train(data, model, device, Info,
|
125 |
+
print_epoch =True,
|
126 |
+
Test_resource =None,
|
127 |
+
treat_text_fun =None):
|
128 |
|
129 |
+
# Meta-learner model
|
|
|
|
|
|
|
|
|
|
|
130 |
learner = Learner(model = model, device = device, **Info)
|
131 |
|
132 |
# Testing tasks
|
133 |
if isinstance(Test_resource, pd.DataFrame):
|
134 |
test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
|
135 |
+
training=False,treat_text =treat_text_fun, **Info)
|
136 |
|
137 |
|
138 |
torch.clear_autocast_cache()
|
139 |
gc.collect()
|
140 |
torch.cuda.empty_cache()
|
141 |
|
142 |
+
# Meta epoch (Outer epoch)
|
143 |
for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
|
|
|
144 |
|
145 |
+
# Train tasks
|
146 |
train = MetaTask(data,
|
147 |
num_task = Info['num_task_train'],
|
148 |
k_support=Info['k_qry'],
|
149 |
+
k_query=Info['k_spt'],
|
150 |
+
treat_text =treat_text_fun, **Info)
|
151 |
|
152 |
+
# Batch of train tasks
|
153 |
db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
|
154 |
|
155 |
if print_epoch:
|
156 |
# Outer loop bach training
|
157 |
for step, task_batch in enumerate(db):
|
158 |
print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
|
159 |
+
|
160 |
+
# meta-feedfoward (outer-feedfoward)
|
161 |
acc = learner(task_batch, valid_train= print_epoch)
|
162 |
print('Step:', step, '\ttraining Acc:', acc)
|
163 |
+
|
164 |
if isinstance(Test_resource, pd.DataFrame):
|
165 |
+
# Validating Model
|
166 |
if ((epoch+1) % 4) + step == 0:
|
167 |
random_seed(123)
|
168 |
print("\n-----------------Testing Mode-----------------\n")
|
169 |
+
|
170 |
+
# Batch of test tasks
|
171 |
db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
|
172 |
acc_all_test = []
|
173 |
|
|
|
181 |
|
182 |
# Restarting training randomly
|
183 |
random_seed(int(time.time() % 10))
|
184 |
+
|
|
|
185 |
else:
|
186 |
for step, task_batch in enumerate(db):
|
187 |
+
# meta-feedfoward (outer-feedfoward)
|
188 |
acc = learner(task_batch, print_epoch, valid_train= print_epoch)
|
189 |
|
190 |
torch.clear_autocast_cache()
|
|
|
194 |
|
195 |
|
196 |
def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
|
197 |
+
# Start the model's parameters
|
198 |
model_meta = deepcopy(model)
|
199 |
optimizer = Adam(model_meta.parameters(), lr=lr)
|
200 |
|
201 |
model_meta.to(device)
|
202 |
model_meta.train()
|
203 |
|
204 |
+
# Task epoch (Inner epoch)
|
205 |
for i in range(0, epoch):
|
206 |
all_loss = []
|
207 |
|
|
|
210 |
batch = tuple(t.to(device) for t in batch)
|
211 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
212 |
|
213 |
+
# Inner Feedfoward
|
214 |
loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
215 |
|
216 |
+
# compute grads
|
217 |
loss.backward()
|
218 |
|
219 |
+
# update parameters
|
220 |
optimizer.step()
|
221 |
optimizer.zero_grad()
|
222 |
|
|
|
227 |
print("Loss: ", np.mean(all_loss))
|
228 |
|
229 |
|
230 |
+
# Test evaluation
|
231 |
model_meta.eval()
|
232 |
all_loss = []
|
233 |
+
all_acc = []
|
234 |
features = []
|
235 |
labels = []
|
236 |
predi_logit = []
|
237 |
|
238 |
with torch.no_grad():
|
239 |
+
# Test's Batch loop
|
240 |
for inner_step, batch in enumerate(tqdm(data_test_loader,
|
241 |
desc="Test validation | " + name,
|
242 |
ncols=80)) :
|
243 |
batch = tuple(t.to(device) for t in batch)
|
244 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
245 |
|
246 |
+
# Predictions
|
247 |
_, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
248 |
|
249 |
+
# Save batch's predictions
|
250 |
prediction = prediction.detach().cpu().squeeze()
|
251 |
label_id = label_id.detach().cpu()
|
252 |
+
labels.append(label_id.numpy().squeeze())
|
253 |
+
|
254 |
logit = feature[1].detach().cpu()
|
255 |
+
predi_logit.append(logit.numpy())
|
256 |
|
257 |
+
feature_lat = feature[0].detach().cpu()
|
258 |
features.append(feature_lat.numpy())
|
|
|
259 |
|
260 |
+
# Accuracy over the test's bach
|
261 |
+
acc = fn.accuracy(prediction, label_id).item()
|
262 |
+
all_acc.append(acc)
|
263 |
del input_ids, attention_mask, label_id, batch
|
264 |
|
265 |
+
if print_info:
|
266 |
+
print("acc:", np.mean(all_acc))
|
267 |
|
268 |
model_meta.to('cpu')
|
269 |
gc.collect()
|
|
|
271 |
|
272 |
del model_meta, optimizer
|
273 |
|
274 |
+
return map_feature_tsne(features, labels, predi_logit)
|
275 |
|
276 |
+
# Process predictions and map the feature_map in tsne
|
277 |
+
def map_feature_tsne(features, labels, predi_logit):
|
278 |
+
|
279 |
features = np.concatenate(np.array(features,dtype=object))
|
|
|
|
|
|
|
280 |
features = torch.tensor(features.astype(np.float32)).detach().clone()
|
281 |
+
|
282 |
+
labels = np.concatenate(np.array(labels,dtype=object))
|
283 |
labels = torch.tensor(labels.astype(int)).detach().clone()
|
284 |
+
|
285 |
+
logits = np.concatenate(np.array(predi_logit,dtype=object))
|
286 |
logits = torch.tensor(logits.astype(np.float32)).detach().clone()
|
287 |
|
288 |
+
# Dimention reduction
|
289 |
X_embedded = TSNE(n_components=2, learning_rate='auto',
|
290 |
init='random').fit_transform(features.detach().clone())
|
291 |
|
292 |
return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
|
293 |
+
|
|
|
294 |
def wss_calc(logit, labels, trsh = 0.5):
|
295 |
|
296 |
+
# Prediction label given the threshold
|
297 |
predict_trash = torch.sigmoid(logit).squeeze() >= trsh
|
298 |
+
|
299 |
+
# Compute confusion matrix values
|
300 |
CM = confusion_matrix(labels, predict_trash.to(int) )
|
301 |
tn, fp, fne, tp = CM.ravel()
|
302 |
|
|
|
304 |
N = (tn + fp)
|
305 |
recall = tp/(tp+fne)
|
306 |
|
307 |
+
# WSS
|
308 |
+
wss = (tn + fne)/len(labels) -(1- recall)
|
309 |
|
310 |
+
# AWSS
|
311 |
+
awss = (tn/N - fne/P)
|
312 |
|
313 |
return {
|
314 |
+
"wss": round(wss,4),
|
315 |
+
"awss": round(awss,4),
|
316 |
"R": round(recall,4),
|
317 |
"CM": CM
|
318 |
}
|
319 |
|
320 |
|
321 |
+
# Compute the metrics
|
322 |
+
def plot(logits, X_embedded, labels, threshold, show = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
namefig = "plot", make_plot = True, print_stats = True, save = True):
|
324 |
col = pd.MultiIndex.from_tuples([
|
325 |
("Predict", "0"),
|
|
|
332 |
|
333 |
predict = torch.sigmoid(logits).detach().clone()
|
334 |
|
335 |
+
# Roc curve
|
|
|
336 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
337 |
|
338 |
+
# Given by a Recall of 95% (threshold avaliation)
|
339 |
+
## WSS
|
340 |
+
### Index to recall
|
341 |
idx_wss95 = sum(tpr < 0.95)
|
342 |
+
### threshold
|
343 |
thresholds95 = thresholds[idx_wss95]
|
344 |
|
345 |
+
### Compute the metrics
|
346 |
wss95_info = wss_calc(logits,labels, thresholds95 )
|
347 |
acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
|
348 |
f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
|
349 |
|
350 |
|
351 |
+
# Given by a threshold (recall avaliation)
|
352 |
+
### Compute the metrics
|
353 |
+
wss_info = wss_calc(logits,labels, threshold )
|
354 |
+
acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
|
355 |
+
f1_wssR = fn.f1_score(predict, labels, threshold=threshold)
|
|
|
|
|
|
|
|
|
356 |
|
357 |
|
358 |
metrics= {
|
|
|
370 |
# f1
|
371 |
"f1@95": f1_wss95.item(),
|
372 |
"f1@R": f1_wssR.item(),
|
373 |
+
# threshold 95
|
374 |
+
"threshold@95": thresholds95
|
375 |
}
|
376 |
|
377 |
+
# Print stats
|
|
|
378 |
if print_stats:
|
379 |
wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
|
380 |
wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
|
|
|
382 |
print(wss95_adj)
|
383 |
print('Acc.:', round(acc_wss95.item(), 4))
|
384 |
print('F1-score:', round(f1_wss95.item(), 4))
|
385 |
+
print(f"threshold to wss95: {round(thresholds95, 4)}")
|
386 |
cm = pd.DataFrame(wss95_info['CM'],
|
387 |
index=index,
|
388 |
columns=col)
|
389 |
|
390 |
print("\nConfusion matrix:")
|
391 |
print(cm)
|
392 |
+
print("\n---Metrics with threshold:", threshold, "----\n")
|
393 |
wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
|
394 |
print(wss)
|
395 |
wss_adj= f"AWSS@R:{wss_info['awss']}"
|
|
|
404 |
print(cm)
|
405 |
|
406 |
|
407 |
+
# Plots
|
408 |
|
409 |
if make_plot:
|
410 |
|
411 |
fig, axes = plt.subplots(1, 4, figsize=(25,10))
|
412 |
alpha = torch.squeeze(predict).numpy()
|
413 |
|
414 |
+
# TSNE
|
|
|
415 |
p1 = sns.scatterplot(x=X_embedded[:, 0],
|
416 |
y=X_embedded[:, 1],
|
417 |
hue=labels,
|
418 |
+
alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE', size=20)
|
419 |
|
420 |
+
|
421 |
+
# WSS@95
|
422 |
t_wss = predict >= thresholds95
|
423 |
t_wss = t_wss.squeeze().numpy()
|
|
|
424 |
p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
|
425 |
y=X_embedded[t_wss, 1],
|
426 |
hue=labels[t_wss],
|
427 |
+
alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95', size=20)
|
428 |
|
429 |
+
# WSS@R
|
430 |
+
t = predict >= threshold
|
431 |
t = t.squeeze().numpy()
|
|
|
432 |
p3 = sns.scatterplot(x=X_embedded[t, 0],
|
433 |
y=X_embedded[t, 1],
|
434 |
hue=labels[t],
|
435 |
+
alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)
|
|
|
436 |
|
437 |
+
# ROC-Curve
|
438 |
roc_auc = auc(fpr, tpr)
|
439 |
lw = 2
|
|
|
440 |
axes[3].plot(
|
441 |
fpr,
|
442 |
tpr,
|
443 |
color="darkorange",
|
444 |
lw=lw,
|
445 |
label="ROC curve (area = %0.2f)" % roc_auc)
|
|
|
446 |
axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
|
447 |
axes[3].axhline(y=0.95, color='r', linestyle='-')
|
448 |
+
# axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate")
|
449 |
axes[3].legend(loc="lower right")
|
450 |
+
axes[3].set_title(label= "ROC", size = 20)
|
451 |
+
axes[3].set_ylabel("True Positive Rate", fontsize = 15)
|
452 |
+
axes[3].set_xlabel("False Positive Rate", fontsize = 15)
|
453 |
+
|
454 |
|
455 |
if show:
|
456 |
plt.show()
|
|
|
460 |
|
461 |
return metrics
|
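A minimal sketch of how these pieces are meant to be combined; `train_loader`, `test_loader`, `model` and `device` are assumed to already exist (for example from `prepare_data` and the meta-initialized model), and the threshold value is only illustrative:

logits, X_embedded, labels, features = train_loop(train_loader, test_loader,
                                                  model, device,
                                                  epoch=4, lr=2e-5,
                                                  print_info=False, name="demo")
metrics = plot(logits, X_embedded, labels, threshold=0.9,
               show=False, namefig="demo", make_plot=False,
               print_stats=True, save=False)
print(metrics["f1@95"], metrics["threshold@95"])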
462 |
|
463 |
+
|
464 |
def auc_plot(logits,labels, color = "darkorange", label = "test"):
|
465 |
predict = torch.sigmoid(logits).detach().clone()
|
466 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
|
|
480 |
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
|
481 |
plt.axhline(y=0.95, color='r', linestyle='-')
|
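Assuming `auc_plot` draws on the current matplotlib axes, it can be called repeatedly to overlay ROC curves from different runs; `logits_base` and `logits_meta` below are placeholders for logits returned by two `train_loop` calls on the same test labels:

plt.figure(figsize=(7, 7))
auc_plot(logits_base, labels, color="navy", label="baseline")
auc_plot(logits_meta, labels, color="darkorange", label="meta-initialized")
plt.legend(loc="lower right")
plt.show()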
482 |
|
483 |
+
# Interface to evaluation
|
484 |
class diagnosis():
|
485 |
+
def __init__(self, names, Valid_resource, batch_size_test,
|
486 |
+
model, Info, device, treat_text_fun=None, start=0):
|
487 |
self.names=names
|
488 |
self.Valid_resource=Valid_resource
|
489 |
self.batch_size_test=batch_size_test
|
490 |
self.model=model
|
491 |
+
self.start=start
|
492 |
+
self.Info = Info
|
493 |
+
self.device = device
|
494 |
+
self.treat_text_fun = treat_text_fun
|
495 |
+
|
496 |
|
497 |
+
# BOX INPUT
|
498 |
self.value_trash = widgets.FloatText(
|
499 |
value=0.95,
|
500 |
+
description='threshold',
|
501 |
disabled=False
|
502 |
)
|
|
|
503 |
self.valueb = widgets.IntText(
|
504 |
value=10,
|
505 |
description='size',
|
506 |
disabled=False
|
507 |
)
|
508 |
|
509 |
+
# Buttons
|
510 |
self.train_b = widgets.Button(description="Train")
|
511 |
self.next_b = widgets.Button(description="Next")
|
512 |
self.eval_b = widgets.Button(description="Evaluation")
|
513 |
|
514 |
self.hbox = widgets.HBox([self.train_b, self.valueb])
|
515 |
|
516 |
+
# Click buttons functions
|
517 |
self.next_b.on_click(self.Next_button)
|
518 |
self.train_b.on_click(self.Train_button)
|
519 |
self.eval_b.on_click(self.Evaluation_button)
|
|
|
524 |
clear_output()
|
525 |
self.i=self.i+1
|
526 |
|
527 |
+
# Select the domain data
|
528 |
+
self.domain = self.names[self.i]
|
529 |
self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
|
530 |
+
|
531 |
+
print("Name:", self.domain)
|
532 |
print(self.data['label'].value_counts())
|
|
|
533 |
display(self.hbox)
|
534 |
display(self.next_b)
|
535 |
|
536 |
+
|
537 |
# Train button
|
538 |
def Train_button(self, y):
|
539 |
clear_output()
|
540 |
print(self.domain)
|
541 |
|
542 |
+
# Prepare data for training (domain-learner)
|
543 |
self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
|
544 |
train_size_per_class = self.valueb.value,
|
545 |
+
batch_size = {'train': self.Info['inner_batch_size'],
|
546 |
+
'test': self.batch_size_test},
|
547 |
+
max_seq_length = self.Info['max_seq_length'],
|
548 |
+
tokenizer = self.Info['tokenizer'],
|
549 |
input = "text",
|
550 |
+
output = "label",
|
551 |
+
treat_text_fun=self.treat_text_fun)
|
552 |
|
553 |
+
# Train the model and predict in the test set
|
554 |
self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
|
555 |
+
self.model, self.device,
|
556 |
+
epoch = self.Info['inner_update_step'],
|
557 |
+
lr=self.Info['inner_update_lr'],
|
558 |
print_info=True,
|
559 |
name = self.domain)
|
560 |
|
|
|
563 |
display(tresh_box)
|
564 |
display(self.next_b)
|
565 |
|
566 |
+
|
567 |
# Evaluation button
|
568 |
def Evaluation_button(self, te):
|
569 |
clear_output()
|
|
|
572 |
print(self.domain)
|
573 |
# print("\n")
|
574 |
print("-------Train data-------")
|
575 |
+
print(self.data_train['label'].value_counts())
|
576 |
print("-------Test data-------")
|
577 |
+
print(self.data_test['label'].value_counts())
|
578 |
# print("\n")
|
579 |
|
580 |
display(self.next_b)
|
581 |
display(tresh_box)
|
582 |
display(self.hbox)
|
583 |
|
584 |
+
# Compute metrics
|
585 |
metrics = plot(self.logits, self.X_embedded, self.labels,
|
586 |
+
threshold=self.Info['threshold'], show = True,
|
|
|
587 |
namefig= 'test',
|
588 |
make_plot = True,
|
589 |
print_stats = True,
|
|
|
591 |
|
592 |
def __call__(self):
|
593 |
self.i= self.start-1
|
|
|
594 |
clear_output()
|
595 |
display(self.next_b)
|
596 |
|
597 |
|
598 |
|
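A sketch of how the interface is typically instantiated in a notebook; `Valid_resource` is assumed to be a DataFrame with 'text', 'label' and 'domain' columns, and `Info`, `model` and `device` are the objects used elsewhere in this module:

diag = diagnosis(names=names_to_valid, Valid_resource=Valid_resource,
                 batch_size_test=100, model=model, Info=Info,
                 device=device, treat_text_fun=None, start=0)
diag()   # shows the "Next" button; then use Train / Evaluation per domain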
599 |
|
600 |
+
# Simulate repeated attempts of the domain learner
|
601 |
+
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
|
602 |
+
model, Info, device, initializer_model,
|
603 |
+
treat_text_fun=None):
|
604 |
+
n_attempt = 5
|
605 |
+
batch_test = 100
|
606 |
|
607 |
+
# Create a directory to save the information of each domain
|
608 |
+
for name in names_to_valid:
|
609 |
+
name = re.sub("\.csv", "",name)
|
610 |
+
Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)
|
611 |
|
612 |
+
# Dict to save ROC curves
|
613 |
+
roc_stats = defaultdict(lambda: defaultdict(
|
614 |
+
lambda: defaultdict(
|
615 |
+
list
|
616 |
+
)
|
617 |
+
)
|
618 |
+
)
|
619 |
|
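# The nested defaultdicts give roc_stats the shape
# {domain: {bert_layers: {"fpr": [curve per attempt, ...],
#                         "tpr": [curve per attempt, ...]}}},
# so each attempt can append its ROC curve without pre-creating keys and the
# whole object can be written out directly with json.dump below.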
620 |
|
621 |
622 |
|
623 |
+
all_metrics = []
|
624 |
+
# Loop over a list of domains
|
625 |
+
for name in names_to_valid:
|
626 |
|
627 |
+
# Select a domain dataset
|
628 |
+
data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)
|
629 |
|
630 |
+
# Attempts simulation
|
631 |
+
for attempt in range(n_attempt):
|
632 |
+
print("---"*4,"attempt", attempt, "---"*4)
|
633 |
+
|
634 |
+
# Prepare data to pass to the model
|
635 |
+
data_train_loader, data_test_loader, _ , _ = prepare_data(data,
|
636 |
+
train_size_per_class = Info['k_spt'],
|
637 |
+
batch_size = {'train': Info['inner_batch_size'],
|
638 |
+
'test': batch_test},
|
639 |
+
max_seq_length = Info['max_seq_length'],
|
640 |
+
tokenizer = Info['tokenizer'],
|
641 |
+
input = "text",
|
642 |
+
output = "label",
|
643 |
+
treat_text_fun=treat_text_fun)
|
644 |
+
|
645 |
+
# Train the model and evaluate on the test set of the domain
|
646 |
+
logits, X_embedded, labels, features = train_loop(data_train_loader, data_test_loader,
|
647 |
+
model, device,
|
648 |
+
epoch = Info['inner_update_step'],
|
649 |
+
lr=Info['inner_update_lr'],
|
650 |
+
print_info=False,
|
651 |
+
name = name)
|
652 |
+
|
653 |
+
|
654 |
+
name_domain = re.sub("\.csv", "",name)
|
655 |
|
656 |
+
# Compute the metrics
|
657 |
+
metrics = plot(logits, X_embedded, labels,
|
658 |
+
threshold=Info['threshold'], show = False,
|
659 |
+
namefig= path_save + name_domain + "/img/" + str(attempt) + 'plots',
|
660 |
+
make_plot = True, print_stats = False, save = True)
|
661 |
|
662 |
+
# Compute the roc-curve
|
663 |
+
fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())
|
664 |
+
|
665 |
+
# Save the corresponding information for the domain
|
666 |
+
metrics['name'] = name_domain
|
667 |
+
metrics['layer_size'] = Info['bert_layers']
|
668 |
+
metrics['attempt'] = attempt
|
669 |
+
roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
|
670 |
+
roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
|
671 |
+
all_metrics.append(metrics)
|
672 |
+
|
673 |
+
# Save the metrics and the ROC curve of the attempt
|
674 |
+
pd.DataFrame(all_metrics).to_csv(path_save+ "metrics.csv")
|
675 |
+
roc_path = path_save + "roc_stats.json"
|
676 |
+
with open(roc_path, 'w') as fp:
|
677 |
+
json.dump(roc_stats, fp)
|
678 |
+
|
679 |
+
|
680 |
+
del fpr, tpr, logits, X_embedded, labels
|
681 |
+
del features, metrics, _
|
682 |
+
|
683 |
+
|
684 |
+
# Save the information used to evaluate the validation resource
|
685 |
+
save_info = Info.copy()
|
686 |
+
save_info['model'] = initializer_model.tokenizer.name_or_path
|
687 |
+
save_info.pop("tokenizer")
|
688 |
+
save_info.pop("bert_layers")
|
689 |
+
|
690 |
+
info_path = path_save+"info.json"
|
691 |
+
with open(info_path, 'w') as fp:
|
692 |
+
json.dump(save_info, fp)
|
693 |
+
|
694 |
+
|
695 |
+
# Loading dataset statistics
|
696 |
+
def load_data_statistics(paths, names):
|
697 |
+
size = []
|
698 |
+
pos = []
|
699 |
+
neg = []
|
700 |
+
for p in paths:
|
701 |
+
data = pd.read_csv(p)
|
702 |
+
data = data.dropna()
|
703 |
+
# Dataset size
|
704 |
+
size.append(len(data))
|
705 |
+
# Number of positive labels
|
706 |
+
pos.append(data['labels'].value_counts()[1])
|
707 |
+
# Number of negative labels
|
708 |
+
neg.append(data['labels'].value_counts()[0])
|
709 |
+
del data
|
710 |
+
|
711 |
+
info_load = pd.DataFrame({
|
712 |
+
"size":size,
|
713 |
+
"pos":pos,
|
714 |
+
"neg":neg,
|
715 |
+
"names":names,
|
716 |
+
"paths": paths })
|
717 |
+
return info_load
|
718 |
+
|
719 |
+
# Loading the datasets
|
720 |
+
def load_data(train_info_load):
|
721 |
+
|
722 |
+
col = ['abstract','title', 'labels', 'domain']
|
723 |
+
|
724 |
+
data_train = pd.DataFrame(columns=col)
|
725 |
+
for p in train_info_load['paths']:
|
726 |
+
data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
|
728 |
+
data_temp['domain'] = os.path.basename(p)
|
729 |
+
data_train = pd.concat([data_train, data_temp])
|
730 |
+
|
731 |
+
data_train['text'] = data_train['title'] + ' ' + data_train['abstract'].replace(np.nan, '')
|
732 |
|
733 |
+
return( data_train \
|
734 |
+
.replace({"labels":{0:"negative", 1:'positive'}})\
|
735 |
+
.rename({"labels":"label"} , axis=1)\
|
736 |
+
.loc[ :,("text","domain","label")]
|
737 |
+
)
|
738 |
|
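A sketch of how the two loaders fit together; the paths and names are placeholders, and each CSV is assumed to carry 'labels', 'title' and 'abstract' columns:

paths = ["data/domain_a.csv", "data/domain_b.csv"]
names = ["domain_a.csv", "domain_b.csv"]

stats = load_data_statistics(paths, names)   # size / pos / neg per dataset
Valid_resource = load_data(stats)            # columns: text, domain, label
print(Valid_resource['domain'].value_counts())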
|
|
739 |
|
740 |
|