File size: 4,788 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Energy extractor."""
from typing import Any
from typing import Dict
from typing import Tuple
from typing import Union
import humanfriendly
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet2.layers.stft import Stft
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
class Energy(AbsFeatsExtract):
    """Energy extractor.

    Computes per-frame signal energy (L2 norm of the STFT magnitude over the
    frequency axis) from raw waveforms, optionally averaged per token using
    phoneme durations — used as a variance target in TTS models.
    """

    def __init__(
        self,
        fs: Union[int, str] = 22050,
        n_fft: int = 1024,
        win_length: Union[int, None] = None,
        hop_length: int = 256,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_token_averaged_energy: bool = True,
        reduction_factor: Union[int, None] = None,
    ):
        """Initialize the extractor.

        Args:
            fs: Sampling rate. A human-friendly string (e.g. "22.05k") is
                parsed to an int.
            n_fft: FFT size.
            win_length: STFT window length. ``None`` delegates the default to
                ``Stft``.
            hop_length: STFT hop size in samples.
            window: Window function name.
            center: Whether frames are centered by padding the signal.
            normalized: Whether the STFT is normalized.
            onesided: Whether a one-sided spectrum is returned.
            use_token_averaged_energy: If True, frame energies are averaged
                over each token span given by ``durations`` in ``forward``.
            reduction_factor: Model reduction factor; must be >= 1 when
                ``use_token_averaged_energy`` is True.
        """
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        self.fs = fs
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.use_token_averaged_energy = use_token_averaged_energy
        if use_token_averaged_energy:
            # Token averaging needs durations scaled by the reduction factor.
            assert reduction_factor >= 1
        self.reduction_factor = reduction_factor
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def output_size(self) -> int:
        """Return the feature dimension of the output (always 1)."""
        return 1

    def get_parameters(self) -> Dict[str, Any]:
        """Return the extractor configuration as a dict (for reproducibility)."""
        return dict(
            fs=self.fs,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            win_length=self.win_length,
            center=self.stft.center,
            normalized=self.stft.normalized,
            use_token_averaged_energy=self.use_token_averaged_energy,
            reduction_factor=self.reduction_factor,
        )

    def forward(
        self,
        input: torch.Tensor,
        input_lengths: Union[torch.Tensor, None] = None,
        feats_lengths: Union[torch.Tensor, None] = None,
        durations: Union[torch.Tensor, None] = None,
        durations_lengths: Union[torch.Tensor, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute energy from waveforms.

        Args:
            input: Waveform batch (B, T_wav).
            input_lengths: Valid lengths of each waveform (B,). If ``None``,
                all inputs are assumed to be full length.
            feats_lengths: Target frame lengths (B,) to pad/trim energies to,
                e.g. to match a mel-spectrogram extracted separately.
            durations: Per-token durations (B, T_text), required when
                ``use_token_averaged_energy`` is True.
            durations_lengths: Valid lengths of ``durations`` (B,).

        Returns:
            Tuple of energy tensor (B, T, 1) and its lengths (B,).
        """
        # If not provided, we assume that the inputs have the same length.
        if input_lengths is None:
            input_lengths = (
                input.new_ones(input.shape[0], dtype=torch.long) * input.shape[1]
            )

        # Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, energy_lengths = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # The last dimension holds (real, imag) parts.
        assert input_stft.shape[-1] == 2, input_stft.shape

        # input_stft: (..., F, 2) -> (..., F)
        input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
        # Sum over frequency (B, N, F) -> (B, N); clamp avoids sqrt(0) grads.
        energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))

        # (Optional): Adjust length to match with the mel-spectrogram
        if feats_lengths is not None:
            energy = [
                self._adjust_num_frames(e[:el].view(-1), fl)
                for e, el, fl in zip(energy, energy_lengths, feats_lengths)
            ]
            energy_lengths = feats_lengths

        # (Optional): Average by duration to calculate token-wise energy
        if self.use_token_averaged_energy:
            # Durations are in reduced frames; rescale to full frame counts.
            durations = durations * self.reduction_factor
            energy = [
                self._average_by_duration(e[:el].view(-1), d)
                for e, el, d in zip(energy, energy_lengths, durations)
            ]
            energy_lengths = durations_lengths

        # Padding
        if isinstance(energy, list):
            energy = pad_list(energy, 0.0)

        # Return with the shape (B, T, 1)
        return energy.unsqueeze(-1), energy_lengths

    def _average_by_duration(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Average frame-level energies over each token's duration span.

        Args:
            x: Frame-level energy (T,).
            d: Token durations (T_text,), already scaled by reduction factor.

        Returns:
            Token-level energy (T_text,); zero for zero-duration tokens.
        """
        # Length mismatch must be smaller than one reduction step.
        assert 0 <= len(x) - d.sum() < self.reduction_factor
        d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
        x_avg = [
            x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0)
            for start, end in zip(d_cumsum[:-1], d_cumsum[1:])
        ]
        return torch.stack(x_avg)

    @staticmethod
    def _adjust_num_frames(x: torch.Tensor, num_frames: torch.Tensor) -> torch.Tensor:
        """Pad with zeros or trim ``x`` so that it has ``num_frames`` frames."""
        if num_frames > len(x):
            x = F.pad(x, (0, num_frames - len(x)))
        elif num_frames < len(x):
            x = x[:num_frames]
        return x
|