|
import json
|
|
|
|
def create_model_from_config(model_config):
    """Instantiate the model described by a parsed model config dict.

    Args:
        model_config: dict holding a 'model_type' key that selects which
            factory to use; the full dict is forwarded to that factory.

    Returns:
        The model built by the type-specific factory.

    Raises:
        AssertionError: if 'model_type' is missing from the config.
        NotImplementedError: if 'model_type' names an unknown model type.
    """
    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    # Imports are deferred and type-specific so unused model families
    # never get loaded; pick the factory first, call it once at the end.
    if model_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config as factory
    elif model_type == 'diffusion_uncond':
        from .diffusion import create_diffusion_uncond_from_config as factory
    elif model_type in ('diffusion_cond', 'diffusion_cond_inpaint', 'diffusion_prior'):
        from .diffusion import create_diffusion_cond_from_config as factory
    elif model_type == 'diffusion_autoencoder':
        from .autoencoders import create_diffAE_from_config as factory
    elif model_type == 'lm':
        from .lm import create_audio_lm_from_config as factory
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return factory(model_config)
|
|
|
|
def create_model_from_config_path(model_config_path):
    """Load a JSON model config from disk and build the model it describes.

    Args:
        model_config_path: path to a JSON file containing the model config.

    Returns:
        The model created by create_model_from_config.

    Raises:
        AssertionError: if the config lacks 'model_type'.
        NotImplementedError: if the config names an unknown model type.
    """
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(model_config_path, encoding="utf-8") as f:
        model_config = json.load(f)

    return create_model_from_config(model_config)
|
|
|
|
def create_pretransform_from_config(pretransform_config, sample_rate):
    """Build the pretransform module described by a pretransform config dict.

    Args:
        pretransform_config: dict with a 'type' key selecting the pretransform
            family, a type-specific 'config' sub-dict, and optional top-level
            flags ('scale', 'model_half', 'iterate_batch', 'chunked',
            'enable_grad').
        sample_rate: audio sample rate, needed when building an autoencoder
            pretransform.

    Returns:
        The instantiated pretransform, put in eval mode, with gradients
        enabled only when the config asks for them.

    Raises:
        AssertionError: if 'type' is missing from the config.
        NotImplementedError: if 'type' names an unknown pretransform.
    """
    pretransform_type = pretransform_config.get('type', None)

    assert pretransform_type is not None, 'type must be specified in pretransform config'

    if pretransform_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # The autoencoder factory expects a full model config that wraps the
        # inner pretransform config.
        autoencoder = create_autoencoder_from_config(
            {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        )

        pretransform = AutoencoderPretransform(
            autoencoder,
            scale=pretransform_config.get("scale", 1.0),
            model_half=pretransform_config.get("model_half", False),
            iterate_batch=pretransform_config.get("iterate_batch", False),
            chunked=pretransform_config.get("chunked", False),
        )
    elif pretransform_type == 'wavelet':
        from .pretransforms import WaveletPretransform

        wavelet_config = pretransform_config["config"]
        pretransform = WaveletPretransform(
            wavelet_config["channels"],
            wavelet_config["levels"],
            wavelet_config["wavelet"],
        )
    elif pretransform_type == 'pqmf':
        from .pretransforms import PQMFPretransform

        pretransform = PQMFPretransform(**pretransform_config["config"])
    elif pretransform_type == 'dac_pretrained':
        from .pretransforms import PretrainedDACPretransform

        pretransform = PretrainedDACPretransform(**pretransform_config["config"])
    elif pretransform_type == "audiocraft_pretrained":
        from .pretransforms import AudiocraftCompressionPretransform

        pretransform = AudiocraftCompressionPretransform(**pretransform_config["config"])
    else:
        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')

    # Freeze (or unfreeze) the whole module according to the config flag,
    # and put it in eval mode since pretransforms are not trained here.
    pretransform.enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform
|
|
|
|
def create_bottleneck_from_config(bottleneck_config):
    """Build the bottleneck module described by a bottleneck config dict.

    Args:
        bottleneck_config: dict with a 'type' key selecting the bottleneck,
            an optional 'config' sub-dict of constructor kwargs (required by
            the quantizer-style types), and an optional 'requires_grad' flag
            (default True).

    Returns:
        The instantiated bottleneck; its parameters are frozen when
        'requires_grad' is False.

    Raises:
        AssertionError: if 'type' is missing from the config.
        NotImplementedError: if 'type' names an unknown bottleneck.
    """
    bottleneck_type = bottleneck_config.get('type', None)

    assert bottleneck_type is not None, 'type must be specified in bottleneck config'

    # Shared defaults for the residual-VQ variants ('rvq' and 'rvq_vae').
    # Previously duplicated per branch, which risked the copies drifting apart.
    rvq_defaults = {
        "dim": 128,
        "codebook_size": 1024,
        "num_quantizers": 8,
        "decay": 0.99,
        "kmeans_init": True,
        "kmeans_iters": 50,
        "threshold_ema_dead_code": 2,
    }

    if bottleneck_type == 'tanh':
        from .bottleneck import TanhBottleneck
        bottleneck = TanhBottleneck()
    elif bottleneck_type == 'vae':
        from .bottleneck import VAEBottleneck
        bottleneck = VAEBottleneck()
    elif bottleneck_type == 'rvq':
        from .bottleneck import RVQBottleneck

        # Merge into a fresh dict so user config overrides defaults without
        # mutating the shared defaults dict.
        quantizer_params = {**rvq_defaults, **bottleneck_config["config"]}
        bottleneck = RVQBottleneck(**quantizer_params)
    elif bottleneck_type == "dac_rvq":
        from .bottleneck import DACRVQBottleneck
        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'rvq_vae':
        from .bottleneck import RVQVAEBottleneck

        quantizer_params = {**rvq_defaults, **bottleneck_config["config"]}
        bottleneck = RVQVAEBottleneck(**quantizer_params)
    elif bottleneck_type == 'dac_rvq_vae':
        from .bottleneck import DACRVQVAEBottleneck
        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'l2_norm':
        from .bottleneck import L2Bottleneck
        bottleneck = L2Bottleneck()
    elif bottleneck_type == "wasserstein":
        from .bottleneck import WassersteinBottleneck
        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == "fsq":
        from .bottleneck import FSQBottleneck
        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')

    requires_grad = bottleneck_config.get('requires_grad', True)

    if not requires_grad:
        # Freeze the bottleneck so it is never updated during training.
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck
|
|
|