PeechTTSv22050 / models /vocoder /univnet /multi_period_discriminator.py
nickovchinnikov's picture
Init
9d61c9b
import torch
from torch import nn
from torch.nn import Module
from models.config import VocoderModelConfig
from .discriminator_p import DiscriminatorP
class MultiPeriodDiscriminator(Module):
    r"""Multi-period discriminator (MPD) for the UnivNet vocoder.

    Builds one ``DiscriminatorP`` sub-network for every period listed in
    ``model_config.mpd.periods`` and, on each forward pass, applies all of
    them to the same input.

    Args:
        model_config (VocoderModelConfig): Configuration object for the UnivNet
            vocoder model. ``model_config.mpd.periods`` supplies the periods,
            and the whole config is also handed to each ``DiscriminatorP``.
    """

    def __init__(
        self,
        model_config: VocoderModelConfig,
    ):
        super().__init__()
        # One sub-discriminator per configured period. Appending keeps the
        # configured order, so the registered submodule keys are
        # ``discriminators.0``, ``discriminators.1``, ... exactly as before.
        sub_discs = nn.ModuleList()
        for period in model_config.mpd.periods:
            sub_discs.append(DiscriminatorP(period, model_config=model_config))
        self.discriminators = sub_discs

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        r"""Apply every period discriminator to ``x``.

        Args:
            x (torch.Tensor): Input passed unchanged to each sub-discriminator,
                of shape (batch_size, channels, time_steps).

        Returns:
            list: One entry per sub-discriminator, ordered like
            ``model_config.mpd.periods``. NOTE(review): each entry is whatever
            ``DiscriminatorP.forward`` returns; confirm that this matches the
            ``list[torch.Tensor]`` annotation.
        """
        outputs = []
        for disc in self.discriminators:
            outputs.append(disc(x))
        return outputs