Dionyssos's picture
revert pattern preserves 4
dcfe0d4
raw
history blame
9.42 kB
import typing as tp
import omegaconf
from torch import nn
import torch
from huggingface_hub import hf_hub_download
import os
from omegaconf import OmegaConf, DictConfig
from .encodec import EncodecModel
from .lm import LMModel
from .seanet import SEANetDecoder
from .codebooks_patterns import DelayedPatternProvider
from .conditioners import (
ConditioningProvider,
T5Conditioner,
ConditioningAttributes
)
from .vq import ResidualVectorQuantizer
def _delete_param(cfg: DictConfig, full_name: str):
    """Delete the dotted key ``full_name`` (e.g. ``'a.b.c'``) from ``cfg``.

    If any intermediate node along the path is missing, nothing happens. If the
    parent node exists, struct mode is switched off on it so the leaf key can be
    removed (when present), and struct mode is then switched on again
    unconditionally.
    """
    *parent_keys, leaf_key = full_name.split('.')
    node = cfg
    for key in parent_keys:
        if key not in node:
            # Path does not exist: nothing to delete.
            return
        node = node[key]
    OmegaConf.set_struct(node, False)
    if leaf_key in node:
        del node[leaf_key]
    OmegaConf.set_struct(node, True)
def dict_from_config(cfg):
    """Convert an OmegaConf node into plain Python containers.

    Interpolations are resolved (``resolve=True``) so the result holds final
    values and can be splatted as ``**kwargs`` into constructors.
    """
    return omegaconf.OmegaConf.to_container(cfg, resolve=True)
# ============================================== DEFINE AUDIOGEN
class AudioGen(nn.Module):
    """Text-to-audio generator assembled from the facebook/audiogen-medium checkpoints.

    Reference: https://huggingface.co/facebook/audiogen-medium

    ``__init__`` downloads (or reuses from the cache) two checkpoints and builds
    a module from the config stored in each:

    * ``self.compression_model``: an ``EncodecModel`` holding only the decoder
      side and the quantizer, always built on CPU;
    * ``self.lm``: the text-conditioned transformer LM, built on ``device``.

    Args:
        duration: Generation length, converted to ``int(duration * frame_rate)``
            LM steps (presumably seconds, since ``frame_rate`` is a rate).
        device: Device for the LM and its conditioners. Also selects the LM
            config dtype: 'float32' on 'cpu', 'float16' otherwise.
    """

    def __init__(self,
                 duration=0.024,
                 device='cpu'):
        super().__init__()
        self.device = device  # needed for loading & selecting the float16 LM
        self.load_compression_model()
        self.load_lm_model()
        self.duration = duration

    @property
    def frame_rate(self):
        """Frame rate of the compression model (LM steps per unit of ``duration``)."""
        return self.compression_model.frame_rate

    def generate(self,
                 descriptions):
        """Generate audio for each text description.

        Args:
            descriptions: Iterable of text prompts, one per batch item. A single
                ``str`` is treated as a one-item batch instead of being iterated
                character by character.

        Returns:
            Whatever ``compression_model.decode`` returns for the generated
            tokens (the author's notes record a ``[bs, 1, n_samples]`` tensor).
        """
        if isinstance(descriptions, str):
            descriptions = [descriptions]
        with torch.no_grad():
            attributes = [
                ConditioningAttributes(text={'description': d}) for d in descriptions]
            gen_tokens = self.lm.generate(
                conditions=attributes,
                max_gen_len=int(self.duration * self.frame_rate))  # author's note: [bs, 4, 37 * self.lm.n_draw]
            x = self.compression_model.decode(gen_tokens, None)  # author's note: [bs, 1, 11840]
        return x

    # == BUILD Fn
    def get_quantizer(self, quantizer, cfg, dimension):
        """Build the quantizer named ``quantizer`` from the matching section of ``cfg``.

        Only ``'rvq'`` (``ResidualVectorQuantizer``) is implemented here; its
        kwargs come from ``cfg.rvq`` plus ``dimension``.

        Raises:
            KeyError: for any other quantizer name, including ``'no_quant'``,
                which has no implementation in this package. It used to be mapped
                to ``None`` and then failed with an opaque ``TypeError``.
        """
        if quantizer != 'rvq':
            raise KeyError(f"Unsupported quantizer {quantizer!r}; only 'rvq' is available")
        kwargs = dict_from_config(getattr(cfg, quantizer))
        kwargs['dimension'] = dimension
        return ResidualVectorQuantizer(**kwargs)

    def get_encodec_autoencoder(self, cfg):
        """Build the SEANet decoder from ``cfg.seanet``.

        Shared ``seanet`` keys are merged with the ``decoder`` overrides, which
        win on conflicts. The ``encoder`` section is discarded because only
        decoding is performed by this class.
        """
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        return SEANetDecoder(**decoder_kwargs)

    def get_compression_model(self, cfg):
        """Instantiate the (decoder-only) EnCodec compression model described by ``cfg``.

        Frame rate, quantizer dimension, sample rate, channel count and causality
        are hard-coded to the values the original author observed in the
        audiogen-medium checkpoint (frame_rate=50, dimension=128,
        sample_rate=16000, channels=1, causal=False).
        NOTE(review): a different checkpoint would need these read from ``cfg``.

        Raises:
            KeyError: if ``cfg.compression_model`` is not ``'encodec'``.
        """
        if cfg.compression_model != 'encodec':
            raise KeyError(f"Unexpected compression model {cfg.compression_model}")
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        quantizer_name = kwargs.pop('quantizer')
        decoder = self.get_encodec_autoencoder(cfg)
        quantizer = self.get_quantizer(quantizer_name, cfg, 128)
        renormalize = kwargs.pop('renormalize', False)
        return EncodecModel(decoder=decoder,
                            quantizer=quantizer,
                            frame_rate=50,
                            renormalize=renormalize,
                            sample_rate=16000,
                            channels=1,
                            causal=False
                            ).to(cfg.device)

    def get_lm_model(self, cfg):
        """Instantiate the transformer LM described by ``cfg`` on ``cfg.device``.

        Raises:
            KeyError: if ``cfg.lm_model`` is not a transformer LM variant.
            ValueError: if neither ``cfg.codebooks_pattern.modeling`` nor
                ``cfg.transformer_lm.q_modeling`` defines a codebook pattern.
        """
        if cfg.lm_model not in ('transformer_lm', 'transformer_lm_magnet'):
            raise KeyError(f"Unexpected LM model {cfg.lm_model}")
        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
        n_q = kwargs['n_q']
        q_modeling = kwargs.pop('q_modeling', None)
        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
        cfg_prob = cls_free_guidance['training_dropout']
        cfg_coef = cls_free_guidance['inference_coef']
        condition_provider = self.get_conditioner_provider(kwargs["dim"], cfg).to(self.device)
        # Cross-attention is enabled unconditionally instead of being derived
        # from the condition fuser's cross-attention entries.
        kwargs['cross_attention'] = True
        if codebooks_pattern_cfg.modeling is None:
            # No explicit pattern: build one whose modeling is q_modeling,
            # with delays 0..n_q-1.
            if q_modeling is None:
                raise ValueError(
                    "LM model should either have a codebook pattern defined or transformer_lm.q_modeling")
            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
            )
        pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
        return LMModel(
            pattern_provider=pattern_provider,
            condition_provider=condition_provider,
            cfg_dropout=cfg_prob,
            cfg_coef=cfg_coef,
            attribute_dropout=attribute_dropout,
            dtype=getattr(torch, cfg.dtype),
            device=self.device,
            **kwargs
        ).to(cfg.device)

    def get_conditioner_provider(self, output_dim,
                                 cfg):
        """Instantiate the T5 text conditioners listed under ``cfg.conditioners``.

        Each entry other than ``args`` names a conditioner; its ``model`` field
        selects the implementation and the section of that name holds its kwargs.

        Raises:
            ValueError: if a conditioner uses a model other than ``'t5'``.
        """
        conditioners_cfg = getattr(cfg, 'conditioners')
        dict_cfg = {} if conditioners_cfg is None else dict_from_config(conditioners_cfg)
        # 'args' holds provider-level options, not a conditioner; they are unused here.
        dict_cfg.pop('args', None)
        conditioners = {}
        for cond, cond_cfg in dict_cfg.items():
            model_type = cond_cfg['model']
            if model_type != 't5':
                raise ValueError(f"Unrecognized conditioning model: {model_type}")
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
                                                    device=self.device,
                                                    **cond_cfg[model_type])
        return ConditioningProvider(conditioners)

    def get_codebooks_pattern_provider(self, n_q, cfg):
        """Build the codebook pattern provider named by ``cfg.modeling``.

        Only the ``'delay'`` pattern is available. Its kwargs come from the
        config section named after the pattern (e.g. ``cfg.delay``), if present.

        Raises:
            KeyError: if ``cfg.modeling`` names an unknown pattern.
        """
        pattern_providers = {
            'delay': DelayedPatternProvider,
        }
        name = cfg.modeling
        kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
        klass = pattern_providers[name]
        return klass(n_q, **kwargs)

    # ======================
    def _download_checkpoint(self, filename):
        """Return a local path to ``filename`` from the facebook/audiogen-medium repo.

        The file is fetched into the Hugging Face cache, or into
        ``$AUDIOCRAFT_CACHE_DIR`` when that variable is set.
        """
        return hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename=filename,
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')  # audiocraft.__version__ (from its __init__.py)

    def load_compression_model(self):
        """Download the compression checkpoint and set ``self.compression_model`` (on CPU)."""
        file = self._download_checkpoint("compression_state_dict.bin")
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        pkg = torch.load(file, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])
        # The compression model is always built on CPU, independent of self.device.
        # NOTE(review): with a non-CPU device the LM tokens live on a different
        # device than this decoder; confirm EncodecModel.decode handles that.
        cfg.device = 'cpu'
        model = self.get_compression_model(cfg)
        model.load_state_dict(pkg['best_state'], strict=False)  # ckpt also has unused encoder weights
        self.compression_model = model

    def load_lm_model(self):
        """Download the LM checkpoint and set ``self.lm`` on ``self.device``."""
        file = self._download_checkpoint("state_dict.bin")
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        pkg = torch.load(file, map_location=self.device)
        cfg = OmegaConf.create(pkg['xp.cfg'])
        # Override the device stored in the checkpoint config; otherwise
        # get_lm_model's `.to(cfg.device)` ignores self.device.
        cfg.device = str(self.device)
        cfg.dtype = 'float32' if cfg.device == 'cpu' else 'float16'
        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
        _delete_param(cfg, 'conditioners.args.drop_desc_p')
        model = self.get_lm_model(cfg)
        model.load_state_dict(pkg['best_state'])
        model.cfg = cfg
        # NOTE(review): this casts all LM weights to float32 even when
        # cfg.dtype is 'float16'; confirm this is intended.
        self.lm = model.to(torch.float)