# wavlm-base-spkreg / modeling_wavlm_spkreg.py
# yangwang825's picture
# Update modeling_wavlm_spkreg.py
# 5b7e94f verified
import math
import warnings
from typing import Union, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from transformers.models.wavlm.modeling_wavlm import (
WavLMGumbelVectorQuantizer,
WavLMPositionalConvEmbedding,
WavLMFeatureProjection,
WavLMFeatureEncoder,
WavLMEncoderStableLayerNorm,
WavLMEncoder,
WavLMAdapter,
_HIDDEN_STATES_START_POSITION
)
from .configuration_wavlm_spkreg import WavLMSpkRegConfig
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
class WavLMSpkRegPreTrainedModel(PreTrainedModel):
    """
    Common base for the speaker-recognition WavLM models: weight initialisation plus the helpers that map
    raw-audio lengths and attention masks onto the frame rate of the convolutional feature encoder.
    """

    config_class = WavLMSpkRegConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise the parameters of `module` according to its type."""
        if isinstance(module, WavLMGumbelVectorQuantizer):
            # the gumbel-softmax quantizer gets its own initialisation scheme
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, WavLMPositionalConvEmbedding):
            conv = module.conv
            nn.init.normal_(
                conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (conv.kernel_size[0] * conv.in_channels)),
            )
            nn.init.constant_(conv.bias, 0)
        elif isinstance(module, WavLMFeatureProjection):
            bound = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-bound, b=bound)
            nn.init.uniform_(module.projection.bias, a=-bound, b=bound)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-bound, b=bound)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Map input lengths to the lengths produced by the convolutional layers.

        Args:
            input_lengths: Length(s) before the convolutions.
            add_adapter: Whether to also account for the adapter layers; `None` defers to
                `config.add_adapter`.

        Returns:
            The length(s) after every convolution has been applied.
        """
        if add_adapter is None:
            add_adapter = self.config.add_adapter

        def after_conv(length, kernel_size, stride):
            # output-length formula of a 1D convolution, see
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = after_conv(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = after_conv(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """
        Reduce an input-level attention mask to one boolean entry per feature frame.

        Args:
            feature_vector_length: Number of frames along the feature axis.
            attention_mask: Mask over the raw input; each row sum is taken as that row's valid length.
            add_adapter: Forwarded to `_get_feat_extract_output_lengths`.

        Returns:
            Boolean tensor of shape `(batch_size, feature_vector_length)`.
        """
        # Same value as attention_mask.sum(-1), written as a cumsum so that no in-place op is involved
        # and the code can run in inference mode.
        valid_input_lengths = attention_mask.cumsum(dim=-1)[:, -1]
        frame_lengths = self._get_feat_extract_output_lengths(valid_input_lengths, add_adapter=add_adapter)
        frame_lengths = frame_lengths.to(torch.long)

        num_rows = attention_mask.shape[0]
        frame_mask = torch.zeros(
            (num_rows, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # Flag the last valid frame of each row, then propagate the flag backwards (flip, cumsum, flip)
        # so that every frame up to and including it is attended to.
        row_index = torch.arange(frame_mask.shape[0], device=frame_mask.device)
        frame_mask[(row_index, frame_lengths - 1)] = 1
        return frame_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
class WavLMSpkRegModel(WavLMSpkRegPreTrainedModel):
    """WavLM backbone: convolutional feature encoder, feature projection, encoder and optional adapter."""

    def __init__(self, config: WavLMSpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = WavLMFeatureEncoder(config)
        self.feature_projection = WavLMFeatureProjection(config)

        # The learned mask embedding is only created when some form of masking is configured.
        masking_enabled = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if masking_enabled:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        encoder_cls = WavLMEncoderStableLayerNorm if config.do_stable_layer_norm else WavLMEncoder
        self.encoder = encoder_cls(config)

        self.adapter = WavLMAdapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`; emits a `FutureWarning` and then delegates to it.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Disable gradient computation for the feature encoder so that its parameters are not updated during
        training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Apply [SpecAugment](https://arxiv.org/abs/1904.08779) masking along the time and/or feature axis.

        `hidden_states` is modified in place and also returned. Time positions are overwritten with the learned
        `masked_spec_embed`; masked feature channels are set to zero. Masks are sampled only in training mode,
        whereas an explicitly supplied `mask_time_indices` is always applied.
        """
        # `config.apply_spec_augment` can switch masking off entirely
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        # Time axis: sample a mask when none was supplied and time masking is active.
        if mask_time_indices is None and self.config.mask_time_prob > 0 and self.training:
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)
        if mask_time_indices is not None:
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        # Feature axis: one mask per batch row, shared by every time step of that row.
        if self.config.mask_feature_prob > 0 and self.training:
            sampled_feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            feature_mask = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
            feature_mask = feature_mask[:, None].expand(-1, sequence_length, -1)
            hidden_states[feature_mask] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        """
        Run the backbone on `input_values`.

        Returns a `Wav2Vec2BaseModelOutput` when `return_dict` resolves to true; otherwise a tuple that starts
        with `(last_hidden_state, extract_features)` followed by the remaining encoder outputs. Flags left as
        `None` fall back to the corresponding config values.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # The feature encoder output is transposed so the frame axis comes before the channel axis.
        conv_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # reduce the input-level mask to one entry per feature frame
            attention_mask = self._get_feature_vector_attention_mask(
                conv_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(conv_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if return_dict:
            return Wav2Vec2BaseModelOutput(
                last_hidden_state=hidden_states,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (hidden_states, extract_features) + encoder_outputs[1:]
class AngularLinear(nn.Module):
    """
    Bias-free linear layer whose outputs are cosine similarities.

    Both the input vectors and the rows of the weight matrix are L2-normalised along the feature axis
    before the matrix product, so every output lies in `[-1, 1]` (up to rounding) and can be fed to a
    margin-based softmax loss.

    Args:
        in_features: Size of each input vector.
        out_features: Number of output units; the weight has shape `(out_features, in_features)`.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # `torch.empty` replaces the legacy `torch.FloatTensor(...)` constructor; the values are
        # overwritten by the Xavier initialisation right below.
        self.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features), requires_grad=True
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        """
        Compute cos(theta) between the inputs and every weight row.

        Args:
            inputs: Tensor of shape `(..., in_features)`.

        Returns:
            Tensor of shape `(..., out_features)` holding the cosine similarities.
        """
        # Normalise over the last (feature) axis explicitly. `F.normalize` defaults to `dim=1`, which is
        # the feature axis only for exactly 2-D inputs: it fails on 1-D inputs and normalises the wrong
        # axis for inputs with more than two dimensions.
        unit_inputs = F.normalize(inputs, dim=-1)
        unit_weight = F.normalize(self.weight, dim=-1)
        return F.linear(unit_inputs, unit_weight)

    def extra_repr(self) -> str:
        """Return the layer dimensions for the module's printed representation."""
        return f"in_features={self.in_features}, out_features={self.out_features}"
class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Factor the margin-adjusted cosines are multiplied by (default: 30.0)
            margin: Value subtracted from the target-class cosine (default: 0.35)
            label_smoothing: Label smoothing factor passed to the cross-entropy (default: 0.0)
            reduction: Reduction method passed to the cross-entropy (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. the outputs of AngularLinear()
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape

        # keep the cosines strictly inside (-1, 1)
        cosines = inputs.clamp(-1.0 + 1e-7, 1.0 - 1e-7)

        # only the entry of the ground-truth class is penalised by the margin
        is_target = F.one_hot(targets, num_labels).bool()
        adjusted = torch.where(is_target, cosines - self.margin, cosines)
        logits = self.scale * adjusted

        return F.cross_entropy(
            logits,
            targets,
            label_smoothing=self.label_smoothing,
            reduction=self.reduction,
        )
class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Factor the margin-adjusted cosines are multiplied by (default: 30.0)
            margin: Angular margin added to the target-class angle (default: 0.2)
            easy_margin: Apply the angular margin only where cos(theta) > 0 (default: False)
            label_smoothing: Label smoothing factor passed to the cross-entropy (default: 0.0)
            reduction: Reduction method passed to the cross-entropy (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. the outputs of AngularLinear()
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape
        tiny = 1e-6

        # cos(theta + m) is expanded as cos(theta)cos(m) - sin(theta)sin(m), which avoids calling acos
        cosines = inputs.clamp(-1.0 + tiny, 1.0 - tiny)
        sines = torch.sqrt(1.0 - cosines.pow(2)).clamp(0.0 + tiny, 1.0 - tiny)
        with_margin = cosines * math.cos(self.margin) - sines * math.sin(self.margin)

        if self.easy_margin:
            # outside cos(theta) > 0 the plain cosine is kept
            margin_applies = cosines > 0
            otherwise = cosines
        else:
            # beyond theta = pi - m, cos(theta + m) would stop decreasing; switch to cos(theta) - m there
            margin_applies = (cosines - math.cos(math.pi - self.margin)) > 0
            otherwise = cosines - self.margin
        target_values = torch.where(margin_applies, with_margin, otherwise)

        # only the entry of the ground-truth class receives the margin-adjusted value
        is_target = F.one_hot(targets, num_labels).bool()
        logits = self.scale * torch.where(is_target, target_values, cosines)

        return F.cross_entropy(
            logits,
            targets,
            label_smoothing=self.label_smoothing,
            reduction=self.reduction,
        )
class WavLMSpkRegForSequenceClassification(WavLMSpkRegPreTrainedModel):
    """
    WavLM backbone with a mean-pooled classification head.

    `config.loss_fct` selects both the classifier type and the training loss: `'cross_entropy'` uses a plain
    linear classifier, while `'additive_margin'` and `'additive_angular_margin'` use a cosine classifier
    (`AngularLinear`) together with the matching margin-based softmax loss.
    """

    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
            )

        self.wavlm = WavLMSpkRegModel(config)

        # one weight per transformer layer plus one for the input embeddings
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        loss_name = self.config.loss_fct
        if loss_name == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif loss_name in ('additive_margin', 'additive_angular_margin'):
            # both margin-based losses operate on cosine similarities
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`; emits a `FutureWarning` and then delegates to it.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
    def freeze_feature_encoder(self):
        """
        Disable gradient computation for the feature encoder so that its parameters are not updated during
        training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
    def freeze_base_model(self):
        """
        Disable gradient computation for every parameter of the base model; parameters outside `self.wavlm`
        are left untouched.
        """
        for backbone_param in self.wavlm.parameters():
            backbone_param.requires_grad = False

    # Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices used to compute the loss selected by `config.loss_fct` (cross-entropy, additive
            margin softmax or additive angular margin softmax). No loss is computed when `labels` is `None`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # the weighted layer sum needs every layer's hidden states
        if self.config.use_weighted_layer_sum:
            output_hidden_states = True

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-weighted average over the stacked per-layer hidden states
            stacked_layers = torch.stack(outputs[_HIDDEN_STATES_START_POSITION], dim=1)
            layer_probs = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (stacked_layers * layer_probs.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # average over the non-padded frames only
            frame_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~frame_mask] = 0.0
            valid_frames = frame_mask.sum(dim=1).view(-1, 1)
            pooled_output = hidden_states.sum(dim=1) / valid_frames

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_name = self.config.loss_fct
            if loss_name == 'cross_entropy':
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction,
                )
            elif loss_name == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction,
                )
            elif loss_name == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction,
                )
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
        if loss is not None:
            return (loss,) + output
        return output