from typing import Sequence, Optional, Union
import math
import random
import torch
import torch.nn as nn
import numpy as np
from ..modules.seanet import SEANetEncoder, SEANetDecoder
from ..quantization import ResidualVectorQuantizer
class SoundStream(nn.Module):
    """SoundStream / EnCodec model.

    Args:
        n_filters (int): Base width for the model.
        D (int): Intermediate representation dimension.
        target_bandwidths (Sequence[Union[int, float]]): Target bandwidths in kbps.
        ratios (Sequence[int]): Downsampling factors; their product is the hop size.
        sample_rate (int): Waveform sampling rate.
        bins (int): Number of code words per codebook.
        normalize (bool): Audio normalization (accepted but unused in this module).
        causal (bool): Whether the SEANet encoder/decoder use causal convolutions.
    """
def __init__(
self,
n_filters: int = 32,
D: int = 512,
target_bandwidths: Sequence[Union[int, float]] = [0.5, 1, 1.5, 2, 4, 6],
ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
sample_rate: int = 16000,
bins: int = 1024,
normalize: bool = False,
causal: bool = False,
):
super().__init__()
        self.hop_length = int(np.prod(ratios))  # 320 with the default ratios
        # Total number of codebooks needed to reach the highest target bandwidth:
        # each codebook costs frame_rate * bits_per_codebook bits per second.
        self.frame_rate = math.ceil(sample_rate / self.hop_length)  # 50 Hz with the defaults
        self.bits_per_codebook = int(math.log2(bins))  # 1024 bins => 10 bits
        n_q = int(1000 * target_bandwidths[-1] // (self.frame_rate * self.bits_per_codebook))
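        # Worked example with the defaults (a sketch of the arithmetic above):
        #   sr=16000, hop_length=8*5*4*2=320 -> frame_rate = ceil(16000/320) = 50 Hz
        #   bins=1024 -> 10 bits per codebook per frame -> 500 bits/s per codebook
        #   max bandwidth 6 kbps -> n_q = 6000 // 500 = 12 codebooks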
self.target_bandwidths = target_bandwidths
self.n_q = n_q
self.sample_rate = sample_rate
# Encoder model
self.encoder = SEANetEncoder(n_filters=n_filters, dimension=D, ratios=ratios, causal=causal)
# RVQ model
self.quantizer = ResidualVectorQuantizer(dimension=D, n_q=n_q, bins=bins)
# Decoder model
        self.decoder = SEANetDecoder(n_filters=n_filters, dimension=D, ratios=ratios, causal=causal)
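        # Rough shape flow with the defaults (assuming the SEANet defaults; exact
        # lengths may differ slightly at sequence boundaries): the encoder maps
        # (B, 1, T) waveforms to (B, D, ~T // 320) latent frames, the quantizer
        # discretizes those frames, and the decoder maps them back to waveforms.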
def get_last_layer(self):
return self.decoder.layers[-1].weight
    def forward(self, x: torch.Tensor):
        e = self.encoder(x)
        # Randomly sample a bandwidth during training so the RVQ learns to
        # operate at every target bitrate.
        bw = random.choice(self.target_bandwidths)
        quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
        o = self.decoder(quantized)
        # The final None is an unused placeholder output.
        return o, commit_loss, None
    def encode(self, x: torch.Tensor, target_bw: Optional[float] = None) -> torch.Tensor:
        e = self.encoder(x)
        # Default to the highest bandwidth, i.e. all n_q codebooks.
        bw = self.target_bandwidths[-1] if target_bw is None else target_bw
        codes = self.quantizer.encode(e, self.frame_rate, bw)
        return codes
def decode(self, codes: torch.Tensor) -> torch.Tensor:
quantized = self.quantizer.decode(codes)
o = self.decoder(quantized)
return o
# Smoke test
if __name__ == '__main__':
    soundstream = SoundStream(n_filters=32, D=256)
    for i in range(10):
        print(f"Iter {i}:")
        x = torch.rand(1, 1, 16000)  # one second of random audio at 16 kHz
        o, _, _ = soundstream(x)
        print('output', o.shape)
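    # A minimal sketch of the inference path (assumes the default
    # target_bandwidths; 6 kbps selects all 12 codebooks). The exact shape of
    # `codes` depends on the ResidualVectorQuantizer implementation.
    x = torch.rand(1, 1, 16000)
    codes = soundstream.encode(x, target_bw=6)
    print('codes', codes.shape)
    recon = soundstream.decode(codes)
    print('reconstruction', recon.shape)  # should match the input shape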