PeechTTSv22050 / models /vocoder /univnet /multi_resolution_discriminator.py
nickovchinnikov's picture
Init
9d61c9b
import torch
from torch import nn
from torch.nn import Module
from models.config import VocoderModelConfig
from .discriminator_r import DiscriminatorR
class MultiResolutionDiscriminator(Module):
    r"""UnivNet multi-resolution discriminator.

    Holds one ``DiscriminatorR`` per entry of ``model_config.mrd.resolutions``
    and, on every forward pass, runs all of them on the same input and
    collects their outputs in configuration order.

    Args:
        model_config (VocoderModelConfig): Vocoder configuration. Its
            ``mrd.resolutions`` field decides how many sub-discriminators are
            built and which resolution each one is given; the whole config is
            also passed through to every ``DiscriminatorR``.

    Attributes:
        resolutions: ``model_config.mrd.resolutions``, stored as-is.
        discriminators (nn.ModuleList): The ``DiscriminatorR`` sub-modules,
            one per resolution, in the same order as ``resolutions``.

    Methods:
        forward(x): Runs every sub-discriminator on ``x``.
    """

    def __init__(
        self,
        model_config: VocoderModelConfig,
    ):
        super().__init__()
        self.resolutions = model_config.mrd.resolutions

        # Collect the sub-discriminators in a ModuleList so PyTorch registers
        # them as children of this module (parameters, state_dict, .to()).
        sub_discriminators = nn.ModuleList()
        for res in self.resolutions:
            sub_discriminators.append(
                DiscriminatorR(res, model_config=model_config),
            )
        self.discriminators = sub_discriminators

    def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]:
        r"""Run every resolution-specific discriminator on ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T].

        Returns:
            list: One entry per ``DiscriminatorR``, ordered like
            ``self.discriminators``. Each entry is that sub-discriminator's
            raw output, expected to be a ``(feature_maps, score)`` pair.
            NOTE(review): the annotation types the feature maps as a single
            Tensor; confirm against ``DiscriminatorR.forward``.
        """
        results = []
        for discriminator in self.discriminators:
            # Each sub-discriminator receives the identical, unmodified input.
            results.append(discriminator(x))
        return results