Spaces:
Sleeping
Sleeping
File size: 1,805 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from lightning.pytorch.core import LightningModule
from torch import Tensor
from models.config import AcousticENModelConfig
from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.tts.delightful_tts.delightful_tts import DelightfulTTS
from models.vocoder.univnet import UnivNet
class DelightfulUnivnet(LightningModule):
    r"""Inference wrapper chaining a DelightfulTTS model with a UnivNet vocoder.

    The DelightfulTTS model is restored from a checkpoint and the UnivNet
    vocoder is built with its default constructor. ``freeze()`` is called on
    both sub-models right after they are created.
    """

    def __init__(
        self,
        delightful_checkpoint_path: str,
        lang: str = "en",
        sampling_rate: int = 22050,
    ):
        r"""Build the frozen TTS + vocoder pipeline.

        Args:
            delightful_checkpoint_path (str): Path of the DelightfulTTS
                checkpoint handed to ``DelightfulTTS.load_from_checkpoint``.
            lang (str): Language code forwarded to the DelightfulTTS model.
            sampling_rate (int): Sampling rate stored on the instance and
                forwarded to the preprocessing config and the TTS model.
        """
        super().__init__()

        self.sampling_rate = sampling_rate

        # NOTE(review): the preprocessing language is always "english_only",
        # independent of `lang` -- confirm this is intended.
        preprocessing = PreprocessingConfig(
            "english_only",
            sampling_rate=sampling_rate,
        )
        self.preprocess_config = preprocessing

        # Keyword arguments passed through to the DelightfulTTS model.
        tts_kwargs = {
            "preprocess_config": preprocessing,
            "model_config": AcousticENModelConfig(),
            "lang": lang,
            "sampling_rate": sampling_rate,
        }
        self.delightful_tts = DelightfulTTS.load_from_checkpoint(
            delightful_checkpoint_path,
            strict=False,
            **tts_kwargs,
        )
        self.delightful_tts.freeze()

        # No separate vocoder checkpoint is passed here; per the original
        # author's note a previously used checkpoint is relied upon
        # (presumably loaded inside UnivNet() -- TODO confirm).
        vocoder = UnivNet()
        self.univnet = vocoder
        self.univnet.freeze()

    def forward(
        self,
        text: str,
        speaker_idx: Tensor,
    ) -> Tensor:
        r"""Run the text through DelightfulTTS and then through UnivNet.

        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (Tensor): The index of the speaker.

        Returns:
            Tensor: The output of the UnivNet vocoder (presumably the
                generated waveform).
        """
        # First stage: DelightfulTTS prediction (named `mel_pred` in the
        # original, so presumably a mel spectrogram); second stage: vocoder.
        acoustic_out = self.delightful_tts.forward(text, speaker_idx)
        return self.univnet.forward(acoustic_out)
|