|
from typing import List |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import librosa |
|
import numpy as np |
|
import torch |
|
from torch_complex.tensor import ComplexTensor |
|
|
|
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
|
|
|
|
class FeatureTransform(torch.nn.Module):
    """Turn a complex STFT into (optionally normalized) log-Mel features.

    The forward pass computes the power spectrum of the input, projects it
    with :class:`LogMel`, and then applies :class:`GlobalMVN` (when a
    ``stats_file`` was given) and :class:`UtteranceMVN` (when
    ``apply_uttmvn`` is true).

    Args:
        fs: Sampling rate forwarded to :class:`LogMel`.
        n_fft: Number of FFT components forwarded to :class:`LogMel`.
        n_mels: Number of Mel bands forwarded to :class:`LogMel`.
        fmin: Lowest filterbank frequency forwarded to :class:`LogMel`.
        fmax: Highest filterbank frequency forwarded to :class:`LogMel`.
        stats_file: Path handed to :class:`GlobalMVN`; ``None`` disables
            global normalization.
        apply_uttmvn: Whether to apply per-utterance normalization.
        uttmvn_norm_means: ``norm_means`` option of :class:`UtteranceMVN`.
        uttmvn_norm_vars: ``norm_vars`` option of :class:`UtteranceMVN`.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # ``apply_uttmvn`` is a bool, so test its truth value: the previous
        # ``is not None`` check was always true and built the sub-module even
        # when utterance normalization was switched off.
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute features from a complex spectrum.

        Args:
            x: Complex spectrum with 3 or 4 dimensions. For 4-dim input,
                axis 2 is treated as the channel axis and reduced to a
                single channel.
            ilens: Sequence lengths; converted to a tensor on ``x.device``
                when given as an array or list.

        Returns:
            The feature tensor and ``ilens`` (as a tensor).

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            if self.training:
                # Pick one channel at random while training.
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Always use the first channel outside of training.
                h = x[:, :, 0, :]
        else:
            h = x

        # Complex spectrum -> power spectrum.
        h = h.real ** 2 + h.imag ** 2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
|
|
|
|
|
class LogMel(torch.nn.Module):
    """Project a power spectrum onto a Mel filterbank and take the log.

    The constructor arguments are passed straight to
    ``librosa.filters.mel`` (``fs`` is passed as ``sr``).

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the
            mel band (area normalization). Otherwise, leave all the
            triangles aiming for a peak value of 1.0

    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        # Kept on the instance so that extra_repr() can report them.
        self.mel_options = {
            "sr": fs,
            "n_fft": n_fft,
            "n_mels": n_mels,
            "fmin": fmin,
            "fmax": fmax,
            "htk": htk,
            "norm": norm,
        }

        filterbank = librosa.filters.mel(**self.mel_options)

        # Stored transposed so that forward() can right-multiply features.
        self.register_buffer("melmat", torch.from_numpy(filterbank.T).float())

    def extra_repr(self):
        """Return the filterbank options as ``key=value`` pairs."""
        parts = [f"{k}={v}" for k, v in self.mel_options.items()]
        return ", ".join(parts)

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Apply the filterbank and log, zeroing padded frames along axis 1.

        Args:
            feat: Spectral features whose last axis matches the filterbank.
            ilens: Valid lengths used to build the padding mask.

        Returns:
            The log-Mel features and ``ilens`` unchanged.
        """
        projected = torch.matmul(feat, self.melmat)

        # Small offset keeps the log finite for all-zero bins.
        log_energy = torch.log(projected + 1e-20)

        pad_mask = make_pad_mask(ilens, log_energy, 1)
        return log_energy.masked_fill(pad_mask, 0.0), ilens
|
|
|
|
|
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file: Path of a 1-dim array readable by ``np.load``. With
            ``D = (len(array) - 1) / 2``, the first ``D`` elements are the
            per-dimension sum of the features, the next ``D`` elements are
            the per-dimension sum of squared features, and the last element
            is the number of samples.
        norm_means: Add the negated global mean to the input.
        norm_vars: Multiply the input by the inverse global standard
            deviation.
        eps: Lower bound applied to the standard deviation before it is
            inverted.
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        half = (len(stats) - 1) // 2
        mean = stats[:half] / count
        var = stats[half:-1] / count - mean * mean
        # E[x^2] - E[x]^2 can come out slightly negative through rounding;
        # clamp at zero so the square root never yields NaN, then floor the
        # standard deviation with ``eps``.
        std = np.maximum(np.sqrt(np.maximum(var, 0.0)), eps)

        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        """Return the stats path and the two normalization flags."""
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` in place with the global statistics.

        Args:
            x: Feature tensor; modified in place. Axis 1 is the axis masked
                according to ``ilens``.
            ilens: Valid lengths used to build the padding mask.

        Returns:
            The normalized ``x`` and ``ilens`` unchanged.
        """
        if self.norm_means:
            x += self.bias.type_as(x)
            # Adding the bias makes padded frames non-zero; reset them. The
            # in-place variant is required here: plain ``masked_fill``
            # returns a new tensor, which was previously discarded.
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
|
|
|
|
|
class UtteranceMVN(torch.nn.Module):
    """Module wrapper that forwards to :func:`utterance_mvn`.

    Args:
        norm_means: Passed to :func:`utterance_mvn` as ``norm_means``.
        norm_vars: Passed to :func:`utterance_mvn` as ``norm_vars``.
        eps: Passed to :func:`utterance_mvn` as ``eps``.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.eps = eps
        self.norm_vars = norm_vars
        self.norm_means = norm_means

    def extra_repr(self):
        """Return the two normalization flags for the module repr."""
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` with the options stored on this module."""
        options = dict(
            norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )
        return utterance_mvn(x, ilens, **options)
|
|
|
|
|
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization.

    Statistics are computed per utterance over axis 1 using only the first
    ``ilens[b]`` frames. ``x`` is modified in place whenever normalization
    is applied.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,) number of valid frames of each utterance
        norm_means: subtract the per-utterance mean
        norm_vars: divide by the per-utterance standard deviation
        eps: lower bound of the variance before the square root is taken

    Returns:
        The normalized tensor (the same object as ``x``) and ``ilens``
        unchanged. With both flags off, ``x`` is returned untouched.

    """
    ilens_ = ilens.type_as(x)
    pad_mask = make_pad_mask(ilens, x, 1)

    # Padding is assumed to be zero, so a plain sum over time divided by the
    # true length is the mean over the valid frames.
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        # The subtraction turns padded frames into ``-mean``; zero them
        # again in place (``masked_fill`` without the underscore returns a
        # new tensor and its result used to be thrown away).
        x.masked_fill_(pad_mask, 0.0)
        centered = x
    elif norm_vars:
        # The mean-free copy is only needed to estimate the variance.
        centered = (x - mean[:, None, :]).masked_fill(pad_mask, 0.0)
    else:
        # Nothing requested: hand the input back instead of a
        # mean-subtracted copy.
        return x, ilens

    if norm_vars:
        var = centered.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
    return x, ilens
|
|
|
|
|
def feature_transform_for(args, n_fft):
    """Build a :class:`FeatureTransform` from an options object.

    Args:
        args: Object exposing ``fbank_fs``, ``n_mels``, ``fbank_fmin``,
            ``fbank_fmax``, ``stats_file``, ``apply_uttmvn``,
            ``uttmvn_norm_means`` and ``uttmvn_norm_vars`` attributes.
        n_fft: Number of FFT components of the incoming spectrum.

    Returns:
        The configured :class:`FeatureTransform`.
    """
    # Mel filterbank related options.
    mel_kwargs = dict(
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
    )
    # Normalization related options.
    norm_kwargs = dict(
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
    return FeatureTransform(**mel_kwargs, **norm_kwargs)
|
|