import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Assumed sources for the CRF layer and the base model; in this project they are
# presumably re-exported by the wildcard imports below (pytorch-crf and Hugging
# Face transformers are the usual providers).
from torchcrf import CRF
from transformers import RobertaForTokenClassification

from imports import *
from all_datasets import *

class PhoBertLstmCrf(RobertaForTokenClassification):
    """PhoBERT encoder with a BiLSTM layer and a CRF head for token classification."""

    def __init__(self, config):
        super(PhoBertLstmCrf, self).__init__(config=config)
        self.num_labels = config.num_labels
        # Bidirectional LSTM; the two directions concatenate back to hidden_size.
        self.lstm = nn.LSTM(input_size=config.hidden_size,
                            hidden_size=config.hidden_size // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)

    @staticmethod
    def sort_batch(src_tensor, lengths):
        """
        Sort a minibatch by sequence length, longest sequences first, and
        return the sorted batch tensor, the sorted lengths, and the indices
        that restore the original order. This way the output can be fed
        directly to pack_padded_sequence(...).
        """
        seq_lengths, perm_idx = lengths.sort(0, descending=True)
        seq_tensor = src_tensor[perm_idx]
        _, reversed_idx = perm_idx.sort(0, descending=False)
        return seq_tensor, seq_lengths, reversed_idx

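    # Illustration (hypothetical values): for lengths = tensor([3, 5, 1]),
    # sort_batch returns seq_lengths = tensor([5, 3, 1]), the batch reordered by
    # perm_idx = tensor([1, 0, 2]), and reversed_idx = tensor([1, 0, 2]); indexing
    # the LSTM output with reversed_idx restores the original batch order.
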
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
                label_masks=None):
        # Contextual embeddings from the PhoBERT encoder (last hidden state).
        seq_outputs = self.roberta(input_ids=input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask,
                                   head_mask=None)[0]

        batch_size, max_len, feat_dim = seq_outputs.shape
        # Word-level sequence lengths: label_masks is 1 at every real word position.
        seq_lens = torch.sum(label_masks, dim=-1)
        # Keep only the hidden state of each word's first sub-token, indexed by valid_ids.
        range_vector = torch.arange(0, batch_size, dtype=torch.long, device=seq_outputs.device).unsqueeze(1)
        seq_outputs = seq_outputs[range_vector, valid_ids]
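        # Illustration (hypothetical): if valid_ids[0] = tensor([0, 1, 3, 5, 0, ...]),
        # the indexing above keeps the hidden states at sub-token positions 0, 1, 3
        # and 5 for sentence 0 (one vector per word); trailing zeros re-select
        # position 0 as padding, which the CRF mask ignores later.
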
        # Sort by length for packing, run the BiLSTM, then restore batch order.
        sorted_seq_outputs, sorted_seq_lens, reversed_idx = self.sort_batch(src_tensor=seq_outputs,
                                                                            lengths=seq_lens)
        packed_words = pack_padded_sequence(sorted_seq_outputs, sorted_seq_lens.cpu(), batch_first=True)
        lstm_outs, _ = self.lstm(packed_words)
        lstm_outs, _ = pad_packed_sequence(lstm_outs, batch_first=True, total_length=max_len)
        seq_outputs = lstm_outs[reversed_idx]

        seq_outputs = self.dropout(seq_outputs)
        logits = self.classifier(seq_outputs)
        # Viterbi decoding over the valid (non-padding) word positions.
        seq_tags = self.crf.decode(logits, mask=label_masks != 0)

        if labels is not None:
            # The CRF returns a log-likelihood; negate it for the training loss.
            log_likelihood = self.crf(logits, labels, mask=label_masks.type(torch.uint8))
            return NerOutput(loss=-1.0 * log_likelihood, tags=seq_tags, cls_metrics=seq_tags)
        else:
            return NerOutput(tags=seq_tags)
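
# Usage sketch (minimal, with assumed values): "vinai/phobert-base" is the public
# PhoBERT checkpoint; the num_labels value, tensor shapes, and mask contents below
# are illustrative only, and the LSTM/CRF weights are freshly initialised, so the
# decoded tags are meaningless until the model is fine-tuned.
if __name__ == "__main__":
    from transformers import RobertaConfig

    config = RobertaConfig.from_pretrained("vinai/phobert-base", num_labels=9)
    model = PhoBertLstmCrf.from_pretrained("vinai/phobert-base", config=config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # valid_ids points each word slot at its first sub-token (identity here);
    # label_masks marks the first 10 positions as real words.
    valid_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    label_masks = torch.zeros(batch_size, seq_len, dtype=torch.long)
    label_masks[:, :10] = 1

    with torch.no_grad():
        output = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       valid_ids=valid_ids,
                       label_masks=label_masks)
    print(output.tags)  # one decoded tag-id sequence per sentence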