File size: 3,069 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
import humanfriendly
import torch
from typeguard import check_argument_types
from espnet2.layers.log_mel import LogMel
from espnet2.layers.stft import Stft
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
class LogMelFbank(AbsFeatsExtract):
    """Log-Mel filterbank feature extractor for TTS.

    Pipeline: Stft -> amplitude-spec -> Log-Mel-Fbank

    Unlike the ASR front-end, which takes the natural log of the *power*
    spectrum, this extractor feeds the *amplitude* spectrum to ``LogMel``
    with ``log_base=10.0``.

    Args:
        fs: Sampling rate. A string is converted to a number with
            ``humanfriendly.parse_size``.
        n_fft: FFT size, forwarded to both ``Stft`` and ``LogMel``.
        win_length: Window length forwarded to ``Stft``. ``None`` is passed
            through unchanged, leaving the choice to ``Stft``.
        hop_length: Hop size forwarded to ``Stft``; reported as ``n_shift``
            by :meth:`get_parameters`.
        window: Window function name forwarded to ``Stft``.
        center: Forwarded to ``Stft``.
        normalized: Forwarded to ``Stft``.
        onesided: Forwarded to ``Stft``.
        n_mels: Number of mel bins; also the value of :meth:`output_size`.
        fmin: Lower frequency bound forwarded to ``LogMel``.
        fmax: Upper frequency bound forwarded to ``LogMel``.
        htk: Forwarded to ``LogMel``.
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = 80,
        fmax: Optional[int] = 7600,
        htk: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        # Accept a human-readable sampling rate given as a string.
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Kept as plain attributes so that get_parameters() can report them.
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
            log_base=10.0,
        )

    def output_size(self) -> int:
        """Return the feature dimension, i.e. the number of mel bins."""
        return self.n_mels

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder"""
        return dict(
            fs=self.fs,
            n_fft=self.n_fft,
            n_shift=self.hop_length,
            window=self.window,
            n_mels=self.n_mels,
            win_length=self.win_length,
            fmin=self.fmin,
            fmax=self.fmax,
        )

    def forward(
        self, input: torch.Tensor, input_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute log-mel filterbank features from a waveform.

        Args:
            input: Waveform tensor handed to ``Stft``.
            input_lengths: Optional lengths of the waveforms, handed to
                ``Stft`` together with ``input``.

        Returns:
            A tuple ``(feats, feats_lens)``: the output of ``LogMel`` and the
            lengths reported by ``Stft``.
        """
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
        #   TTS: log_10(abs(stft))
        #   ASR: log_e(power(stft))

        # 2. Amplitude spectrum: input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        # Clamp keeps the power strictly positive before the square root.
        input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))

        # 3. Log-mel projection; the lengths from the STFT are the ones returned.
        input_feats, _ = self.logmel(input_amp, feats_lens)
        return input_feats, feats_lens
|