wavlm-large / s3prl_s3prl_main /s3prl /dataset /pretrain_tera_pipe.py
lmzjms's picture
Upload 1162 files
0b32ad6 verified
from .base import SequentialDataPipe
from .common_pipes import LoadAudio, SetOutputKeys
from .extract_feat_pipes import ExtractOnlineFeat
from .masked_reconstruction_pipes import MaskedReconstruction, PrepareTargetFeat
from .noise_augmentation_pipes import NoiseAugmentation
from .norm_wav_pipes import NormWavDecibel
class PretrainTeraPipe(SequentialDataPipe):
"""
each item in the input dataset should have:
wav_path: str
"""
def __init__(
self,
output_keys: dict = None,
position_encoding_size: int = 768,
mask_proportion: float = 0.15,
mask_consecutive_min: int = 7,
mask_consecutive_max: int = 7,
mask_allow_overlap: bool = True,
mask_bucket_ratio: float = 1.5,
mask_frequency: int = 0.2,
noise_proportion: float = 0.0,
win_ms: int = 25,
hop_ms: int = 10,
n_freq: int = 201,
n_mels: int = 80,
n_mfcc: int = 13,
input: dict = {
"channel": 0,
"cmvn": True,
"delta": 0,
"feat_type": "mel",
"log": True,
},
target: dict = {
"channel": 1,
"cmvn": True,
"delta": 0,
"feat_type": "mel",
"log": True,
},
target_level: int = -25,
audio_sample_rate: int = 16000,
audio_channel_reduction: str = "first",
n_jobs: int = 6,
):
"""
Args:
output_keys (dict): args for the output handle
position_encoding_size (int): this should be identical to `hidden_size`
mask_proportion (float): mask this percentage of all spectrogram frames in each sequence at random during MAM training
mask_consecutive_min (int): mask this amount of consecutive frames
mask_consecutive_max (int): mask this amount of consecutive frames
mask_allow_overlap (bool): allow overlap masking
mask_bucket_ratio (float): only used when overlap is not allowed. sample a mask from each bucket in size of [sampled mask_consecutive * mask_bucket_ratio]
mask_frequency (float): mask maximum this percentage of frequency bands, set to 0 for no frequency mask
noise_proportion (float): for this percentage of the time, Gaussian noise will be applied on all frames during MAM training, set to 0 for no noise
win_ms (int): window size in ms
hop_ms (int): hop size in ms
n_freq (int): number of frequency bins
n_mels (int): number of mel features
n_mfcc (int): number of mfcc features
input (dict): args for the input feat, example - {"channel": 0, "cmvn": True, "delta": 0, "feat_type": "mel", "log": True,}
target (dict): args for the output feat, example - {"channel": 1, "cmvn": True, "delta": 0, "feat_type": "mel", "log": True,}
target_level (int): normalize the wav decibel level to the target value
audio_sample_rate (int): audio sample rate
audio_channel_reduction (str): "first" channel
n_jobs (int): number of workers
"""
output_keys = output_keys or dict(
x="masked_feat",
label="target_feat",
label_mask="label_mask",
position_encoding="pos_enc",
attention_mask="attn_mask",
unique_name="id",
)
super().__init__(
LoadAudio(
n_jobs=n_jobs,
audio_sample_rate=audio_sample_rate,
audio_channel_reduction=audio_channel_reduction,
),
NormWavDecibel(
target_level=target_level,
),
ExtractOnlineFeat(
win_ms=win_ms,
hop_ms=hop_ms,
n_freq=n_freq,
n_mels=n_mels,
n_mfcc=n_mfcc,
input=input,
target=target,
feat_name="source_feat",
),
PrepareTargetFeat(
use_copy=True,
source_feat_name="source_feat",
target_feat_name="target_feat",
),
NoiseAugmentation(
noise_proportion=noise_proportion,
input_feat_name="source_feat",
output_feat_name="noised_feat",
),
MaskedReconstruction(
position_encoding_size=position_encoding_size,
mask_proportion=mask_proportion,
mask_consecutive_min=mask_consecutive_min,
mask_consecutive_max=mask_consecutive_max,
mask_allow_overlap=mask_allow_overlap,
mask_bucket_ratio=mask_bucket_ratio,
mask_frequency=mask_frequency,
source_feat_name="noised_feat",
target_feat_name="target_feat",
masked_feat_name="masked_feat",
pos_enc_name="pos_enc",
attn_mask_name="attn_mask",
label_mask_name="label_mask",
),
SetOutputKeys(output_keys=output_keys),
)