from typing import Any, List, Tuple

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.utils.parametrizations import spectral_norm, weight_norm

from models.config import VocoderModelConfig


class DiscriminatorP(Module):
    r"""DiscriminatorP is a class that implements a discriminator network for the UnivNet vocoder.

    Args:
        period (int): The period of the Mel spectrogram.
        model_config (VocoderModelConfig): The configuration object for the UnivNet vocoder model.
    """

    def __init__(
        self,
        period: int,
        model_config: VocoderModelConfig,
    ):
        super().__init__()

        self.LRELU_SLOPE = model_config.mpd.lReLU_slope
        self.period = period

        kernel_size = model_config.mpd.kernel_size
        stride = model_config.mpd.stride

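        # Parametrization applied to every conv layer: spectral norm or
        # weight norm, selected by the config flag.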
        norm_f: Any = (
            spectral_norm if model_config.mpd.use_spectral_norm else weight_norm
        )

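        # Five 2D conv blocks: channels 1 -> 64 -> 128 -> 256 -> 512 -> 1024.
        # Kernels (kernel_size, 1) act only on the frame axis; the first four
        # blocks downsample it with stride (stride, 1), the last uses stride 1,
        # and the period axis is left untouched throughout.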
        self.convs = nn.ModuleList(
            [
                norm_f(
                    nn.Conv2d(
                        1,
                        64,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    ),
                ),
                norm_f(
                    nn.Conv2d(
                        64,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    ),
                ),
                norm_f(
                    nn.Conv2d(
                        128,
                        256,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    ),
                ),
                norm_f(
                    nn.Conv2d(
                        256,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    ),
                ),
                norm_f(
                    nn.Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(kernel_size // 2, 0),
                    ),
                ),
            ],
        )
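        # Final projection from 1024 channels to a single score map.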
        self.conv_post = norm_f(
            nn.Conv2d(
                1024,
                1,
                (3, 1),
                1,
                padding=(1, 0),
            ),
        )

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
        r"""Forward pass of the discriminator network.

        Args:
            x (torch.Tensor): The input waveform of shape (batch_size, channels, time_steps).

        Returns:
            Tuple[List[torch.Tensor], torch.Tensor]: The intermediate feature
            maps (one per conv layer, including conv_post) and the flattened
            output scores of shape (batch_size, downsampled_frames * period).
        """
        fmap = []

        # 1D -> 2D: reflect-pad the time axis up to a multiple of the period,
        # then fold it into (frames, period).
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        # Cast once so the input dtype matches the conv weights (e.g. when
        # running under mixed precision).
        x = x.to(dtype=self.conv_post.weight.dtype)
        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return fmap, x
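

# --- Usage sketch (illustrative) --------------------------------------------
# A minimal smoke test of the shapes documented above. It assumes
# VocoderModelConfig can be constructed with its defaults and that its
# ``mpd`` settings expose ``lReLU_slope``, ``kernel_size``, ``stride`` and
# ``use_spectral_norm``; adapt to the actual config in this repository.
if __name__ == "__main__":
    model_config = VocoderModelConfig()  # assumption: default-constructible
    disc = DiscriminatorP(period=2, model_config=model_config)

    # A batch of mono waveforms: (batch_size, channels, time_steps).
    wav = torch.randn(4, 1, 16000)
    fmap, score = disc(wav)

    print(len(fmap))    # 6: one map per conv layer plus conv_post
    print(score.shape)  # (4, downsampled_frames * period)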