from typing import Optional

import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from vocos.spectral_ops import IMDCT, ISTFT
from vocos.modules import symexp


class FourierHead(nn.Module):
"""Base class for inverse fourier modules.""" |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Args: |
|
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, |
|
L is the sequence length, and H denotes the model dimension. |
|
|
|
Returns: |
|
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. |
|
""" |
|
raise NotImplementedError("Subclasses must implement the forward method.") |
|
|
|
|
|
class ISTFTHead(FourierHead): |
|
""" |
|
ISTFT Head module for predicting STFT complex coefficients. |
|
|
|
Args: |
|
dim (int): Hidden dimension of the model. |
|
n_fft (int): Size of Fourier transform. |
|
hop_length (int): The distance between neighboring sliding window frames, which should align with |
|
the resolution of the input features. |
|
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". |
|
""" |
|
|
|
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same", ckpt: Optional[str] = None): |
|
super().__init__() |
|
out_dim = n_fft + 2 |
|
self.out = torch.nn.Linear(dim, out_dim) |
|
self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) |
|
|
|
if ckpt is not None: |
|
params = torch.load(ckpt, map_location="cpu") |
|
|
|
out_weight = params["head.out.weight"] |
|
out_bias = params["head.out.bias"] |
|
|
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes

        # Build the complex spectrogram directly from cos/sin of the predicted phase,
        # avoiding a redundant atan2 / complex-exp round-trip.
        x = torch.cos(p)
        y = torch.sin(p)
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
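# Usage sketch for ISTFTHead (illustrative values; dim, n_fft, and hop_length are
# arbitrary choices for demonstration): each of the L input frames yields one STFT
# column, so the reconstructed signal is roughly L * hop_length samples long
# (the exact length depends on the padding mode).
#
#   head = ISTFTHead(dim=512, n_fft=1024, hop_length=256)
#   features = torch.randn(2, 100, 512)  # (B, L, H)
#   audio = head(features)               # (B, T), T ≈ 100 * 256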
class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with a symmetric exponential function.

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """
    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # Perceptual initialization: scale each output row by an inverse mel-spaced
            # frequency ramp, so perceptually important low frequencies start with
            # larger weights than high frequencies.
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        # symexp (the inverse of symlog) expands the network's compressed predictions
        # back to the natural dynamic range of MDCT coefficients.
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio
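# Usage sketch for IMDCTSymExpHead (illustrative; 24 kHz is an assumed sample
# rate, not a requirement): the head predicts mdct_frame_len // 2 coefficients
# per frame, and MDCT frames overlap by 50%, so L frames cover roughly
# L * mdct_frame_len // 2 samples.
#
#   head = IMDCTSymExpHead(dim=512, mdct_frame_len=1024, sample_rate=24000)
#   audio = head(torch.randn(2, 100, 512))  # (B, T), T ≈ 100 * 512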
class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p).

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)  # mdct_frame_len // 2 values each for m and p
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)
        return audio
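if __name__ == "__main__":
    # Minimal smoke test (a sketch: B, L, H, n_fft, hop_length, mdct_frame_len,
    # and the 24 kHz sample rate are arbitrary demo values, and the exact output
    # length T depends on the padding mode of ISTFT/IMDCT).
    B, L, H = 2, 100, 512
    features = torch.randn(B, L, H)

    istft_head = ISTFTHead(dim=H, n_fft=1024, hop_length=256)
    print(istft_head(features).shape)  # (B, T), T ≈ L * 256

    symexp_head = IMDCTSymExpHead(dim=H, mdct_frame_len=1024, sample_rate=24000)
    print(symexp_head(features).shape)  # (B, T), T ≈ L * 512

    cos_head = IMDCTCosHead(dim=H, mdct_frame_len=1024)
    print(cos_head(features).shape)  # (B, T), T ≈ L * 512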