from abc import ABC, abstractmethod
import typing as tp

import omegaconf
import torch

from .encodec import CompressionModel
from .lm import LMModel
from .builders import get_wrapped_compression_model
from .utils.audio_utils import convert_audio
from .conditioners import ConditioningAttributes
from .utils.autocast import TorchAutocast


class BaseGenModel(ABC):
    """Base generative model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.

    Raises:
        ValueError: If `max_duration` is not given and `lm` carries no `cfg`
            attribute to infer it from.
    """

    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        self.cfg: tp.Optional[omegaconf.DictConfig] = None
        # Just to be safe, let's put everything in eval mode.
        self.compression_model.eval()
        self.lm.eval()

        # The training config is optional: it is only picked up when the LM exposes one.
        if hasattr(lm, 'cfg'):
            cfg = lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.cfg = cfg

        if self.cfg is not None:
            self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)

        if max_duration is None:
            if self.cfg is not None:
                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError("You must provide max_duration when building directly your GenModel")
        assert max_duration is not None

        self.max_duration: float = max_duration
        # By default, generate exactly as much audio as a single LM pass can produce.
        self.duration = self.max_duration

        # self.extend_stride is the length of audio extension when generating samples longer
        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
        # positive float value when generating with self.duration > self.max_duration.
        self.extend_stride: tp.Optional[float] = None
        # The model is assumed to live on the device of the first LM parameter.
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        # Autocast (float16) is enabled on every device type except CPU.
        if self.device.type == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16)

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per seconds."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback.

        Args:
            progress_callback (callable, optional): Called with two ints
                (generated tokens, tokens to generate) while generating with `progress=True`.
                Passing None restores the default behavior, which prints the progress.
        """
        self._progress_callback = progress_callback

    @abstractmethod
    def set_generation_params(self, *args, **kwargs):
        """Set the generation parameters."""
        raise NotImplementedError("No base implementation for setting generation params.")

    @staticmethod
    @abstractmethod
    def get_pretrained(name: str, device=None):
        """Build a pretrained model from its name. Must be implemented by derived classes."""
        raise NotImplementedError("No base implementation for getting pretrained model")

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
        Returns:
            tuple: The list of conditioning attributes (one per description) and the
                prompt encoded by the compression model, or None when no prompt is given.
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if prompt is None:
            return attributes, None

        # `descriptions` has already been iterated above, so it cannot be None here.
        assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
        prompt = prompt.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt)
        assert scale is None
        return attributes, prompt_tokens

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        Returns:
            torch.Tensor or tuple: The decoded audio, or `(audio, tokens)` when `return_tokens` is True.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        Returns:
            torch.Tensor or tuple: The decoded audio, or `(audio, tokens)` when `return_tokens` is True.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                              progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on audio prompts and an optional text description.

        Args:
            prompt (torch.Tensor): A batch of waveforms used for continuation.
                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        Returns:
            torch.Tensor or tuple: The decoded audio, or `(audio, tokens)` when `return_tokens` is True.
        Raises:
            ValueError: If the prompt is neither 2- nor 3-dimensional.
        """
        # A single [C, T] waveform is promoted to a batch of one.
        if prompt.dim() == 2:
            prompt = prompt[None]
        if prompt.dim() != 3:
            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
        # Bring the prompt to the sample rate and channel count of the compression model.
        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
        if descriptions is None:
            descriptions = [None] * len(prompt)
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
        assert prompt_tokens is not None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        When `self.duration` exceeds `self.max_duration`, generation proceeds chunk by chunk:
        each chunk is generated from the tail of the previous one, and the window advances
        by `self.extend_stride` seconds at every step.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated tokens; the last dimension is the time axis and its length
                is defined by the generation params.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            # Shift the count by the offset of the chunk currently being generated.
            generated_tokens += current_gen_offset
            # The total is forwarded exactly as received from the LM.
            # NOTE(review): during extended generation the shifted count and this total are
            # not on the same scale -- confirm whether total_gen_len should be reported instead.
            if self._progress_callback is not None:
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = _progress_callback if progress else None

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
            return gen_tokens

        # Extended generation: produce overlapping chunks and stitch the new parts together.
        assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
        assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        stride_tokens = int(self.frame_rate * self.extend_stride)
        # Without this guard a zero (or negative) stride never advances the offset below
        # and the loop would not terminate.
        assert stride_tokens > 0, "Stride should cover at least one frame to generate beyond max_duration"

        all_tokens = []
        if prompt_tokens is None:
            prompt_length = 0
        else:
            all_tokens.append(prompt_tokens)
            prompt_length = prompt_tokens.shape[-1]

        while current_gen_offset + prompt_length < total_gen_len:
            time_offset = current_gen_offset / self.frame_rate
            # The last chunk may be shorter than max_duration.
            chunk_duration = min(self.duration - time_offset, self.max_duration)
            max_gen_len = int(chunk_duration * self.frame_rate)
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=max_gen_len, **self.generation_params)
            # Only keep what comes after the prompt: the prompt part was stored previously.
            if prompt_tokens is None:
                all_tokens.append(gen_tokens)
            else:
                all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
            # Everything past the stride becomes the prompt of the next chunk.
            prompt_tokens = gen_tokens[:, :, stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens

        return torch.cat(all_tokens, dim=-1)

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate Audio from tokens.

        Args:
            gen_tokens (torch.Tensor): Generated tokens, must be a 3-dimensional tensor.
        Returns:
            torch.Tensor: The output of the compression model decoder for these tokens.
        """
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio