Spaces:
Runtime error
Runtime error
from typing import List | |
from typing import Tuple | |
from typing import Union | |
import librosa | |
import numpy as np | |
import torch | |
from torch_complex.tensor import ComplexTensor | |
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask | |
class FeatureTransform(torch.nn.Module):
    """Convert a complex STFT into log-mel features with optional normalization.

    The pipeline applied by ``forward`` is:
    power spectrum -> ``LogMel`` -> ``GlobalMVN`` (only when ``stats_file`` is
    given) -> ``UtteranceMVN`` (only when ``apply_uttmvn`` is true).

    Args:
        fs: Sampling rate, forwarded to ``LogMel``.
        n_fft: Number of FFT components, forwarded to ``LogMel``.
        n_mels: Number of mel bands, forwarded to ``LogMel``.
        fmin: Lowest filterbank frequency, forwarded to ``LogMel``.
        fmax: Highest filterbank frequency, forwarded to ``LogMel``.
        stats_file: Path handed to ``GlobalMVN``; ``None`` disables global
            mean/variance normalization.
        apply_uttmvn: Whether to apply per-utterance normalization.
        uttmvn_norm_means: ``norm_means`` option of ``UtteranceMVN``.
        uttmvn_norm_vars: ``norm_vars`` option of ``UtteranceMVN``.
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # `apply_uttmvn` is a bool, so it has to be tested for truthiness;
        # an `is not None` check is always true and would build the module
        # even when utterance normalization is disabled.
        if apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute features for a batch of STFTs.

        Args:
            x: Complex STFT of shape (B, T, F) or (B, T, C, F).
            ilens: Number of valid frames per utterance; tensors are used
                as-is, anything else is converted to a tensor on ``x.device``.

        Returns:
            Tuple ``(features, ilens)`` where ``ilens`` is a tensor.

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) (power spectrum)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.global_mvn is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
class LogMel(torch.nn.Module):
    """Map STFT-domain features onto a log-mel filterbank.

    The constructor arguments are handed to ``librosa.filters.mel`` to build
    the filterbank, which is stored transposed in the ``melmat`` buffer.

    Args:
        fs: Sampling rate of the incoming signal (passed to librosa as ``sr``).
        n_fft: Number of FFT components.
        n_mels: Number of mel bands to generate.
        fmin: Lowest frequency (in Hz).
        fmax: Highest frequency (in Hz); ``None`` is passed through to librosa.
        htk: Use the HTK formula instead of Slaney.
        norm: Filterbank normalization option passed through to librosa.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()
        # Kept on the instance so extra_repr() can report the configuration.
        self.mel_options = {
            "sr": fs,
            "n_fft": n_fft,
            "n_mels": n_mels,
            "fmin": fmin,
            "fmax": fmax,
            "htk": htk,
            "norm": norm,
        }

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        filterbank = librosa.filters.mel(**self.mel_options)
        # filterbank: (D2, D1) -> stored as (D1, D2) so feat @ melmat works.
        transposed = torch.from_numpy(filterbank.T).float()
        self.register_buffer("melmat", transposed)

    def extra_repr(self):
        """Return the mel options as a ``key=value`` list for ``repr``."""
        parts = [f"{k}={v}" for k, v in self.mel_options.items()]
        return ", ".join(parts)

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Apply the filterbank and take the log.

        Args:
            feat: (B, T, D1) input features.
            ilens: Valid frame count per utterance, used to zero the padding.

        Returns:
            Tuple ``(logmel_feat, ilens)`` with ``logmel_feat`` of shape
            (B, T, D2).
        """
        # (B, T, D1) x (D1, D2) -> (B, T, D2)
        projected = torch.matmul(feat, self.melmat)
        # The small offset keeps log() finite for all-zero bins.
        log_feat = torch.log(projected + 1e-20)

        # Zero out the padded frames.
        pad_mask = make_pad_mask(ilens, log_feat, 1)
        return log_feat.masked_fill(pad_mask, 0.0), ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file: Path of a 1-dim array file readable by ``np.load``.
            With ``n = (len(array) - 1) / 2``, the first ``n`` elements are
            treated as the sum of features, the next ``n`` elements as the
            sum of squared features, and the last element as the number of
            samples.
        norm_means: Add the (negated) global mean to the input.
        norm_vars: Multiply the input by the inverse global standard
            deviation.
        eps: Lower bound applied to the standard deviation before it is
            inverted.
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        # Layout is [sum..., sum_sq..., count], so the length must be odd.
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        half = (len(stats) - 1) // 2
        mean = stats[:half] / count
        var = stats[half:-1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        # Stored as additive bias / multiplicative scale so forward() only
        # needs an add and a multiply.
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        """Return the configuration shown in the module ``repr``."""
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` in place and return it.

        Args:
            x: (B, T, D) features; modified in place.
            ilens: Valid frame count per utterance.

        Returns:
            Tuple ``(x, ilens)``.
        """
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            # Adding the bias makes padded frames non-zero, so reset them.
            # The in-place variant is required: `masked_fill` returns a new
            # tensor and leaves `x` untouched.
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    """Module wrapper that calls ``utterance_mvn`` with stored options.

    Args:
        norm_means: Forwarded to ``utterance_mvn`` as ``norm_means``.
        norm_vars: Forwarded to ``utterance_mvn`` as ``norm_vars``.
        eps: Forwarded to ``utterance_mvn`` as ``eps``.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means, self.norm_vars, self.eps = norm_means, norm_vars, eps

    def extra_repr(self):
        """Return the normalization flags shown in the module ``repr``."""
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Run ``utterance_mvn`` on ``x`` and ``ilens`` and return its result."""
        options = {
            "norm_means": self.norm_means,
            "norm_vars": self.norm_vars,
            "eps": self.eps,
        }
        return utterance_mvn(x, ilens, **options)
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization.

    Statistics are taken over the time axis (dim 1) and divided by the true
    length from ``ilens``, which is why the input must be zero in padded
    frames.

    Note that ``x`` is modified in place when ``norm_means`` or ``norm_vars``
    is set.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,) number of valid frames per utterance
        norm_means: Subtract the per-utterance mean from ``x``.
        norm_vars: Divide ``x`` by the per-utterance standard deviation.
        eps: Lower bound applied to the variance before the square root.

    Returns:
        Tuple ``(normalized, ilens)``.
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        # NOTE(review): when norm_vars is also False this centered copy is
        # what gets returned, i.e. the mean is still removed. Kept as-is for
        # compatibility - confirm whether callers rely on it.
        x_ = x - mean[:, None, :]

    # Zero padding: subtracting the mean turns padded frames into -mean.
    # `masked_fill_` (in place) is required here; the out-of-place
    # `masked_fill` returns a new tensor, so discarding its result leaves
    # the padding in both the variance estimate and the output.
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)

    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
def feature_transform_for(args, n_fft):
    """Build a ``FeatureTransform`` from an argument namespace.

    Args:
        args: Object exposing ``fbank_fs``, ``n_mels``, ``fbank_fmin``,
            ``fbank_fmax``, ``stats_file``, ``apply_uttmvn``,
            ``uttmvn_norm_means`` and ``uttmvn_norm_vars`` attributes.
        n_fft: Number of FFT components, passed through unchanged.

    Returns:
        The configured ``FeatureTransform`` instance.
    """
    # Mel options
    mel_kwargs = {
        "fs": args.fbank_fs,
        "n_fft": n_fft,
        "n_mels": args.n_mels,
        "fmin": args.fbank_fmin,
        "fmax": args.fbank_fmax,
    }
    # Normalization
    norm_kwargs = {
        "stats_file": args.stats_file,
        "apply_uttmvn": args.apply_uttmvn,
        "uttmvn_norm_means": args.uttmvn_norm_means,
        "uttmvn_norm_vars": args.uttmvn_norm_vars,
    }
    return FeatureTransform(**mel_kwargs, **norm_kwargs)