import torch
import torch.nn as nn
from audioldm2.latent_diffusion.util import (
    instantiate_from_config,
)

# from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2
from transformers import GPT2Config, GPT2Model
import torch.optim.lr_scheduler as lr_scheduler


class Sequence2AudioMAE(nn.Module):
    """GPT-2 based sequence model that autoregressively predicts a fixed-length
    sequence of AudioMAE-style feature tokens (768-dim) from one or more
    conditioning sequences projected into the same embedding space."""

    def __init__(
        self,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        cond_stage_config,
        optimizer_type="AdamW",
        use_warmup=True,
        use_ar_gen_loss=False,
        use_audiomae_linear=False,
        target_tokens_mask_ratio=0.0,
        random_mask_ratio=False,
        **kwargs
    ):
        super().__init__()
        assert use_audiomae_linear == False
        self.random_mask_ratio = random_mask_ratio
        self.learning_rate = base_learning_rate
        self.cond_stage_config = cond_stage_config
        self.use_audiomae_linear = use_audiomae_linear
        self.optimizer_type = optimizer_type
        self.use_warmup = use_warmup
        self.use_ar_gen_loss = use_ar_gen_loss

        # Even though the LDM can be conditioned on multiple pooling rates,
        # this model always predicts the highest pooling rate.
        # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"])
        # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"])
        # self.mae_token_num = int(512/(self.time_pool*self.freq_pool))
        self.mae_token_num = sequence_gen_length
        self.sequence_input_key = sequence_input_key
        self.sequence_input_embed_dim = sequence_input_embed_dim
        self.target_tokens_mask_ratio = target_tokens_mask_ratio

        # One learned start/end-of-sequence token per conditioning sequence key.
        self.start_of_sequence_tokens = nn.Embedding(32, 768)
        self.end_of_sequence_tokens = nn.Embedding(32, 768)

        self.input_sequence_embed_linear = nn.ModuleList([])
        self.initial_learning_rate = None

        # Project each conditioning sequence into the 768-dim GPT-2 embedding space.
        for dim in self.sequence_input_embed_dim:
            self.input_sequence_embed_linear.append(nn.Linear(dim, 768))

        self.cond_stage_models = nn.ModuleList([])
        self.instantiate_cond_stage(cond_stage_config)
        self.initialize_param_check_toolkit()

        # configuration = GPT2Config(n_layer=1) # TODO
        # self.model = GPT2Model(configuration)
        ###################
        # self.model = nn.Linear(768, 768, bias=False)  # TODO change the model
        # with torch.no_grad():
        #     self.model.weight.copy_(torch.eye(768))
        ###################
        self.model = GPT2Model(GPT2Config.from_pretrained("gpt2"))
        ###################
        # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1, bias=False)  # TODO

        # self.loss_fn = nn.MSELoss()
        self.loss_fn = nn.L1Loss()

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None
        self.logger_version = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def cfg_uncond(self, batch_size):
        unconditional_conditioning = {}
        for key in self.cond_stage_model_metadata:
            model_idx = self.cond_stage_model_metadata[key]["model_idx"]
            unconditional_conditioning[key] = self.cond_stage_models[
                model_idx
            ].get_unconditional_condition(batch_size)
        assert (
            "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
        ), "The module is not initialized with AudioMAE"
        unconditional_conditioning[
            "crossattn_clap_to_audiomae_feature"
        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
        return unconditional_conditioning

    def configure_optimizers(self):
        lr = float(self.learning_rate)
        # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters())
        params = list(self.parameters())

        # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
        # Resolve the optimizer class by name from torch.optim (e.g. "AdamW").
        opt = getattr(torch.optim, self.optimizer_type)(params, lr=lr)
        scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
        return [opt], [scheduler]

    def add_sos_eos_tokens(self, _id, sequence, attn_mask):
        # Frame the sequence with the learned SOS/EOS tokens associated with this
        # conditioning key (indexed by its position in `sequence_input_key`).
        batchsize = sequence.size(0)

        new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device)
        key_id = torch.tensor([_id]).to(sequence.device)

        # Add two more steps to the attention mask
        new_attn_mask = torch.cat(
            [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
        )

        # Add two more tokens to the sequence
        sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
        eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
        new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
        return new_sequence, new_attn_mask

    def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
        if sequence.size(1) > max_len:
            print(
                "The input sequence length to the GPT-2 model is too long:",
                sequence.size(1),
            )
            return sequence[:, :max_len], mask[:, :max_len]
        else:
            return sequence, mask

    def get_input_sequence_and_mask(self, cond_dict):
        input_embeds = None
        input_embeds_attn_mask = None
        for _id, sequence_key in enumerate(self.sequence_input_key):
            assert sequence_key in cond_dict.keys(), (
                "Invalid sequence key %s" % sequence_key
            )
            cond_embed = cond_dict[sequence_key]
            if isinstance(cond_embed, list):
                assert (
                    len(cond_embed) == 2
                ), "The crossattn returned list should have length 2 (embed and attn_mask)"
                item_input_embeds, item_attn_mask = cond_embed
                item_input_embeds = self.input_sequence_embed_linear[_id](
                    item_input_embeds
                )
                item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
                    _id, item_input_embeds, item_attn_mask
                )
                if input_embeds is None and input_embeds_attn_mask is None:
                    input_embeds, input_embeds_attn_mask = (
                        item_input_embeds,
                        item_attn_mask,
                    )
                else:
                    # Concatenate along the time-step dimension
                    input_embeds = torch.cat(
                        [input_embeds, item_input_embeds], dim=1
                    )
                    input_embeds_attn_mask = torch.cat(
                        [input_embeds_attn_mask, item_attn_mask], dim=1
                    )
            else:
                assert isinstance(cond_embed, torch.Tensor)
                cond_embed = self.input_sequence_embed_linear[_id](cond_embed)
                attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to(
                    cond_embed.device
                )
                item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
                    _id, cond_embed, attn_mask
                )
                if input_embeds is None and input_embeds_attn_mask is None:
                    input_embeds, input_embeds_attn_mask = (
                        item_input_embeds,
                        item_attn_mask,
                    )
                else:
                    input_embeds, input_embeds_attn_mask = torch.cat(
                        [input_embeds, item_input_embeds], dim=1
                    ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1)

        assert input_embeds is not None and input_embeds_attn_mask is not None

        # Keep the conditioning prefix short enough that the generated tokens
        # still fit within GPT-2's 1024-position context.
        input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
            input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
        )
        # The index from which we start to collect the output embeddings
        cond_sequence_end_time_idx = input_embeds.size(1)

        return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx

    def warmup_step(self):
        # NOTE: relies on `self.global_step` and `self.trainer`, which are not
        # defined on this nn.Module; they are expected to be provided by the
        # training wrapper (e.g. a PyTorch Lightning module/trainer).
        if self.initial_learning_rate is None:
            self.initial_learning_rate = float(self.learning_rate)

        # Only the first parameter group
        if self.global_step <= 1000:
            if self.global_step == 0:
                print(
                    "Warming up learning rate, starting from %s"
                    % self.initial_learning_rate
                )
            self.trainer.optimizers[0].param_groups[0]["lr"] = (
                self.global_step / 1000
            ) * self.initial_learning_rate
        else:
            # TODO set learning rate here
            self.trainer.optimizers[0].param_groups[0][
                "lr"
            ] = self.initial_learning_rate

    def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
        time_seq_mask = None
        if self.target_tokens_mask_ratio > 1e-4:
            batchsize, time_seq_len, embed_dim = target_embeds.size()
            _, time_seq_len = target_embeds_attn_mask.size()
            # Sample a random mask ratio if requested, otherwise use the fixed one
            if self.random_mask_ratio:
                mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
            else:
                mask_ratio = self.target_tokens_mask_ratio

            time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
                target_embeds.device
            )
            # Mask the target embedding
            target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
            target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
        return target_embeds, target_embeds_attn_mask, time_seq_mask

    def generate_partial(self, batch, cond_dict=None, no_grad=False):
        # Prompt the model with the conditioning prefix plus the first quarter of
        # the target AudioMAE tokens, then autoregressively generate the remaining
        # three quarters (in-context continuation).
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        print("Generate partially prompted audio with in-context learning")
        # self.model.train()
        # assert self.model.training == True

        target_embeds, target_embeds_attn_mask = (
            cond_dict["crossattn_audiomae_pooled"][0],
            cond_dict["crossattn_audiomae_pooled"][1],
        )

        target_time_steps = target_embeds.size(1)

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)

        model_input = torch.cat(
            [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1
        )
        model_input_mask = torch.cat(
            [
                input_embeds_attn_mask,
                target_embeds_attn_mask[:, : target_time_steps // 4],
            ],
            dim=1,
        )

        steps = self.mae_token_num

        for _ in range(3 * steps // 4):
            output = self.model(
                inputs_embeds=model_input, attention_mask=model_input_mask
            )["last_hidden_state"]
            # Append the newest prediction to the model input
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
            # Extend the attention mask by one step
            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
                model_input.device
            )
            model_input_mask = torch.cat(
                [model_input_mask, attention_mask_new_step], dim=1
            )

        output = model_input[:, cond_sequence_end_time_idx:]

        return output, cond_dict

    def generate(self, batch, cond_dict=None, no_grad=False):
        # Autoregressively generate `self.mae_token_num` feature tokens, feeding
        # each prediction back as input for the next step.
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        # self.model.train()

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)
        model_input = input_embeds
        model_input_mask = input_embeds_attn_mask
        steps = self.mae_token_num

        for _ in range(steps):
            output = self.model(
                inputs_embeds=model_input, attention_mask=model_input_mask
            )["last_hidden_state"]
            # Append the newest prediction to the model input
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
            # Extend the attention mask by one step
            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
                model_input.device
            )
            model_input_mask = torch.cat(
                [model_input_mask, attention_mask_new_step], dim=1
            )

        return model_input[:, cond_sequence_end_time_idx:], cond_dict

    def get_input_item(self, batch, k):
        fname, text, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        ret = {}

        ret["fbank"] = (
            fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
        )
        ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
        # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
        ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
        ret["text"] = list(text)
        ret["fname"] = fname

        for key in batch.keys():
            if key not in ret.keys():
                ret[key] = batch[key]

        return ret[k]

    def get_input(self, batch):
        cond_dict = {}
        if len(self.cond_stage_model_metadata.keys()) > 0:
            unconditional_cfg = False
            for cond_model_key in self.cond_stage_model_metadata.keys():
                cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
                    "cond_stage_key"
                ]

                # if(not self.training):
                #     if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)):
                #         assert cond_stage_key == "text"  # CLAP model should use text for evaluation

                # The original data for conditioning
                xc = self.get_input_item(batch, cond_stage_key)
                if type(xc) == torch.Tensor:
                    # NOTE: `self.device` is not defined by nn.Module; it is
                    # expected to be set by the surrounding pipeline.
                    xc = xc.to(self.device)

                c = self.get_learned_conditioning(
                    xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
                )
                cond_dict[cond_model_key] = c

        return cond_dict

    def instantiate_cond_stage(self, config):
        self.cond_stage_model_metadata = {}
        for i, cond_model_key in enumerate(config.keys()):
            model = instantiate_from_config(config[cond_model_key])
            self.cond_stage_models.append(model)
            self.cond_stage_model_metadata[cond_model_key] = {
                "model_idx": i,
                "cond_stage_key": config[cond_model_key]["cond_stage_key"],
                "conditioning_key": config[cond_model_key]["conditioning_key"],
            }

    def get_learned_conditioning(self, c, key, unconditional_cfg):
        assert key in self.cond_stage_model_metadata.keys()

        # Classifier-free guidance: return either the learned conditioning or the
        # conditioning model's unconditional (null) condition.
        if not unconditional_cfg:
            c = self.cond_stage_models[
                self.cond_stage_model_metadata[key]["model_idx"]
            ](c)
        else:
            if isinstance(c, torch.Tensor):
                batchsize = c.size(0)
            elif isinstance(c, list):
                batchsize = len(c)
            else:
                raise NotImplementedError()
            c = self.cond_stage_models[
                self.cond_stage_model_metadata[key]["model_idx"]
            ].get_unconditional_condition(batchsize)

        return c

    def initialize_param_check_toolkit(self):
        self.tracked_steps = 0
        self.param_dict = {}

    def statistic_require_grad_tensor_number(self, module, name=None):
        requires_grad_num = 0
        total_num = 0
        require_grad_tensor = None
        for p in module.parameters():
            if p.requires_grad:
                requires_grad_num += 1
                if require_grad_tensor is None:
                    require_grad_tensor = p
            total_num += 1
        print(
            "Module: [%s] has %s trainable parameters out of %s total parameters (%.2f)"
            % (name, requires_grad_num, total_num, requires_grad_num / total_num)
        )
        return require_grad_tensor
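

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training/inference
    # pipeline): it builds the module with an empty conditioning-stage config and
    # feeds a pre-computed conditioning sequence directly through `cond_dict`,
    # so no dataset or conditioning encoders are required. The key name
    # "hypothetical_text_embed" and the shapes below are illustrative
    # placeholders, not values from the AudioLDM2 configs. Note that
    # GPT2Config.from_pretrained("gpt2") in __init__ needs network access
    # (or a local cache) to fetch the GPT-2 configuration.
    model = Sequence2AudioMAE(
        base_learning_rate=1e-4,
        sequence_gen_length=8,  # number of tokens to generate autoregressively
        sequence_input_key=["hypothetical_text_embed"],
        sequence_input_embed_dim=[512],
        cond_stage_config={},  # skip instantiating any conditioning encoders
    )
    model.eval()

    # One conditioning sequence: batch of 2, 20 steps, 512-dim embeddings.
    cond_dict = {"hypothetical_text_embed": torch.randn(2, 20, 512)}
    with torch.no_grad():
        generated, _ = model.generate(batch=None, cond_dict=cond_dict)

    # 20 conditioning steps + SOS/EOS framing form the 22-token prefix, followed
    # by `sequence_gen_length` generated 768-dim tokens, which are returned.
    print(generated.shape)  # torch.Size([2, 8, 768])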