File size: 7,975 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
from typing import List
from typing import Tuple
from typing import Union
import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
class FeatureTransform(torch.nn.Module):
    """Turn a (possibly multi-channel) complex spectrum into log-mel features.

    Pipeline applied in ``forward``:
      1. if the input has a channel axis, keep a single channel
         (random during training, the first one in eval mode);
      2. compute the power spectrum ``real**2 + imag**2``;
      3. apply :class:`LogMel`;
      4. apply :class:`GlobalMVN` when ``stats_file`` is given;
      5. apply :class:`UtteranceMVN` when ``apply_uttmvn`` is true.

    Args:
        fs, n_fft, n_mels, fmin, fmax: mel filterbank options forwarded to
            :class:`LogMel`.
        stats_file: statistics file for :class:`GlobalMVN`; ``None`` disables
            global normalization.
        apply_uttmvn: enable utterance-level normalization.
        uttmvn_norm_means, uttmvn_norm_vars: options for :class:`UtteranceMVN`.
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn
        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None
        # apply_uttmvn is a boolean flag: only build the normalizer when it is
        # enabled (comparing it against None would always be true).
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute (normalized) log-mel features.

        Args:
            x: complex spectrum of shape (B, T, F) or (B, T, C, F).
            ilens: number of valid frames per utterance, shape (B,); a numpy
                array or list is converted to an int64 tensor on ``x.device``.

        Returns:
            Tuple of the features (B, T, n_mels) and ``ilens`` as a tensor.

        Raises:
            ValueError: if ``x`` is neither 3- nor 4-dimensional.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.as_tensor(np.asarray(ilens), dtype=torch.long, device=x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select one channel at random (data augmentation).
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Deterministic choice at inference: the first channel.
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) power spectrum
        h = h.real ** 2 + h.imag ** 2

        h, _ = self.logmel(h, ilens)
        if self.global_mvn is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)
        return h, ilens
class LogMel(torch.nn.Module):
    """Map linear-frequency power features to log-mel features.

    The constructor arguments are forwarded to ``librosa.filters.mel``
    (``fs`` is passed as ``sr``).

    Args:
        fs: sampling rate of the incoming signal (> 0).
        n_fft: number of FFT components (> 0).
        n_mels: number of mel bands to generate (> 0).
        fmin: lowest frequency in Hz (>= 0).
        fmax: highest frequency in Hz (>= 0); ``None`` lets librosa use
            ``fs / 2.0``.
        htk: use the HTK mel formula instead of Slaney's.
        norm: filter normalization, passed to librosa unchanged.
            NOTE(review): the meaning of ``1`` differs between librosa
            releases (Slaney area normalization in older ones, an L1 norm in
            newer ones) — confirm against the pinned librosa version.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()
        self.mel_options = {
            "sr": fs,
            "n_fft": n_fft,
            "n_mels": n_mels,
            "fmin": fmin,
            "fmax": fmax,
            "htk": htk,
            "norm": norm,
        }
        # Note(kamo): The mel matrix of librosa is different from kaldi.
        filterbank = librosa.filters.mel(**self.mel_options)
        # Store the transpose, (D2, D1) -> (D1, D2), so that a right-multiply
        # maps frequency bins (D1) onto mel bands (D2).
        self.register_buffer("melmat", torch.from_numpy(filterbank.T).float())

    def extra_repr(self):
        parts = [f"{name}={value}" for name, value in self.mel_options.items()]
        return ", ".join(parts)

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Project ``feat`` onto the mel filterbank and take the log.

        Args:
            feat: (B, T, D1) power features.
            ilens: valid frame counts; positions flagged by
                ``make_pad_mask(ilens, ..., 1)`` are set to 0 in the output.

        Returns:
            Tuple of the (B, T, D2) log-mel features and ``ilens`` unchanged.
        """
        mel = feat @ self.melmat
        # The tiny offset keeps log() finite for all-zero bins.
        logmel = torch.log(mel + 1e-20)
        padding = make_pad_mask(ilens, logmel, 1)
        return logmel.masked_fill(padding, 0.0), ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file: statistics file, either a ``.npy`` file or a text file
            readable by ``np.loadtxt``, holding a flat array of length
            ``2 * D + 1`` for feature dimension D. Elements ``[0, D)`` are the
            per-dimension sums of the features, elements ``[D, 2 * D)`` are the
            sums of the squared features, and the last element is the number
            of samples.
        norm_means: subtract the global mean in ``forward``.
        norm_vars: multiply by the inverse global standard deviation in
            ``forward``.
        eps: floor applied to the standard deviation before it is inverted.
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.stats_file = stats_file

        stats = self._load_stats(stats_file).astype(float).flatten()
        assert (len(stats) - 1) % 2 == 0, stats.shape
        dim = (len(stats) - 1) // 2
        count = stats[-1]
        mean = stats[:dim] / count
        var = stats[dim:-1] / count - mean * mean
        # E[x^2] - E[x]^2 can come out slightly negative through cancellation;
        # clamp at 0 so sqrt gives a number (then floored by eps), not NaN.
        std = np.maximum(np.sqrt(np.maximum(var, 0.0)), eps)
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    @staticmethod
    def _load_stats(stats_file: str) -> np.ndarray:
        """Load the statistics array from a .npy file, falling back to text."""
        try:
            return np.load(stats_file)
        except ValueError:
            # np.load refuses non-.npy content (it will not unpickle by
            # default); read it as the text format the class also accepts.
            return np.loadtxt(stats_file)

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` (B, T, D) with the global statistics.

        The input tensor is not modified. Positions flagged by
        ``make_pad_mask(ilens, x, 1)`` are zero in the output.
        """
        if self.norm_means:
            x = x + self.bias.type_as(x)
        if self.norm_vars:
            x = x * self.scale.type_as(x)
        # masked_fill is out-of-place: keep its result so the padding
        # (which picked up the bias above) is actually reset to zero.
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    """Module wrapper around the ``utterance_mvn`` function.

    Args:
        norm_means: whether to normalize the per-utterance mean.
        norm_vars: whether to normalize the per-utterance variance.
        eps: numerical floor forwarded unchanged to ``utterance_mvn``.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means, self.norm_vars, self.eps = norm_means, norm_vars, eps

    def extra_repr(self):
        return "norm_means={}, norm_vars={}".format(self.norm_means, self.norm_vars)

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        options = dict(norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps)
        return utterance_mvn(x, ilens, **options)
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance-level mean and variance normalization.

    Statistics are computed only over the first ``ilens[b]`` frames (axis 1)
    of each utterance. Frames at or beyond ``ilens[b]`` are zero in the output.
    The input tensor is not modified.

    Args:
        x: (B, T, D) features; extra trailing axes after T are also accepted.
            Padded frames are ignored.
        ilens: (B,) number of valid frames per utterance.
        norm_means: subtract the per-utterance mean.
        norm_vars: divide by the per-utterance standard deviation.
        eps: floor applied to the variance before taking its square root.

    Returns:
        Tuple of the normalized features (same shape as ``x``) and ``ilens``.
    """
    lengths = ilens.to(x.device)
    # (B, 1, 1, ...) so per-utterance values broadcast over the other axes.
    stat_shape = (-1,) + (1,) * (x.dim() - 1)
    # Clamp to 1 so an empty utterance yields zeros instead of NaN.
    denom = lengths.type_as(x).clamp(min=1).view(stat_shape)
    # pad_mask: (B, T, 1, ...), True on frames at or beyond each length.
    steps = torch.arange(x.size(1), device=x.device)
    pad_mask = (steps[None, :] >= lengths[:, None]).view(
        x.size(0), x.size(1), *([1] * (x.dim() - 2))
    )

    # Enforce the zero padding so padded frames never enter the statistics.
    x = x.masked_fill(pad_mask, 0.0)
    mean = x.sum(dim=1, keepdim=True) / denom
    # Re-zero the padding, which would otherwise hold -mean after centering.
    centered = (x - mean).masked_fill(pad_mask, 0.0)

    out = centered if norm_means else x
    if norm_vars:
        var = centered.pow(2).sum(dim=1, keepdim=True) / denom
        out = out / var.clamp(min=eps).sqrt()
    return out, ilens
def feature_transform_for(args, n_fft):
    """Build a ``FeatureTransform`` from an options object.

    ``n_fft`` is given explicitly; every other option is read from attributes
    of ``args`` (presumably a parsed argparse namespace — confirm with the
    caller): ``fbank_fs``, ``n_mels``, ``fbank_fmin``, ``fbank_fmax``,
    ``stats_file``, ``apply_uttmvn``, ``uttmvn_norm_means`` and
    ``uttmvn_norm_vars``.
    """
    mel_options = dict(
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
    )
    normalization_options = dict(
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
    return FeatureTransform(**mel_options, **normalization_options)
|