SongGeneration / codeclm /tokenizer /audio_tokenizer.py
waytan22's picture
add auto prompt and interface
d658154
raw
history blame
7.54 kB
"""
Tokenizer or wrapper around existing models.
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""
from abc import ABC, abstractmethod
import logging
import typing as tp
import torch
from torch import nn
logger = logging.getLogger()
class AudioTokenizer(ABC, nn.Module):
"""Base API for all compression model that aim at being used as audio tokenizers
with a language model.
"""
@abstractmethod
def forward(self, x: torch.Tensor) :
...
@abstractmethod
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
"""See `EncodecModel.encode`."""
...
@abstractmethod
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
"""See `EncodecModel.decode`."""
...
@abstractmethod
def decode_latent(self, codes: torch.Tensor):
"""Decode from the discrete codes to continuous latent space."""
...
@property
@abstractmethod
def channels(self) -> int:
...
@property
@abstractmethod
def frame_rate(self) -> float:
...
@property
@abstractmethod
def sample_rate(self) -> int:
...
@property
@abstractmethod
def cardinality(self) -> int:
...
@property
@abstractmethod
def num_codebooks(self) -> int:
...
@property
@abstractmethod
def total_codebooks(self) -> int:
...
@abstractmethod
def set_num_codebooks(self, n: int):
"""Set the active number of codebooks used by the quantizer."""
...
@staticmethod
def get_pretrained(
name: str,
vae_config: str,
vae_model: str,
device: tp.Union[torch.device, str] = 'cpu',
mode='extract'
) -> 'AudioTokenizer':
"""Instantiate a AudioTokenizer model from a given pretrained model.
Args:
name (Path or str): name of the pretrained model. See after.
device (torch.device or str): Device on which the model is loaded.
"""
model: AudioTokenizer
if name.split('_')[0] == 'Flow1dVAESeparate':
model_type = name.split('_', 1)[1]
logger.info("Getting pretrained compression model from semantic model %s", model_type)
model = Flow1dVAESeparate(model_type, vae_config, vae_model)
elif name.split('_')[0] == 'Flow1dVAE1rvq':
model_type = name.split('_', 1)[1]
logger.info("Getting pretrained compression model from semantic model %s", model_type)
model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
else:
raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
name))
return model.to(device).eval()
class Flow1dVAE1rvq(AudioTokenizer):
def __init__(
self,
model_type: str = "model_2_fixed.safetensors",
vae_config: str = "",
vae_model: str = "",
):
super().__init__()
from codeclm.tokenizer.Flow1dVAE.generate_1rvq import Tango
model_path = model_type
self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device='cuda')
print ("Successfully loaded checkpoint from:", model_path)
self.n_quantizers = 1
def forward(self, x: torch.Tensor) :
# We don't support training with this.
raise NotImplementedError("Forward and training with DAC not supported.")
@torch.no_grad()
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
if x.ndim == 2:
x = x.unsqueeze(1)
codes = self.model.sound2code(x) # [B T] -> [B N T]
return codes, None
@torch.no_grad()
def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
return wav[None]
@torch.no_grad()
def decode_latent(self, codes: torch.Tensor):
"""Decode from the discrete codes to continuous latent space."""
# import pdb; pdb.set_trace()
return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
@property
def channels(self) -> int:
return 2
@property
def frame_rate(self) -> float:
return 25
@property
def sample_rate(self) -> int:
return self.samplerate
@property
def cardinality(self) -> int:
return 10000
@property
def num_codebooks(self) -> int:
return self.n_quantizers
@property
def total_codebooks(self) -> int:
# return self.model.RVQ
return 1
def set_num_codebooks(self, n: int):
"""Set the active number of codebooks used by the quantizer.
"""
assert n >= 1
assert n <= self.total_codebooks
self.n_quantizers = n
class Flow1dVAESeparate(AudioTokenizer):
def __init__(
self,
model_type: str = "model_2.safetensors",
vae_config: str = "",
vae_model: str = "",
):
super().__init__()
from codeclm.tokenizer.Flow1dVAE.generate_septoken import Tango
model_path = model_type
self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device='cuda')
print ("Successfully loaded checkpoint from:", model_path)
self.n_quantizers = 1
def forward(self, x: torch.Tensor) :
# We don't support training with this.
raise NotImplementedError("Forward and training with DAC not supported.")
@torch.no_grad()
def encode(self, x_vocal: torch.Tensor, x_bgm: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
if x_vocal.ndim == 2:
x_vocal = x_vocal.unsqueeze(1)
if x_bgm.ndim == 2:
x_bgm = x_bgm.unsqueeze(1)
codes_vocal, codes_bgm = self.model.sound2code(x_vocal, x_bgm)
return codes_vocal, codes_bgm
@torch.no_grad()
def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None):
wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
return wav[None]
@torch.no_grad()
def decode_latent(self, codes: torch.Tensor):
"""Decode from the discrete codes to continuous latent space."""
# import pdb; pdb.set_trace()
return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
@property
def channels(self) -> int:
return 2
@property
def frame_rate(self) -> float:
return 25
@property
def sample_rate(self) -> int:
return self.samplerate
@property
def cardinality(self) -> int:
return 10000
@property
def num_codebooks(self) -> int:
return self.n_quantizers
@property
def total_codebooks(self) -> int:
# return self.model.RVQ
return 1
def set_num_codebooks(self, n: int):
"""Set the active number of codebooks used by the quantizer.
"""
assert n >= 1
assert n <= self.total_codebooks
self.n_quantizers = n