|
import typing as tp |
|
import omegaconf |
|
from torch import nn |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
from omegaconf import OmegaConf, DictConfig |
|
|
|
from .encodec import EncodecModel |
|
from .lm import LMModel |
|
from .seanet import SEANetDecoder |
|
from .codebooks_patterns import DelayedPatternProvider |
|
from .conditioners import ( |
|
ConditioningProvider, |
|
T5Conditioner, |
|
ConditioningAttributes |
|
) |
|
from .vq import ResidualVectorQuantizer |
|
|
|
|
|
|
|
|
|
def _delete_param(cfg: DictConfig, full_name: str):
    """Delete the entry at dotted path ``full_name`` from ``cfg``, if it exists.

    Silently returns when any intermediate key along the path is missing.
    Struct mode is temporarily disabled on the parent node so the deletion
    is permitted, then re-enabled.
    """
    *path, leaf = full_name.split('.')
    node = cfg
    for key in path:
        if key not in node:
            return
        node = node[key]
    OmegaConf.set_struct(node, False)
    if leaf in node:
        del node[leaf]
    OmegaConf.set_struct(node, True)
|
|
|
|
|
|
|
def dict_from_config(cfg):
    """Convert an OmegaConf node to a plain Python container, resolving interpolations."""
    return omegaconf.OmegaConf.to_container(cfg, resolve=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioGen(nn.Module):
    """Text-to-audio generation built from the 'facebook/audiogen-medium' checkpoints.

    On construction this downloads and assembles two sub-models from the
    HuggingFace Hub:
      * a decoder-only EnCodec compression model that turns discrete audio
        tokens back into a waveform, and
      * a transformer language model that generates those tokens from text
        conditioning.
    """

    def __init__(self,
                 duration=0.024,
                 device='cpu'):
        """
        Args:
            duration: target audio length in seconds; the number of generated
                tokens is ``int(duration * frame_rate)``.
            device: device for the LM (the compression model is built on CPU).
        """
        super().__init__()
        self.device = device
        self.load_compression_model()
        self.load_lm_model()
        self.duration = duration

    @property
    def frame_rate(self):
        # Token frame rate (frames/second) of the compression model.
        return self.compression_model.frame_rate

    def generate(self, descriptions):
        """Generate audio for a list of text descriptions.

        Args:
            descriptions: list of strings, one per item in the batch.

        Returns:
            Waveform tensor, peak-normalized per item along dim 2 (time).
        """
        with torch.no_grad():
            attributes = [
                ConditioningAttributes(text={'description': d})
                for d in descriptions]
            gen_tokens = self.lm.generate(
                conditions=attributes,
                max_gen_len=int(self.duration * self.frame_rate))
            x = self.compression_model.decode(gen_tokens, None)
        # Peak-normalize. The epsilon must be inside the denominator: the
        # previous `x / peak + 1e-7` divided first and then added a DC offset,
        # giving no protection against an all-zero waveform.
        peak = x.abs().max(2, keepdim=True)[0]
        return x / (peak + 1e-7)

    def get_quantizer(self, quantizer, cfg, dimension):
        """Instantiate the quantizer named by ``quantizer`` from ``cfg``.

        Returns:
            A ResidualVectorQuantizer for 'rvq', or None for 'no_quant'
            (previously 'no_quant' attempted to call None, raising TypeError).

        Raises:
            KeyError: for an unknown quantizer name.
        """
        klass = {
            'no_quant': None,
            'rvq': ResidualVectorQuantizer
        }[quantizer]
        if klass is None:
            return None
        kwargs = dict_from_config(getattr(cfg, quantizer))
        kwargs['dimension'] = dimension
        return klass(**kwargs)

    def get_encodec_autoencoder(self, cfg):
        """Build the SEANet decoder described by ``cfg.seanet``.

        Encoder settings are discarded — generation only ever decodes.
        """
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        return SEANetDecoder(**decoder_kwargs)

    def get_compression_model(self, cfg):
        """Instantiate a compression model (decoder-only EnCodec).

        Raises:
            KeyError: if ``cfg.compression_model`` is not 'encodec'.
        """
        if cfg.compression_model == 'encodec':
            kwargs = dict_from_config(getattr(cfg, 'encodec'))
            quantizer_name = kwargs.pop('quantizer')
            decoder = self.get_encodec_autoencoder(cfg)
            quantizer = self.get_quantizer(quantizer_name, cfg, 128)
            renormalize = kwargs.pop('renormalize', False)
            kwargs.pop('renorm', None)  # legacy key present in older checkpoints
            return EncodecModel(decoder=decoder,
                                quantizer=quantizer,
                                frame_rate=50,
                                renormalize=renormalize,
                                sample_rate=16000,
                                channels=1,
                                causal=False).to(cfg.device)
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")

    def get_lm_model(self, cfg):
        """Instantiate a transformer LM.

        Raises:
            KeyError: if ``cfg.lm_model`` is not a supported transformer LM.
        """
        if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']:
            kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
            n_q = kwargs['n_q']
            q_modeling = kwargs.pop('q_modeling', None)
            codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
            attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
            cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
            cfg_prob = cls_free_guidance['training_dropout']
            cfg_coef = cls_free_guidance['inference_coef']
            condition_provider = self.get_conditioner_provider(
                kwargs["dim"], cfg).to(self.device)
            kwargs['cross_attention'] = True
            if codebooks_pattern_cfg.modeling is None:
                # Older checkpoints store the pattern as transformer_lm.q_modeling;
                # synthesize an equivalent codebooks_pattern config from it.
                assert q_modeling is not None, \
                    "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
                codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                    {'modeling': q_modeling,
                     'delay': {'delays': list(range(n_q))}})
            pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
            return LMModel(
                pattern_provider=pattern_provider,
                condition_provider=condition_provider,
                cfg_dropout=cfg_prob,
                cfg_coef=cfg_coef,
                attribute_dropout=attribute_dropout,
                dtype=getattr(torch, cfg.dtype),
                device=self.device,
                **kwargs).to(cfg.device)
        raise KeyError(f"Unexpected LM model {cfg.lm_model}")

    def get_conditioner_provider(self, output_dim, cfg):
        """Instantiate the T5 text conditioner(s) declared in ``cfg.conditioners``.

        Args:
            output_dim: dimension the conditioner projects into (the LM dim).

        Raises:
            ValueError: for any conditioner whose model type is not 't5'.
        """
        cfg = getattr(cfg, 'conditioners')
        dict_cfg = {} if cfg is None else dict_from_config(cfg)
        conditioners = {}
        condition_provider_args = dict_cfg.pop('args', {})
        # Training-time text-merging options are irrelevant at inference.
        condition_provider_args.pop('merge_text_conditions_p', None)
        condition_provider_args.pop('drop_desc_p', None)
        for cond, cond_cfg in dict_cfg.items():
            model_type = cond_cfg['model']
            model_args = cond_cfg[model_type]
            if model_type != 't5':
                raise ValueError(f"Unrecognized conditioning model: {model_type}")
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
                                                    device=self.device,
                                                    **model_args)
        return ConditioningProvider(conditioners)

    def get_codebooks_pattern_provider(self, n_q, cfg):
        """Build the codebook interleaving pattern provider named by ``cfg.modeling``.

        Raises:
            KeyError: for a pattern name other than 'delay'.
        """
        pattern_providers = {
            'delay': DelayedPatternProvider,
        }
        name = cfg.modeling
        kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
        klass = pattern_providers[name]
        return klass(n_q, **kwargs)

    def load_compression_model(self):
        """Download and build the EnCodec compression model; sets ``self.compression_model``."""
        file = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="compression_state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')
        # NOTE(review): torch.load unpickles the downloaded checkpoint; this is
        # acceptable only because the repo_id is a fixed, trusted source.
        pkg = torch.load(file, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])
        cfg.device = 'cpu'
        model = self.get_compression_model(cfg)
        # strict=False tolerates state-dict key mismatches (e.g. the encoder
        # weights present in the checkpoint but absent from this decoder-only
        # model) — TODO confirm against the checkpoint contents.
        model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model = model

    def load_lm_model(self):
        """Download and build the transformer LM; sets ``self.lm`` (cast to float32)."""
        file = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')
        # NOTE(review): same unpickling caveat as load_compression_model.
        pkg = torch.load(file, map_location=self.device)
        cfg = OmegaConf.create(pkg['xp.cfg'])
        cfg.dtype = 'float32' if self.device == 'cpu' else 'float16'
        # Strip training-only conditioner options the inference build rejects.
        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
        _delete_param(cfg, 'conditioners.args.drop_desc_p')
        model = self.get_lm_model(cfg)
        model.load_state_dict(pkg['best_state'])
        model.cfg = cfg
        self.lm = model.to(torch.float)