from typing import Sequence, Optional, Union import math import random import torch import torch.nn as nn import numpy as np from ..modules.seanet import SEANetEncoder, SEANetDecoder from ..quantization import ResidualVectorQuantizer class SoundStream(nn.Module): """ SoundStream model or EnCodec model. Args: n_filters (int): n_filters (int): Base width for the model. D (int): Intermediate representation dimension. target_bandwidths (Sequence[int]): Target bandwidths in K-bits/second. ratios (Sequence[int]): downsampling factors, whose multiplication is the hop size. sample_rate (int): wave sampling rate. bins (int): number of code words in a codebook. normalize (bool): audio normalization. """ def __init__( self, n_filters: int = 32, D: int = 512, target_bandwidths: Sequence[Union[int, float]] = [0.5, 1, 1.5, 2, 4, 6], ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320 sample_rate: int = 16000, bins: int = 1024, normalize: bool = False, causal: bool = False, ): super().__init__() self.hop_length = np.prod(ratios) # total nb of codebooks, e.g., 6Kb/s, sr=16000 and hop_length=320 => nq = 12 n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / self.hop_length) * 10)) self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz self.bits_per_codebook = int(math.log2(bins)) # 1024 => 10 self.target_bandwidths = target_bandwidths self.n_q = n_q self.sample_rate = sample_rate # Encoder model self.encoder = SEANetEncoder(n_filters=n_filters, dimension=D, ratios=ratios, causal=causal) # RVQ model self.quantizer = ResidualVectorQuantizer(dimension=D, n_q=n_q, bins=bins) # Decoder model self.decoder = SEANetDecoder(n_filters= n_filters, dimension=D, ratios=ratios, causal=causal) def get_last_layer(self): return self.decoder.layers[-1].weight def forward(self, x: torch.Tensor): e = self.encoder(x) # randomly select a band-width during training bw = self.target_bandwidths[random.randint(0, len(self.target_bandwidths) - 1)] # [0, len(target_bandwidths) - 1], both included quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) # print('quantized ', quantized.shape) # print('codes ', codes.shape) # print('commit_loss ', commit_loss) # print('bandwidth ', bandwidth) # assert 1==2 #quantized = quantized.permute(0,2,1) o = self.decoder(quantized) # print('o ', o.shape) # assert 1==2 return o, commit_loss, None def encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor: e = self.encoder(x) if target_bw is None: bw = self.target_bandwidths[-1] else: bw = target_bw codes = self.quantizer.encode(e, self.frame_rate, bw) return codes def decode(self, codes: torch.Tensor) -> torch.Tensor: quantized = self.quantizer.decode(codes) o = self.decoder(quantized) return o # test if __name__ == '__main__': soundstream = SoundStream(n_filters=32, D=256) for i in range(10): print(f"Iter {i}: ") x = torch.rand(1, 1, 16000) o, _, _ = soundstream(x) print('output', o.shape)