import json
import logging
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from fairseq.data.audio.audio_utils import (
    TTSSpectrogram,
    get_fourier_basis,
    get_mel_filters,
    get_window,
)
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.text_to_speech.codehifigan import CodeGenerator as CodeHiFiGANModel
from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel
from fairseq.models.text_to_speech.hub_interface import VocoderHubInterface

logger = logging.getLogger(__name__)

class PseudoInverseMelScale(torch.nn.Module):
    def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None:
        super(PseudoInverseMelScale, self).__init__()
        self.n_mels = n_mels
        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
        basis = torch.pinverse(basis)  # freq x n_mels
        self.register_buffer("basis", basis)

    def forward(self, melspec: torch.Tensor) -> torch.Tensor:
        # pack batch
        shape = melspec.shape  # (..., n_mels, time)
        n_mels, time = shape[-2], shape[-1]
        melspec = melspec.view(-1, n_mels, time)

        freq, _ = self.basis.size()  # (freq, n_mels)
        assert self.n_mels == n_mels, (self.n_mels, n_mels)
        specgram = self.basis.matmul(melspec).clamp(min=0)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram
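
# A minimal usage sketch (illustrative values, not taken from any fairseq
# config): recover a linear-frequency spectrogram from an 80-band mel
# spectrogram via the pseudo-inverse of the mel filterbank.
#
#   inv_mel = PseudoInverseMelScale(
#       n_stft=1025, n_mels=80, sample_rate=22050, f_min=0.0, f_max=8000.0
#   )
#   mel = torch.rand(80, 100)  # n_mels x time
#   linear = inv_mel(mel)      # -> 1025 x 100, i.e. (n_fft // 2 + 1) x time
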
class GriffinLim(torch.nn.Module):
    def __init__(
        self,
        n_fft: int,
        win_length: int,
        hop_length: int,
        n_iter: int,
        window_fn=torch.hann_window,
    ):
        super(GriffinLim, self).__init__()
        self.transform = TTSSpectrogram(
            n_fft, win_length, hop_length, return_phase=True
        )

        basis = get_fourier_basis(n_fft)
        basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
        basis *= get_window(window_fn, n_fft, win_length)
        self.register_buffer("basis", basis)

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_iter = n_iter

        self.tiny = 1.1754944e-38

    @classmethod
    def get_window_sum_square(
        cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window
    ) -> torch.Tensor:
        w_sq = get_window(window_fn, n_fft, win_length) ** 2
        n = n_fft + hop_length * (n_frames - 1)
        x = torch.zeros(n, dtype=torch.float32)
        for i in range(n_frames):
            ofst = i * hop_length
            x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))]
        return x

    def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor:
        x = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        x = F.conv_transpose1d(x, self.basis, stride=self.hop_length)
        win_sum_sq = self.get_window_sum_square(
            magnitude.shape[-1],
            hop_length=self.hop_length,
            win_length=self.win_length,
            n_fft=self.n_fft,
        ).to(magnitude.device)

        # remove modulation effects from overlapping windows
        approx_nonzero_indices = win_sum_sq > self.tiny
        x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices]
        x *= self.n_fft / self.hop_length
        x = x[:, :, self.n_fft // 2 :]
        x = x[:, :, : -self.n_fft // 2]
        return x

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        # start from random phase, then refine it for n_iter iterations
        angles = np.angle(np.exp(2j * np.pi * np.random.rand(*specgram.shape)))
        angles = torch.from_numpy(angles).to(specgram)
        _specgram = specgram.view(-1, specgram.shape[-2], specgram.shape[-1])
        waveform = self.inverse(_specgram, angles).squeeze(1)
        for _ in range(self.n_iter):
            _, angles = self.transform(waveform)
            waveform = self.inverse(_specgram, angles).squeeze(1)
        return waveform.squeeze(0)
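
# A minimal usage sketch (illustrative hyper-parameters): estimate a waveform
# from a linear magnitude spectrogram, e.g. the output of PseudoInverseMelScale
# above, by running n_iter rounds of Griffin-Lim phase refinement.
#
#   gl = GriffinLim(n_fft=2048, win_length=1024, hop_length=256, n_iter=32)
#   spec = torch.rand(1025, 100)  # (n_fft // 2 + 1) x time
#   wav = gl(spec)                # 1-D waveform tensor
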
class GriffinLimVocoder(nn.Module):
    def __init__(
        self,
        sample_rate,
        win_size,
        hop_size,
        n_fft,
        n_mels,
        f_min,
        f_max,
        window_fn,
        spec_bwd_max_iter=32,
        fp16=False,
    ):
        super().__init__()
        self.inv_mel_transform = PseudoInverseMelScale(
            n_stft=n_fft // 2 + 1,
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
        )
        self.gl_transform = GriffinLim(
            n_fft=n_fft,
            win_length=win_size,
            hop_length=hop_size,
            window_fn=window_fn,
            n_iter=spec_bwd_max_iter,
        )
        if fp16:
            self.half()
            self.inv_mel_transform.half()
            self.gl_transform.half()
        else:
            self.float()
            self.inv_mel_transform.float()
            self.gl_transform.float()

    def forward(self, x):
        # x: (B x) T x D log-mel spectrogram -> waveform
        self.eval()
        x = x.exp().transpose(-1, -2)
        x = self.inv_mel_transform(x)
        x = self.gl_transform(x)
        return x

    @classmethod
    def from_data_cfg(cls, args, data_cfg: S2TDataConfig):
        feat_cfg = data_cfg.config["features"]
        window_fn = getattr(torch, feat_cfg["window_fn"] + "_window")
        return cls(
            sample_rate=feat_cfg["sample_rate"],
            win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]),
            hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]),
            n_fft=feat_cfg["n_fft"],
            n_mels=feat_cfg["n_mels"],
            f_min=feat_cfg["f_min"],
            f_max=feat_cfg["f_max"],
            window_fn=window_fn,
            spec_bwd_max_iter=args.spec_bwd_max_iter,
            fp16=args.fp16,
        )

class HiFiGANVocoder(nn.Module):
    def __init__(
        self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False
    ) -> None:
        super().__init__()
        self.model = HiFiGANModel(model_cfg)
        state_dict = torch.load(checkpoint_path)
        self.model.load_state_dict(state_dict["generator"])
        if fp16:
            self.model.half()
        logger.info(f"loaded HiFiGAN checkpoint from {checkpoint_path}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B x) T x D mel spectrogram -> waveform
        model = self.model.eval()
        if len(x.shape) == 2:
            return model(x.unsqueeze(0).transpose(1, 2)).detach().squeeze(0)
        else:
            return model(x.transpose(-1, -2)).detach()

    @classmethod
    def from_data_cfg(cls, args, data_cfg: S2TDataConfig):
        vocoder_cfg = data_cfg.vocoder
        assert vocoder_cfg.get("type", "griffin_lim") == "hifigan"
        with open(vocoder_cfg["config"]) as f:
            model_cfg = json.load(f)
        return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16)
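
# A minimal usage sketch; the checkpoint/config paths are hypothetical
# placeholders. The forward pass accepts either a single T x D spectrogram or
# a batched B x T x D tensor.
#
#   with open("hifigan_config.json") as f:
#       model_cfg = json.load(f)
#   vocoder = HiFiGANVocoder("hifigan_checkpoint.pt", model_cfg)
#   wav = vocoder(torch.rand(100, 80))  # T x D -> waveform
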
@register_model("CodeHiFiGANVocoder")
class CodeHiFiGANVocoder(BaseFairseqModel):
    def __init__(
        self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False
    ) -> None:
        super().__init__()
        self.model = CodeHiFiGANModel(model_cfg)
        if torch.cuda.is_available():
            state_dict = torch.load(checkpoint_path)
        else:
            state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(state_dict["generator"])
        self.model.eval()
        if fp16:
            self.model.half()
        self.model.remove_weight_norm()
        logger.info(f"loaded CodeHiFiGAN checkpoint from {checkpoint_path}")

    def forward(self, x: Dict[str, torch.Tensor], dur_prediction=False) -> torch.Tensor:
        assert "code" in x
        x["dur_prediction"] = dur_prediction

        # remove invalid code (negative values mark padding)
        mask = x["code"] >= 0
        x["code"] = x["code"][mask].unsqueeze(dim=0)
        if "f0" in x:
            f0_up_ratio = x["f0"].size(1) // x["code"].size(1)
            mask = mask.unsqueeze(2).repeat(1, 1, f0_up_ratio).view(-1, x["f0"].size(1))
            x["f0"] = x["f0"][mask].unsqueeze(dim=0)

        return self.model(**x).detach().squeeze()

    @classmethod
    def from_data_cfg(cls, args, data_cfg):
        vocoder_cfg = data_cfg.vocoder
        assert vocoder_cfg is not None, "vocoder not specified in the data config"
        with open(vocoder_cfg["config"]) as f:
            model_cfg = json.load(f)
        return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16)

    @classmethod
    def hub_models(cls):
        base_url = "http://dl.fbaipublicfiles.com/fairseq/vocoder"
        model_ids = [
            "unit_hifigan_mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj_dur",
            "unit_hifigan_mhubert_vp_en_es_fr_it3_400k_layer11_km1000_es_css10_dur",
            "unit_hifigan_HK_layer12.km2500_frame_TAT-TTS",
        ]
        return {i: f"{base_url}/{i}.tar.gz" for i in model_ids}

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        config="config.json",
        fp16: bool = False,
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            config_yaml=config,
            fp16=fp16,
            is_vocoder=True,
            **kwargs,
        )

        with open(f"{x['args']['data']}/{config}") as f:
            vocoder_cfg = json.load(f)
        assert len(x["args"]["model_path"]) == 1, "Too many vocoder models in the input"

        vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg)
        return VocoderHubInterface(vocoder_cfg, vocoder)
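
# A minimal usage sketch for unit-to-waveform synthesis; the paths are
# hypothetical placeholders and the unit IDs are dummy values (real units come
# from the unit extractor matching the checkpoint; see hub_models above for
# downloadable checkpoints).
#
#   with open("code_hifigan_config.json") as f:
#       model_cfg = json.load(f)
#   vocoder = CodeHiFiGANVocoder("code_hifigan_checkpoint.pt", model_cfg)
#   units = torch.LongTensor([[920, 41, 41, 705, 705]])  # B x T discrete units
#   wav = vocoder({"code": units}, dur_prediction=True)  # 1-D waveform
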
def get_vocoder(args, data_cfg: S2TDataConfig):
    if args.vocoder == "griffin_lim":
        return GriffinLimVocoder.from_data_cfg(args, data_cfg)
    elif args.vocoder == "hifigan":
        return HiFiGANVocoder.from_data_cfg(args, data_cfg)
    elif args.vocoder == "code_hifigan":
        return CodeHiFiGANVocoder.from_data_cfg(args, data_cfg)
    else:
        raise ValueError(f"Unknown vocoder: {args.vocoder}")
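
# A minimal usage sketch: `args` is any namespace carrying the attributes read
# above (vocoder, fp16, plus spec_bwd_max_iter for Griffin-Lim) and `data_cfg`
# is the task's S2TDataConfig.
#
#   from argparse import Namespace
#
#   args = Namespace(vocoder="hifigan", fp16=False)
#   vocoder = get_vocoder(args, data_cfg)
#   wav = vocoder(feats)  # spectrogram features -> waveform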