PeechTTSv22050 / models /delightful_univnet.py
nickovchinnikov's picture
Init
9d61c9b
from lightning.pytorch.core import LightningModule
from torch import Tensor
from models.config import AcousticENModelConfig
from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.tts.delightful_tts.delightful_tts import DelightfulTTS
from models.vocoder.univnet import UnivNet
class DelightfulUnivnet(LightningModule):
    r"""Text-to-waveform pipeline: a DelightfulTTS model followed by a UnivNet vocoder.

    The DelightfulTTS weights are restored from a checkpoint and ``freeze()``
    is called on both sub-models, so this module is set up for inference
    rather than training.

    Args:
        delightful_checkpoint_path (str): Path of the checkpoint handed to
            ``DelightfulTTS.load_from_checkpoint``.
        lang (str): Language code forwarded to ``DelightfulTTS``. Defaults to
            ``"en"``. Note that the preprocessing config is always created
            with ``"english_only"``, independent of this value.
        sampling_rate (int): Sampling rate stored on the module and forwarded
            to both the preprocessing config and ``DelightfulTTS``.
            Defaults to 22050.
    """

    def __init__(
        self,
        delightful_checkpoint_path: str,
        lang: str = "en",
        sampling_rate: int = 22050,
    ) -> None:
        super().__init__()

        self.sampling_rate = sampling_rate

        self.preprocess_config = PreprocessingConfig(
            "english_only",
            sampling_rate=sampling_rate,
        )

        self.delightful_tts = DelightfulTTS.load_from_checkpoint(
            delightful_checkpoint_path,
            # Do not require the checkpoint keys to match the module exactly.
            strict=False,
            # kwargs forwarded to the DelightfulTTS constructor
            preprocess_config=self.preprocess_config,
            model_config=AcousticENModelConfig(),
            lang=lang,
            sampling_rate=sampling_rate,
        )
        self.delightful_tts.freeze()

        # No separate vocoder checkpoint is passed here; UnivNet is built
        # with its defaults.
        # NOTE(review): presumably UnivNet() restores its own pretrained
        # weights internally - confirm against models.vocoder.univnet.
        self.univnet = UnivNet()
        self.univnet.freeze()

    def forward(
        self,
        text: str,
        speaker_idx: Tensor,
    ) -> Tensor:
        r"""Synthesize a waveform from text with DelightfulTTS followed by UnivNet.

        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (Tensor): The index of the speaker.

        Returns:
            Tensor: The waveform produced by the UnivNet vocoder.
        """
        # Output of the DelightfulTTS model (presumably a mel spectrogram,
        # judging by how it is fed to the vocoder - confirm).
        mel_pred = self.delightful_tts.forward(text, speaker_idx)

        # The vocoder turns the predicted features into audio.
        wav = self.univnet.forward(mel_pred)

        return wav