Spaces:
Sleeping
Sleeping
from lightning.pytorch.core import LightningModule | |
from torch import Tensor | |
from models.config import PreprocessingConfigHifiGAN as PreprocessingConfig | |
from models.tts.delightful_tts.delightful_tts import DelightfulTTS | |
from models.vocoder.hifigan import HifiGan | |
class DelightfulHiFi(LightningModule):
    r"""Text-to-waveform pipeline that chains DelightfulTTS with a HiFi-GAN generator.

    Both sub-models are restored from checkpoints in ``__init__`` and have
    ``freeze()`` called on them immediately after loading, so this module is
    meant to be used with the loaded checkpoint weights rather than trained.

    Attributes set in ``__init__``:
        sampling_rate: The sampling rate passed by the caller.
        preprocess_config: The ``PreprocessingConfig`` built for that sampling rate.
        delightful_tts: The ``DelightfulTTS`` model loaded from its checkpoint.
        hifi_gan: The ``HifiGan`` model loaded from its checkpoint.
    """

    def __init__(
        self,
        delightful_checkpoint_path: str,
        hifi_checkpoint_path: str,
        lang: str = "en",
        sampling_rate: int = 44100,
    ) -> None:
        r"""Load both checkpoints and freeze the resulting models.

        Args:
            delightful_checkpoint_path (str): Path to the DelightfulTTS checkpoint.
            hifi_checkpoint_path (str): Path to the HifiGan checkpoint.
            lang (str): Language code forwarded to DelightfulTTS. Defaults to ``"en"``.
            sampling_rate (int): Sampling rate forwarded to the preprocessing
                config and to DelightfulTTS. Defaults to ``44100``.
        """
        super().__init__()

        self.sampling_rate = sampling_rate

        self.preprocess_config = PreprocessingConfig(
            "multilingual",
            sampling_rate=sampling_rate,
        )

        self.delightful_tts = DelightfulTTS.load_from_checkpoint(
            delightful_checkpoint_path,
            # kwargs to be used in the model
            lang=lang,
            sampling_rate=sampling_rate,
            preprocess_config=self.preprocess_config,
        )
        self.delightful_tts.freeze()

        self.hifi_gan = HifiGan.load_from_checkpoint(
            hifi_checkpoint_path,
        )
        self.hifi_gan.freeze()

    def forward(
        self,
        text: str,
        speaker_idx: Tensor,
    ) -> Tensor:
        r"""Synthesize a waveform from text with DelightfulTTS followed by HiFi-GAN.

        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (Tensor): The index of the speaker.

        Returns:
            Tensor: The generated waveform with hifi-gan.
        """
        # Call the modules themselves rather than `.forward(...)` so that
        # nn.Module.__call__ runs any registered forward hooks.
        mel_pred = self.delightful_tts(text, speaker_idx)

        # The TTS output is fed straight into the HiFi-GAN generator.
        return self.hifi_gan.generator(mel_pred)