 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
All the functions to build the relevant models and modules
from the Hydra config.
"""

import typing as tp

import omegaconf
import torch
from codeclm.utils.utils import dict_from_config
from codeclm.modules.pattern import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
)
from codeclm.modules.conditioners import (
    BaseConditioner,
    QwTokenizerConditioner,
    QwTextConditioner,
    PhonemeTokenizerConditioner,
    QuantizedEmbeddingConditioner,
    ConditionerProvider,
    ConditionFuser,
)


def get_audio_tokenizer_model(checkpoint_path: str, cfg: omegaconf.DictConfig):
    """Instantiate a compression (audio tokenizer) model.

    Args:
        checkpoint_path: Tokenizer checkpoint location. ``None`` or the empty
            string yields no model. A path of the form ``//pretrained/<name>``
            resolves to the pretrained model ``<name>``; any other value is
            used as the checkpoint name directly.
        cfg: Config expected to provide ``vae_config``, ``vae_model`` and
            ``mode`` — TODO confirm these are always present on callers' cfg.

    Returns:
        An ``AudioTokenizer`` loaded on CPU, or ``None`` when no path is given.
    """
    # Imported lazily so importing this module does not pull tokenizer deps.
    from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
    # NOTE: the original placed the docstring after this import, where it was
    # a no-op string expression rather than a docstring — fixed here.
    if checkpoint_path is None or checkpoint_path == "":
        return None
    if checkpoint_path.startswith('//pretrained/'):
        # Drop the '//pretrained/' prefix (split keeps everything after the
        # third '/'), leaving the pretrained model name.
        name = checkpoint_path.split('/', 3)[-1]
    else:
        name = checkpoint_path
    # Both branches previously issued this identical call; deduplicated.
    return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
    
def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
    """Build the language model described by the Hydra config."""
    kwargs = dict_from_config(getattr(cfg, 'lm'))

    # Number of RVQ codebooks; q_modeling is popped so it is not forwarded
    # to the LM constructor.
    n_q = kwargs['code_depth']
    q_modeling = kwargs.pop('q_modeling', None)

    # Conditioning provider (text / phoneme / embedding conditioners).
    provider = get_conditioner_provider(kwargs["dim"], cfg)

    # Codebook interleaving pattern. When no pattern is configured, fall back
    # to a delay pattern synthesized from q_modeling.
    pattern_cfg = getattr(cfg, 'codebooks_pattern')
    if pattern_cfg.modeling is None:
        assert q_modeling is not None, \
            "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
        pattern_cfg = omegaconf.OmegaConf.create(
            {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
        )
    pattern = get_codebooks_pattern_provider(n_q, pattern_cfg)

    # Condition dropout and classifier-free guidance settings.
    attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
    guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
    cfg_prob = guidance['training_dropout']
    cfg_coef = guidance['inference_coef']

    # How conditions are merged into the transformer input.
    fuser = get_condition_fuser(cfg)

    # Dispatch on lm_type; only the Llama-style LM is implemented.
    lm_type = kwargs['lm_type']
    if lm_type != 'Llama':
        raise KeyError(f"Unexpected LM model {lm_type}")
    from .lm_levo import LmModel
    return LmModel(
        pattern_provider=pattern,
        condition_provider=provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        cfg=cfg,
        **kwargs
    ).to('cpu')


def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
    """Build the ConditionerProvider holding every configured conditioner."""
    section = getattr(cfg, 'conditioners')
    cfg_dict = dict_from_config(section) if section is not None else {}
    # 'args' holds provider-level options rather than a conditioner entry.
    provider_args = cfg_dict.pop('args', {})

    # Registry: config model name -> (conditioner class, keyword that
    # receives the LM embedding dimension).
    registry = {
        'QwTokenizer': (QwTokenizerConditioner, 'output_dim'),
        'QwTextTokenizer': (QwTextConditioner, 'output_dim'),
        'PhonemeTokenizer': (PhonemeTokenizerConditioner, 'output_dim'),
        'qt_embedding': (QuantizedEmbeddingConditioner, 'dim'),
    }

    conditioners: tp.Dict[str, BaseConditioner] = {}
    for cond_name, entry in cfg_dict.items():
        model_type = entry['model']
        # Each entry nests the model-specific kwargs under its model name.
        model_args = entry[model_type]
        if model_type not in registry:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
        klass, dim_key = registry[model_type]
        conditioners[str(cond_name)] = klass(**{dim_key: output_dim}, **model_args)

    return ConditionerProvider(conditioners, **provider_args)


def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Build the ConditionFuser from the 'fuser' section of the config."""
    fuser_cfg = getattr(cfg, 'fuser')
    methods = ('sum', 'prepend')
    # Per-method condition lists; indexing (not .get) so a missing method
    # still raises, matching the original contract.
    fuse2cond = {method: fuser_cfg[method] for method in methods}
    # Everything else in the section is forwarded as constructor kwargs.
    extra = {}
    for key, value in fuser_cfg.items():
        if key not in methods:
            extra[key] = value
    return ConditionFuser(fuse2cond=fuse2cond, **extra)


def get_codebooks_pattern_provider(code_depth: int, cfg: omegaf.DictConfig if False else omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Build the codebook-pattern provider named by ``cfg.modeling``."""
    providers = {
        'delay': DelayedPatternProvider,
    }
    pattern_name = cfg.modeling
    # Forward the pattern-specific sub-config (e.g. cfg.delay) as kwargs
    # when it exists; otherwise use the provider's defaults.
    if hasattr(cfg, pattern_name):
        provider_kwargs = dict_from_config(cfg.get(pattern_name))
    else:
        provider_kwargs = {}
    return providers[pattern_name](code_depth, **provider_kwargs)