File size: 8,000 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from typing import Tuple

import torch
from torch import nn

from .helpers import average_over_durations
from .variance_predictor import VariancePredictor


class EnergyAdaptor(nn.Module):
    """Variance adaptor with an added 1D conv layer, used to get energy embeddings.

    An energy predictor estimates an energy contour from the input features.
    A 1D convolution then projects a single-channel energy contour (the
    duration-averaged target during training, the prediction during inference)
    to ``channels_hidden`` channels so that it can be added to the features.

    Args:
        channels_in (int): Number of in channels for the energy predictor.
        channels_hidden (int): Passed to the energy predictor as its ``channels``
            argument; also the number of out channels of the energy embedding conv.
        channels_out (int): Number of out channels of the energy predictor.
        kernel_size (int): Size of the kernel for the predictor's conv layers.
        dropout (float): Probability of dropout.
        leaky_relu_slope (float): Slope for the leaky relu.
        emb_kernel_size (int): Size of the kernel for the energy embedding conv.

    Inputs: inputs, target, dr, mask
        - **inputs** (batch, time1, dim): Tensor containing input vector
        - **target** (batch, 1, time2): Tensor containing the energy target (train only)
        - **dr** (batch, time1): Tensor containing aligner durations vector (train only)
        - **mask** (batch, time1): Tensor containing indices to be masked
    Returns:
        - **energy prediction** (batch, 1, time1): Tensor produced by energy predictor
        - **energy embedding** (batch, channels_hidden, time1): Tensor produced by energy adaptor
        - **average energy target(train only)** (batch, 1, time1): Tensor produced after averaging over durations

    Note:
        The ``add_*`` methods compute ``x + embedding.transpose(1, 2)``, so the last
        dimension of ``x`` must be broadcast-compatible with ``channels_hidden``.
    """

    def __init__(
        self,
        channels_in: int,
        channels_hidden: int,
        channels_out: int,
        kernel_size: int,
        dropout: float,
        leaky_relu_slope: float,
        emb_kernel_size: int,
    ):
        super().__init__()
        self.energy_predictor = VariancePredictor(
            channels_in=channels_in,
            channels=channels_hidden,
            channels_out=channels_out,
            kernel_size=kernel_size,
            p_dropout=dropout,
            leaky_relu_slope=leaky_relu_slope,
        )
        # Projects the 1-channel energy contour to `channels_hidden` channels.
        # With stride 1 this padding keeps the time dimension unchanged for odd
        # kernel sizes (an even kernel size shortens the output by one frame).
        self.energy_emb = nn.Conv1d(
            1,
            channels_hidden,
            kernel_size=emb_kernel_size,
            padding=(emb_kernel_size - 1) // 2,
        )

    def get_energy_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Function is used during training to get the energy prediction, average energy target, and energy embedding.

        The embedding is computed from the duration-averaged ground-truth target
        (not from the prediction); the prediction is returned alongside it.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                            T_src is the source sequence length, and C is the number of channels.
            target (torch.Tensor): A 3D tensor of shape [B, 1, T_max2] where B is the batch size,
                                T_max2 is the maximum target sequence length.
            dr (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the durations.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the mask.

        Returns:
            energy_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                        T_src is the source sequence length. The values represent the energy prediction.
            avg_energy_target (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                            T_src is the source sequence length. The values represent the average energy target.
            energy_emb (torch.Tensor): A 3D tensor of shape [B, C, T_src] where B is the batch size,
                                    C is the number of embedding channels (``channels_hidden``), T_src is the
                                    source sequence length. The values represent the energy embedding.

        Shapes:
            x: :math:`[B, T_src, C]`
            target: :math:`[B, 1, T_max2]`
            dr: :math:`[B, T_src]`
            mask: :math:`[B, T_src]`
        """
        # Call the module itself rather than `.forward(...)` so that registered
        # forward / pre-forward hooks are executed.
        energy_pred = self.energy_predictor(x, mask)
        # Insert a channel axis at dim 1 (expected result: [B, 1, T_src]).
        energy_pred = energy_pred.unsqueeze(1)

        avg_energy_target = average_over_durations(target, dr)
        energy_emb = self.energy_emb(avg_energy_target)

        return energy_pred, avg_energy_target, energy_emb

    def add_energy_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Add energy embedding during training.

        This method calculates the energy embedding and adds it to the input tensor 'x'.
        It also returns the predicted energy and the average target energy.
        The input tensor is not modified in place; a new tensor is returned.

        Args:
            x (torch.Tensor): The input tensor to which the energy embedding will be added.
            target (torch.Tensor): The target tensor used in the energy embedding calculation.
            dr (torch.Tensor): The duration tensor used in the energy embedding calculation.
            mask (torch.Tensor): The mask tensor passed to the energy predictor.

        Returns:
            x (torch.Tensor): The input tensor with added energy embedding.
            energy_pred (torch.Tensor): The predicted energy tensor.
            avg_energy_target (torch.Tensor): The average target energy tensor.
        """
        energy_pred, avg_energy_target, energy_emb = self.get_energy_embedding_train(
            x=x,
            target=target,
            dr=dr,
            mask=mask,
        )
        # The embedding is channels-first; swap to channels-last to line up with `x`.
        x_energy = x + energy_emb.transpose(1, 2)
        return x_energy, energy_pred, avg_energy_target

    def get_energy_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Function is used during inference to get the energy embedding and energy prediction.

        Unlike the training variant, the embedding is computed from the
        predicted energy because no target is available.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                            T_src is the source sequence length, and C is the number of channels.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the mask.

        Returns:
            energy_emb_pred (torch.Tensor): A 3D tensor of shape [B, C, T_src] where B is the batch size,
                                            C is the number of embedding channels (``channels_hidden``), T_src is the
                                            source sequence length. The values represent the energy embedding.
            energy_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                        T_src is the source sequence length. The values represent the energy prediction.
        """
        # Call the module itself rather than `.forward(...)` so that registered
        # forward / pre-forward hooks are executed.
        energy_pred = self.energy_predictor(x, mask)
        # Insert a channel axis at dim 1 so the prediction can feed the 1-channel conv.
        energy_pred = energy_pred.unsqueeze(1)

        energy_emb_pred = self.energy_emb(energy_pred)
        return energy_emb_pred, energy_pred

    def add_energy_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Add energy embedding during inference.

        This method calculates the energy embedding and adds it to the input tensor 'x'.
        It also returns the predicted energy.
        The input tensor is not modified in place; a new tensor is returned.

        Args:
            x (torch.Tensor): The input tensor to which the energy embedding will be added.
            mask (torch.Tensor): The mask tensor passed to the energy predictor.

        Returns:
            x (torch.Tensor): The input tensor with added energy embedding.
            energy_pred (torch.Tensor): The predicted energy tensor.
        """
        energy_emb_pred, energy_pred = self.get_energy_embedding(x, mask)
        # The embedding is channels-first; swap to channels-last to line up with `x`.
        x_energy = x + energy_emb_pred.transpose(1, 2)
        return x_energy, energy_pred