# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

""" | |
Main model for using AudioGen. This will combine all the required components | |
and provide easy access to the generation API. | |
""" | |

import typing as tp

import torch

from .encodec import CompressionModel
from .genmodel import BaseGenModel
from .lm import LMModel
from .builders import get_debug_compression_model, get_debug_lm_model
from .loaders import load_compression_model, load_lm_model


class AudioGen(BaseGenModel):
    """AudioGen main model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce;
            if not provided, it is inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        super().__init__(name, compression_model, lm, max_duration)
        self.set_generation_params(duration=5)  # default duration

    @staticmethod
    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
        """Return a pretrained model. We provide a single model for now:
        - facebook/audiogen-medium (1.5B), text to sound,
          see: https://huggingface.co/facebook/audiogen-medium
        """
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            else:
                device = 'cpu'

        if name == 'debug':
            # used only for unit tests
            compression_model = get_debug_compression_model(device, sample_rate=16000)
            lm = get_debug_lm_model(device)
            return AudioGen(name, compression_model, lm, max_duration=10)
        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model(name, device=device)
        assert 'self_wav' not in lm.condition_provider.conditioners, \
            "AudioGen does not support waveform conditioning for now"
        return AudioGen(name, compression_model, lm)
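
    # A hedged usage sketch for get_pretrained (illustrative only; the 'debug'
    # configuration above is intended solely for unit tests):
    #
    #     model = AudioGen.get_pretrained('facebook/audiogen-medium', device='cpu')
    #     test_model = AudioGen.get_pretrained('debug')  # tiny 16 kHz test model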

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 10.0, cfg_coef: float = 3.0,
                              two_step_cfg: bool = False, extend_stride: float = 2):
        """Set the generation parameters for AudioGen.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            two_step_cfg (bool, optional): If True, performs 2 forward passes for Classifier Free Guidance,
                instead of batching the two together. This has some impact on how things
                are padded but seems to have little impact in practice.
            extend_stride (float, optional): when doing extended generation (i.e. more than 10 seconds),
                by how much should we extend the audio each time. Larger values will mean less context is
                preserved, and shorter values will require extra computations. Defaults to 2.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'two_step_cfg': two_step_cfg,
        }
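

if __name__ == '__main__':
    # A minimal smoke-test sketch, assuming this module is run with
    # `python -m <package>.audiogen` so the relative imports resolve, and that
    # `generate` (inherited from BaseGenModel) accepts a list of text
    # descriptions and returns a [batch, channels, time] tensor.
    model = AudioGen.get_pretrained('facebook/audiogen-medium')
    # Durations beyond max_duration are produced in overlapping windows that
    # advance by extend_stride seconds per step: smaller strides keep more
    # context at the cost of extra forward passes.
    model.set_generation_params(duration=30, extend_stride=2)
    wav = model.generate(['heavy rain on a tin roof'])
    print(wav.shape)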