Commit 0d39895 by sguarnaccio (parent: 647b45a)

Upload clf_ner.py

Files changed (1):
  1. clf_ner.py +84 -0
clf_ner.py ADDED
@@ -0,0 +1,84 @@
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel
import torch
from torch import nn


class ClassifierNER(BertPreTrainedModel):
    """Joint sequence-classification and NER model on top of BERT.

    Expects three extra fields on the config: `tokenizer` (a name or path),
    and `clf_labels` / `ner_labels` (dicts mapping label ids as strings to
    label names, as in a serialized config's id2label).
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.loss_fct = nn.CrossEntropyLoss()
        # Sequence-classification head on the pooled [CLS] output.
        self.clf_labels = config.clf_labels
        self.clf_classes = len(self.clf_labels)
        self.clf_linear = nn.Linear(config.hidden_size, self.clf_classes)
        # NER head: a BiLSTM over the token embeddings, then a linear layer.
        # Note: with the default num_layers=1, the LSTM's dropout argument has
        # no effect (PyTorch only applies it between stacked recurrent layers).
        self.ner_labels = config.ner_labels
        self.ner_classes = len(self.ner_labels)
        self.ner_linear = nn.Linear(config.hidden_size, self.ner_classes)
        self.ner_lstm = nn.LSTM(
            config.hidden_size,
            config.hidden_size // 2,
            dropout=config.hidden_dropout_prob,
            batch_first=True,
            bidirectional=True,
        )
        # Initialize the newly added heads (standard PreTrainedModel convention).
        self.init_weights()

    def forward(self, input_ids, token_type_ids, attention_mask,
                clf_labels=None, ner_labels=None, **kwargs):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        # Classification branch: pooled output -> dropout -> linear.
        clf_output = self.dropout(outputs[1])
        clf_logits = self.clf_linear(clf_output)
        clf_loss = 0
        if clf_labels is not None:
            # as_tensor avoids a copy when the labels are already a tensor, and
            # placing them on the logits' device keeps GPU training working.
            clf_labels = torch.as_tensor(
                clf_labels, dtype=torch.long, device=clf_logits.device
            )
            clf_loss = self.loss_fct(
                clf_logits.view(-1, self.clf_classes), clf_labels.view(-1)
            )
        # NER branch: token embeddings -> dropout -> BiLSTM -> linear.
        ner_output = self.dropout(outputs[0])
        lstm_output, _ = self.ner_lstm(ner_output)
        ner_logits = self.ner_linear(lstm_output)
        ner_loss = 0
        if ner_labels is not None:
            ner_loss = self.loss_fct(
                ner_logits.view(-1, self.ner_classes), ner_labels.view(-1)
            )
        if clf_labels is not None or ner_labels is not None:
            # Multi-task loss is the plain sum of the two task losses.
            return clf_loss + ner_loss, clf_logits, ner_logits
        return clf_logits, ner_logits

    def predict(self, text):
        # Disable dropout so predictions are deterministic.
        self.eval()
        with torch.no_grad():
            tokenized = self.tokenizer(
                text,
                truncation=True,
                max_length=512,
                return_tensors="pt",
                return_offsets_mapping=True,
            )
            clf_logits, ner_logits = self(
                tokenized["input_ids"],
                tokenized["token_type_ids"],
                tokenized["attention_mask"],
            )
            clf_prediction = self.clf_labels[
                str(torch.argmax(clf_logits, dim=-1).item())
            ]
            entities = self.align_predictions(
                text, ner_logits, tokenized["offset_mapping"]
            )
            return {"classification": clf_prediction, "entities": entities}

    def align_predictions(self, text, predictions, offsets):
        """Group B-/I- token predictions into character-level entity spans."""
        results = []
        predictions = torch.argmax(predictions, dim=-1)[0].tolist()
        offsets = offsets[0].tolist()
        idx = 0
        while idx < len(predictions):
            label = self.ner_labels[str(predictions[idx])]
            if label == "O":
                idx += 1
                continue
            # Strip the B- or I- prefix to get the entity type.
            label = label[2:]
            start, end = offsets[idx]
            idx += 1
            # Extend the span over the following tokens labeled I-<type>; idx
            # then already points at the next unseen token, so the loop never
            # skips the token immediately after an entity.
            while (
                idx < len(predictions)
                and self.ner_labels[str(predictions[idx])] == f"I-{label}"
            ):
                _, end = offsets[idx]
                idx += 1
            results.append(
                {
                    "label": label,
                    "entity": text[start:end],
                    "start": start,
                    "end": end,
                }
            )
        return results
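For context, a minimal usage sketch (not part of the commit): the class reads tokenizer, clf_labels, and ner_labels off its config, so it can be wired up from a stock BERT checkpoint roughly as follows. The checkpoint name, label maps, and example sentence are illustrative assumptions; the label ids are string keys to match how the class indexes its label dicts.

from transformers import BertConfig

# Assumed setup: checkpoint and label maps are placeholders, not from the commit.
config = BertConfig.from_pretrained("bert-base-cased")
config.tokenizer = "bert-base-cased"
config.clf_labels = {"0": "other", "1": "request"}
config.ner_labels = {"0": "O", "1": "B-PER", "2": "I-PER", "3": "B-LOC", "4": "I-LOC"}

model = ClassifierNER.from_pretrained("bert-base-cased", config=config)

# The BERT body is loaded from the checkpoint; the two heads start randomly
# initialized, so predictions are arbitrary until the model is fine-tuned.
# After training, the returned dict has this shape:
# {"classification": "request",
#  "entities": [{"label": "PER", "entity": "Angela Merkel", "start": 0, "end": 13},
#               {"label": "LOC", "entity": "Berlin", "start": 23, "end": 29}]}
print(model.predict("Angela Merkel lives in Berlin."))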
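And a single, equally hypothetical training step: passing both label sets makes forward return the summed multi-task loss first, so an optimization loop reduces to the usual pattern. The batch contents and hyperparameters below are stand-ins, not from the commit.

import torch

model.train()  # re-enable dropout after any predict() call
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

batch = model.tokenizer(
    ["Angela Merkel lives in Berlin."],
    return_tensors="pt", padding=True, truncation=True,
)
clf_labels = torch.tensor([1])                     # one class id per sequence
ner_labels = torch.zeros_like(batch["input_ids"])  # one tag id per token; a real
                                                   # dataset would use -100 on special
                                                   # tokens so the loss ignores them

loss, clf_logits, ner_logits = model(
    batch["input_ids"], batch["token_type_ids"], batch["attention_mask"],
    clf_labels=clf_labels, ner_labels=ner_labels,
)
loss.backward()
optimizer.step()
optimizer.zero_grad()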