from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from espnet2.layers.stft import Stft
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract


class LogSpectrogram(AbsFeatsExtract):
    """Log-amplitude spectrogram feature extractor.

    Stft -> log-amplitude-spec

    The waveform is passed through :class:`Stft` and the result is converted
    to ``log10(abs(stft))`` (the TTS definition of the log-spectrogram; the
    ASR one is ``log_e(power(stft))``).

    Args:
        n_fft: FFT size handed to ``Stft``.
        win_length: Window length handed to ``Stft``; ``None`` is forwarded
            unchanged and leaves the choice to ``Stft``.
        hop_length: Hop (shift) length handed to ``Stft``.
        window: Window type name handed to ``Stft`` (may be ``None``).
        center: Forwarded to ``Stft``.
        normalized: Forwarded to ``Stft``.
        onesided: Forwarded to ``Stft``.
    """

    def __init__(
        self,
        n_fft: int = 1024,
        # Explicit Optional: implicit Optional for ``= None`` defaults is
        # deprecated (PEP 484) and no longer inferred by
        # ``typing.get_type_hints`` on Python >= 3.11.
        win_length: Optional[int] = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        assert check_argument_types()
        super().__init__()

        # Kept on the instance so that get_parameters() can report them.
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def output_size(self) -> int:
        """Return the feature dimension, computed as ``n_fft // 2 + 1``."""
        # NOTE(review): this formula does not depend on ``onesided``;
        # confirm it is still right when ``onesided=False`` is used.
        return self.n_fft // 2 + 1

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder.

        ``hop_length`` is reported under the key ``n_shift``; ``win_length``
        is reported exactly as given to the constructor (possibly ``None``).
        """
        return {
            "n_fft": self.n_fft,
            "n_shift": self.hop_length,
            "win_length": self.win_length,
            "window": self.window,
        }

    def forward(
        self, input: torch.Tensor, input_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the log-amplitude spectrogram of ``input``.

        Args:
            input: Signal tensor passed straight to ``self.stft``.
            input_lengths: Optional lengths passed straight to ``self.stft``.

        Returns:
            A tuple ``(log_amp, feats_lens)`` where ``log_amp`` is
            ``log10(abs(stft))`` with the trailing real/imag axis removed,
            and ``feats_lens`` is the second value returned by ``self.stft``.
        """
        # 1. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
        #   TTS: log_10(abs(stft))
        #   ASR: log_e(power(stft))

        # STFT -> Power spectrum
        # input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        # 0.5 * log10(power) == log10(amplitude); the clamp floors the power
        # at 1e-10 so that log10 never sees zero.
        log_amp = 0.5 * torch.log10(torch.clamp(input_power, min=1.0e-10))
        return log_amp, feats_lens