"""Text conditioning utilities: attribute containers, a T5-based text
conditioner, a provider that batches/tokenizes conditions, and a fuser that
hands the conditions over to the model.
"""

from collections import defaultdict
from dataclasses import dataclass, field
import logging
import random
import typing as tp
import warnings

import soundfile
from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
import torch
from torch import nn

from .streaming import StreamingModule
from .utils.autocast import TorchAutocast


logger = logging.getLogger(__name__)
TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask


class JointEmbedCondition(tp.NamedTuple):
    """Container for a joint (wav + text) embedding condition."""
    wav: torch.Tensor
    text: tp.List[tp.Optional[str]]
    length: torch.Tensor
    sample_rate: tp.List[int]
    # NOTE: NamedTuple defaults are evaluated once, so these two lists are shared
    # between every instance created without an explicit value. Never mutate them in place.
    path: tp.List[tp.Optional[str]] = []
    seek_time: tp.List[tp.Optional[float]] = []


@dataclass
class ConditioningAttributes:
    """Per-sample conditioning attributes, grouped by kind (`text`, `wav`, `joint_embed`)."""
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        """Allow `attributes["text"]` style access to the per-kind dictionaries."""
        return getattr(self, item)

    @property
    def text_attributes(self):
        return self.text.keys()

    @property
    def wav_attributes(self):
        return self.wav.keys()

    @property
    def joint_embed_attributes(self):
        return self.joint_embed.keys()

    @property
    def attributes(self):
        """Mapping from each kind to the attribute names it currently holds."""
        return {
            "text": self.text_attributes,
            "wav": self.wav_attributes,
            "joint_embed": self.joint_embed_attributes,
        }

    def to_flat_dict(self):
        """Flatten into a single dict whose keys are `"<kind>.<attribute>"`."""
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
        }

    @classmethod
    def from_flat_dict(cls, x):
        """Rebuild an instance from the output of `to_flat_dict`.

        Keys are split on the first dot only, so attribute names that contain
        dots themselves survive the round trip.
        """
        out = cls()
        for k, v in x.items():
            kind, att = k.split(".", 1)
            out[kind][att] = v
        return out


class Tokenizer:
    """Base tokenizer implementation
    (in case we want to introduce more advanced tokenizers in the future).
    """
    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()


class BaseConditioner(nn.Module):
    """Base model for all conditioner modules.
    We allow the output dim to be different than the hidden dim for two reasons:
    1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model.
        output_dim (int): Output dim of the conditioner.
    """
    def __init__(self, dim: int, output_dim: int):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        self.output_proj = nn.Linear(dim, output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Should be any part of the processing that will lead to a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and returned later when calling forward().
        """
        raise NotImplementedError()

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
        Outputs a ConditionType, after the input data was embedded as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - And a mask telling real positions apart from padding.
        """
        raise NotImplementedError()


class TextConditioner(BaseConditioner):
    ...


class T5Conditioner(TextConditioner):
    """T5-based TextConditioner.

    Args:
        name (str): Name of the T5 model.
        output_dim (int): Output dim of the conditioner.
        finetune (bool): Whether to fine-tune T5 at train time.
        device (str): Device for T5 Conditioner.
        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
        word_dropout (float, optional): Word dropout probability.
        normalize_text (bool, optional): Whether to apply text normalization.
    """
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # Hidden size of the encoder output for every supported checkpoint.
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        # Every name accepted through MODELS must have an entry here, otherwise
        # __init__ fails with a KeyError right after the name check.
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
        # Legacy keys, kept for backward compatibility; they are not listed in MODELS.
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
    }

    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
                 normalize_text: bool = False):
        assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
        super().__init__(self.MODELS_DIMS[name], output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.word_dropout = word_dropout
        if autocast_dtype is None or self.device == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
            if self.device != 'cpu':
                logger.warning("T5 has no autocast, this might lead to NaN")
        else:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
        # thanks https://gist.github.com/simon-weber/7853144
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
            finally:
                # Always restore the previous logging threshold, even if loading failed.
                logging.disable(previous_level)
        if finetune:
            self.t5 = t5
        else:
            # this makes sure that the t5 models is not part
            # of the saved checkpoint
            self.__dict__['t5'] = t5.to(device)

        self.normalize_text = normalize_text
        if normalize_text:
            # NOTE(review): WhiteSpaceTokenizer is neither defined nor imported in the
            # visible part of this file -- confirm it is available before enabling
            # normalize_text, otherwise this line raises NameError.
            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        """Tokenize a batch of optional strings with the T5 tokenizer.

        Missing entries (None) are treated as empty strings. Rows whose final text
        is empty get an all-zero attention mask.

        Args:
            x (list[str, optional]): Batch of text conditions.
        Returns:
            dict[str, torch.Tensor]: Tokenizer output moved to `self.device`.
        """
        # if current sample doesn't have a certain attribute, replace with empty string
        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
        if self.normalize_text:
            _, _, entries = self.text_normalizer(entries, return_text=True)
        if self.word_dropout > 0. and self.training:
            # Randomly drop individual words, only while training.
            new_entries = []
            for entry in entries:
                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
                new_entries.append(" ".join(words))
            entries = new_entries

        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
        mask = inputs['attention_mask']
        mask[empty_idx, :] = 0  # zero-out index where the input is non-existent
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        """Embed tokenized text with T5, project to `output_dim` and mask the result.

        Args:
            inputs (dict[str, torch.Tensor]): Output of `tokenize()`.
        Returns:
            ConditionType: The masked, projected embeddings and the attention mask.
        """
        mask = inputs['attention_mask']
        # Gradients flow through T5 only when it is being fine-tuned.
        with torch.set_grad_enabled(self.finetune), self.autocast:
            embeds = self.t5(**inputs).last_hidden_state
        # Match the dtype/device of the projection weights before projecting.
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask


class ConditioningProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        device (torch.device or str, optional): Device for conditioners and output condition types.
    """
    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        self.device = device
        self.conditioners = nn.ModuleDict(conditioners)

    # @property
    # def joint_embed_conditions(self):
    #     return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]

    # @property
    # def has_joint_embed_conditions(self):
    #     return len(self.joint_embed_conditions) > 0

    @property
    def text_conditions(self):
        """Names of the registered conditioners that are text conditioners."""
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Only text conditions are handled here.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        # The two literals are joined into ONE string; a comma between them would turn
        # the assertion message into a tuple.
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]"
            f" but types were {set([type(x) for x in inputs])}"
        )

        output = {}
        text = self._collate_text(inputs)
        # wavs = self._collate_wavs(inputs)
        # joint_embeds = self._collate_joint_embeds(inputs)

        # assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
        #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
        #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
        # )

        for attribute, batch in text.items():  # , joint_embeds.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
        The output is for example:
        {
            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
            ...
        }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }

        Every sample must provide a key (possibly mapped to None) for each registered
        text condition; a missing key raises KeyError.

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out


class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        # Reverse lookup: condition name -> fusing method.
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model input.

        In this implementation the input is returned untouched and every condition
        is treated as a cross-attention source, whatever `fuse2cond` says; when
        several conditions are given, the last one in iteration order is used.

        Args:
            input (torch.Tensor): Transformer input.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention or None if no cross attention inputs exist.
        """
        # The unpacking doubles as a check that `input` has exactly three dimensions.
        _, T, _ = input.shape

        offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)

        cross_attention_output = None
        for cond, _cond_mask in conditions.values():
            cross_attention_output = cond

        # NOTE(review): `_is_streaming` / `_streaming_state` are presumably provided by
        # StreamingModule, which is not visible here -- confirm.
        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output