from typing import List
import torch
import torchaudio
from torch import nn
import math
# from inspiremusic.wavtokenizer.decoder.modules import safe_log
from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder
from inspiremusic.wavtokenizer.encoder import EncodecModel
from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
    Computes the element-wise logarithm of the input tensor, clipping the input to avoid taking the log of values at or near zero.
    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Lower bound applied to the input before the log. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
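# Illustrative sketch (assumes only the torch import above): clipping at clip_val=1e-7
# keeps silent inputs finite, e.g. safe_log(torch.tensor([0.0, 1.0])) gives
# tensor([-16.1181, 0.0000]) instead of -inf for the zero entry.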
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)
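# Illustrative sketch (assumes only the torch import above): symlog and symexp are mutual
# inverses; symlog compresses large magnitudes while staying roughly linear near zero, e.g.
# symlog(torch.tensor([-3.0, 0.0, 3.0])) -> tensor([-1.3863, 0.0000, 1.3863]),
# and symexp(symlog(x)) recovers x up to floating-point error.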
class FeatureExtractor(nn.Module):
"""Base class for feature extractors."""
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Extract features from the given audio.
Args:
audio (Tensor): Input audio waveform.
Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                C is the number of output features (channels), and L is the sequence length.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class MelSpectrogramFeatures(FeatureExtractor):
def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=padding == "center",
power=1,
)
def forward(self, audio, **kwargs):
if self.padding == "same":
pad = self.mel_spec.win_length - self.mel_spec.hop_length
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
mel = self.mel_spec(audio)
features = safe_log(mel)
return features
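# Illustrative note on the defaults above (sample_rate=24000, hop_length=256, n_mels=100):
# the returned log-magnitude mel features have shape (B, 100, ~T / 256); "center" padding
# yields roughly one more frame than "same" padding for the same input length.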
class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
        num_quantizers: int = 1,  # number of residual quantizers (n_q)
        dowmsamples: List[int] = [6, 5, 5, 4],  # encoder downsampling ratios
        vq_bins: int = 16384,  # codebook size (bins) of each quantizer
        vq_kmeans: int = 800,  # k-means iterations for codebook initialisation
    ):
super().__init__()
        self.frame_rate = 25  # not used
        # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
        n_q = num_quantizers  # important: number of residual VQ codebooks
encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
true_skip=False, compress=2)
decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
true_skip=False, compress=2)
quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
decay=0.99, kmeans_init=True)
if encodec_model == "encodec_24khz":
self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
target_bandwidths=bandwidths, sample_rate=24000, channels=1)
else:
raise ValueError(
f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
)
for param in self.encodec.parameters():
param.requires_grad = True
# self.num_q = n_q
# codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
# self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
self.bandwidths = bandwidths
# @torch.no_grad()
# def get_encodec_codes(self, audio):
# audio = audio.unsqueeze(1)
# emb = self.encodec.encoder(audio)
# codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
# return codes
    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
        if self.training:
            self.encodec.train()
        audio = audio.unsqueeze(1)  # (B, T) -> (B, 1, T), e.g. (16, 24000) -> (16, 1, 24000)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # e.g. codes: (8, 16, 75), quantized features: (16, 128, 75)
        return quantized, codes, commit_loss
# codes = self.get_encodec_codes(audio)
# # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
# # with offsets given by the number of bins, and finally summed in a vectorized operation.
# offsets = torch.arange(
# 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
# )
# embeddings_idxs = codes + offsets.view(-1, 1, 1)
# features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
# return features.transpose(1, 2)
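    # Note: forward() quantizes through self.encodec.quantizer(...) (the path that also
    # yields the commitment penalty used during training), while infer()/_infer() below go
    # through the quantizer's infer(...) path; all three return (quantized, codes, commit_loss).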
    def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()
        audio = audio.unsqueeze(1)  # (B, T) -> (B, 1, T)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # e.g. codes: (8, 16, 75), quantized features: (16, 128, 75)
        return quantized, codes, commit_loss

    def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
        if self.training:
            self.encodec.train()
        audio = audio.unsqueeze(1)  # (B, T) -> (B, 1, T)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # e.g. codes: (8, 16, 75), quantized features: (16, 128, 75)
        return quantized, codes, commit_loss
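
if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes torch, torchaudio and the inspiremusic
    # wavtokenizer modules imported above are installed; the shapes below are
    # illustrative for one second of 24 kHz mono audio with the default settings.
    batch = torch.randn(2, 24000)

    mel_extractor = MelSpectrogramFeatures(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100)
    mel = mel_extractor(batch)
    print("mel features:", tuple(mel.shape))  # e.g. (2, 100, 94) with "center" padding

    # EncodecFeatures builds a SEANet encoder/decoder and a residual VQ from scratch
    # (randomly initialised here); extractor.forward(audio, bandwidth_id) would return
    # the (quantized, codes, commit_loss) triple shown in the methods above.
    encodec_extractor = EncodecFeatures(num_quantizers=1)
    n_params = sum(p.numel() for p in encodec_extractor.parameters())
    print(f"EncodecFeatures parameters: {n_params / 1e6:.1f}M")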