Create model.py
model.py
ADDED
# `imports` and `all_datasets` are project modules; the names used below
# (torch, nn, CRF from torchcrf, pack_padded_sequence, pad_packed_sequence,
# RobertaForTokenClassification, NerOutput) are expected to come from them.
from imports import *
from all_datasets import *


class PhoBertLstmCrf(RobertaForTokenClassification):
    def __init__(self, config):
        super(PhoBertLstmCrf, self).__init__(config=config)
        self.num_labels = config.num_labels
        # Bidirectional LSTM over the encoder output; hidden_size // 2 per
        # direction keeps the concatenated output at config.hidden_size.
        self.lstm = nn.LSTM(input_size=config.hidden_size,
                            hidden_size=config.hidden_size // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)

    @staticmethod
    def sort_batch(src_tensor, lengths):
        """
        Sort a minibatch by sequence length, longest sequences first, and
        return the sorted batch, the sorted lengths, and the indices that
        restore the original order. The sorted output can then be consumed
        by pack_padded_sequence(...).
        """
        seq_lengths, perm_idx = lengths.sort(0, descending=True)
        seq_tensor = src_tensor[perm_idx]
        _, reversed_idx = perm_idx.sort(0, descending=False)
        return seq_tensor, seq_lengths, reversed_idx

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
                label_masks=None):
        seq_outputs = self.roberta(input_ids=input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask,
                                   head_mask=None)[0]

        batch_size, max_len, feat_dim = seq_outputs.shape
        seq_lens = torch.sum(label_masks, dim=-1)
        # Gather the hidden state of the first sub-token of each word.
        range_vector = torch.arange(0, batch_size, dtype=torch.long, device=seq_outputs.device).unsqueeze(1)
        seq_outputs = seq_outputs[range_vector, valid_ids]

        # Sort by length, pack, run the LSTM, unpack, then restore order.
        sorted_seq_outputs, sorted_seq_lens, reversed_idx = self.sort_batch(src_tensor=seq_outputs,
                                                                            lengths=seq_lens)
        packed_words = pack_padded_sequence(sorted_seq_outputs, sorted_seq_lens.cpu(), batch_first=True)
        lstm_outs, _ = self.lstm(packed_words)
        lstm_outs, _ = pad_packed_sequence(lstm_outs, batch_first=True, total_length=max_len)
        seq_outputs = lstm_outs[reversed_idx]

        seq_outputs = self.dropout(seq_outputs)
        logits = self.classifier(seq_outputs)
        seq_tags = self.crf.decode(logits, mask=label_masks != 0)

        if labels is not None:
            # The CRF returns the log-likelihood; negate it for the loss.
            log_likelihood = self.crf(logits, labels, mask=label_masks.type(torch.uint8))
            return NerOutput(loss=-1.0 * log_likelihood, tags=seq_tags, cls_metrics=seq_tags)
        else:
            return NerOutput(tags=seq_tags)
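
A minimal usage sketch, not part of the committed file: the checkpoint name "vinai/phobert-base", the label count, and the dummy tensors below are illustrative assumptions. `valid_ids` marks the first sub-token of each word, and `label_masks` marks the word-level positions the CRF should score; here every position is treated as a word start for simplicity.

    import torch
    from transformers import RobertaConfig
    from model import PhoBertLstmCrf

    # Hypothetical setup: checkpoint and num_labels are placeholders.
    config = RobertaConfig.from_pretrained("vinai/phobert-base", num_labels=9)
    model = PhoBertLstmCrf.from_pretrained("vinai/phobert-base", config=config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    valid_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)  # every sub-token starts a word here
    label_masks = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    valid_ids=valid_ids,
                    label_masks=label_masks)
    print(out.tags)  # one list of predicted label ids per sentence

Since `labels` is omitted, the forward pass skips the CRF log-likelihood and returns only the Viterbi-decoded tag sequences.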