Spaces:
Sleeping
Sleeping
File size: 1,657 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from torch import nn
from torch.nn import Module
from models.config import VocoderModelConfig
from .discriminator_r import DiscriminatorR
class MultiResolutionDiscriminator(Module):
    r"""Bundle of `DiscriminatorR` sub-discriminators for the UnivNet vocoder.

    One `DiscriminatorR` is created per entry in
    `model_config.mrd.resolutions`; every sub-discriminator is fed the same
    input and their results are gathered into a single list.

    Args:
        model_config (VocoderModelConfig): Model configuration object. Its
            `mrd.resolutions` field drives how many sub-discriminators are
            built, and the whole object is forwarded to each of them.

    Attributes:
        resolutions (list): The resolutions read from the configuration, one
            per sub-discriminator.
        discriminators (nn.ModuleList): The `DiscriminatorR` instances, in the
            same order as `resolutions`.

    Methods:
        forward(x): Runs every sub-discriminator on `x`.
    """

    def __init__(self, model_config: VocoderModelConfig):
        super().__init__()

        self.resolutions = model_config.mrd.resolutions

        # Build the sub-discriminators one by one, preserving the order of
        # the configured resolutions.
        per_resolution = nn.ModuleList()
        for res in self.resolutions:
            per_resolution.append(DiscriminatorR(res, model_config=model_config))
        self.discriminators = per_resolution

    def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]:
        r"""Run every sub-discriminator on the same input.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T].

        Returns:
            list: One entry per `DiscriminatorR` instance, in registration
            order. Each entry is that instance's output, i.e. a
            (feature maps, score) pair: [(feat, score), (feat, score), ...].
        """
        results = []
        for discriminator in self.discriminators:
            results.append(discriminator(x))
        return results
|