# Source: wavlm-large / s3prl_s3prl_main / s3prl / dataset / extract_feat_pipes.py
# Uploaded by lmzjms ("Upload 1162 files"), commit 0b32ad6 (verified).
from dataclasses import dataclass
from s3prl.upstream.apc.audio import create_transform as apc_feat_extracter
from s3prl.upstream.baseline.extracter import get_extracter as kaldi_feat_extracter
from s3prl.upstream.baseline.preprocessor import (
get_preprocessor as online_feat_extracter,
)
from s3prl.upstream.npc.audio import create_transform as npc_feat_extracter
from .base import AugmentedDynamicItemDataset, DataPipe
@dataclass
class ExtractKaldiFeat(DataPipe):
    """
    Attach a Kaldi-style feature extracter to a dataset and register a
    dynamic item that turns each waveform into features.

    Args:
        kaldi (dict): args for the kaldi extracter
        delta (dict): args for applying delta on features
        cmvn (dict): args for applying cmvn on features
        wav_name (str): handle for the `takes` (input)
        feat_name (str): handle for the `provides` (output)

    Tools registered on the dataset: ``extracter``, ``feat_dim``,
    ``frame_shift``.
    """

    # NOTE(review): these ``None`` defaults are forwarded unchanged to
    # ``kaldi_feat_extracter``; confirm it accepts ``None`` for each section
    # before relying on the defaults.
    kaldi: dict = None
    delta: dict = None
    cmvn: dict = None
    wav_name: str = "wav"
    feat_name: str = "feat"

    def extract_feat(self, extracter, wav):
        """
        Apply ``extracter`` to one waveform, without any reshaping.

        Expected shapes (per the original author's note):
            (wav_seq_len, 1) -> extracter -> (feat_seq_len, feat_dim)

        Args:
            extracter: the callable registered as the ``extracter`` tool.
            wav: the item stored under ``wav_name``.

        Returns:
            The value returned by ``extracter(wav)``.
        """
        feat = extracter(wav)
        return feat

    def __call__(self, dataset: AugmentedDynamicItemDataset):
        """
        Build the extracter and wire it into ``dataset``.

        Args:
            dataset: the dataset to extend; it is modified in place.

        Returns:
            The same ``dataset`` object, with the ``extracter``, ``feat_dim``
            and ``frame_shift`` tools plus a ``feat_name`` dynamic item added.
        """
        _audio_config = {
            "kaldi": self.kaldi,
            "delta": self.delta,
            "cmvn": self.cmvn,
        }
        # The factory's result is unpacked as (extracter, feat_dim, frame_shift);
        # all three are exposed as dataset tools.
        extracter, feat_dim, frame_shift = kaldi_feat_extracter(_audio_config)
        dataset.add_tool("extracter", extracter)
        dataset.add_tool("feat_dim", feat_dim)
        dataset.add_tool("frame_shift", frame_shift)
        # The dynamic item takes the shared "extracter" tool and the waveform,
        # and provides the features under ``feat_name``.
        dataset.add_dynamic_item(
            self.extract_feat,
            takes=["extracter", self.wav_name],
            provides=self.feat_name,
        )
        return dataset
@dataclass
class ExtractOnlineFeat(DataPipe):
    """
    Attach an online (on-the-fly) feature preprocessor to a dataset and
    register a dynamic item that turns each waveform into features.

    Args:
        win_ms (int): window size in ms
        hop_ms (int): hop size in ms
        n_freq (int): number of frequency bins
        n_mels (int): number of mel features
        n_mfcc (int): number of mfcc features
        input (dict): args for the input feat, example - {"channel": 0, "cmvn": True, "delta": 0, "feat_type": "mel", "log": True,}
        target (dict): args for the output feat, example - {"channel": 1, "cmvn": True, "delta": 0, "feat_type": "mel", "log": True,}
        wav_name (str): handle for the `takes` (input)
        feat_name (str): handle for the `provides` (output)

    Tools registered on the dataset: ``extracter``, ``feat_dim``.
    """

    win_ms: int = 25
    hop_ms: int = 10
    n_freq: int = 201
    n_mels: int = 80
    n_mfcc: int = 13
    # ``input`` shadows the builtin, but it is part of the public constructor
    # interface, so the name is kept.
    input: dict = None
    target: dict = None
    wav_name: str = "wav"
    feat_name: str = "feat"

    def extract_feat(self, extracter, wav):
        """
        Reshape one waveform for the preprocessor, run it, and pick out a
        single feature tensor.

        Expected shapes (per the original author's note):
            (wav_seq_len, 1) -> permute + unsqueeze ->
            (1, 1, wav_seq_len) -> extracter -> (feat_seq_len, feat_dim)

        Args:
            extracter: the callable registered as the ``extracter`` tool.
            wav: the item stored under ``wav_name``; ``permute(1, 0)`` implies
                a 2-D tensor.

        Returns:
            ``extracter(...)[0][0]``. This is presumably the first returned
            feature for the single batch entry — TODO confirm against
            ``get_preprocessor``.
        """
        # Swap the two axes, then add a leading batch axis.
        wav = wav.permute(1, 0).unsqueeze(0)
        feat = extracter(wav)[0][0]
        return feat

    def __call__(self, dataset: AugmentedDynamicItemDataset):
        """
        Build the preprocessor and wire it into ``dataset``.

        Args:
            dataset: the dataset to extend; it is modified in place.

        Returns:
            The same ``dataset`` object, with the ``extracter`` and
            ``feat_dim`` tools plus a ``feat_name`` dynamic item added.
        """
        _audio_config = {
            "win_ms": self.win_ms,
            "hop_ms": self.hop_ms,
            "n_freq": self.n_freq,
            "n_mels": self.n_mels,
            "n_mfcc": self.n_mfcc,
            "input": self.input,
            "target": self.target,
        }
        # The factory returns three values; the third is not used by this pipe.
        extracter, feat_dim, _ = online_feat_extracter(_audio_config)
        dataset.add_tool("extracter", extracter)
        dataset.add_tool("feat_dim", feat_dim)
        dataset.add_dynamic_item(
            self.extract_feat,
            takes=["extracter", self.wav_name],
            provides=self.feat_name,
        )
        return dataset
@dataclass
class ExtractApcFeat(DataPipe):
    """
    Attach the APC upstream's feature transform to a dataset and register a
    dynamic item that turns each waveform into features.

    Args:
        feat_type (str): feature type
        feat_dim (int): feature dimension
        frame_length (int): window size in ms
        frame_shift (int): hop size in ms
        decode_wav (bool): whether to decode wav
        cmvn (bool): whether to apply utterance-wise CMVN on the feature
        wav_name (str): handle for the `takes` (input)
        feat_name (str): handle for the `provides` (output)

    Tools registered on the dataset: ``extracter``, ``feat_dim``.
    """

    feat_type: str = "fbank"
    feat_dim: int = 80
    frame_length: int = 25
    frame_shift: int = 10
    decode_wav: bool = False
    cmvn: bool = True
    wav_name: str = "wav"
    feat_name: str = "feat"

    def extract_feat(self, extracter, wav):
        """
        Transpose one waveform and apply ``extracter`` to it.

        Expected shapes (per the original author's note):
            (wav_seq_len, 1) -> permute ->
            (1, wav_seq_len) -> extracter -> (feat_seq_len, feat_dim)

        Args:
            extracter: the callable registered as the ``extracter`` tool.
            wav: the item stored under ``wav_name``; ``permute(1, 0)`` implies
                a 2-D tensor.

        Returns:
            The value returned by the extracter.
        """
        feat = extracter(wav.permute(1, 0))
        return feat

    def __call__(self, dataset: AugmentedDynamicItemDataset):
        """
        Build the APC transform and wire it into ``dataset``.

        Args:
            dataset: the dataset to extend; it is modified in place.

        Returns:
            The same ``dataset`` object, with the ``extracter`` and
            ``feat_dim`` tools plus a ``feat_name`` dynamic item added.
        """
        # A fresh dict is built on every call, so the factory receives its own
        # copy of the configuration.
        _audio_config = {
            "feat_type": self.feat_type,
            "feat_dim": self.feat_dim,
            "frame_length": self.frame_length,
            "frame_shift": self.frame_shift,
            "decode_wav": self.decode_wav,
            "cmvn": self.cmvn,
        }
        extracter, feat_dim = apc_feat_extracter(_audio_config)
        dataset.add_tool("extracter", extracter)
        dataset.add_tool("feat_dim", feat_dim)
        dataset.add_dynamic_item(
            self.extract_feat,
            takes=["extracter", self.wav_name],
            provides=self.feat_name,
        )
        return dataset
@dataclass
class ExtractNpcFeat(DataPipe):
    """
    Attach the NPC upstream's feature transform to a dataset and register a
    dynamic item that turns each waveform into features.

    Args:
        feat_type (str): feature type
        feat_dim (int): feature dimension
        frame_length (int): window size in ms
        frame_shift (int): hop size in ms
        decode_wav (bool): whether to decode wav
        cmvn (bool): whether to apply utterance-wise CMVN on the feature
        wav_name (str): handle for the `takes` (input)
        feat_name (str): handle for the `provides` (output)

    Tools registered on the dataset: ``extracter``, ``feat_dim``.
    """

    feat_type: str = "fbank"
    feat_dim: int = 80
    frame_length: int = 25
    frame_shift: int = 10
    decode_wav: bool = False
    cmvn: bool = True
    wav_name: str = "wav"
    feat_name: str = "feat"

    def extract_feat(self, extracter, wav):
        """
        Transpose one waveform and apply ``extracter`` to it.

        Expected shapes (per the original author's note):
            (wav_seq_len, 1) -> permute ->
            (1, wav_seq_len) -> extracter -> (feat_seq_len, feat_dim)

        Args:
            extracter: the callable registered as the ``extracter`` tool.
            wav: the item stored under ``wav_name``; ``permute(1, 0)`` implies
                a 2-D tensor.

        Returns:
            The value returned by the extracter.
        """
        feat = extracter(wav.permute(1, 0))
        return feat

    def __call__(self, dataset: AugmentedDynamicItemDataset):
        """
        Build the NPC transform and wire it into ``dataset``.

        Args:
            dataset: the dataset to extend; it is modified in place.

        Returns:
            The same ``dataset`` object, with the ``extracter`` and
            ``feat_dim`` tools plus a ``feat_name`` dynamic item added.
        """
        # A fresh dict is built on every call, so the factory receives its own
        # copy of the configuration.
        _audio_config = {
            "feat_type": self.feat_type,
            "feat_dim": self.feat_dim,
            "frame_length": self.frame_length,
            "frame_shift": self.frame_shift,
            "decode_wav": self.decode_wav,
            "cmvn": self.cmvn,
        }
        extracter, feat_dim = npc_feat_extracter(_audio_config)
        dataset.add_tool("extracter", extracter)
        dataset.add_tool("feat_dim", feat_dim)
        dataset.add_dynamic_item(
            self.extract_feat,
            takes=["extracter", self.wav_name],
            provides=self.feat_name,
        )
        return dataset