import math
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
import torch
import tqdm
from audiotools import AudioSignal
from torch import nn

SUPPORTED_VERSIONS = ["1.0.0"]


@dataclass
class DACFile:
    codes: torch.Tensor

    # Metadata required for decompression
    chunk_length: int
    original_length: int
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

    def save(self, path):
        artifacts = {
            "codes": self.codes.numpy().astype(np.uint16),
            "metadata": {
                "input_db": self.input_db.numpy().astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **artifacts["metadata"])
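
# Usage sketch (illustrative values, not taken from any shipped model): a
# save/load round trip. Note that `input_db` is annotated as a float but must
# be array-like in practice, since save() calls .numpy() on it.
#
#     dac_file = DACFile(
#         codes=torch.zeros(1, 9, 87, dtype=torch.long),
#         chunk_length=87,
#         original_length=44100,
#         input_db=torch.tensor([-16.0]),
#         channels=1,
#         sample_rate=44100,
#         padding=True,
#         dac_version=SUPPORTED_VERSIONS[-1],
#     )
#     saved = dac_file.save("/tmp/example")  # written as /tmp/example.dac
#     restored = DACFile.load(saved)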


class CodecMixin:
    @property
    def padding(self):
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                # Restore the padding each layer was built with.
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                # Remember the original padding, then zero it out so that
                # chunked inference sees no implicit edge padding.
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        # Any input length works here; the delay is invariant to it.
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        # Walk the layers in reverse, inverting each one's length formula to
        # recover the input length that produces l_out outputs.
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Apply each layer's length formula (with padding disabled) in order.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L
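
    # Worked example of the length arithmetic above (layer values assumed for
    # illustration): for a model whose only layer is
    # nn.Conv1d(1, 1, kernel_size=7, stride=1, dilation=1), disabling padding
    # gives get_output_length(L) = ((L - 6 - 1) / 1) + 1 = L - 6, so
    # get_output_length(0) = -6. get_delay() inverts that formula back to
    # l_in = -6 + 6 = 0 and returns (l_in - l_out) // 2 = (0 - (-6)) // 2 = 3,
    # i.e. an unpadded kernel of 7 trims 3 samples from each side.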

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to compress
        win_duration : float, optional
            window duration in seconds; if None, the whole signal is treated
            as a single window. By default 1.0
        verbose : bool, optional
            Prints progress if True, by default False
        normalize_db : float, optional
            target loudness in dB to normalize the input to, by default -16
        n_quantizers : int, optional
            number of quantizers to use, by default None (all quantizers)

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        audio_signal = audio_signal.clone()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # For very long signals (10 hours or more), fall back to the
        # ffmpeg-based resample and loudness implementations.
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        # Fold channels into the batch dimension so each channel is coded
        # independently.
        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression: the signal fits in a single window.
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference with padding disabled.
            self.padding = False
            # Zero-pad the signal on either side by the model delay.
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples up to the nearest multiple of the hop length.
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )

        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        self.padding = original_padding
        return dac_file
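
    # Usage sketch (illustrative; assumes `model` is a trained codec that
    # mixes in CodecMixin, and that "input.wav" exists):
    #
    #     dac_file = model.compress("input.wav", win_duration=5.0)
    #     dac_file.save("input.dac")  # serialized via DACFile.save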

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding
        self.padding = obj.padding

        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        # Decode one chunk of codes at a time to keep memory usage constant.
        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        resample_fn = recons.resample
        loudness_fn = recons.loudness

        # For very long signals (10 hours or more), fall back to the
        # ffmpeg-based resample and loudness implementations.
        if recons.signal_duration >= 10 * 60 * 60:
            resample_fn = recons.ffmpeg_resample
            loudness_fn = recons.ffmpeg_loudness

        # Restore the input loudness, original sample rate, and original
        # length, then unfold channels back out of the batch dimension.
        recons.normalize(obj.input_db)
        resample_fn(obj.sample_rate)
        recons = recons[..., : obj.original_length]
        loudness_fn()
        recons.audio_data = recons.audio_data.reshape(
            -1, obj.channels, obj.original_length
        )

        self.padding = original_padding
        return recons
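
# Round-trip sketch (illustrative; assumes `model` is a codec instance that
# mixes in CodecMixin and that "input.wav" exists):
#
#     dac_file = model.compress("input.wav")
#     dac_file.save("input.dac")
#     signal = model.decompress("input.dac", verbose=True)
#     signal.write("reconstruction.wav")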