from lightning.pytorch.core import LightningModule
from torch import Tensor

from models.config import AcousticENModelConfig
from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.tts.delightful_tts.delightful_tts import DelightfulTTS
from models.vocoder.univnet import UnivNet


class DelightfulUnivnet(LightningModule):
    r"""Inference wrapper that chains DelightfulTTS with the UnivNet vocoder.

    The DelightfulTTS model is restored from a checkpoint and the UnivNet
    model is built with its default constructor. ``freeze()`` is called on
    both sub-models, so this module is intended for inference rather than
    training.

    Args:
        delightful_checkpoint_path (str): Path to the DelightfulTTS checkpoint.
        lang (str): Language code forwarded to DelightfulTTS. Defaults to "en".
        sampling_rate (int): Sampling rate used for the preprocessing config
            and forwarded to DelightfulTTS. Defaults to 22050.
    """

    def __init__(
        self,
        delightful_checkpoint_path: str,
        lang: str = "en",
        sampling_rate: int = 22050,
    ) -> None:
        super().__init__()

        self.sampling_rate = sampling_rate

        self.preprocess_config = PreprocessingConfig(
            "english_only",
            sampling_rate=sampling_rate,
        )

        self.delightful_tts = DelightfulTTS.load_from_checkpoint(
            delightful_checkpoint_path,
            # Non-strict loading: the checkpoint's state-dict keys are not
            # required to match the model's keys exactly.
            strict=False,
            # The remaining kwargs are handed to the DelightfulTTS constructor.
            preprocess_config=self.preprocess_config,
            model_config=AcousticENModelConfig(),
            lang=lang,
            sampling_rate=sampling_rate,
        )
        self.delightful_tts.freeze()

        # No separate vocoder checkpoint path is passed here; UnivNet is built
        # with its defaults (original note: "prev checkpoint used").
        # NOTE(review): presumably UnivNet restores its own pretrained
        # weights internally -- confirm against the UnivNet implementation.
        self.univnet = UnivNet()
        self.univnet.freeze()

    def forward(
        self,
        text: str,
        speaker_idx: Tensor,
    ) -> Tensor:
        r"""Synthesize a waveform from text with DelightfulTTS followed by UnivNet.

        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (Tensor): The index of the speaker.

        Returns:
            Tensor: The waveform produced by the UnivNet vocoder.
        """
        # Call the sub-modules themselves instead of ``.forward`` so that any
        # registered forward hooks are executed as well.
        mel_pred = self.delightful_tts(text, speaker_idx)
        return self.univnet(mel_pred)