# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Callable, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr_detach.models.emotion2vec.fairseq_modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
    ConvFeatureExtractionModel,
)
from funasr_detach.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
from funasr_detach.models.emotion2vec.base import (
    ModalitySpecificEncoder,
    get_alibi_bias,
)


class AudioEncoder(ModalitySpecificEncoder):
    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):
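        """
        Build the audio-modality encoder on top of ModalitySpecificEncoder:
        a convolutional feature extractor over the raw waveform, a LayerNorm
        plus Linear projection to ``embed_dim``, a stack of grouped-Conv1d
        relative positional layers, and ``prenet_depth`` transformer blocks
        (built by ``make_block``) with linearly increasing drop-path rates.
        ALiBi attention biases are shared across instances via ``alibi_biases``.
        """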
        # The extractor spec is stored as a string of (dim, kernel, stride)
        # tuples and parsed here; the last entry fixes the extractor's
        # output dimension.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # Normalize the extractor output and project it to the model width.
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )
        # Relative positional encoding: a stack of grouped Conv1d layers
        # whose total receptive width is spread over conv_pos_depth layers.
        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),  # trims the extra frame produced when k is even
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
        # Drop-path rates increase linearly across the prenet blocks.
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
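        """
        Convert a sample-level padding mask into a frame-level mask: count
        the unpadded input samples per sequence, push the counts through the
        conv output-length formula, and mark every frame at or beyond each
        resulting length as padding.
        """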
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """

            def _conv_out_length(input_length, kernel_size, stride):
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)

            # apply conv formula to get real output_lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations make sure that all values
                # before the output length indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask

    def reset_parameters(self):
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()

        if self.decoder is not None:
            self.decoder.reset_parameters()
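

# Standalone sketch of the length bookkeeping used in convert_padding_mask.
# The conv spec below is an assumption for illustration (a wav2vec 2.0-style
# default); an actual emotion2vec config may use a different
# ``feature_encoder_spec``.
if __name__ == "__main__":
    spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    # Conv output-size formula, applied once per extractor layer.
    lengths = torch.tensor([16000, 12000])  # raw waveform lengths in samples
    for _, kernel, stride in spec:
        lengths = torch.floor((lengths - kernel) / stride + 1)
    out_len = lengths.to(torch.long)
    print(out_len)  # frame counts after the feature extractor

    # The flip/cumsum/flip trick from convert_padding_mask: set a 1 at index
    # (length - 1); a right-to-left cumulative sum then marks every frame
    # before that index, and inverting yields True exactly at padded frames.
    mask = torch.zeros(out_len.numel(), int(out_len.max()))
    mask[(torch.arange(out_len.numel()), out_len - 1)] = 1
    mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool()
    print(mask.sum(-1))  # number of padded frames per sequence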