import typing as tp
import torch
import torch.nn as nn
from dataclasses import dataclass, field, fields
from itertools import chain
import warnings
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from codeclm.utils.utils import length_to_mask, collate
from codeclm.modules.streaming import StreamingModule
from collections import defaultdict
from copy import deepcopy

ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
# ================================================================
# Condition and Condition attributes definitions
# ================================================================
class AudioCondition(tp.NamedTuple):
    wav: torch.Tensor
    length: torch.Tensor
    sample_rate: tp.List[int]
    path: tp.List[tp.Optional[str]] = []
    seek_time: tp.List[tp.Optional[float]] = []
@dataclass
class ConditioningAttributes:
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    audio: tp.Dict[str, AudioCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        return getattr(self, item)

    @property
    def text_attributes(self):
        return self.text.keys()

    @property
    def audio_attributes(self):
        return self.audio.keys()

    @property
    def attributes(self):
        return {
            "text": self.text_attributes,
            "audio": self.audio_attributes,
        }

    def to_flat_dict(self):
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"audio.{k}": v for k, v in self.audio.items()},
        }

    @classmethod
    def from_flat_dict(cls, x):
        out = cls()
        for k, v in x.items():
            kind, att = k.split(".")
            out[kind][att] = v
        return out
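
# --------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module): builds a
# ConditioningAttributes instance and round-trips it through the flat-dict
# helpers. The attribute names ("description", "prompt_audio") are
# hypothetical examples.
# --------------------------------------------------------------
def _example_conditioning_attributes():
    attrs = ConditioningAttributes()
    attrs.text["description"] = "A rock song with a guitar solo"
    attrs.audio["prompt_audio"] = AudioCondition(
        wav=torch.zeros(1, 1, 48000),
        length=torch.tensor([48000]),
        sample_rate=[48000],
    )
    flat = attrs.to_flat_dict()  # {"text.description": ..., "audio.prompt_audio": ...}
    restored = ConditioningAttributes.from_flat_dict(flat)
    assert restored.text["description"] == attrs.text["description"]
    return restored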
# ================================================================
# Conditioner (tokenize and encode raw conditions) definitions
# ================================================================
class BaseConditioner(nn.Module):
    """Base model for all conditioner modules.

    We allow the output dim to be different than the hidden dim for two reasons:
    1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model.
        output_dim (int): Output dim of the conditioner.
    """
    def __init__(self, dim: int, output_dim: int, input_token=False, padding_idx=0):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        if input_token:
            self.output_proj = nn.Embedding(dim, output_dim, padding_idx)
        else:
            self.output_proj = nn.Linear(dim, output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Should be any part of the processing that will lead to a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and returned later when calling forward().
        """
        raise NotImplementedError()

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Gets input that should be used as conditioning (e.g., genre, description or a waveform).
        Outputs a ConditionType, after the input data was embedded as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - A mask indicating the positions of the padding tokens.
        """
        raise NotImplementedError()


class TextConditioner(BaseConditioner):
    ...
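
# --------------------------------------------------------------
# Minimal sketch of a concrete conditioner (illustration only, not part of the
# original module). It shows the tokenize()/forward() contract of
# BaseConditioner with a hypothetical whitespace tokenizer and a toy
# vocabulary; it returns the two-element ConditionType declared above, while
# the real conditioners below return (embeds, embeds, mask) triples.
# --------------------------------------------------------------
class _ToyTextConditioner(TextConditioner):
    def __init__(self, output_dim: int, vocab=("", "rock", "jazz", "pop")):
        self.vocab = list(vocab)
        super().__init__(len(self.vocab), output_dim, input_token=True, padding_idx=0)

    def tokenize(self, texts: tp.List[tp.Optional[str]]) -> torch.Tensor:
        # map each word to its vocab index, padding with 0 (the null token)
        ids = [[self.vocab.index(w) for w in (t or "").split() if w in self.vocab] or [0]
               for t in texts]
        max_len = max(len(i) for i in ids)
        return torch.tensor([i + [0] * (max_len - len(i)) for i in ids]).long()

    def forward(self, tokens: torch.Tensor) -> ConditionType:
        embeds = self.output_proj(tokens.to(self.output_proj.weight.device))  # [B, T, D]
        mask = (tokens != 0).int()
        return embeds, mask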
class PhonemeTokenizerConditioner(TextConditioner):
    def __init__(self,
                 output_dim: int,
                 vocab_list,
                 max_len=600,
                 max_sentence_per_structure=50,
                 structure_tokens=None,
                 structure_split_tokens=[','],
                 sentence_split_tokens=['.'],
                 mode='sum',
                 structure_output_dim=64,
                 sentence_output_dim=64,
                 max_duration=120,
                 ):
        self.vocab_list = vocab_list
        self.max_len = max_len
        self.mode = mode
        self.max_sentence_per_structure = max_sentence_per_structure
        voc_size = len(self.vocab_list)
        if structure_tokens is None:
            structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
        self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
        self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
        self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]

        # here initialize an output_proj (nn.Embedding) layer
        # By default the first vocab entry is "" (null)
        if mode == 'sum':
            content_output_dim = output_dim
            sentence_output_dim = output_dim
            structure_output_dim = output_dim
        else:  # 'concat'
            raise NotImplementedError("concat mode is not implemented yet")
            # content_output_dim = output_dim - sentence_output_dim - structure_output_dim  # by default
        super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
        self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0)
        self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
        # the first index is the "empty structure" token
        self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
        self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
        print("max_len", self.max_len)
        print(self.structure_token_ids)
        self.resolution = max_duration / max_len  # e.g., 120 / 600 = 0.2s
        print(self.__class__, f"resolution = {self.resolution}")
    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.List[torch.Tensor]:
        inputs = []
        for xx in x:
            xx = '' if xx is None else xx
            vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
            inputs.append(torch.tensor(vocab_id).long())  # [T]
        return inputs
    def forward(self, batch_tokens: tp.List, structure_dur=None) -> ConditionType:
        """
        Encode token ids into three types of embeddings:
            1) content embedding: phonemes only (or meaningful content to be sung out);
            2) structure embedding: structure / separation embeddings, covering structures
               (verse/chorus/...) and separators (. / ,). These two share the same embedding
               layer, but could be switched to separate embedding layers;
            3) sentence-index embedding (per structure).
        """
        embeds_batch = []
        for b in range(len(batch_tokens)):
            tokens = batch_tokens[b]
            content_tokens = torch.zeros_like(tokens)
            special_tokens = torch.zeros_like(tokens)
            sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
            sentence_reidx_in_structure_tokens = torch.zeros_like(tokens)
            current_sentence_in_structure_idx = 1
            current_structure = 0
            for i in range(tokens.shape[-1]):
                token = tokens[i]
                if token in self.structure_token_ids:  # structure token
                    # only update the structure token, leave content and sentence index tokens null (default 0)
                    special_tokens[i] = token
                    content_tokens[i] = token
                    current_structure = token
                    current_sentence_in_structure_idx = 1
                    sentence_idx_in_structure_tokens[i] = 0
                elif token in self.sentence_split_token_ids:  # utterance split token
                    # only update the structure token, leave content and sentence index tokens null (default 0)
                    # and advance the sentence index
                    special_tokens[i] = current_structure
                    content_tokens[i] = token
                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
                    current_sentence_in_structure_idx += 1
                elif token in self.structure_split_token_ids:  # structure split token
                    # update the structure token (current structure) and content token (current token),
                    # keep the previous sentence index
                    content_tokens[i] = token
                    special_tokens[i] = current_structure
                    sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i - 1]
                else:  # content tokens
                    content_tokens[i] = token
                    special_tokens[i] = current_structure
                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
            # backward pass: derive the reversed sentence index (distance to the end of the structure)
            current_sentence_num = sentence_idx_in_structure_tokens[-1]
            for i in range(tokens.shape[-1] - 1, -1, -1):
                if current_sentence_num != 0:
                    sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1)
                if sentence_idx_in_structure_tokens[i] == 0 and i > 0:
                    current_sentence_num = sentence_idx_in_structure_tokens[i - 1]
            # print("tokens", tokens.max(), tokens.min())
            # print("special tokens", special_tokens.max(), special_tokens.min())
            # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
            device = self.output_proj.weight.device
            content_embeds = self.output_proj(content_tokens.to(device))  # [T, N]
            structure_embeds = self.output_proj(special_tokens.to(device))
            # sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
            sentence_idx_embeds = (self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
                                   + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device)))
            if self.mode == 'sum':
                embeds = content_embeds + structure_embeds + sentence_idx_embeds
            else:
                embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1)  # [T, N]
            embeds_batch.append(embeds)
        # set batch_size = 1, [B, T, N]
        if self.max_len is not None:
            max_len = self.max_len
        else:
            max_len = max([e.shape[0] for e in embeds_batch])
        embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)
        return embeds, embeds, mask
    def pad_2d_tensor(self, xs, max_len):
        new_tensor = []
        new_mask = []
        for x in xs:
            seq_len, dim = x.size()
            pad_len = max_len - seq_len
            if pad_len > 0:
                pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device)  # T, D
                padded_tensor = torch.cat([x, pad_tensor], dim=0)
                mask = torch.cat((torch.ones_like(x[:, 0]),
                                  torch.zeros_like(pad_tensor[:, 0])), 0)  # T
            elif pad_len < 0:
                padded_tensor = x[:max_len]
                mask = torch.ones_like(padded_tensor[:, 0])
            else:
                padded_tensor = x
                mask = torch.ones_like(x[:, 0])
            new_tensor.append(padded_tensor)
            new_mask.append(mask)
        # [B, T, D] & [B, T]
        return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
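
# --------------------------------------------------------------
# Usage sketch (illustration only): drives PhonemeTokenizerConditioner with a
# tiny hypothetical vocabulary. Index 0 must be the null token "";
# "[verse]"/"[chorus]" act as structure tokens, "." and "," as sentence and
# structure separators. Real vocabularies are phoneme inventories.
# --------------------------------------------------------------
def _example_phoneme_conditioner():
    vocab = ["", "[verse]", "[chorus]", ",", ".", "la", "da"]
    cond = PhonemeTokenizerConditioner(output_dim=32, vocab_list=vocab, max_len=16)
    tokens = cond.tokenize(["[verse] la la . da da , [chorus] la da ."])
    embeds, _, mask = cond(tokens)
    # embeds: [1, 16, 32] (padded to max_len), mask: [1, 16]
    return embeds.shape, mask.shape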
class QwTokenizerConditioner(TextConditioner):
    def __init__(self, output_dim: int,
                 token_path="",
                 max_len=300,
                 add_token_list=[]):
        from transformers import Qwen2Tokenizer
        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
        if add_token_list != []:
            self.text_tokenizer.add_tokens(add_token_list, special_tokens=True)
        voc_size = len(self.text_tokenizer.get_vocab())
        # here initialize an output_proj (nn.Embedding) layer
        super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
        self.max_len = max_len
        self.padding_idx = '<|endoftext|>'
        vocab = self.text_tokenizer.get_vocab()
        # struct_tokens collects all structure tokens
        struct_tokens = [i for i in add_token_list if i[0] == '[' and i[-1] == ']']
        self.struct_token_ids = [vocab[i] for i in struct_tokens]
        self.pad_token_idx = 151643
        self.structure_emb = nn.Embedding(200, output_dim, padding_idx=0)
        # self.split_token_id = vocab["."]
        print("all structure tokens: ", {self.text_tokenizer.convert_ids_to_tokens(i): i for i in self.struct_token_ids})
    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
        # x = [xi if xi is not None else "" for xi in x]
        inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        """
        Add structure embeddings of {verse, chorus, bridge, ...} to the text/lyric tokens
        that belong to these structures accordingly, then delete or keep these structure
        embeddings.
        """
        mask = inputs['attention_mask']
        tokens = inputs['input_ids']
        B = tokens.shape[0]
        is_sp_embed = torch.any(torch.stack([tokens == i for i in self.struct_token_ids], dim=-1), dim=-1)
        tp_cover_range = torch.zeros_like(tokens)
        for b, is_sp in enumerate(is_sp_embed):
            sp_list = torch.where(is_sp)[0].tolist()
            sp_list.append(mask[b].sum())
            for i, st in enumerate(sp_list[:-1]):
                # map each structure token id to a small index for structure_emb
                tp_cover_range[b, st: sp_list[i + 1]] = tokens[b, st] - 151645
        if self.max_len is not None:
            if inputs['input_ids'].shape[-1] > self.max_len:
                warnings.warn(f"Max length limit ({self.max_len}) exceeded! "
                              f"{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
            tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
            mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
            tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
        device = self.output_proj.weight.device
        content_embeds = self.output_proj(tokens.to(device))
        structure_embeds = self.structure_emb(tp_cover_range.to(device))
        embeds = content_embeds + structure_embeds
        return embeds, embeds, mask
    def pad_2d_tensor(self, x, max_len, pad_id):
        batch_size, seq_len = x.size()
        pad_len = max_len - seq_len
        if pad_len > 0:
            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
            padded_tensor = torch.cat([x, pad_tensor], dim=1)
        elif pad_len < 0:
            padded_tensor = x[:, :max_len]
        else:
            padded_tensor = x
        return padded_tensor
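
# --------------------------------------------------------------
# Usage sketch (illustration only): wiring up QwTokenizerConditioner.
# "path/to/qwen2_tokenizer" is a hypothetical local path to a Qwen2 tokenizer;
# the structure tokens passed via add_token_list must match the markers used
# in the lyrics.
# --------------------------------------------------------------
def _example_qw_tokenizer_conditioner():
    cond = QwTokenizerConditioner(
        output_dim=512,
        token_path="path/to/qwen2_tokenizer",
        max_len=300,
        add_token_list=['[verse]', '[chorus]', '[bridge]'],
    )
    inputs = cond.tokenize(["[verse] city lights fade . [chorus] we run all night ."])
    embeds, _, mask = cond(inputs)  # embeds: [1, 300, 512], mask: [1, 300]
    return embeds, mask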
class QwTextConditioner(TextConditioner):
    def __init__(self, output_dim: int,
                 token_path="",
                 max_len=300):
        from transformers import Qwen2Tokenizer
        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
        voc_size = len(self.text_tokenizer.get_vocab())
        # here initialize an output_proj (nn.Embedding) layer
        super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
        self.max_len = max_len

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
        inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor], structure_dur=None) -> ConditionType:
        """
        Embed plain text tokens with the shared output projection; unlike
        QwTokenizerConditioner, no structure embeddings are added here.
        """
        mask = inputs['attention_mask']
        tokens = inputs['input_ids']
        if self.max_len is not None:
            if inputs['input_ids'].shape[-1] > self.max_len:
                warnings.warn(f"Max length limit ({self.max_len}) exceeded! "
                              f"{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
            tokens = self.pad_2d_tensor(tokens, self.max_len, 151643).to(self.output_proj.weight.device)
            mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
        embeds = self.output_proj(tokens)
        return embeds, embeds, mask
    def pad_2d_tensor(self, x, max_len, pad_id):
        batch_size, seq_len = x.size()
        pad_len = max_len - seq_len
        if pad_len > 0:
            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
            padded_tensor = torch.cat([x, pad_tensor], dim=1)
        elif pad_len < 0:
            padded_tensor = x[:, :max_len]
        else:
            padded_tensor = x
        return padded_tensor


class AudioConditioner(BaseConditioner):
    ...
class QuantizedEmbeddingConditioner(AudioConditioner):
    def __init__(self, dim: int,
                 code_size: int,
                 code_depth: int,
                 max_len: int,
                 **kwargs):
        super().__init__(dim, dim, input_token=True)
        self.code_depth = code_depth
        # extra indices for the <s> token and the padding index (code_size + 1)
        self.emb = nn.ModuleList([nn.Embedding(code_size + 2, dim, padding_idx=code_size + 1) for _ in range(code_depth)])
        # add End-Of-Text embedding
        self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
        self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
        self.output_proj = None
        self.max_len = max_len
        self.vocab_size = code_size

    def tokenize(self, x: AudioCondition) -> AudioCondition:
        """No extra ops."""
        # wav, length, sample_rate, path, seek_time = x
        # assert length is not None
        return x  # AudioCondition(wav, length, sample_rate, path, seek_time)
    def forward(self, x: AudioCondition):
        wav, lengths, *_ = x
        B = wav.shape[0]
        # `wav` carries quantized token ids here, reshaped to [B, code_depth, T]
        wav = wav.reshape(B, self.code_depth, -1).long()
        if wav.shape[2] < self.max_len - 1:
            wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size + 1)
        else:
            wav = wav[:, :, :self.max_len - 1]
        embeds1 = self.emb[0](wav[:, 0])
        embeds1 = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
                             embeds1), dim=1)
        embeds2 = sum([self.emb[k](wav[:, k]) for k in range(1, self.code_depth)])  # B,T,D
        embeds2 = torch.cat((self.layer2_EOT_emb.unsqueeze(0).repeat(B, 1, 1),
                             embeds2), dim=1)
        lengths = lengths + 1
        lengths = torch.clamp(lengths, max=self.max_len)
        if lengths is not None:
            mask = length_to_mask(lengths, max_len=embeds1.shape[1]).int()  # type: ignore
        else:
            mask = torch.ones((B, self.code_depth), device=embeds1.device, dtype=torch.int)
        return embeds1, embeds2, mask
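
# --------------------------------------------------------------
# Usage sketch (illustration only): feeds dummy quantized codes through
# QuantizedEmbeddingConditioner. The code_size/code_depth/max_len values are
# hypothetical; real codes come from the audio tokenizer.
# --------------------------------------------------------------
def _example_quantized_embedding_conditioner():
    cond = QuantizedEmbeddingConditioner(dim=512, code_size=16384, code_depth=2, max_len=751)
    codes = torch.randint(0, 16384, (1, 2, 300))  # [B, code_depth, T]
    audio = AudioCondition(wav=codes.reshape(1, 1, -1), length=torch.tensor([300]), sample_rate=[48000])
    embeds1, embeds2, mask = cond(cond.tokenize(audio))
    # embeds1/embeds2: [1, 751, 512] (EOT slot + padded codes), mask: [1, 751]
    return embeds1.shape, embeds2.shape, mask.shape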
# ================================================================
# Aggregate all conditions and corresponding conditioners
# ================================================================
class ConditionerProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
    """
    def __init__(self, conditioners: tp.Dict[str, BaseConditioner]):
        super().__init__()
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def text_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def audio_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, AudioConditioner)]

    @property
    def has_audio_condition(self):
        return len(self.audio_conditions) > 0
    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/audios with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and audio conditions.
        """
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
            f" but types were {set([type(x) for x in inputs])}")
        output = {}
        text = self._collate_text(inputs)
        audios = self._collate_audios(inputs)
        assert set(text.keys() | audios.keys()).issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
            f"got {text.keys(), audios.keys()}")
        for attribute, batch in chain(text.items(), audios.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output
    def forward(self, tokenized: tp.Dict[str, tp.Any], structure_dur=None) -> tp.Dict[str, ConditionType]:
        """Compute `(embedding_1, embedding_2, mask)` triples using the configured conditioners
        and the tokenized representations. The output is for example:
            {
                "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
                "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
                ...
            }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            if attribute == 'description' and structure_dur is not None:
                condition1, condition2, mask = self.conditioners[attribute](inputs, structure_dur=structure_dur)
            else:
                condition1, condition2, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition1, condition2, mask)
        return output
    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.

        For example:
            Input:
                [
                    ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, audio=...),
                    ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, audio=...),
                ]
            Output:
                {
                    "genre": ["Rock", "Hip-hop"],
                    "description": ["A rock song with a guitar solo", "A hip-hop verse"]
                }

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out
    def _collate_audios(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, AudioCondition]:
        """Generate a dict where the keys are attributes by which we fetch similar audios,
        and the values are tensors of audios according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some audios
        inside the "audio" attribute. It should be either:
            1. A real audio
            2. A null audio due to the sample having no similar audios (nullified by the dataset)
            3. A null audio due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, AudioCondition]: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        out: tp.Dict[str, AudioCondition] = {}
        for sample in samples:
            for attribute in self.audio_conditions:
                wav, length, sample_rate, path, seek_time = sample.audio[attribute]
                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
                wavs[attribute].append(wav.flatten())  # [C*T]
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)
        # stack all wavs to a single tensor
        for attribute in self.audio_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            out[attribute] = AudioCondition(
                stacked_wav.unsqueeze(1),
                torch.cat(lengths[attribute]), sample_rates[attribute],
                paths[attribute], seek_times[attribute])
        return out
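
# --------------------------------------------------------------
# Usage sketch (illustration only): a provider combining a text conditioner
# and an audio-prompt conditioner. The names ("description", "prompt_audio")
# and the tokenizer path are hypothetical; real configs define them elsewhere.
# --------------------------------------------------------------
def _example_conditioner_provider():
    provider = ConditionerProvider({
        "description": QwTextConditioner(output_dim=512, token_path="path/to/qwen2_tokenizer"),
        "prompt_audio": QuantizedEmbeddingConditioner(dim=512, code_size=16384, code_depth=2, max_len=751),
    })
    attrs = ConditioningAttributes()
    attrs.text["description"] = "female vocal, pop, piano"
    attrs.audio["prompt_audio"] = AudioCondition(
        wav=torch.randint(0, 16384, (1, 1, 600)),
        length=torch.tensor([300]), sample_rate=[48000], path=[None], seek_time=[None])
    tokenized = provider.tokenize([attrs])
    conditions = provider(tokenized)  # {"description": (emb1, emb2, mask), "prompt_audio": ...}
    return conditions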
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
            }
    """
    FUSING_METHODS = ["sum", "prepend"]  # , "cross", "input_interpolate"] (not supported in this simplest version)

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]]):
        super().__init__()
        assert all([k in self.FUSING_METHODS for k in fuse2cond.keys()]
                   ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method
    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model inputs.

        Args:
            input1 (torch.Tensor): First transformer input stream.
            input2 (torch.Tensor): Second transformer input stream.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The two transformer input streams after the
            conditions have been fused (summed and/or prepended along the time dimension).
        """
        B, T, _ = input1.shape

        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input1.shape[0], dtype=torch.long, device=input1.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"

        # if 'prepend' mode is used,
        # the concatenation order will be the SAME as the conditions in the config:
        # prepend: ['description', 'prompt_audio'] (then comes the input)
        fused_input_1 = input1
        fused_input_2 = input2
        for fuse_op in self.fuse2cond.keys():
            fuse_op_conditions = self.fuse2cond[fuse_op]
            if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
                for cond in fuse_op_conditions:
                    this_cond_1, this_cond_2, cond_mask = conditions[cond]
                    fused_input_1 += this_cond_1
                    fused_input_2 += this_cond_2
            elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
                if not first_step:
                    continue
                reverse_list = deepcopy(fuse_op_conditions)
                reverse_list.reverse()
                for cond in reverse_list:
                    this_cond_1, this_cond_2, cond_mask = conditions[cond]
                    fused_input_1 = torch.cat((this_cond_1, fused_input_1), dim=1)  # concat along T dim
                    fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1)  # concat along T dim
            elif fuse_op not in self.FUSING_METHODS:
                raise ValueError(f"unknown op ({fuse_op})")

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return fused_input_1, fused_input_2
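
# --------------------------------------------------------------
# Usage sketch (illustration only): prepends a text condition to both input
# streams. The shapes are hypothetical; in the real model input1/input2 are
# the embeddings of the two code streams.
# --------------------------------------------------------------
def _example_condition_fuser():
    fuser = ConditionFuser({"prepend": ["description"], "sum": []})
    B, T, D, T_cond = 2, 100, 512, 30
    input1 = torch.randn(B, T, D)
    input2 = torch.randn(B, T, D)
    conditions = {"description": (torch.randn(B, T_cond, D), torch.randn(B, T_cond, D), torch.ones(B, T_cond))}
    fused1, fused2 = fuser(input1, input2, conditions)
    # fused1/fused2: [2, 130, 512] -- condition prepended along the time dimension
    return fused1.shape, fused2.shape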
# ================================================================
# Condition Dropout
# ================================================================
class DropoutModule(nn.Module):
    """Base module for all dropout modules."""
    def __init__(self, seed: int = 1234):
        super().__init__()
        self.rng = torch.Generator()
        self.rng.manual_seed(seed)


class ClassifierFreeGuidanceDropout(DropoutModule):
    """Classifier Free Guidance dropout.
    All attributes are dropped with the same probability.

    Args:
        p (float): Probability to apply condition dropout during training.
        seed (int): Random seed.
    """
    def __init__(self, p: float, seed: int = 1234):
        super().__init__(seed=seed)
        self.p = p

    def check(self, sample, condition_type, condition):
        if condition_type not in ['text', 'audio']:
            raise ValueError("dropout_condition got an unexpected condition type!"
                             f" expected 'text', 'audio' but got '{condition_type}'")
        if condition not in getattr(sample, condition_type):
            raise ValueError(
                "dropout_condition received an unexpected condition!"
                f" expected audio={sample.audio.keys()} and text={sample.text.keys()}"
                f" but got '{condition}' of type '{condition_type}'!")

    def get_null_wav(self, wav, sr=48000) -> AudioCondition:
        # replace the codes with a constant null id (16385) so the audio prompt carries no information
        out = wav * 0 + 16385
        return AudioCondition(
            wav=out,
            length=torch.Tensor([0]).long(),
            sample_rate=[sr],)

    def dropout_condition(self,
                          sample: ConditioningAttributes,
                          condition_type: str,
                          condition: str) -> ConditioningAttributes:
        """Utility function for nullifying an attribute inside a ConditioningAttributes object.
        If the condition is of type "audio", nullify it using `get_null_wav`.
        If the condition is of any other type, set its value to None.
        Works in-place.
        """
        self.check(sample, condition_type, condition)
        if condition_type == 'audio':
            audio_cond = sample.audio[condition]
            sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
        else:
            sample.text[condition] = None
        return sample

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
        """
        # decide per sample whether to drop all of its condition attributes
        # drop = torch.rand(1, generator=self.rng).item() < self.p
        # if not drop:
        #     return samples
        # nullify conditions of all attributes
        samples = deepcopy(samples)
        for sample in samples:
            drop = torch.rand(1, generator=self.rng).item()
            if drop < self.p:
                for condition_type in ["audio", "text"]:
                    for condition in sample.attributes[condition_type]:
                        self.dropout_condition(sample, condition_type, condition)
        return samples

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"
class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
    """Classifier Free Guidance dropout during inference.
    All selected attributes are dropped with probability 1.

    Args:
        seed (int): Random seed.
    """
    def __init__(self, seed: int = 1234):
        super().__init__(p=1, seed=seed)

    def dropout_condition_customized(self,
                                     sample: ConditioningAttributes,
                                     condition_type: str,
                                     condition: str,
                                     customized: list = None) -> ConditioningAttributes:
        """Utility function for nullifying an attribute inside a ConditioningAttributes object.
        If the condition is of type "audio", nullify it using `get_null_wav`.
        If the condition is of any other type, set its value to None, or strip only the
        customized parts (e.g. structure tags or sentence separators).
        Works in-place.
        """
        self.check(sample, condition_type, condition)
        if condition_type == 'audio':
            audio_cond = sample.audio[condition]
            sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
        else:
            if customized is None:
                sample.text[condition] = None
            else:
                text_cond = deepcopy(sample.text[condition])
                if "structure" in customized:
                    for _s in ['[inst]', '[outro]', '[intro]', '[verse]', '[chorus]', '[bridge]']:
                        text_cond = text_cond.replace(_s, "")
                    text_cond = text_cond.replace(' , ', '')
                    text_cond = text_cond.replace("  ", " ")
                if '.' in customized:
                    text_cond = text_cond.replace(" . ", " ")
                    text_cond = text_cond.replace(".", " ")
                sample.text[condition] = text_cond
        return sample

    def forward(self, samples: tp.List[ConditioningAttributes],
                condition_types=["audio", "text"],
                customized=None,
                ) -> tp.List[ConditioningAttributes]:
        """
        Drop (with probability 100%) some condition attributes (description, prompt_audio)
        or condition types (text, audio) of the samples during inference.

        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
        """
        new_samples = deepcopy(samples)
        for condition_type in condition_types:
            for sample in new_samples:
                for condition in sample.attributes[condition_type]:
                    self.dropout_condition_customized(sample, condition_type, condition, customized)
        return new_samples
class AttributeDropout(ClassifierFreeGuidanceDropout):
    """Dropout with a given probability per attribute.
    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
    must also be dropped.

    Args:
        p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping condition types to per-attribute
            dropout probabilities. For example:
            ...
            "genre": 0.1,
            "artist": 0.5,
            "audio": 0.25,
            ...
        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
        seed (int, optional): Random seed.
    """
    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
        super().__init__(p=p, seed=seed)
        self.active_on_eval = active_on_eval
        # construct a dict that returns the values from p, otherwise 0
        self.p = {}
        for condition_type, probs in p.items():
            self.p[condition_type] = defaultdict(lambda: 0, probs)

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
        """
        if not self.training and not self.active_on_eval:
            return samples

        samples = deepcopy(samples)
        for condition_type, ps in self.p.items():  # for condition types [text, audio]
            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
                if torch.rand(1, generator=self.rng).item() < p:
                    for sample in samples:
                        self.dropout_condition(sample, condition_type, condition)
        return samples
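
# --------------------------------------------------------------
# Usage sketch (illustration only): per-attribute dropout probabilities are
# given as a nested dict {condition_type: {attribute: p}}. The attribute names
# below are hypothetical.
# --------------------------------------------------------------
def _example_attribute_dropout():
    dropout = AttributeDropout(p={
        "text": {"description": 0.3},
        "audio": {"prompt_audio": 0.5},
    })
    attrs = ConditioningAttributes(
        text={"description": "lofi hip-hop, mellow"},
        audio={"prompt_audio": AudioCondition(
            wav=torch.zeros(1, 1, 10), length=torch.tensor([10]), sample_rate=[48000])})
    dropout.train()  # dropout is only active in training mode unless active_on_eval=True
    return dropout([attrs])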