File size: 2,298 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
from typeguard import check_argument_types
from espnet2.layers.stft import Stft
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
class LogSpectrogram(AbsFeatsExtract):
    """Log-amplitude spectrogram feature extractor (TTS definition).

    Stft -> log-amplitude-spec

    The input is transformed with ``Stft`` and the base-10 logarithm of the
    STFT magnitude is returned. This differs from the ASR convention, which
    uses the natural logarithm of the power spectrum.

    Args:
        n_fft: FFT size forwarded to ``Stft``.
        win_length: Window length forwarded to ``Stft``; ``None`` is passed
            through unchanged.
        hop_length: Hop size forwarded to ``Stft``; reported as ``n_shift``
            by :meth:`get_parameters`.
        window: Window name forwarded to ``Stft``.
        center: Forwarded to ``Stft``.
        normalized: Forwarded to ``Stft``.
        onesided: Forwarded to ``Stft``; also decides the value returned by
            :meth:`output_size`.
    """

    def __init__(
        self,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        # Remembered so that output_size() can report the right bin count.
        self.onesided = onesided
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def output_size(self) -> int:
        """Return the number of frequency bins per frame."""
        # A one-sided spectrum keeps only the non-negative frequencies
        # (n_fft // 2 + 1 bins); a two-sided spectrum keeps all n_fft bins.
        if self.onesided:
            return self.n_fft // 2 + 1
        return self.n_fft

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder"""
        return dict(
            n_fft=self.n_fft,
            n_shift=self.hop_length,
            win_length=self.win_length,
            window=self.window,
        )

    def forward(
        self, input: torch.Tensor, input_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the log10-amplitude spectrogram of ``input``.

        Args:
            input: Tensor handed to ``Stft``.
            input_lengths: Optional lengths tensor handed to ``Stft``.

        Returns:
            A tuple ``(log_amp, feats_lens)``: ``log_amp`` is the STFT output
            with its trailing real/imag axis reduced to a log10 magnitude,
            and ``feats_lens`` is the second value returned by ``Stft``.
        """
        # 1. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
        #   TTS: log_10(abs(stft))
        #   ASR: log_e(power(stft))

        # 2. STFT -> Power spectrum
        # input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        # 0.5 * log10(power) == log10(sqrt(power)) == log10(abs(stft));
        # the clamp keeps the argument of the logarithm strictly positive.
        log_amp = 0.5 * torch.log10(torch.clamp(input_power, min=1.0e-10))
        return log_amp, feats_lens
|