# Source: wavlm-large / s3prl_s3prl_main/s3prl/dataset/pretrain_mockingjay_pipe.py
# (uploaded by lmzjms: "Upload 1162 files", commit 0b32ad6, verified)
from .base import SequentialDataPipe
from .common_pipes import LoadAudio, SetOutputKeys
from .extract_feat_pipes import ExtractKaldiFeat
from .masked_reconstruction_pipes import MaskedReconstruction, PrepareTargetFeat
class PretrainMockingjayPipe(SequentialDataPipe):
    """
    Data pipeline for Mockingjay-style masked reconstruction pretraining.

    The pipeline runs these stages in order:
        1. LoadAudio - load the audio of each item
        2. ExtractKaldiFeat - extract Kaldi features into "source_feat"
        3. PrepareTargetFeat - copy "source_feat" into "target_feat"
        4. MaskedReconstruction - configured to write "masked_feat",
           "pos_enc", "attn_mask" and "label_mask" from the source/target
           features
        5. SetOutputKeys - expose the produced keys under `output_keys`

    each item in the input dataset should have:
        wav_path: str
    """

    def __init__(
        self,
        output_keys: dict = None,
        position_encoding_size: int = 768,
        mask_proportion: float = 0.15,
        mask_consecutive_min: int = 7,
        mask_consecutive_max: int = 7,
        mask_allow_overlap: bool = True,
        mask_bucket_ratio: float = 1.5,
        mask_frequency: float = 0.2,
        kaldi: dict = {
            "feat_type": "fbank",
            "fbank": {
                "frame_length": 25.0,
                "frame_shift": 10.0,
                # 80 static mel bins; with the default delta={"order": 2}
                # the final feature dim is presumably 80 * 3 = 240 (static +
                # delta + delta-delta) - confirm in ExtractKaldiFeat.
                "num_mel_bins": 80,
                "use_log_fbank": True,
            },
            "mfcc": {"frame_length": 25.0, "frame_shift": 10.0, "num_ceps": 13},
            "spectrogram": {"frame_length": 25.0, "frame_shift": 10.0},
        },
        delta: dict = {"order": 2, "win_length": 5},
        cmvn: dict = {"use_cmvn": True},
        audio_sample_rate: int = 16000,
        audio_channel_reduction: str = "first",
        n_jobs: int = 6,
    ):
        """
        Args:
            output_keys (dict): args for the output handle. Defaults to
                x="masked_feat", label="target_feat", label_mask="label_mask",
                position_encoding="pos_enc", attention_mask="attn_mask",
                unique_name="id". A falsy value (None or {}) selects the default.
            position_encoding_size (int): this should be identical to `hidden_size`
            mask_proportion (float): mask this percentage of all spectrogram
                frames in each sequence at random during MAM training
            mask_consecutive_min (int): minimum number of consecutive frames
                in one mask span
            mask_consecutive_max (int): maximum number of consecutive frames
                in one mask span
            mask_allow_overlap (bool): allow overlap masking
            mask_bucket_ratio (float): only used when overlap is not allowed.
                sample a mask from each bucket in size of
                [sampled mask_consecutive * mask_bucket_ratio]
            mask_frequency (float): mask maximum this percentage of frequency
                bands, set to 0 for no frequency mask
            kaldi (dict): args for the kaldi extractor
            delta (dict): args for applying delta on features
            cmvn (dict): args for applying cmvn on features
            audio_sample_rate (int): audio sample rate
            audio_channel_reduction (str): "first" channel
            n_jobs (int): number of workers
        """
        import copy

        # The dict defaults in the signature are evaluated once and shared by
        # every call. Hand the feature extractor private copies so that any
        # in-place change made downstream cannot leak into later instances
        # (or back into a caller's own dicts). deepcopy(None) is None, so an
        # explicit None is still forwarded unchanged.
        kaldi = copy.deepcopy(kaldi)
        delta = copy.deepcopy(delta)
        cmvn = copy.deepcopy(cmvn)

        output_keys = output_keys or dict(
            x="masked_feat",
            label="target_feat",
            label_mask="label_mask",
            position_encoding="pos_enc",
            attention_mask="attn_mask",
            unique_name="id",
        )
        super().__init__(
            LoadAudio(
                n_jobs=n_jobs,
                audio_sample_rate=audio_sample_rate,
                audio_channel_reduction=audio_channel_reduction,
            ),
            ExtractKaldiFeat(
                kaldi=kaldi, delta=delta, cmvn=cmvn, feat_name="source_feat"
            ),
            # The reconstruction target starts as a copy of the source features.
            PrepareTargetFeat(
                use_copy=True,
                source_feat_name="source_feat",
                target_feat_name="target_feat",
            ),
            MaskedReconstruction(
                position_encoding_size=position_encoding_size,
                mask_proportion=mask_proportion,
                mask_consecutive_min=mask_consecutive_min,
                mask_consecutive_max=mask_consecutive_max,
                mask_allow_overlap=mask_allow_overlap,
                mask_bucket_ratio=mask_bucket_ratio,
                mask_frequency=mask_frequency,
                source_feat_name="source_feat",
                target_feat_name="target_feat",
                masked_feat_name="masked_feat",
                pos_enc_name="pos_enc",
                attn_mask_name="attn_mask",
                label_mask_name="label_mask",
            ),
            SetOutputKeys(output_keys=output_keys),
        )