"""
All the functions to build the relevant models and modules
from the Hydra config.
"""
import typing as tp
import omegaconf
import torch
from codeclm.utils.utils import dict_from_config
from codeclm.modules.pattern import (
CodebooksPatternProvider,
DelayedPatternProvider,
)
from codeclm.modules.conditioners import (
BaseConditioner,
QwTokenizerConditioner,
QwTextConditioner,
PhonemeTokenizerConditioner,
QuantizedEmbeddingConditioner,
ConditionerProvider,
ConditionFuser,
)


def get_audio_tokenizer_model(checkpoint_path: tp.Optional[str], cfg: omegaconf.DictConfig):
    """Instantiate an audio tokenizer (compression model) from a checkpoint path."""
    from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
if checkpoint_path is None:
return None
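    # '//pretrained/<name>' selects a named pretrained tokenizer; any other
    # non-empty string is treated as a checkpoint path/name and passed through.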
if checkpoint_path.startswith('//pretrained/'):
name = checkpoint_path.split('/', 3)[-1]
return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
elif checkpoint_path == "":
return None
else:
name = checkpoint_path
return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)


def get_lm_model(cfg: omegaconf.DictConfig):  # -> LmModel (imported lazily below)
    """Instantiate an LM."""
    lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
    # code_depth: number of RVQ codebooks (n_q)
    code_depth = lm_kwargs['code_depth']
q_modeling = lm_kwargs.pop('q_modeling', None)
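    # q_modeling is popped so it is not forwarded to the LM constructor; it only
    # serves as a fallback for the codebook pattern (handled below).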
# conditioner
condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg)
# codebook pattern: delay
codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
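    # If no explicit codebook pattern is configured, fall back to q_modeling with
    # a default per-codebook delay pattern of 0..code_depth-1.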
if codebooks_pattern_cfg.modeling is None:
        assert q_modeling is not None, \
            "LM model should either have a codebook pattern defined or lm.q_modeling"
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
{'modeling': q_modeling, 'delay': {'delays': list(range(code_depth))}}
)
pattern_provider = get_codebooks_pattern_provider(code_depth, codebooks_pattern_cfg)
    # condition dropout and classifier-free guidance
attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
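    # training_dropout: probability of dropping conditions during training;
    # inference_coef: classifier-free guidance coefficient used at inference.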
cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
# condition fuser
fuser = get_condition_fuser(cfg)
    lm_type = lm_kwargs['lm_type']  # YCY: select which LM implementation to import based on lm_type
if lm_type == 'Llama':
from .lm_levo import LmModel
return LmModel(
pattern_provider=pattern_provider,
condition_provider=condition_provider,
fuser=fuser,
cfg_dropout=cfg_prob,
cfg_coef=cfg_coef,
attribute_dropout=attribute_dropout,
cfg=cfg,
**lm_kwargs
).to('cpu')
else:
raise KeyError(f"Unexpected LM model {lm_type}")


def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
    """Instantiate the conditioner provider holding all conditioning models."""
cfg = getattr(cfg, 'conditioners')
dict_cfg = {} if cfg is None else dict_from_config(cfg)
conditioners: tp.Dict[str, BaseConditioner] = {}
condition_provider_args = dict_cfg.pop('args', {})
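    # Each entry maps a condition name to {'model': <type>, <type>: {<kwargs>}}.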
for cond, cond_cfg in dict_cfg.items():
model_type = cond_cfg['model']
model_args = cond_cfg[model_type]
if model_type == 'QwTokenizer':
conditioners[str(cond)] = QwTokenizerConditioner(
output_dim=output_dim,
**model_args
)
elif model_type == "QwTextTokenizer":
conditioners[str(cond)] = QwTextConditioner(
output_dim=output_dim,
**model_args
)
elif model_type == 'PhonemeTokenizer':
conditioners[str(cond)] = PhonemeTokenizerConditioner(
output_dim=output_dim,
**model_args
)
elif model_type == "qt_embedding":
conditioners[str(cond)] = QuantizedEmbeddingConditioner(
dim=output_dim,
**model_args
)
else:
raise ValueError(f"Unrecognized conditioning model: {model_type}")
conditioner = ConditionerProvider(conditioners, **condition_provider_args)
return conditioner


def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
"""Instantiate a condition fuser object."""
fuser_cfg = getattr(cfg, 'fuser')
fuser_methods = ['sum', 'prepend']
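    # Split the fuser config into per-method condition lists ('sum', 'prepend')
    # and pass any remaining entries through as ConditionFuser kwargs.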
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
return fuser


def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
"""Instantiate a codebooks pattern provider object."""
pattern_providers = {
'delay': DelayedPatternProvider,
}
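    # cfg.modeling selects the pattern class; optional kwargs live under cfg.<modeling>.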
name = cfg.modeling
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
klass = pattern_providers[name]
return klass(code_depth, **kwargs)
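

# Example usage (sketch): the config path and the `audio_tokenizer_checkpoint`
# attribute below are illustrative, not part of this module; any DictConfig that
# provides the sections these builders read ('lm', 'conditioners', 'fuser',
# 'codebooks_pattern', ...) works the same way:
#
#   cfg = omegaconf.OmegaConf.load("conf/example.yaml")
#   audio_tokenizer = get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
#   lm = get_lm_model(cfg)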