# Provenance (copied from the file-viewer page this source was taken from):
# author: nickovchinnikov · commit 9d61c9b ("Init") · size 8 kB
from typing import Tuple
import torch
from torch import nn
from .helpers import average_over_durations
from .variance_predictor import VariancePredictor
class EnergyAdaptor(nn.Module):
    """Energy variance adaptor: a variance predictor plus a 1D conv embedding.

    ``energy_predictor`` estimates an energy value for every source position,
    and ``energy_emb`` (an ``nn.Conv1d`` with a single input channel) maps an
    energy contour to an embedding that the ``add_*`` methods add to the input
    sequence. During training the embedding is built from the ground-truth
    energy averaged over the aligner durations; during inference it is built
    from the predicted energy.

    Args:
        channels_in (int): Number of input channels for the predictor's conv layers.
        channels_hidden (int): Number of hidden channels of the predictor; also
            the number of output channels of the energy embedding conv.
        channels_out (int): Number of output channels of the predictor.
        kernel_size (int): Size of the kernel for the predictor's conv layers.
        dropout (float): Probability of dropout in the predictor.
        leaky_relu_slope (float): Slope for the predictor's leaky ReLU.
        emb_kernel_size (int): Size of the kernel for the energy embedding conv.
            Padding is ``(emb_kernel_size - 1) // 2``, so an odd size keeps the
            time length unchanged and an even size shortens it by one step.

    Inputs (training, ``add_energy_embedding_train``): x, target, dr, mask
        - **x** (batch, time1, dim): Tensor containing the input vectors.
        - **target** (batch, 1, time2): Tensor containing the energy target.
        - **dr** (batch, time1): Tensor containing the aligner durations.
        - **mask** (batch, time1): Tensor containing the positions to be masked.

    Inputs (inference, ``add_energy_embedding``): x, mask

    Returns:
        - **x with energy embedding** (batch, time1, dim): Input plus the
          transposed energy embedding.
        - **energy prediction** (batch, 1, time1): Tensor produced by the
          energy predictor.
        - **average energy target (training only)** (batch, 1, time1): Target
          averaged over durations.

    Note:
        The embedding has ``channels_hidden`` channels and is added to ``x``
        along its last axis, so ``x.shape[-1]`` must match ``channels_hidden``.
    """

    def __init__(
        self,
        channels_in: int,
        channels_hidden: int,
        channels_out: int,
        kernel_size: int,
        dropout: float,
        leaky_relu_slope: float,
        emb_kernel_size: int,
    ):
        super().__init__()
        self.energy_predictor = VariancePredictor(
            channels_in=channels_in,
            channels=channels_hidden,
            channels_out=channels_out,
            kernel_size=kernel_size,
            p_dropout=dropout,
            leaky_relu_slope=leaky_relu_slope,
        )
        self.energy_emb = nn.Conv1d(
            1,
            channels_hidden,
            kernel_size=emb_kernel_size,
            # Integer "same"-style padding; equal to int((k - 1) / 2) for k >= 1
            # without the float round-trip.
            padding=(emb_kernel_size - 1) // 2,
        )

    def _predict_energy(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Run the energy predictor and insert a singleton channel axis at dim 1."""
        # Call the module itself rather than ``.forward`` so that any registered
        # forward / pre-forward hooks are executed.
        return self.energy_predictor(x, mask).unsqueeze(1)

    def get_energy_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Compute the energy prediction, averaged energy target and energy embedding for training.

        The embedding is computed from the duration-averaged ground-truth
        target, not from the prediction.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                T_src is the source sequence length, and C is the number of channels.
            target (torch.Tensor): A 3D tensor of shape [B, 1, T_max2] where B is the batch size,
                T_max2 is the maximum target sequence length.
            dr (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                T_src is the source sequence length. The values represent the durations.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                T_src is the source sequence length. The values represent the mask.

        Returns:
            energy_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] (assuming the
                predictor returns [B, T_src]) holding the energy prediction.
            avg_energy_target (torch.Tensor): The target averaged over durations by
                ``average_over_durations``, expected shape [B, 1, T_src].
            energy_emb (torch.Tensor): A 3D tensor of shape [B, channels_hidden, T_src]
                holding the energy embedding of ``avg_energy_target``.

        Shapes:
            x: :math: `[B, T_src, C]`
            target: :math: `[B, 1, T_max2]`
            dr: :math: `[B, T_src]`
            mask: :math: `[B, T_src]`
        """
        energy_pred = self._predict_energy(x, mask)
        avg_energy_target = average_over_durations(target, dr)
        energy_emb = self.energy_emb(avg_energy_target)
        return energy_pred, avg_energy_target, energy_emb

    def add_energy_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Add the (target-derived) energy embedding to ``x`` during training.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T_src, C] to which the
                energy embedding is added.
            target (torch.Tensor): Energy target, shape [B, 1, T_max2].
            dr (torch.Tensor): Aligner durations, shape [B, T_src].
            mask (torch.Tensor): Mask passed to the energy predictor, shape [B, T_src].

        Returns:
            x_energy (torch.Tensor): A new tensor ``x + energy_emb.transpose(1, 2)``;
                ``x`` itself is not modified.
            energy_pred (torch.Tensor): The predicted energy, shape [B, 1, T_src].
            avg_energy_target (torch.Tensor): The duration-averaged energy target.
        """
        energy_pred, avg_energy_target, energy_emb = self.get_energy_embedding_train(
            x=x,
            target=target,
            dr=dr,
            mask=mask,
        )
        # Embedding is [B, channels, T]; transpose to [B, T, channels] to match x.
        x_energy = x + energy_emb.transpose(1, 2)
        return x_energy, energy_pred, avg_energy_target

    def get_energy_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Compute the energy embedding and energy prediction for inference.

        The embedding is computed from the predicted energy.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                T_src is the source sequence length, and C is the number of channels.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                T_src is the source sequence length. The values represent the mask.

        Returns:
            energy_emb_pred (torch.Tensor): A 3D tensor of shape [B, channels_hidden, T_src]
                holding the embedding of the predicted energy.
            energy_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] (assuming the
                predictor returns [B, T_src]) holding the energy prediction.
        """
        energy_pred = self._predict_energy(x, mask)
        energy_emb_pred = self.energy_emb(energy_pred)
        return energy_emb_pred, energy_pred

    def add_energy_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Add the (prediction-derived) energy embedding to ``x`` during inference.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T_src, C] to which the
                energy embedding is added.
            mask (torch.Tensor): Mask passed to the energy predictor, shape [B, T_src].

        Returns:
            x_energy (torch.Tensor): A new tensor ``x + energy_emb_pred.transpose(1, 2)``;
                ``x`` itself is not modified.
            energy_pred (torch.Tensor): The predicted energy, shape [B, 1, T_src].
        """
        energy_emb_pred, energy_pred = self.get_energy_embedding(x, mask)
        # Embedding is [B, channels, T]; transpose to [B, T, channels] to match x.
        x_energy = x + energy_emb_pred.transpose(1, 2)
        return x_energy, energy_pred