Spaces:
Running
on
Zero
Running
on
Zero
from typing import Sequence, Optional, Union | |
import math | |
import random | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from ..modules.seanet import SEANetEncoder, SEANetDecoder | |
from ..quantization import ResidualVectorQuantizer | |
class SoundStream(nn.Module):
    """SoundStream / EnCodec-style neural audio codec.

    The model chains a SEANet encoder, a residual vector quantizer (RVQ) and a
    SEANet decoder. In ``forward`` a bandwidth is sampled at random from
    ``target_bandwidths`` so a single model learns to work at several bitrates.

    Args:
        n_filters (int): Base width (number of filters) of the encoder/decoder.
        D (int): Dimension of the latent representation that gets quantized.
        target_bandwidths (Sequence[int | float]): Supported bandwidths in
            kbps. The largest value determines the total number of codebooks.
        ratios (Sequence[int]): Downsampling factors; their product is the
            hop size in samples.
        sample_rate (int): Waveform sampling rate in Hz.
        bins (int): Number of code words in each codebook.
        normalize (bool): Audio normalization flag. NOTE(review): accepted for
            interface compatibility but not used anywhere in this class.
        causal (bool): Forwarded to both the SEANet encoder and decoder.

    Raises:
        ValueError: If ``target_bandwidths`` is empty.
    """

    def __init__(
        self,
        n_filters: int = 32,
        D: int = 512,
        target_bandwidths: Sequence[Union[int, float]] = (0.5, 1, 1.5, 2, 4, 6),
        ratios: Sequence[int] = (8, 5, 4, 2),  # downsampling by 320
        sample_rate: int = 16000,
        bins: int = 1024,
        normalize: bool = False,
        causal: bool = False,
    ):
        super().__init__()
        # Copy into fresh lists: the (immutable) tuple defaults are never shared
        # between instances, and a caller's sequence is never aliased.
        target_bandwidths = list(target_bandwidths)
        ratios = list(ratios)
        if not target_bandwidths:
            raise ValueError("target_bandwidths must contain at least one bandwidth")

        self.hop_length = np.prod(ratios)
        self.frame_rate = math.ceil(sample_rate / self.hop_length)  # 50 Hz for the defaults
        self.bits_per_codebook = int(math.log2(bins))  # 1024 => 10
        # Total number of codebooks needed to reach the highest bandwidth, e.g.
        # 6 kbps at 50 Hz with 10 bits per codebook => 6000 // 500 = 12.
        # Uses the real bits per codebook (not a hard-coded 10) so that
        # bins != 1024 yields a consistent codebook count.
        n_q = int(
            1000 * max(target_bandwidths)
            // (self.frame_rate * self.bits_per_codebook)
        )
        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate

        # Sub-modules are registered in encoder -> quantizer -> decoder order
        # (keeps state_dict key order unchanged).
        self.encoder = SEANetEncoder(
            n_filters=n_filters, dimension=D, ratios=ratios, causal=causal
        )
        self.quantizer = ResidualVectorQuantizer(dimension=D, n_q=n_q, bins=bins)
        self.decoder = SEANetDecoder(
            n_filters=n_filters, dimension=D, ratios=ratios, causal=causal
        )

    def get_last_layer(self):
        """Return the weight tensor of the decoder's final layer."""
        return self.decoder.layers[-1].weight

    def forward(self, x: torch.Tensor):
        """Encode ``x``, quantize it at a randomly chosen bandwidth, and decode.

        Args:
            x: Input waveform. NOTE(review): presumably shaped
                (batch, channels, samples), as in the ``__main__`` smoke
                test — confirm against SEANetEncoder.

        Returns:
            tuple: ``(reconstruction, commit_loss, None)``; the third element
            is always ``None``.
        """
        e = self.encoder(x)
        # Sample one configured bandwidth uniformly at random.
        # NOTE(review): this happens regardless of ``self.training``; use
        # ``encode``/``decode`` when a fixed bitrate is required.
        bw = random.choice(self.target_bandwidths)
        quantized, _codes, _bandwidth, commit_loss = self.quantizer(
            e, self.frame_rate, bw
        )
        o = self.decoder(quantized)
        return o, commit_loss, None

    def encode(
        self, x: torch.Tensor, target_bw: Optional[Union[int, float]] = None
    ) -> torch.Tensor:
        """Encode ``x`` into discrete RVQ codes.

        Args:
            x: Input waveform (see ``forward`` for the assumed layout).
            target_bw: Bandwidth in kbps. Defaults to the last entry of
                ``target_bandwidths``.

        Returns:
            The codes produced by ``self.quantizer.encode``.
        """
        bw = self.target_bandwidths[-1] if target_bw is None else target_bw
        e = self.encoder(x)
        return self.quantizer.encode(e, self.frame_rate, bw)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Reconstruct a waveform from codes produced by ``encode``."""
        quantized = self.quantizer.decode(codes)
        return self.decoder(quantized)
# Smoke test: push ten random 16000-sample clips (one second at the default
# 16 kHz rate) through the full encode/quantize/decode path.
if __name__ == '__main__':
    model = SoundStream(n_filters=32, D=256)
    for i in range(10):
        print(f"Iter {i}: ")
        wave = torch.rand(1, 1, 16000)
        reconstruction, _, _ = model(wave)
        print('output', reconstruction.shape)