from typing import List, Tuple

import torch
from torch.nn import Module

from models.config import AcousticModelConfigType
from models.helpers import tools
from .variance_predictor import VariancePredictor

class LengthAdaptor(Module):
    r"""DEPRECATED: The LengthAdaptor module adjusts the duration of phonemes.

    It contains a dedicated duration predictor and methods to upsample the input
    features to match the predicted durations.

    Args:
        model_config (AcousticModelConfigType): The model configuration object containing model parameters.
    """

    def __init__(
        self,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()
        # Initialize the duration predictor, which regresses one value
        # (a log-scale duration) per input position.
        self.duration_predictor = VariancePredictor(
            channels_in=model_config.encoder.n_hidden,
            channels=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            p_dropout=model_config.variance_adaptor.p_dropout,
        )

    def length_regulate(
        self,
        x: torch.Tensor,
        duration: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Regulates the length of the input tensor using the duration tensor.

        Args:
            x (torch.Tensor): The input tensor.
            duration (torch.Tensor): The tensor containing a duration for each time step in x.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The regulated output tensor, padded to the
            longest sequence, and a tensor containing the length of each sequence in the batch.
        """
        output = torch.jit.annotate(List[torch.Tensor], [])
        mel_len = torch.jit.annotate(List[int], [])
        max_len = 0
        # Expand each sequence in the batch, tracking the longest result so the
        # batch can be padded to a common length afterwards.
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            if expanded.shape[0] > max_len:
                max_len = expanded.shape[0]
            output.append(expanded)
            mel_len.append(expanded.shape[0])
        output = tools.pad(output, max_len)
        return output, torch.tensor(mel_len, dtype=torch.int64)
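
    # Illustrative example (made-up numbers): given three phoneme vectors with
    # durations [2, 1, 3], length_regulate repeats them into 2 + 1 + 3 = 6
    # frames; if another item in the batch expanded to 8 frames, max_len would
    # be 8 and the 6-frame item would be padded up to 8 by tools.pad.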

    def expand(self, batch: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
        r"""Expands the input tensor based on the predicted durations.

        Args:
            batch (torch.Tensor): The input tensor.
            predicted (torch.Tensor): The tensor containing predicted expansion factors.

        Returns:
            torch.Tensor: The expanded tensor.
        """
        out = torch.jit.annotate(List[torch.Tensor], [])
        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            # Clamp negative predictions to zero; a zero duration drops the vector.
            out.append(vec.expand(max(int(expand_size), 0), -1))
        return torch.cat(out, 0)
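
    # Note: torch.Tensor.expand returns a broadcast view of shape
    # (duration, channels) without copying data; the torch.cat in expand()
    # then materializes the repeated frames into one contiguous tensor.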

    def upsample_train(
        self,
        x: torch.Tensor,
        x_res: torch.Tensor,
        duration_target: torch.Tensor,
        embeddings: torch.Tensor,
        src_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Upsamples the input tensor during training using ground-truth durations.

        Args:
            x (torch.Tensor): The input tensor.
            x_res (torch.Tensor): A second input tensor, fed to the duration predictor.
            duration_target (torch.Tensor): The ground-truth durations tensor.
            embeddings (torch.Tensor): The tensor containing phoneme embeddings.
            src_mask (torch.Tensor): The mask tensor indicating valid entries in x and x_res.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The upsampled x, the log duration prediction, and the upsampled embeddings.
        """
        # Detach x_res so duration-predictor gradients do not flow back into
        # the network that produced it.
        x_res = x_res.detach()
        log_duration_prediction = self.duration_predictor(
            x_res,
            src_mask,
        )  # type: torch.Tensor
        # Upsample with the ground-truth durations (teacher forcing); the
        # prediction is returned only so a duration loss can be computed.
        x, _ = self.length_regulate(x, duration_target)
        embeddings, _ = self.length_regulate(embeddings, duration_target)
        return x, log_duration_prediction, embeddings
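
    # Note: upsample() inverts the prediction as exp(p) - 1, which suggests the
    # predictor regresses log(duration + 1); the duration loss itself is not
    # computed here.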

    def upsample(
        self,
        x: torch.Tensor,
        x_res: torch.Tensor,
        src_mask: torch.Tensor,
        embeddings: torch.Tensor,
        control: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Upsamples the input tensor during inference using predicted durations.

        Args:
            x (torch.Tensor): The input tensor.
            x_res (torch.Tensor): A second input tensor, fed to the duration predictor.
            src_mask (torch.Tensor): The mask tensor indicating valid entries in x and x_res.
            embeddings (torch.Tensor): The tensor containing phoneme embeddings.
            control (float): A factor that scales the predicted durations, regulating the speaking rate.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The upsampled x, the rounded durations, and the upsampled embeddings.
        """
        log_duration_prediction = self.duration_predictor(
            x_res,
            src_mask,
        )
        # Invert the log-scale prediction, round to whole frames, scale by the
        # control factor, and clamp so durations are never negative.
        duration_rounded = torch.clamp(
            torch.round(torch.exp(log_duration_prediction) - 1) * control,
            min=0,
        )
        x, _ = self.length_regulate(x, duration_rounded)
        embeddings, _ = self.length_regulate(embeddings, duration_rounded)
        return x, duration_rounded, embeddings
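

# A minimal, standalone sketch of the length-regulation idea in plain torch,
# independent of this module's helpers. The phoneme vectors and durations
# below are made-up example values, not real model outputs.
if __name__ == "__main__":
    # Three 4-dimensional "phoneme" vectors, one row per phoneme.
    phones = torch.arange(3, dtype=torch.float32).unsqueeze(-1).repeat(1, 4)
    # A duration (number of output frames) for each phoneme.
    durations = torch.tensor([2, 1, 3])
    # Repeat each vector `duration` times along the time axis, mirroring
    # LengthAdaptor.expand.
    frames = torch.cat(
        [vec.expand(max(int(d), 0), -1) for vec, d in zip(phones, durations)],
        dim=0,
    )
    print(frames.shape)  # torch.Size([6, 4]): 2 + 1 + 3 frames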