|
import torch
|
|
from audiotools import AudioSignal
|
|
from audiotools.ml import BaseModel
|
|
from encodec import EncodecModel
|
|
|
|
|
|
class Encodec(BaseModel):
|
|
def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
|
|
super().__init__()
|
|
|
|
if sample_rate == 24000:
|
|
self.model = EncodecModel.encodec_model_24khz()
|
|
else:
|
|
self.model = EncodecModel.encodec_model_48khz()
|
|
self.model.set_target_bandwidth(bandwidth)
|
|
self.sample_rate = 44100
|
|
|
|
def forward(
|
|
self,
|
|
audio_data: torch.Tensor,
|
|
sample_rate: int = 44100,
|
|
n_quantizers: int = None,
|
|
):
|
|
signal = AudioSignal(audio_data, sample_rate)
|
|
signal.resample(self.model.sample_rate)
|
|
recons = self.model(signal.audio_data)
|
|
recons = AudioSignal(recons, self.model.sample_rate)
|
|
recons.resample(sample_rate)
|
|
return {"audio": recons.audio_data}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import numpy as np
|
|
from functools import partial
|
|
|
|
model = Encodec()
|
|
|
|
for n, m in model.named_modules():
|
|
o = m.extra_repr()
|
|
p = sum([np.prod(p.size()) for p in m.parameters()])
|
|
fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
|
|
setattr(m, "extra_repr", partial(fn, o=o, p=p))
|
|
print(model)
|
|
print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
|
|
|
|
length = 88200 * 2
|
|
x = torch.randn(1, 1, length).to(model.device)
|
|
x.requires_grad_(True)
|
|
x.retain_grad()
|
|
|
|
|
|
out = model(x)["audio"]
|
|
|
|
print(x.shape, out.shape)
|
|
|