from typing import List

import torch
import torchaudio
from encodec import EncodecModel
from omegaconf import OmegaConf
from torch import nn

from vocos.modules import safe_log


class FeatureExtractor(nn.Module):
"""Base class for feature extractors.""" |
|
|
|
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: |
|
""" |
|
Extract features from the given audio. |
|
|
|
Args: |
|
audio (Tensor): Input audio waveform. |
|
|
|
Returns: |
|
Tensor: Extracted features of shape (B, C, L), where B is the batch size, |
|
C denotes output features, and L is the sequence length. |
|
""" |
|
raise NotImplementedError("Subclasses must implement the forward method.") |
|
|
|
|
|
class MelSpectrogramFeatures(FeatureExtractor): |
|
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=padding == "center",
            power=1,  # magnitude spectrogram rather than power
        )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            # Reflect-pad by (win_length - hop_length) so the frame count matches
            # T // hop_length, mimicking "same"-style framing without center padding.
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        features = safe_log(mel)
        return features
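
# A minimal usage sketch (illustrative values, not taken from a real config):
#
#     extractor = MelSpectrogramFeatures(sample_rate=24000, padding="center")
#     audio = torch.randn(2, 24000)  # (B, T) batch of 1-second waveforms
#     mel = extractor(audio)         # (B, 100, 24000 // 256 + 1) log-mel features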


class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
    ):
        super().__init__()
        if encodec_model == "encodec_24khz":
            encodec = EncodecModel.encodec_model_24khz
        elif encodec_model == "encodec_48khz":
            encodec = EncodecModel.encodec_model_48khz
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
            )
        self.encodec = encodec(pretrained=True)
        # Freeze the pretrained EnCodec model; only the copied codebooks below may train.
        for param in self.encodec.parameters():
            param.requires_grad = False
        self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
            self.encodec.frame_rate, bandwidth=max(bandwidths)
        )
        # Concatenate the codebooks of the first num_q quantizers into one embedding
        # table, so dequantization in forward() becomes a single lookup.
        codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    @torch.no_grad()
    def get_encodec_codes(self, audio):
        audio = audio.unsqueeze(1)  # (B, T) -> (B, 1, T): add the channel dim the encoder expects
        emb = self.encodec.encoder(audio)
        codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
        return codes

    def forward(self, audio: torch.Tensor, **kwargs):
        bandwidth_id = kwargs.get("bandwidth_id")
        if bandwidth_id is None:
            raise ValueError("The 'bandwidth_id' argument is required")
        self.encodec.eval()  # Force eval mode, since a training loop may set child modules to train
        self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
        codes = self.get_encodec_codes(audio)
        # `codes` has shape (num_q, B, L). Each quantizer's codes index its own codebook,
        # so offset them by `bins` per quantizer to address the concatenated
        # `codebook_weights`, then dequantize with one embedding lookup and sum over quantizers.
        offsets = torch.arange(
            0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        )
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        return features.transpose(1, 2)  # (B, L, C) -> (B, C, L)
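
# A minimal usage sketch (illustrative; requires the `encodec` package):
#
#     extractor = EncodecFeatures("encodec_24khz", bandwidths=[1.5, 3.0, 6.0, 12.0])
#     audio = torch.randn(2, 24000)                # (B, T) at 24 kHz
#     features = extractor(audio, bandwidth_id=2)  # index 2 selects the 6.0 kbps setting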


class xCodecFeatures(FeatureExtractor):
    def __init__(
        self,
        config: str,
        ckpt: str,
    ):
        super().__init__()
        self.config = OmegaConf.load(config)
        # The generator class named in the config must be importable in this module's
        # namespace for eval() to resolve it.
        self.model = eval(self.config.generator.name)(**self.config.generator.config)
        parameter_dict = torch.load(ckpt, map_location="cpu")
        self.model.load_state_dict(parameter_dict["codec_model"])
        # Note: assumes 44.1 kHz input audio and pins the resampler to a CUDA device.
        self.resampler = torchaudio.transforms.Resample(orig_freq=44100, new_freq=16000).to("cuda")
        self.model.eval()

    def forward(self, audio: torch.Tensor, **kwargs):
        audio = self.resampler(audio)
        with torch.no_grad():
            codes = self.model.encode(audio, target_bw=6)  # fixed 6 kbps target bandwidth
        return codes
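
# A minimal usage sketch (the paths are placeholders; a CUDA device is required
# because the resampler is created on "cuda"):
#
#     extractor = xCodecFeatures(config="path/to/config.yaml", ckpt="path/to/ckpt.pt").to("cuda")
#     audio = torch.randn(2, 44100, device="cuda")  # (B, T) at 44.1 kHz
#     codes = extractor(audio)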