nam194 commited on
Commit
737279c
·
1 Parent(s): 8824528

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +54 -0
model.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imports import *
2
+ from all_datasets import *
3
+
4
+ class PhoBertLstmCrf(RobertaForTokenClassification):
5
+ def __init__(self, config):
6
+ super(PhoBertLstmCrf, self).__init__(config=config)
7
+ self.num_labels = config.num_labels
8
+ self.lstm = nn.LSTM(input_size=config.hidden_size,
9
+ hidden_size=config.hidden_size // 2,
10
+ num_layers=1,
11
+ batch_first=True,
12
+ bidirectional=True)
13
+ self.crf = CRF(config.num_labels, batch_first=True)
14
+
15
+ @staticmethod
16
+ def sort_batch(src_tensor, lengths):
17
+ """
18
+ Sort a minibatch by the length of the sequences with the longest sequences first
19
+ return the sorted batch targes and sequence lengths.
20
+ This way the output can be used by pack_padd ed_sequences(...)
21
+ """
22
+ seq_lengths, perm_idx = lengths.sort(0, descending=True)
23
+ seq_tensor = src_tensor[perm_idx]
24
+ _, reversed_idx = perm_idx.sort(0, descending=False)
25
+ return seq_tensor, seq_lengths, reversed_idx
26
+
27
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
28
+ label_masks=None):
29
+ seq_outputs = self.roberta(input_ids=input_ids,
30
+ token_type_ids=token_type_ids,
31
+ attention_mask=attention_mask,
32
+ head_mask=None)[0]
33
+
34
+ batch_size, max_len, feat_dim = seq_outputs.shape
35
+ seq_lens = torch.sum(label_masks, dim=-1)
36
+ range_vector = torch.arange(0, batch_size, dtype=torch.long, device=seq_outputs.device).unsqueeze(1)
37
+ seq_outputs = seq_outputs[range_vector, valid_ids]
38
+
39
+ sorted_seq_outputs, sorted_seq_lens, reversed_idx = self.sort_batch(src_tensor=seq_outputs,
40
+ lengths=seq_lens)
41
+ packed_words = pack_padded_sequence(sorted_seq_outputs, sorted_seq_lens.cpu(), True)
42
+ lstm_outs, _ = self.lstm(packed_words)
43
+ lstm_outs, _ = pad_packed_sequence(lstm_outs, batch_first=True, total_length=max_len)
44
+ seq_outputs = lstm_outs[reversed_idx]
45
+
46
+ seq_outputs = self.dropout(seq_outputs)
47
+ logits = self.classifier(seq_outputs)
48
+ seq_tags = self.crf.decode(logits, mask=label_masks != 0)
49
+
50
+ if labels is not None:
51
+ log_likelihood = self.crf(logits, labels, mask=label_masks.type(torch.uint8))
52
+ return NerOutput(loss=-1.0 * log_likelihood, tags=seq_tags, cls_metrics=seq_tags)
53
+ else:
54
+ return NerOutput(tags=seq_tags)