# Source: Hugging Face upload by flosstradamus ("Upload 194 files"),
# commit afe1a07 (verified); original file size 17.3 kB.
import torch
import torch.nn as nn
from audioldm2.latent_diffusion.util import (
instantiate_from_config,
)
# from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2
from transformers import GPT2Config, GPT2Model
import torch.optim.lr_scheduler as lr_scheduler
class Sequence2AudioMAE(nn.Module):
    """Autoregressively generate AudioMAE feature tokens with a GPT-2 backbone.

    Every conditioning sequence listed in ``sequence_input_key`` is projected
    to the GPT-2 hidden size by its own ``nn.Linear``, wrapped in a learned
    start/end token pair chosen by the sequence's position in
    ``sequence_input_key``, and concatenated along the time axis into a
    prompt.  GPT-2 then extends the prompt one step at a time, appending its
    last output hidden state as the next input embedding.

    NOTE(review): some methods read attributes this class never assigns
    (``self.device`` in ``get_input``; ``self.global_step`` and
    ``self.trainer`` in ``warmup_step``).  They are presumably provided by a
    subclass or a training wrapper -- confirm before calling those methods.
    """

    # Width of the GPT-2 embeddings; every prompt element is projected to it.
    _GPT2_HIDDEN_SIZE = 768
    # Prompt + generated tokens must fit in GPT-2's default context length.
    _GPT2_MAX_POSITIONS = 1024
    # Size of the start/end token tables, i.e. the maximum number of
    # conditioning sequences that can be told apart.
    _NUM_SEQUENCE_TOKENS = 32
    # Length (in global steps) of the linear learning-rate warmup.
    _WARMUP_STEPS = 1000
    # generate_partial prompts with the first 1/_PARTIAL_PROMPT_DIVISOR of
    # the ground-truth target tokens.
    _PARTIAL_PROMPT_DIVISOR = 4

    def __init__(
        self,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        cond_stage_config,
        optimizer_type="AdamW",
        use_warmup=True,
        use_ar_gen_loss=False,
        use_audiomae_linear=False,
        target_tokens_mask_ratio=0.0,
        random_mask_ratio=False,
        **kwargs
    ):
        """Build the projections, special tokens, conditioners and GPT-2.

        Args:
            base_learning_rate: Learning rate for ``configure_optimizers`` and
                the warmup target of ``warmup_step``.
            sequence_gen_length: Number of tokens produced by ``generate``.
            sequence_input_key: Ordered keys of the ``cond_dict`` entries that
                make up the prompt.
            sequence_input_embed_dim: Feature size of each prompt entry (same
                order as ``sequence_input_key``); each gets its own
                ``nn.Linear(dim, 768)``.
            cond_stage_config: Mapping ``name -> config``; each config is
                built with ``instantiate_from_config`` and must contain
                ``"cond_stage_key"`` and ``"conditioning_key"``.
            optimizer_type: Name of a ``torch.optim`` optimizer class, either
                bare (``"AdamW"``) or qualified (``"torch.optim.AdamW"``).
            use_warmup, use_ar_gen_loss: Stored as attributes; not read by
                any method of this class.
            use_audiomae_linear: Must be falsy (asserted).
            target_tokens_mask_ratio: Upper bound on the fraction of target
                tokens masked by ``mask_target_sequence``; values <= 1e-4
                disable masking.
            random_mask_ratio: If true, ``mask_target_sequence`` draws the
                ratio uniformly from ``[0, target_tokens_mask_ratio)``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        assert not use_audiomae_linear, "use_audiomae_linear is not supported"
        self.random_mask_ratio = random_mask_ratio
        self.learning_rate = base_learning_rate
        self.cond_stage_config = cond_stage_config
        self.use_audiomae_linear = use_audiomae_linear
        self.optimizer_type = optimizer_type
        self.use_warmup = use_warmup
        self.use_ar_gen_loss = use_ar_gen_loss
        # Even though the LDM can be conditioned on multiple AudioMAE pooling
        # rates, this model always predicts the highest one; the resulting
        # token count is passed in directly as sequence_gen_length.
        self.mae_token_num = sequence_gen_length
        self.sequence_input_key = sequence_input_key
        self.sequence_input_embed_dim = sequence_input_embed_dim
        self.target_tokens_mask_ratio = target_tokens_mask_ratio

        # The statement order below is kept as-is: it fixes both the order
        # in which parameters are randomly initialised and the order of
        # self.parameters().
        hidden_size = self._GPT2_HIDDEN_SIZE
        # One learned start/end token per conditioning sequence, indexed by
        # that sequence's position in sequence_input_key.
        self.start_of_sequence_tokens = nn.Embedding(
            self._NUM_SEQUENCE_TOKENS, hidden_size
        )
        self.end_of_sequence_tokens = nn.Embedding(
            self._NUM_SEQUENCE_TOKENS, hidden_size
        )
        self.input_sequence_embed_linear = nn.ModuleList([])
        self.initial_learning_rate = None
        for dim in self.sequence_input_embed_dim:
            self.input_sequence_embed_linear.append(nn.Linear(dim, hidden_size))
        self.cond_stage_models = nn.ModuleList([])
        self.instantiate_cond_stage(cond_stage_config)
        self.initialize_param_check_toolkit()
        # from_pretrained only fetches the GPT-2 *configuration*; the model
        # weights created here are freshly initialised.
        self.model = GPT2Model(GPT2Config.from_pretrained("gpt2"))
        self.loss_fn = nn.L1Loss()
        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None
        self.logger_version = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        """Record where (and under which experiment names) logs are saved."""
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def cfg_uncond(self, batch_size):
        """Return the unconditional (classifier-free guidance) condition dict.

        Every registered conditioning model contributes its
        ``get_unconditional_condition(batch_size)`` output under its key.  The
        ``"crossattn_audiomae_pooled"`` entry is mandatory and is additionally
        aliased as ``"crossattn_clap_to_audiomae_feature"``.

        Raises:
            AssertionError: if no ``"crossattn_audiomae_pooled"`` model exists.
        """
        unconditional_conditioning = {
            key: self.cond_stage_models[
                meta["model_idx"]
            ].get_unconditional_condition(batch_size)
            for key, meta in self.cond_stage_model_metadata.items()
        }
        assert (
            "crossattn_audiomae_pooled" in unconditional_conditioning
        ), "The module is not initialized with AudioMAE"
        unconditional_conditioning[
            "crossattn_clap_to_audiomae_feature"
        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
        return unconditional_conditioning

    @staticmethod
    def _resolve_optimizer_class(optimizer_type):
        """Map an optimizer name (or class) to a ``torch.optim`` class.

        Replaces the former ``eval(optimizer_type)``, which both executed
        arbitrary config text and failed with ``NameError`` for the default
        ``"AdamW"`` because that name is not imported in this module.

        Raises:
            ValueError: if the name does not denote a ``torch.optim`` optimizer.
        """
        if isinstance(optimizer_type, type):
            return optimizer_type
        name = str(optimizer_type)
        for prefix in ("torch.optim.", "optim."):
            if name.startswith(prefix):
                name = name[len(prefix):]
                break
        optimizer_cls = getattr(torch.optim, name, None)
        if not (
            isinstance(optimizer_cls, type)
            and issubclass(optimizer_cls, torch.optim.Optimizer)
        ):
            raise ValueError("Unsupported optimizer_type: %r" % (optimizer_type,))
        return optimizer_cls

    def configure_optimizers(self):
        """Create the optimizer over all parameters plus a StepLR schedule.

        Returns:
            ``([optimizer], [scheduler])``; the scheduler multiplies the
            learning rate by 0.8 every 10 scheduler steps.
        """
        lr = float(self.learning_rate)
        params = list(self.parameters())
        optimizer_cls = self._resolve_optimizer_class(self.optimizer_type)
        opt = optimizer_cls(params, lr=lr)
        scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
        return [opt], [scheduler]

    def add_sos_eos_tokens(self, _id, sequence, attn_mask):
        """Wrap ``sequence`` in the learned start/end tokens of condition ``_id``.

        Args:
            _id: Condition index; selects row ``_id`` of both token tables.
            sequence: Tensor indexed as ``(batch, time, hidden)``.
            attn_mask: Tensor indexed as ``(batch, time)``.

        Returns:
            ``(new_sequence, new_attn_mask)``, both two time steps longer; the
            two added positions are always attended (mask value 1).
        """
        batchsize = sequence.size(0)
        device = sequence.device
        key_id = torch.tensor([_id], device=device)
        sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
        eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
        new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
        new_attn_mask_step = torch.ones((batchsize, 1), device=device)
        new_attn_mask = torch.cat(
            [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
        )
        return new_sequence, new_attn_mask

    def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
        """Keep only the first ``max_len`` time steps (dim 1) of both tensors.

        A message is printed whenever truncation actually happens.
        """
        if sequence.size(1) <= max_len:
            return sequence, mask
        print(
            "The input sequence length to GPT-2 model is too long:",
            sequence.size(1),
        )
        return sequence[:, :max_len], mask[:, :max_len]

    def get_input_sequence_and_mask(self, cond_dict):
        """Build the GPT-2 prompt from the conditioning sequences.

        Each ``cond_dict[key]`` for ``key`` in ``sequence_input_key`` is
        either a tensor (every step attended) or a two-element list
        ``[embeds, attn_mask]``.  Each entry is projected by its own linear
        layer, wrapped by ``add_sos_eos_tokens`` and concatenated along the
        time axis.  The prompt is truncated so that the ``mae_token_num``
        generated tokens still fit in the GPT-2 context.

        Returns:
            ``(input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx)``
            where the index is the prompt length, i.e. the position at which
            generated tokens start.

        Raises:
            AssertionError: on a missing key, a malformed list entry, an
                unsupported entry type, or an empty ``sequence_input_key``.
        """
        embeds_parts = []
        mask_parts = []
        for _id, sequence_key in enumerate(self.sequence_input_key):
            assert sequence_key in cond_dict.keys(), (
                "Invalid sequence key %s" % sequence_key
            )
            cond_embed = cond_dict[sequence_key]
            projection = self.input_sequence_embed_linear[_id]
            if isinstance(cond_embed, list):
                assert (
                    len(cond_embed) == 2
                ), "The crossattn returned list should have length 2, including embed and attn_mask"
                item_embeds, item_mask = cond_embed
                item_embeds = projection(item_embeds)
            else:
                assert isinstance(cond_embed, torch.Tensor)
                item_embeds = projection(cond_embed)
                # A bare tensor carries no mask: attend to every time step.
                item_mask = torch.ones(
                    (item_embeds.size(0), item_embeds.size(1)),
                    device=item_embeds.device,
                )
            item_embeds, item_mask = self.add_sos_eos_tokens(
                _id, item_embeds, item_mask
            )
            embeds_parts.append(item_embeds)
            mask_parts.append(item_mask)
        assert embeds_parts and mask_parts
        # dim 1 is the time axis.
        input_embeds = torch.cat(embeds_parts, dim=1)
        input_embeds_attn_mask = torch.cat(mask_parts, dim=1)
        input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
            input_embeds,
            input_embeds_attn_mask,
            int(self._GPT2_MAX_POSITIONS - self.mae_token_num),
        )
        cond_sequence_end_time_idx = input_embeds.size(1)
        return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx

    def warmup_step(self):
        """Linearly ramp the first param group's learning rate, then hold it.

        For ``global_step <= 1000`` the rate is
        ``global_step / 1000 * initial_learning_rate``; afterwards it is
        pinned to ``initial_learning_rate``.  Reads ``self.global_step`` and
        ``self.trainer``, which this class does not define (see class note).
        """
        if self.initial_learning_rate is None:
            self.initial_learning_rate = float(self.learning_rate)
        step = self.global_step
        if step <= self._WARMUP_STEPS:
            if step == 0:
                print(
                    "Warming up learning rate start with %s"
                    % self.initial_learning_rate
                )
            new_lr = (step / self._WARMUP_STEPS) * self.initial_learning_rate
        else:
            new_lr = self.initial_learning_rate
        # Only the first parameter group is adjusted.
        self.trainer.optimizers[0].param_groups[0]["lr"] = new_lr

    def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
        """Randomly zero out target time steps (token masking).

        Args:
            target_embeds: Tensor indexed as ``(batch, time, hidden)``.
            target_embeds_attn_mask: Tensor indexed as ``(batch, time)``.

        Returns:
            ``(target_embeds, target_embeds_attn_mask, time_seq_mask)``.
            ``time_seq_mask`` is ``None`` when masking is disabled; otherwise
            it is a boolean ``(batch, time)`` tensor, True for kept steps.
        """
        time_seq_mask = None
        if self.target_tokens_mask_ratio > 1e-4:
            batchsize, _, _ = target_embeds.size()
            _, time_seq_len = target_embeds_attn_mask.size()
            if self.random_mask_ratio:
                mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
            else:
                mask_ratio = self.target_tokens_mask_ratio
            # Sampled on the CPU generator and then moved, so results do not
            # depend on the target device's RNG.
            time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
                target_embeds.device
            )
            target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
            target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
        return target_embeds, target_embeds_attn_mask, time_seq_mask

    def _autoregressive_extend(self, model_input, model_input_mask, num_steps):
        """Append ``num_steps`` GPT-2 outputs to ``model_input``, one per step.

        Each step re-runs the model on the whole current input, appends the
        last output hidden state as the next input embedding, and grows the
        attention mask by a column of ones.

        Returns:
            The extended ``(model_input, model_input_mask)``.
        """
        for _ in range(num_steps):
            output = self.model(
                inputs_embeds=model_input, attention_mask=model_input_mask
            )["last_hidden_state"]
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
            new_mask_step = torch.ones(
                (model_input_mask.size(0), 1), device=model_input.device
            )
            model_input_mask = torch.cat([model_input_mask, new_mask_step], dim=1)
        return model_input, model_input_mask

    def generate_partial(self, batch, cond_dict=None, no_grad=False):
        """Continue a ground-truth AudioMAE prefix (in-context generation).

        The first quarter of ``cond_dict["crossattn_audiomae_pooled"]`` tokens
        is appended to the conditioning prompt, and the model then generates
        the remaining ``mae_token_num - prompt_len`` tokens.  The returned
        sequence therefore has ``mae_token_num`` steps (prompt included),
        matching ``generate``; the old ``3 * steps // 4`` count came up short
        when ``mae_token_num`` was not a multiple of 4.

        Args:
            batch: Raw batch; only used when ``cond_dict`` is None.
            cond_dict: Precomputed conditions (see ``get_input``).
            no_grad: If true, run without building an autograd graph.

        Returns:
            ``(tokens, cond_dict)``.
        """
        with torch.set_grad_enabled(torch.is_grad_enabled() and not no_grad):
            if cond_dict is None:
                cond_dict = self.get_input(batch)
            print("Generate partially prompted audio with in-context learning")
            target_embeds = cond_dict["crossattn_audiomae_pooled"][0]
            target_embeds_attn_mask = cond_dict["crossattn_audiomae_pooled"][1]
            prompt_len = target_embeds.size(1) // self._PARTIAL_PROMPT_DIVISOR
            (
                input_embeds,
                input_embeds_attn_mask,
                cond_sequence_end_time_idx,
            ) = self.get_input_sequence_and_mask(cond_dict)
            model_input = torch.cat(
                [input_embeds, target_embeds[:, :prompt_len, :]], dim=1
            )
            model_input_mask = torch.cat(
                [input_embeds_attn_mask, target_embeds_attn_mask[:, :prompt_len]],
                dim=1,
            )
            num_new_tokens = max(self.mae_token_num - prompt_len, 0)
            model_input, _ = self._autoregressive_extend(
                model_input, model_input_mask, num_new_tokens
            )
        return model_input[:, cond_sequence_end_time_idx:], cond_dict

    def generate(self, batch, cond_dict=None, no_grad=False):
        """Generate ``mae_token_num`` tokens from the conditioning prompt.

        Args:
            batch: Raw batch; only used when ``cond_dict`` is None.
            cond_dict: Precomputed conditions (see ``get_input``).
            no_grad: If true, run without building an autograd graph.

        Returns:
            ``(tokens, cond_dict)`` where ``tokens`` holds only the generated
            steps (the prompt is sliced off).
        """
        with torch.set_grad_enabled(torch.is_grad_enabled() and not no_grad):
            if cond_dict is None:
                cond_dict = self.get_input(batch)
            (
                input_embeds,
                input_embeds_attn_mask,
                cond_sequence_end_time_idx,
            ) = self.get_input_sequence_and_mask(cond_dict)
            model_input, _ = self._autoregressive_extend(
                input_embeds, input_embeds_attn_mask, self.mae_token_num
            )
        return model_input[:, cond_sequence_end_time_idx:], cond_dict

    def get_input_item(self, batch, k):
        """Return batch entry ``k``, normalising the well-known audio keys.

        ``"fbank"`` is derived from ``batch["log_mel_spec"]`` with an axis
        inserted at dim 1; ``"stft"`` and ``"waveform"`` are made contiguous
        and cast to float32; ``"text"`` is converted to a list.  Any other key
        is returned from ``batch`` unchanged.  Only the entry needed for ``k``
        is read, so ``batch`` need not contain the other audio keys.

        Raises:
            KeyError: if the entry required for ``k`` is missing.
        """
        if k == "fbank":
            return (
                batch["log_mel_spec"]
                .unsqueeze(1)
                .to(memory_format=torch.contiguous_format)
                .float()
            )
        if k in ("stft", "waveform"):
            return batch[k].to(memory_format=torch.contiguous_format).float()
        if k == "text":
            return list(batch["text"])
        return batch[k]

    def get_input(self, batch):
        """Run every registered conditioning model on its input from ``batch``.

        Tensor inputs are moved to ``self.device`` (not set by this class;
        see class note).  Always returns the conditional embeddings;
        ``cfg_uncond`` builds the unconditional counterpart.

        Returns:
            Dict mapping each conditioning model key to its output.
        """
        cond_dict = {}
        unconditional_cfg = False
        for cond_model_key, meta in self.cond_stage_model_metadata.items():
            xc = self.get_input_item(batch, meta["cond_stage_key"])
            if isinstance(xc, torch.Tensor):
                xc = xc.to(self.device)
            cond_dict[cond_model_key] = self.get_learned_conditioning(
                xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
            )
        return cond_dict

    def instantiate_cond_stage(self, config):
        """Instantiate each conditioning model in ``config`` and index it.

        Fills ``self.cond_stage_model_metadata[name]`` with the model's index
        in ``self.cond_stage_models`` plus its ``cond_stage_key`` and
        ``conditioning_key`` from the config.
        """
        self.cond_stage_model_metadata = {}
        for cond_model_key, model_config in config.items():
            model = instantiate_from_config(model_config)
            # Index taken before appending so it always points at this model.
            model_idx = len(self.cond_stage_models)
            self.cond_stage_models.append(model)
            self.cond_stage_model_metadata[cond_model_key] = {
                "model_idx": model_idx,
                "cond_stage_key": model_config["cond_stage_key"],
                "conditioning_key": model_config["conditioning_key"],
            }

    def get_learned_conditioning(self, c, key, unconditional_cfg):
        """Apply the conditioning model registered under ``key`` to ``c``.

        With ``unconditional_cfg`` the model's unconditional condition is
        returned instead, sized by the batch of ``c`` (tensor dim 0 or list
        length).

        Raises:
            AssertionError: if ``key`` is not registered.
            NotImplementedError: if ``unconditional_cfg`` and ``c`` is
                neither a tensor nor a list.
        """
        assert key in self.cond_stage_model_metadata.keys()
        cond_model = self.cond_stage_models[
            self.cond_stage_model_metadata[key]["model_idx"]
        ]
        if not unconditional_cfg:
            return cond_model(c)
        if isinstance(c, torch.Tensor):
            batchsize = c.size(0)
        elif isinstance(c, list):
            batchsize = len(c)
        else:
            raise NotImplementedError()
        return cond_model.get_unconditional_condition(batchsize)

    def initialize_param_check_toolkit(self):
        """Reset the bookkeeping used for parameter-update checks."""
        self.tracked_steps = 0
        self.param_dict = {}

    def statistic_require_grad_tensor_number(self, module, name=None):
        """Print how many parameter tensors of ``module`` are trainable.

        Counts parameter tensors, not elements.  A module without parameters
        reports a ratio of 0.00 instead of raising ``ZeroDivisionError``.

        Returns:
            The first parameter with ``requires_grad=True``, or None.
        """
        requires_grad_num = 0
        total_num = 0
        require_grad_tensor = None
        for p in module.parameters():
            total_num += 1
            if p.requires_grad:
                requires_grad_num += 1
                if require_grad_tensor is None:
                    require_grad_tensor = p
        ratio = requires_grad_num / total_num if total_num else 0.0
        print(
            "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
            % (name, requires_grad_num, total_num, ratio)
        )
        return require_grad_tensor