Dionyssos's picture
cleanup 3
5067878
raw
history blame
9.66 kB
from abc import ABC, abstractmethod
import typing as tp
import omegaconf
import torch
from .encodec import CompressionModel
from .lm import LMModel
from .utils.audio_utils import convert_audio
from .conditioners import ConditioningAttributes
from .utils.autocast import TorchAutocast
class BaseGenModel(ABC):
    """Base generative model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """

    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        """Store the sub-models, switch them to eval mode and derive generation defaults.

        Raises:
            ValueError: if `max_duration` is not given and `lm` exposes no `cfg`
                from which it could be inferred.
        """
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        self.cfg: tp.Optional[omegaconf.DictConfig] = None
        # Just to be safe, let's put everything in eval mode.
        self.compression_model.eval()
        self.lm.eval()

        if hasattr(lm, 'cfg'):
            cfg = lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.cfg = cfg

        if max_duration is None:
            if self.cfg is not None:
                # self.cfg is lm.cfg (validated just above).
                max_duration = self.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError("You must provide max_duration when building directly your GenModel")
        assert max_duration is not None

        self.max_duration: float = max_duration
        # By default, generate for as long as a single LM pass allows.
        self.duration = self.max_duration

        # self.extend_stride is the length of audio extension when generating samples longer
        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
        # positive float value when generating with self.duration > self.max_duration.
        self.extend_stride: tp.Optional[float] = None
        # The model runs on whichever device holds the first LM parameter.
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        # float16 autocast is enabled on every device type except CPU.
        if self.device.type == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16)

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per seconds."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback.

        Args:
            progress_callback (callable, optional): Called with
                `(generated_tokens, tokens_to_generate)` while generating with
                `progress=True`. Passing None restores the default, which prints
                the two counters on a single refreshed line.
        """
        self._progress_callback = progress_callback

    @abstractmethod
    def set_generation_params(self, *args, **kwargs):
        """Set the generation parameters."""
        raise NotImplementedError("No base implementation for setting generation params.")

    @staticmethod
    @abstractmethod
    def get_pretrained(name: str, device=None):
        """Build a pretrained model from its name; must be provided by subclasses."""
        raise NotImplementedError("No base implementation for getting pretrained model")

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
        Returns:
            tuple: One `ConditioningAttributes` per description, and the prompt
                encoded by the compression model (None when no prompt is given).
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if prompt is None:
            return attributes, None

        # `descriptions` was iterated above, so it cannot be None at this point.
        assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
        prompt = prompt.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt)
        assert scale is None
        return attributes, prompt_tokens

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        Returns:
            The decoded audio, or the tuple `(audio, tokens)` when `return_tokens` is True.
        """
        # A None description per sample means "no text conditioning".
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate(self, descriptions: tp.Sequence[tp.Optional[str]], progress: bool = False,
                 return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                          tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        Returns:
            The decoded audio, or the tuple `(audio, tokens)` when `return_tokens` is True.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated tokens (3-dimensional, time on the last axis); the
                length T is defined by the generation params.
        Raises:
            ValueError: if generating beyond `max_duration` with an `extend_stride`
                that amounts to less than one frame.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            # Shift by the frames already produced by previous chunks.
            generated_tokens += current_gen_offset
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = _progress_callback if progress else None

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
            return gen_tokens

        # Extended generation: produce overlapping chunks of at most max_duration,
        # each one continuing from the tail of the previous chunk.
        assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
        assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        stride_tokens = int(self.frame_rate * self.extend_stride)
        if stride_tokens <= 0:
            # A zero stride would never advance current_gen_offset and loop forever.
            raise ValueError("extend_stride must cover at least one frame to generate beyond max_duration")

        all_tokens = []
        if prompt_tokens is None:
            prompt_length = 0
        else:
            all_tokens.append(prompt_tokens)
            prompt_length = prompt_tokens.shape[-1]

        while current_gen_offset + prompt_length < total_gen_len:
            time_offset = current_gen_offset / self.frame_rate
            # The last chunk may be shorter than max_duration.
            chunk_duration = min(self.duration - time_offset, self.max_duration)
            max_gen_len = int(chunk_duration * self.frame_rate)
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=max_gen_len, **self.generation_params)
            if prompt_tokens is None:
                all_tokens.append(gen_tokens)
            else:
                # Keep only what comes after the prompt; the prompt is already stored.
                all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
            # Everything past the stride becomes the prompt of the next chunk.
            prompt_tokens = gen_tokens[:, :, stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens

        return torch.cat(all_tokens, dim=-1)

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate Audio from tokens.

        Args:
            gen_tokens (torch.Tensor): 3-dimensional tensor of tokens to decode.
        Returns:
            torch.Tensor: Output of the compression model decoder.
        """
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio