Spaces:
Sleeping
Sleeping
import torch | |
from torch.nn import Module | |
from models.helpers.tools import get_mask_from_lengths | |
from .generator import Generator | |
class TracedGenerator(Module):
    r"""Inference wrapper around a UnivNet ``Generator``.

    The wrapper masks the input mel-spectrogram, appends a few padding frames,
    runs the wrapped generator, trims the audio produced by the padding frames
    and zeroes the audio positions selected by the length mask.

    Note:
        Despite the class name, the generator is currently NOT traced; it is
        stored and called as-is (see the TODO in ``__init__``).
    """

    def __init__(
        self,
        generator: Generator,
    ):
        r"""Store the generator and the settings ``forward`` needs.

        Args:
            generator (Generator): The UnivNet generator instance to wrap. Its
                ``mel_mask_value`` and ``hop_length`` attributes are copied
                onto this module.
        """
        super().__init__()

        self.mel_mask_value: float = generator.mel_mask_value
        self.hop_length: int = generator.hop_length

        # TODO: https://github.com/Lightning-AI/lightning/issues/14036
        # Tracing is disabled because the model is non-deterministic. Once it
        # can be re-enabled, accept an ``example_inputs`` argument and wrap the
        # generator with ``torch.jit.trace(..., check_trace=False)`` (or
        # ``generator.to_torchscript(method="trace", ...)``) here.
        self.generator = generator

    def forward(self, c: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        r"""Generate audio from a batch of mel-spectrograms.

        Args:
            c (torch.Tensor): The input mel-spectrogram tensor. It is indexed
                as a 3-D tensor and padded along its last dimension
                (presumably ``(batch, n_mels, frames)`` - confirm with caller).
            mel_lens (torch.Tensor): The lengths of the input mel-spectrograms,
                in frames.

        Returns:
            torch.Tensor: The generated audio tensor, with the positions
            selected by the length mask set to ``0.0``.
        """
        # Number of extra frames appended to the input; the audio they produce
        # is removed again after generation.
        pad_frames = 10

        # NOTE(review): get_mask_from_lengths presumably marks positions past
        # each sequence's length - confirm against the helper's definition.
        mel_mask = get_mask_from_lengths(mel_lens).unsqueeze(1).to(c.device)
        c = c.masked_fill(mel_mask, self.mel_mask_value)

        # Create the padding in the input's dtype so that concatenation does
        # not promote a reduced-precision input to the default dtype.
        padding = torch.full(
            (c.shape[0], c.shape[1], pad_frames),
            self.mel_mask_value,
            dtype=c.dtype,
            device=c.device,
        )
        mel = torch.cat((c, padding), dim=2)

        audio = self.generator(mel)

        # Drop the samples generated for the padding frames.
        audio = audio[:, :, : -(self.hop_length * pad_frames)]

        # Convert frame lengths to sample lengths with the generator's own hop
        # length (previously a hard-coded 256, which disagreed with the trim
        # above for any other hop length), and keep the mask on the audio's
        # device, mirroring what is done for the mel mask.
        audio_mask = (
            get_mask_from_lengths(mel_lens * self.hop_length)
            .unsqueeze(1)
            .to(audio.device)
        )
        return audio.masked_fill(audio_mask, 0.0)