# BioNExt-Tagger / modeling_bionexttagger.py
# Author: T-Almeida ("Upload model", commit 82bd2af, verified).
import os
from typing import Optional, Union
from transformers import AutoModel, PreTrainedModel, AutoConfig, BertModel
from transformers.modeling_outputs import TokenClassifierOutput
from torch import nn
from torch.nn import CrossEntropyLoss
from typing import List, Optional
import torch
from itertools import islice
from .configuration_bionexttager import BioNExtTaggerConfig
# Number of parameter tensors in one encoder layer. BioNExtTaggerModel.manage_freezing()
# freezes whole layers by taking the first ``num_frozen_encoder * NUM_PER_LAYER`` tensors
# from ``bert.encoder.named_parameters()``.
# NOTE(review): 16 should match transformers' BertLayer (weight + bias for query, key,
# value, attention-output dense, intermediate dense, output dense and two LayerNorms);
# re-check this value if the backbone architecture changes.
NUM_PER_LAYER: int = 16
class BioNExtTaggerModel(PreTrainedModel):
    """BERT token tagger with a dense/GELU projection head and a CRF output layer.

    Pipeline: BERT (no pooler) -> dropout -> Linear(hidden, hidden) -> GELU ->
    Linear(hidden, num_labels) -> CRF.
    """

    config_class = BioNExtTaggerConfig
    # The backbone is built without a pooling layer, so pooler weights found in a
    # checkpoint are ignored instead of being reported as unexpected keys.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """Build the backbone, the classification head and the CRF from ``config``.

        Args:
            config: A ``BioNExtTaggerConfig``. This reads ``num_labels``,
                ``hidden_size``, ``hidden_dropout_prob``, ``crf_reduction`` and
                ``freeze``, plus ``num_frozen_encoder`` when freezing is enabled.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        # The backbone is built from the config alone (no weights are loaded here);
        # pretrained weights are presumably filled in by ``from_pretrained`` — confirm.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_activation = nn.GELU(approximate='none')
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        # Reduction for the CRF loss: 'none' | 'sum' | 'mean' | 'token_mean'.
        self.reduction = config.crf_reduction
        if self.config.freeze:
            self.manage_freezing()

    def manage_freezing(self):
        """Freeze the BERT embeddings and the first ``config.num_frozen_encoder`` encoder layers.

        Encoder layers are frozen by taking the first
        ``num_frozen_encoder * NUM_PER_LAYER`` parameter tensors of the encoder. This
        relies on the encoder yielding its parameters layer by layer, in order.
        """
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False
        num_encoders_to_freeze = self.config.num_frozen_encoder
        # Guard: islice() rejects a negative stop value.
        if num_encoders_to_freeze > 0:
            for param in islice(self.bert.encoder.parameters(),
                                num_encoders_to_freeze * NUM_PER_LAYER):
                param.requires_grad = False

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None
                ):
        """Compute the CRF loss (training/evaluation) or decoded tags (inference).

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states, return_dict:
                Forwarded unchanged to the BERT backbone.
            labels: Gold tag ids. When given, the CRF loss is computed.

        Returns:
            With ``labels``: a tuple ``(loss, logits)``. ``loss`` is the CRF negative
            log-likelihood reduced with ``config.crf_reduction``; ``logits`` are the
            per-token emission scores.
            Without ``labels``: a ``torch.Tensor`` (default float dtype, on the CPU) of
            Viterbi-decoded tag ids. It has one row per sequence and one column per
            position.
        """
        # Default ``model.config.use_return_dict`` is ``True``.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)
        sequence_output = self.dropout(outputs[0])  # B S E
        dense_output = self.dense_activation(self.dense(sequence_output))
        logits = self.classifier(dense_output)  # B S num_labels

        # NOTE(review): ``attention_mask`` is not passed to the CRF, so padded
        # positions are also scored (loss) and decoded (inference). Confirm this is
        # intended.
        if labels is not None:
            # Labels are only passed during train/test, not during inference.
            return self.crf(logits, labels, reduction=self.reduction), logits
        # NOTE: A multi-GPU gather of this output does not work, because the decoded
        # tags are built as a CPU tensor.
        return torch.Tensor(self.crf.decode(logits))
# Taken from https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py and
# fixed for the uint8-mask warning (masks are cast to bool before ``torch.where``).
# This is the score given to transitions forbidden by the BIO scheme (see
# ``CRF.mask_impossible_transitions``); it is low enough to rule them out in practice.
LARGE_NEGATIVE_NUMBER = -1e9


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the negative log likelihood of the given sequence of tags
    and emission score tensor. This class also has a `~CRF.decode` method, which finds
    the best tag sequence given an emission score tensor using the `Viterbi algorithm`_.

    Unlike upstream pytorch-crf, transitions that are invalid under the BIO scheme are
    initialised to ``LARGE_NEGATIVE_NUMBER``. Tag ids are assumed to follow the layout
    ``0 = O``, ``2k + 1 = B-<type k>``, ``2k + 2 = I-<type k>``.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()
        self.mask_impossible_transitions()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def mask_impossible_transitions(self) -> None:
        """Set the scores of BIO-invalid transitions to LARGE_NEGATIVE_NUMBER.

        This assumes the tag layout ``0 = O``, ``2k + 1 = B-<type k>`` and
        ``2k + 2 = I-<type k>``. There are therefore ``(num_tags - 1) // 2`` entity
        types; the 13-tag setup has 6. The following are masked:

        - the start transition into any I tag;
        - O -> any I tag;
        - B-X -> I-Y and I-X -> I-Y whenever X != Y.

        The number of entity types is derived from ``num_tags``. Previously it was
        hard-coded to 6, which raised IndexError for smaller tag sets.
        """
        num_types = (self.num_tags - 1) // 2
        with torch.no_grad():
            for k in range(num_types):
                b_tag, i_tag = 2 * k + 1, 2 * k + 2
                self.start_transitions[i_tag] = LARGE_NEGATIVE_NUMBER
                # O to any I
                self.transitions[0, i_tag] = LARGE_NEGATIVE_NUMBER
                # B-X / I-X to an I tag of a different type
                for other in range(num_types):
                    if other != k:
                        other_i_tag = 2 * other + 2
                        self.transitions[b_tag, other_i_tag] = LARGE_NEGATIVE_NUMBER
                        self.transitions[i_tag, other_i_tag] = LARGE_NEGATIVE_NUMBER

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the negative conditional log likelihood of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The negative log likelihood, suitable as a loss to minimise.
            It has size ``(batch_size,)`` if reduction is ``none``, and ``()`` otherwise.

        Raises:
            ValueError: If the shapes are inconsistent or ``reduction`` is unknown.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator
        nllh = -llh

        if reduction == 'none':
            return nllh
        if reduction == 'sum':
            return nllh.sum()
        if reduction == 'mean':
            return nllh.mean()
        assert reduction == 'token_mean'
        return nllh.sum() / mask.type_as(emissions).sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using the Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.

        Returns:
            A list containing the best tag sequence (a list of ints) for each batch
            element. Each sequence has as many entries as that element has unmasked
            positions; with no mask, every position is included.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        """Raise ValueError if the emissions, tags or mask have inconsistent shapes."""
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        """Return the unnormalised score of the given tag paths, shape (batch_size,)."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        """Return the log partition function (forward algorithm), shape (batch_size,)."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission. score has size
        # (batch_size, num_tags); for each batch element, the j-th column stores
        # the score that the first timestep has tag j.
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags). For each
            # sample, the entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting.
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags. Because we are in score space, the sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far that end in tag i.
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Return the highest-scoring tag sequence for each batch element."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags). For every batch element, the
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j.
        # history saves where the best tag candidates transitioned from; it is used
        # to trace back the best tag sequence.

        # Viterbi algorithm recursive case: compute the score of the best tag sequence
        # for every possible next tag.
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags). For each
            # sample, the entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and
            # emitting.
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tags
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1),
            # and save the index that produces the next score.
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag that maximizes the score at the last timestep; this is the
            # best tag for the last timestep.
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # Trace back where the best last tag comes from, append that to the best tag
            # sequence, and repeat.
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list