import math
import warnings
from typing import Union, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module
from transformers.models.hubert.modeling_hubert import (
    HubertFeatureEncoder,
    HubertFeatureProjection,
    HubertEncoderStableLayerNorm,
    HubertEncoder,
    _HIDDEN_STATES_START_POSITION,
)

from .configuration_hubert_spkreg import HubertSpkRegConfig


class HubertSpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HubertSpkRegConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        - Linear weights are drawn from ``N(0, config.initializer_range)``.
        - LayerNorm/GroupNorm are reset to weight=1, bias=0.
        - Conv1d weights use Kaiming-normal initialization.
        - Linear/Conv1d biases (when present) are zeroed.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # Under DeepSpeed ZeRO-3 the parameters are gathered (modifier_rank=0) before being
                # re-initialized in place; weight-normed convs expose weight_v / weight_g instead.
                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        # Apply the (padding-free) Conv1d length formula once per conv layer in the feature encoder.
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Convert a sample-level attention mask into a boolean frame-level mask.

        Args:
            feature_vector_length: Number of frames produced by the feature encoder.
            attention_mask: Sample-level mask; its per-row sum is taken as the valid input length.

        Returns:
            Bool tensor of shape ``(batch_size, feature_vector_length)`` that is True for the first
            ``output_length`` frames of each row and False afterwards.
        """
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask


# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        Boolean numpy array of shape `shape`; True marks masked positions.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask


class HubertSpkRegModel(HubertSpkRegPreTrainedModel):
    """HuBERT backbone: convolutional feature encoder, feature projection and transformer encoder."""

    def __init__(self, config: HubertSpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        # Learned vector written into time steps selected by SpecAugment time masking. It is only
        # created when some SpecAugment masking probability is non-zero.
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        `hidden_states` is modified in place and also returned. Time-masked frames are replaced by
        `masked_spec_embed`; feature-masked channels are set to 0. Random masks are only generated in
        training mode, while explicitly passed `mask_time_indices` are always applied.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            # Same channel mask for every time step of a given example.
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Returns:
            A `BaseModelOutput` (or a tuple when `return_dict=False`) whose first element is the final encoder
            hidden state; per-layer hidden states and attentions are included when requested.

        Example:

        ```python
        >>> from transformers import AutoProcessor, HubertModel
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        # Pass the frame-level mask so SpecAugment time spans are sampled only inside the non-padded part
        # of each sequence (as Wav2Vec2Model does); otherwise spans can be wasted on padding frames.
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):
    """Bias-free linear layer that outputs cosine similarities.

    Both the inputs and the weight rows are L2-normalized before the matrix product, so each output entry is
    cos(theta) between an input vector and one class weight vector.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(out_features, in_features), requires_grad=True
        )
        # `_init_weights` has no branch for this module type, so this Xavier init is not overridden there.
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        # Calculation of cos(theta): F.normalize defaults to the L2 norm along dim=1
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
        return cosine

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )


class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
    ):
        """
        Args:
            scale: Scaling factor applied to the cosine logits (default: 30.0)
            margin: Cosine margin subtracted from the target-class cosine (default: 0.35)
            label_smoothing: Label smoothing factor passed to `F.cross_entropy` (default: 0.0)
            reduction: Reduction method passed to `F.cross_entropy` (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. the output of `AngularLinear`
            targets: Ground truth class indices of shape (batch_size,)

        Returns:
            Loss value, reduced according to `self.reduction`
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin

        one_hot = nn.functional.one_hot(targets, num_labels)
        # The margin is applied to the target-class logit only, then all logits are scaled.
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss


class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
    ):
        """
        Args:
            scale: Scaling factor applied to the logits (default: 30.0)
            margin: Angular margin in radians added to the target-class angle (default: 0.2)
            easy_margin: Use the easy margin loss (default: False)
            label_smoothing: Label smoothing factor passed to `F.cross_entropy` (default: 0.0)
            reduction: Reduction method passed to `F.cross_entropy` (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. the output of `AngularLinear`
            targets: Ground truth class indices of shape (batch_size,)

        Returns:
            Loss value, reduced according to `self.reduction`
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        epsilon = 1e-6
        # psi = cos(theta + m) is computed with the angle-addition identity instead of acos/cos.
        cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)

        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        psi = cos_theta * cos_m - sin_theta * sin_m  # cos(theta + m)

        if self.easy_margin:
            # Only apply the margin where cos(theta) > 0.
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]:
            # for theta >= pi - m fall back to cos(theta) - m.
            psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)

        one_hot = nn.functional.one_hot(targets, num_labels)
        # The margin is applied to the target-class logit only, then all logits are scaled.
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss


class HubertSpkRegForSequenceClassification(HubertSpkRegPreTrainedModel):
    """HuBERT with a mean-pooled projection head for utterance-level (speaker) classification.

    `config.loss_fct` selects both the classifier and the training loss:
    `"cross_entropy"` (nn.Linear + cross-entropy), `"additive_margin"` (AngularLinear + AM-Softmax) or
    `"additive_angular_margin"` (AngularLinear + AAM-Softmax).
    """

    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
            )
        self.hubert = HubertSpkRegModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        if self.config.loss_fct == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct in ('additive_margin', 'additive_angular_margin'):
            # Margin-based losses expect cosine logits in [-1, 1].
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.hubert.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.hubert.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices used to compute the training loss; should be in `[0, ..., config.num_labels - 1]`.
            The loss is selected by `config.loss_fct`: cross-entropy (`"cross_entropy"`), AM-Softmax
            (`"additive_margin"`) or AAM-Softmax (`"additive_angular_margin"`). There is no regression path.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        # Mean-pool over time; with an attention mask, padded frames are zeroed and excluded from the average.
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.loss_fct == 'cross_entropy':
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            else:
                # `config.loss_fct` was changed after construction; fail clearly instead of hitting an
                # UnboundLocalError on `loss_fct` below.
                raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )