Spaces:
Runtime error
Runtime error
# Copyright 2023 (authors: Feiteng Li) | |
# | |
# See ../../../../LICENSE for clarification regarding multiple authors | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from dataclasses import asdict, dataclass | |
from typing import Any, Dict, Optional, Union | |
import numpy as np | |
import torch | |
# from lhotse.features.base import FeatureExtractor | |
# from lhotse.utils import EPSILON, Seconds, compute_num_frames | |
from librosa.filters import mel as librosa_mel_fn | |
@dataclass
class BigVGANFbankConfig:
    """Configuration for the BigVGAN-style mel filter-bank extractor.

    The defaults describe a 1024-sample analysis window and a 256-sample hop
    at 24 kHz, with 100 mel bins between 0 Hz and 12 kHz.

    The class is a dataclass so that ``dataclasses.asdict`` (used by
    ``to_dict``) and keyword construction (used by ``from_dict``) work.
    """

    # Spectrogram-related part
    # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
    # Both durations are expressed in seconds.
    frame_length: float = 1024 / 24000.0
    frame_shift: float = 256 / 24000.0
    remove_dc_offset: bool = True
    round_to_power_of_two: bool = True

    # Fbank-related part
    low_freq: float = 0.0
    high_freq: float = 12000.0
    num_mel_bins: int = 100
    use_energy: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain ``{field name: value}`` dict."""
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
        """Build a configuration from a ``{field name: value}`` mapping.

        Raises:
            TypeError: if ``data`` contains a key that is not a field name.
        """
        return BigVGANFbankConfig(**data)
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` after flooring it at ``clip_val``.

    Args:
        x: input tensor.
        C: multiplicative factor applied after the floor and before the log.
        clip_val: lower bound applied to ``x`` so the logarithm stays finite.

    Returns:
        ``log(max(x, clip_val) * C)`` computed element-wise.
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
def spectral_normalize_torch(magnitudes):
    """Apply the default dynamic-range (log) compression to ``magnitudes``."""
    return dynamic_range_compression_torch(magnitudes)
# https://github.com/NVIDIA/BigVGAN | |
# bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz | |
class BigVGANFbank(FeatureExtractor):
    """Log-mel filter-bank extractor for 24 kHz audio (BigVGAN-style front end).

    Each frame comes from a 1024-point, Hann-windowed STFT with a hop of
    256 samples. The magnitudes are projected onto a mel basis built from the
    config (``num_mel_bins``, ``low_freq``, ``high_freq``) and log-compressed.

    NOTE(review): ``FeatureExtractor``, ``compute_num_frames`` and ``EPSILON``
    are used here, but the lhotse imports at the top of the file are commented
    out - confirm these names are in scope before running.
    """

    name = "fbank"
    config_type = BigVGANFbankConfig

    def __init__(self, config: Optional[Any] = None):
        """Pre-compute the mel basis and the analysis window.

        Args:
            config: extractor configuration, forwarded to the base class.
                Presumably the base class stores it as ``self.config`` -
                TODO confirm against ``FeatureExtractor``.
        """
        super().__init__(config)
        sampling_rate = 24000
        # Keyword arguments are required by librosa >= 0.10 and are accepted
        # by older releases as well.
        self.mel_basis = torch.from_numpy(
            librosa_mel_fn(
                sr=sampling_rate,
                n_fft=1024,
                n_mels=self.config.num_mel_bins,
                fmin=self.config.low_freq,
                fmax=self.config.high_freq,
            ).astype(np.float32)
        )
        self.hann_window = torch.hann_window(1024)

    def _feature_fn(self, samples, **kwargs):
        """Compute log-mel features for one waveform tensor.

        Args:
            samples: waveform tensor; ``extract`` passes it with a leading
                channel dimension, i.e. ``(1, num_samples)`` for mono audio.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Tensor of log-mel energies; for single-channel input the shape is
            ``(num_frames, num_mel_bins)``.
        """
        win_length, n_fft = 1024, 1024
        hop_size = 256
        sampling_rate = 24000

        num_samples = samples.shape[-1]
        duration = round(num_samples / sampling_rate, ndigits=12)
        expected_num_frames = compute_num_frames(
            duration=duration,
            frame_shift=self.frame_shift,
            sampling_rate=sampling_rate,
        )
        # Right-pad with zeros so that a non-centred STFT yields exactly
        # ``expected_num_frames`` frames.
        pad_size = (expected_num_frames - 1) * hop_size + win_length - num_samples
        assert pad_size >= 0, f"negative padding computed: {pad_size}"

        y = torch.nn.functional.pad(samples, (0, pad_size), mode="constant")
        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_length,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        # Magnitude spectrum; the small constant keeps sqrt away from 0.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        # (channel, mel, frame) -> (frame, mel) for single-channel input.
        return spec.transpose(2, 1).squeeze(0)

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        """Extract log-mel features from a waveform.

        Args:
            samples: waveform as a NumPy array or torch tensor; a 1-D input
                gets a leading channel dimension added.
            sampling_rate: must equal 24000.

        Returns:
            ``float32`` NumPy array of features.

        Raises:
            AssertionError: if ``sampling_rate`` is not 24000.
        """
        assert sampling_rate == 24000, "BigVGANFbank only supports 24 kHz audio"

        # Kaldi-style parameter dict (durations in milliseconds). It is
        # forwarded to ``_feature_fn``, which currently ignores it.
        params = asdict(self.config)
        params.update({"sample_frequency": sampling_rate, "snip_edges": False})
        params["frame_shift"] *= 1000.0
        params["frame_length"] *= 1000.0

        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)

        features = self._feature_fn(samples, **params).to(torch.float32)
        return features.numpy()

    @property
    def frame_shift(self) -> float:
        """Frame shift in seconds, taken from the config.

        Exposed as a property because ``_feature_fn`` reads
        ``self.frame_shift`` as a plain value.
        """
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        """Return the number of mel bins (``sampling_rate`` is not used)."""
        return self.config.num_mel_bins

    @staticmethod
    def mix(
        features_a: np.ndarray,
        features_b: np.ndarray,
        energy_scaling_factor_b: float,
    ) -> np.ndarray:
        """Mix two log-energy feature matrices in the linear energy domain.

        Computes ``log(max(EPSILON, exp(a) + energy_scaling_factor_b * exp(b)))``.
        """
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a)
                + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        """Return the total linear-domain energy, ``sum(exp(features))``."""
        return float(np.sum(np.exp(features)))
def get_fbank_extractor() -> BigVGANFbank:
    """Create a ``BigVGANFbank`` extractor using the default configuration."""
    default_config = BigVGANFbankConfig()
    return BigVGANFbank(default_config)
if __name__ == "__main__":
    # Manual smoke test: run the extractor on random noise, then on a prompt
    # recording, and save a picture of the resulting filter bank.
    extractor = BigVGANFbank(BigVGANFbankConfig())

    # 1) Synthetic input: 1000 random float32 samples, clipped to [-1, 1].
    noise = np.random.random([1000]).astype(np.float32)
    samples = torch.clip(torch.from_numpy(noise), -1.0, 1.0)
    fbank = extractor.extract(samples, 24000.0)
    print(f"fbank {fbank.shape}")

    # 2) Recorded input: the wav samples are divided by MAX_WAV_VALUE before
    # extraction.
    from scipy.io.wavfile import read

    MAX_WAV_VALUE = 32768.0
    prompt_wav = "egs/libritts/prompts/5639_40744_000000_000002.wav"
    sampling_rate, samples = read(prompt_wav)
    print(f"samples: [{samples.min()}, {samples.max()}]")

    scaled_samples = samples.astype(np.float32) / MAX_WAV_VALUE
    fbank = extractor.extract(scaled_samples, 24000)
    print(f"fbank {fbank.shape}")

    # 3) Plot the features with the first axis (frames) on the horizontal axis.
    import matplotlib.pyplot as plt

    figure = plt.figure(figsize=(18, 10))
    image = fbank.transpose(1, 0)
    plt.imshow(
        X=image,
        cmap=plt.get_cmap("jet"),
        aspect="auto",
        interpolation="nearest",
    )
    plt.gca().invert_yaxis()
    prompt_png = "egs/libritts/prompts/5639_40744_000000_000002.png"
    plt.savefig(prompt_png)
    plt.close()

    print("fbank test PASS!")