Spaces:
Sleeping
Sleeping
File size: 1,151 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import torch
from torch import nn
from torch.nn import Module
from models.config import VocoderModelConfig
from .discriminator_p import DiscriminatorP
class MultiPeriodDiscriminator(Module):
    r"""Multi-period discriminator network for the UnivNet vocoder.

    Holds one ``DiscriminatorP`` sub-network for every period listed in
    ``model_config.mpd.periods`` and evaluates all of them on the same input.

    Args:
        model_config (VocoderModelConfig): Configuration object for the UnivNet
            vocoder model. Its ``mpd.periods`` entry selects the periods, and
            the whole object is passed on to each ``DiscriminatorP``.
    """

    def __init__(
        self,
        model_config: VocoderModelConfig,
    ):
        super().__init__()

        # One sub-discriminator per configured period, in configuration order.
        period_discriminators = [
            DiscriminatorP(p, model_config=model_config)
            for p in model_config.mpd.periods
        ]
        # Wrapping in nn.ModuleList registers the sub-networks as submodules
        # so their parameters belong to this module.
        self.discriminators = nn.ModuleList(period_discriminators)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        r"""Run every period discriminator on the same input.

        Args:
            x (torch.Tensor): Input tensor, expected to have shape
                (batch_size, channels, time_steps).

        Returns:
            list: One entry per sub-discriminator, in the same order as
            ``self.discriminators``. Each entry is whatever that
            ``DiscriminatorP`` returns (NOTE(review): confirm it is a single
            tensor, as the annotation claims).
        """
        outputs = []
        for discriminator in self.discriminators:
            outputs.append(discriminator(x))
        return outputs
|