|
from collections import defaultdict
from dataclasses import dataclass, field
import logging
import random
import typing as tp
import warnings

import soundfile
import torch
from torch import nn
from torch.nn import functional as F
from transformers import T5EncoderModel, T5Tokenizer

from .streaming import StreamingModule
from .utils.autocast import TorchAutocast


logger = logging.getLogger(__name__)

TextCondition = tp.Optional[str]  # a text condition can be a string or None (if it doesn't exist)
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask


class JointEmbedCondition(tp.NamedTuple):
    wav: torch.Tensor
    text: tp.List[tp.Optional[str]]
    length: torch.Tensor
    sample_rate: tp.List[int]
    path: tp.List[tp.Optional[str]] = []
    seek_time: tp.List[tp.Optional[float]] = []

@dataclass
class ConditioningAttributes:
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        return getattr(self, item)

    @property
    def text_attributes(self):
        return self.text.keys()

    @property
    def wav_attributes(self):
        return self.wav.keys()

    @property
    def joint_embed_attributes(self):
        return self.joint_embed.keys()

    @property
    def attributes(self):
        return {
            "text": self.text_attributes,
            "wav": self.wav_attributes,
            "joint_embed": self.joint_embed_attributes,
        }

    def to_flat_dict(self):
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()},
        }

    @classmethod
    def from_flat_dict(cls, x):
        out = cls()
        for k, v in x.items():
            kind, att = k.split(".")
            out[kind][att] = v
        return out
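

# Illustrative round-trip through the flat-dict helpers above (a sketch; the
# attribute names "genre" and "description" are arbitrary examples, not keys
# this module requires):
#
#   attrs = ConditioningAttributes(text={"genre": "Rock", "description": "A rock song"})
#   flat = attrs.to_flat_dict()
#   # -> {"text.genre": "Rock", "text.description": "A rock song"}
#   assert ConditioningAttributes.from_flat_dict(flat) == attrs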


class Tokenizer:
    """Base tokenizer implementation
    (in case we want to introduce more advanced tokenizers in the future).
    """
    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()


class BaseConditioner(nn.Module):
    """Base model for all conditioner modules.
    We allow the output dim to be different than the hidden dim for two reasons:
    1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model.
        output_dim (int): Output dim of the conditioner.
    """
    def __init__(self, dim: int, output_dim: int):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        self.output_proj = nn.Linear(dim, output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Should be any part of the processing that will lead to a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and later passed to forward().
        """
        raise NotImplementedError()

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Gets input that should be used as conditioning (e.g., genre, description or a waveform).
        Outputs a ConditionType, after the input data was embedded as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - A mask of size [B, T] indicating the valid (non-padding) positions.
        """
        raise NotImplementedError()
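

# Minimal subclass sketch (illustrative only; a hypothetical lookup-table
# conditioner whose small hidden dim is mapped to the shared output dim by
# `output_proj`):
#
#   class LUTConditioner(BaseConditioner):
#       def __init__(self, n_bins: int, dim: int, output_dim: int):
#           super().__init__(dim, output_dim)
#           self.embed = nn.Embedding(n_bins, dim)
#
#       def tokenize(self, x: torch.Tensor) -> torch.Tensor:
#           return x  # inputs are already integer bins, nothing to do
#
#       def forward(self, inputs: torch.Tensor) -> ConditionType:
#           embeds = self.output_proj(self.embed(inputs))  # [B, T, output_dim]
#           mask = torch.ones(embeds.shape[:2], device=embeds.device)
#           return embeds, mask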


class TextConditioner(BaseConditioner):
    ...

class T5Conditioner(TextConditioner):
    """T5-based TextConditioner.

    Args:
        name (str): Name of the T5 model.
        output_dim (int): Output dim of the conditioner.
        finetune (bool): Whether to fine-tune T5 at train time.
        device (str): Device for T5 Conditioner.
        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
        word_dropout (float, optional): Word dropout probability.
        normalize_text (bool, optional): Whether to apply text normalization.
    """
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
                 normalize_text: bool = False):
        assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
        super().__init__(self.MODELS_DIMS[name], output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.word_dropout = word_dropout
        if autocast_dtype is None or self.device == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
            if self.device != 'cpu':
                logger.warning("T5 has no autocast, this might lead to NaN")
        else:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)

        # Temporarily silence logging and warnings while loading the checkpoint.
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
            finally:
                logging.disable(previous_level)
        if finetune:
            self.t5 = t5
        else:
            # Stash the frozen T5 in __dict__ so it is not registered as a
            # submodule and does not end up in the checkpoint.
            self.__dict__['t5'] = t5.to(device)

        self.normalize_text = normalize_text
        if normalize_text:
            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        # If a sample doesn't have a certain attribute, replace it with an empty string.
        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
        if self.normalize_text:
            _, _, entries = self.text_normalizer(entries, return_text=True)
        if self.word_dropout > 0. and self.training:
            new_entries = []
            for entry in entries:
                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
                new_entries.append(" ".join(words))
            entries = new_entries

        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
        mask = inputs['attention_mask']
        mask[empty_idx, :] = 0  # zero out the mask where the input was empty
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        mask = inputs['attention_mask']
        with torch.set_grad_enabled(self.finetune), self.autocast:
            embeds = self.t5(**inputs).last_hidden_state
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask
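

# Usage sketch (illustrative; downloads "t5-small" weights on first use):
#
#   cond = T5Conditioner("t5-small", output_dim=256, finetune=False, device="cpu")
#   tokens = cond.tokenize(["a calm piano piece", None])
#   embeds, mask = cond(tokens)  # embeds: [B, T, 256]; the None entry is fully masked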


class ConditioningProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        device (torch.device or str, optional): Device for conditioners and output condition types.
    """
    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        self.device = device
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def text_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        assert all(isinstance(x, ConditioningAttributes) for x in inputs), (
            "Got unexpected types as input for conditioner! Should be tp.List[ConditioningAttributes],"
            f" but types were {set(type(x) for x in inputs)}"
        )

        output = {}
        text = self._collate_text(inputs)

        for attribute, batch in text.items():
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
        The output is for example:
            {
                "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
                "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
                ...
            }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
            Input:
                [
                    ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
                    ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
                ]
            Output:
                {
                    "genre": ["Rock", "Hip-hop"],
                    "description": ["A rock song with a guitar solo", "A hip-hop verse"],
                }

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out
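

# Usage sketch (illustrative; "description" is an arbitrary conditioner name):
#
#   provider = ConditioningProvider(
#       {"description": T5Conditioner("t5-small", 256, finetune=False, device="cpu")})
#   attrs = [ConditioningAttributes(text={"description": "a calm piano piece"})]
#   tokenized = provider.tokenize(attrs)  # cheap; call before any real GPU work
#   conditions = provider(tokenized)      # {"description": (embeds, mask)}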


class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    into the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        self.cond2fuse: tp.Dict[str, str] = {}
        # Invert fuse2cond into a per-condition lookup used in forward().
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model input.

        Args:
            input (torch.Tensor): Transformer input.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor or None]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention or None if no cross attention inputs exist.
        """
        B, T, _ = input.shape

        # When streaming, keep track of how many steps were already consumed so
        # that `prepend` conditions are only inserted on the very first step.
        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)

        cross_attention_output = None
        for cond_type, (cond, cond_mask) in conditions.items():
            op = self.cond2fuse[cond_type]
            if op == 'sum':
                input = input + cond
            elif op == 'input_interpolate':
                # Interpolate the condition along time to match the input length.
                cond = F.interpolate(cond.transpose(1, 2), size=T).transpose(1, 2)
                input = input + cond
            elif op == 'prepend':
                if first_step:
                    input = torch.cat([cond, input], dim=1)
            elif op == 'cross':
                if cross_attention_output is not None:
                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                else:
                    cross_attention_output = cond
            else:
                raise ValueError(f"unknown op ({op})")

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output
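

# Usage sketch (illustrative shapes):
#
#   fuser = ConditionFuser({"cross": ["description"]})
#   x = torch.randn(2, 10, 512)                        # transformer input [B, T, D]
#   desc = (torch.randn(2, 4, 512), torch.ones(2, 4))  # (embedding, mask) pair
#   x, cross = fuser(x, {"description": desc})         # cross: [B, 4, 512]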
|