|
|
|
import os
from itertools import islice
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel, PreTrainedModel, AutoConfig, BertModel
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_bionexttager import BioNExtTaggerConfig
|
|
|
|
|
# The original freezing code assumed this many parameter tensors per encoder
# layer. It is kept so that external code importing the name keeps working;
# `manage_freezing` below no longer depends on it.
NUM_PER_LAYER = 16


class BioNExtTaggerModel(PreTrainedModel):
    """BERT encoder followed by a dense layer, a tag classifier and a CRF.

    The forward pass runs BERT (without pooling layer), applies dropout, a
    ``hidden_size -> hidden_size`` linear layer with GELU, and a linear
    classifier producing per-token emission scores that are fed to a CRF.

    Config attributes read by this class: ``num_labels``,
    ``hidden_dropout_prob``, ``hidden_size``, ``crf_reduction``, ``freeze``
    and ``num_frozen_encoder``.
    """

    config_class = BioNExtTaggerConfig
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_activation = nn.GELU(approximate='none')
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        # Reduction mode handed to the CRF when computing the training loss.
        self.reduction = config.crf_reduction

        if self.config.freeze:
            self.manage_freezing()

    def manage_freezing(self):
        """Freeze the BERT embeddings and the first encoder layers.

        All embedding parameters are frozen unconditionally. In addition the
        first ``config.num_frozen_encoder`` encoder layers are frozen when
        that value is positive.
        """
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        num_encoders_to_freeze = self.config.num_frozen_encoder
        if num_encoders_to_freeze > 0:
            # Freeze whole layers instead of counting parameter tensors, so
            # the result does not depend on how many tensors a layer holds.
            for layer in self.bert.encoder.layer[:num_encoders_to_freeze]:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None
                ):
        """Compute the CRF loss (training) or decode tag sequences (inference).

        All arguments except ``labels`` are forwarded to the BERT encoder.

        Returns:
            When ``labels`` is given: a tuple ``(loss, logits)`` where
            ``loss`` is the CRF output for ``labels`` using the configured
            reduction and ``logits`` are the per-token emission scores.
            Otherwise: a tensor built with ``torch.Tensor`` from the decoded
            tag ids of every sequence in the batch.

        Note:
            ``attention_mask`` is only used by BERT; it is not passed to the
            CRF, so every position (padding included) takes part in the loss
            and in decoding. ``return_dict`` only affects the BERT call.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)

        # First element of the encoder output: the per-token hidden states.
        sequence_output = self.dropout(outputs[0])
        dense_output = self.dense_activation(self.dense(sequence_output))
        logits = self.classifier(dense_output)

        if labels is not None:
            return self.crf(logits, labels, reduction=self.reduction), logits

        # Without a mask every decoded path has the same length, so the
        # nested list can be packed into a rectangular tensor.
        return torch.Tensor(self.crf.decode(logits))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LARGE_NEGATIVE_NUMBER = -1e9

# Upper bound on the number of (odd, even) tag pairs that are masked by
# default; matches the six pairs the original implementation hard-coded.
DEFAULT_NUM_ENTITY_TYPES = 6


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the negative log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    Transitions that are impossible under the tagging scheme are initialised
    to ``LARGE_NEGATIVE_NUMBER`` (see `~CRF.mask_impossible_transitions`).

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()
        self.mask_impossible_transitions()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def mask_impossible_transitions(self, num_entity_types: Optional[int] = None) -> None:
        """Set the value of impossible transitions to LARGE_NEGATIVE_NUMBER.

        Tag 0 is treated as ``O`` and, for entity type ``k``, tag ``2k + 2``
        as ``I-X`` (tag ``2k + 1`` is presumably the matching ``B-X``). The
        following scores are overwritten:

        - start transition value of I-X
        - transition score of O -> I
        - transition score from tags ``2j + 1`` / ``2j + 2`` of a different
          entity type ``j`` into I-X

        Args:
            num_entity_types: Number of tag pairs to mask. When ``None`` it
                defaults to the pairs that fit in ``num_tags``, capped at
                ``DEFAULT_NUM_ENTITY_TYPES`` (the original behaviour for 13 or
                more tags).

        Raises:
            ValueError: If ``num_entity_types`` is negative or does not fit
                into ``num_tags``.
        """
        if num_entity_types is None:
            # Clamp so that small tag sets no longer index out of range.
            num_entity_types = min(DEFAULT_NUM_ENTITY_TYPES, (self.num_tags - 1) // 2)
        elif num_entity_types < 0 or 2 * num_entity_types + 1 > self.num_tags:
            raise ValueError(
                f'invalid number of entity types for {self.num_tags} tags: '
                f'{num_entity_types}')

        with torch.no_grad():
            for entity in range(num_entity_types):
                inside = 2 * entity + 2
                self.start_transitions[inside] = LARGE_NEGATIVE_NUMBER
                self.transitions[0, inside] = LARGE_NEGATIVE_NUMBER
                for other in range(num_entity_types):
                    if other == entity:
                        continue
                    # Entering I-X is only allowed from the same entity type.
                    self.transitions[2 * other + 1, inside] = LARGE_NEGATIVE_NUMBER
                    self.transitions[2 * other + 2, inside] = LARGE_NEGATIVE_NUMBER

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the negative conditional log likelihood of a sequence of tags.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                Defaults to all ones.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The negative log likelihood. This will have size
            ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise.

        Raises:
            ValueError: If the inputs fail validation or ``reduction`` is unknown.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            # Internally everything is time-major.
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # Negative log likelihood: log Z - score(tags).
        nllh = denominator - numerator

        if reduction == 'none':
            return nllh
        if reduction == 'sum':
            return nllh.sum()
        if reduction == 'mean':
            return nllh.mean()
        assert reduction == 'token_mean'
        return nllh.sum() / mask.type_as(emissions).sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                Defaults to all ones.

        Returns:
            List of list containing the best tag sequence for each batch.

        Raises:
            ValueError: If the inputs fail validation.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        """Check the shapes of the inputs and that the first timestep is unmasked.

        Raises:
            ValueError: On any shape mismatch, or when the mask is off at the
                first timestep of some sequence.
        """
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            # Pick the first timestep according to the layout in use.
            first_step = mask[:, 0] if self.batch_first else mask[0]
            if not first_step.all():
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        """Score the given tag sequences (time-major inputs).

        emissions: (seq_length, batch_size, num_tags)
        tags: (seq_length, batch_size)
        mask: (seq_length, batch_size)
        Returns a tensor of shape (batch_size,).
        """
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)
        batch_idx = torch.arange(batch_size, device=tags.device)

        # Start transition plus the emission of the first tag.
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]] + emissions[0, batch_idx, tags[0]]

        for i in range(1, seq_length):
            # Transition into and emission of tag i count only where mask == 1.
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            score += emissions[i, batch_idx, tags[i]] * mask[i]

        # Index of the last unmasked timestep of every sequence.
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, batch_idx]
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        """Compute the log partition function with the forward algorithm.

        emissions: (seq_length, batch_size, num_tags)
        mask: (seq_length, batch_size)
        Returns a tensor of shape (batch_size,).
        """
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # score[b, t]: log-sum-exp of all path scores so far that end in tag t.
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Entry [b, j, k]: score of moving from tag j to tag k and emitting.
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum (in log space) over the previous tag.
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Keep the old score where this timestep is masked out.
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        score += self.end_transitions

        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Run Viterbi decoding on time-major inputs.

        emissions: (seq_length, batch_size, num_tags)
        mask: (seq_length, batch_size)
        Returns one list of tag ids per sequence; its length equals the
        number of unmasked timesteps of that sequence.
        """
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # score[b, t]: best score of any path so far that ends in tag t.
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        # history[i - 1][b, t]: best previous tag when tag t is chosen at step i.
        history = []

        for i in range(1, seq_length):
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Best previous tag for every current tag.
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Keep the old score where this timestep is masked out.
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
            history.append(indices)

        score += self.end_transitions

        # Last unmasked timestep per sequence, as plain Python ints so they
        # can be used directly as slice bounds.
        seq_ends = (mask.long().sum(dim=0) - 1).tolist()
        best_tags_list = []

        for idx in range(batch_size):
            # Best final tag, then follow the back-pointers to the start.
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
|
|
|
|
|
|