|
|
|
|
|
|
|
|
|
|
|
"""Compression models or wrappers around existing models.
|
Also defines the main interface that a model must follow to be usable as an audio tokenizer. |
|
""" |
|
|
|
from abc import ABC, abstractmethod |
|
import logging |
|
import math |
|
from pathlib import Path |
|
import typing as tp |
|
|
|
from einops import rearrange |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from transformers import EncodecModel as HFEncodecModel |
|
|
|
import audiocraft.quantization as qt |
|
|
|
|
|
# Module-scoped logger: records carry this module's name and still propagate to the
# root logger's handlers, instead of being emitted directly on the root logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
class CompressionModel(ABC, nn.Module):
    """Base API for all compression models that aim at being used as audio tokenizers
    with a language model.

    Concrete subclasses implement the abstract decoding methods and expose the stream
    properties below (channels, frame rate, sample rate and codebook layout).
    """

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`."""
        ...

    @abstractmethod
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        ...

    @property
    @abstractmethod
    def channels(self) -> int:
        """Number of audio channels."""
        ...

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        """Frame rate of the latent representation."""
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Audio sample rate."""
        ...

    @property
    @abstractmethod
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        """Active number of codebooks used by the quantizer."""
        ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        """Total number of quantizer codebooks available."""
        ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...

    @staticmethod
    def get_pretrained(
        name: tp.Union[str, Path], device: tp.Union[torch.device, str] = 'cpu'
    ) -> 'CompressionModel':
        """Instantiate a CompressionModel from a given pretrained model.

        Args:
            name (Path or str): name of the pretrained model, or a local path. See after.
            device (torch.device or str): Device on which the model is loaded.

        Returns:
            CompressionModel: the model, moved to `device` and switched to eval mode.

        Pretrained models:
            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
            - dac_24khz (same)
            - debug_compression_model: built by `builders.get_debug_compression_model()`.
            - an existing local path: loaded with `loaders.load_compression_model`.
            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
            - your own model on Hugging Face. Export instructions to come...
        """
        from . import builders, loaders
        model: CompressionModel
        if name in ['dac_44khz', 'dac_24khz']:
            # Membership in a list of str guarantees `name` is a str here.
            model_type = str(name).split('_')[1]
            logger.info("Getting pretrained compression model from DAC %s", model_type)
            model = DAC(model_type)
        elif name == 'debug_compression_model':
            logger.info("Getting pretrained compression model for debug")
            model = builders.get_debug_compression_model()
        elif Path(name).exists():
            # Any existing local path is handed to the checkpoint loader.
            model = loaders.load_compression_model(name, device=device)
        else:
            logger.info("Getting pretrained compression model from HF %s", name)
            hf_model = HFEncodecModel.from_pretrained(name)
            model = HFEncodecCompressionModel(hf_model)
        # Single device move and eval switch shared by every branch.
        return model.to(device).eval()
|
|
|
|
|
class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Note: this constructor does not take an encoder network; the methods defined here
    cover normalization helpers and decoding from codes.

    Args:
        decoder (nn.Module): Decoder network, applied to the latents from `decode_latent`.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (float): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """

    # Plain class attributes override the abstract properties of CompressionModel
    # (so the class is instantiable); __init__ overwrites them per instance.
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 decoder: tp.Optional[nn.Module] = None,
                 quantizer: "tp.Optional[qt.BaseQuantizer]" = None,
                 frame_rate: tp.Optional[float] = None,
                 sample_rate: tp.Optional[int] = None,
                 channels: tp.Optional[int] = None,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # Renormalization is not supported together with a causal model.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self) -> int:
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self) -> int:
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally normalize the input waveform by its RMS volume.

        Args:
            x (torch.Tensor): Float audio tensor; presumably [B, C, T], the same layout
                `decode` returns (TODO confirm with callers).

        Returns:
            tuple: `(x, scale)`. With `renormalize` set, `x` is divided by the RMS of its
            mean over dim 1 (plus 1e-8), and `scale` is that divisor flattened with
            `view(-1, 1)`. Otherwise `x` is returned unchanged and `scale` is None.
        """
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            # Epsilon keeps the division finite on silent input.
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo `preprocess` normalization by multiplying `x` by `scale`, if provided.

        A non-None `scale` is only accepted when the model was built with `renormalize=True`.
        """
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        return self.postprocess(out, scale)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)