""" Main model for using CodecLM. This will combine all the required components and provide easy access to the generation API. """ import typing as tp import warnings import sys import time import torch import torch.nn as nn from torch.nn import functional as F import torchaudio import numpy as np import lightning as pl from torchmetrics.classification import MulticlassAccuracy import pdb from codeclm.models import builders import math from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from peft import LoraConfig, get_peft_model from datetime import datetime import os os.environ['TOKENIZERS_PARALLELISM'] = "false" class CodecLM_PL(pl.LightningModule): def __init__(self, cfg): super().__init__() self.cfg = cfg # 1) Build audio tokenizer (usually None during training) self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg) if self.audio_tokenizer is not None: for param in self.audio_tokenizer.parameters(): param.requires_grad = False if "audio_tokenizer_checkpoint_sep" in self.cfg.keys(): self.seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg) for param in self.seperate_tokenizer.parameters(): param.requires_grad = False else: self.seperate_tokenizer = None # 2) Build LM self.audiolm = builders.get_lm_model(self.cfg) print(self.audiolm) # 输出参数量 print('Number of parameters: ', sum(p.numel() for p in self.audiolm.parameters())) # 3) Load pretrained checkpoint (if any) if self.cfg.use_pretrained == 'deepspeed': checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu') missing, unexpected = self.load_state_dict(checkpoint, strict=False) print(f'-------------Missing--------------\n{missing}') print(f'-------------Unexpected--------------\n{unexpected}') print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint)) self.missing = missing else: self.missing = [] # 如果cfg参数中有lora if hasattr(self.cfg, 'lora'): perf_config = LoraConfig( r = self.cfg.lora.r, lora_alpha = self.cfg.lora.lora_alpha, target_modules = self.cfg.lora.target_modules, lora_dropout = self.cfg.lora.lora_dropout, bias = self.cfg.lora.bias, task_type = self.cfg.lora.task_type, ) self.audiolm = get_peft_model(self.audiolm, perf_config) # 4) Build metrics self.val_steps = [] self.train_slide_acc = [] self.train_steps = [] self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy( self.audiolm.code_size, top_k=1, average="micro", multidim_average="global", ignore_index=self.cfg.lm.code_size, # ignore EOS token prediction ) for _ in range(self.audiolm.code_depth)]) self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy( self.audiolm.code_size, top_k=10, average="micro", multidim_average="global", ignore_index=self.cfg.lm.code_size, ) for _ in range(self.audiolm.code_depth)]) self.epoch = 0 print("++++++++++++++++ training +++++++++++++++++") # TODO: move this part to loader def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384): batch_size = sequence_lengths.size(0) max_length = x.size(2) # pad one frame, if the maximum sequence length is equal to the input length if max_length == sequence_lengths.max(): x = F.pad(x, (0, 1), value=end_id) max_length = x.size(2) if max_length <= sequence_lengths.max() + 1: sequence_lengths = sequence_lengths - (sequence_lengths.max()+1 - max_length) # Add end token to x according to the sequence length x[torch.arange(batch_size), :, sequence_lengths] = end_id sequence_lengths += 1 mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1) mask = mask.to(x.device) mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length) x = torch.where(mask_3d, x, end_id+1) return x, mask_3d @torch.no_grad() def preprocess_batch(self, batch): # this function is usually called during training # 处理 dataloader 返回的数据 audio, text_lyric, time_stamp, structure_dur, prompt_audio, structure_labels = batch dur, valid_st, valid_et = zip(*time_stamp) if self.audio_tokenizer is not None: # only used in inference self.audio_tokenizer.eval() with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): audio_tokens, scale = self.audio_tokenizer.encode(audio) audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:] audio_tokens = audio_tokens.long() else: audio_tokens = audio.long() token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int() audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur, end_id=self.audiolm.eos_token_id) condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric), text=text_lyric, audio_qt_emb=prompt_audio) return condition_tensors, audio_tokens, audio_padding_mask def get_time(self): # 获取当前的日期和时间 now = datetime.now() # 使用strftime函数格式化日期和时间 formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f") return formatted_now def training_step(self, batch, batch_idx): # 1) data processing condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) # 2) compute model predictions (model forward) model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors, training_steps=self.global_step) # this input can be ignored logits = model_output.logits.float() mask = padding_mask & model_output.mask # 3) compute loss (float) with torch.cuda.amp.autocast(enabled=False): ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) total_loss = ce if torch.isnan(total_loss): print(self.trainer.global_rank, ce, padding_mask, batch[1]) print('--------------------------------------------------------------') return None # torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000) # import pdb; pdb.set_trace() # 4) compute metrics and log metrics = {} self.log('ce', ce, prog_bar=True) metrics['ppl'] = torch.exp(ce) for k, ce_q in enumerate(ce_per_codebook): metrics[f'ce_q{k + 1}'] = ce_q metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) metrics['acc'] = [] for k in range(self.audiolm.code_depth): metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:, k]).item()) metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item() self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']}) self.log('train_acc', metrics['acc']+1e-8, prog_bar=True) self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True) self.log_dict(metrics) return total_loss @torch.no_grad() def validation_step(self, batch, batch_idx): # 1) data processing condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) # 2) compute model predictions model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors) logits = model_output.logits mask = padding_mask & model_output.mask # 3) compute loss and metrics ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) metrics = {} metrics['val_ce'] = ce metrics['val_ppl'] = torch.exp(ce) for k, ce_q in enumerate(ce_per_codebook): metrics[f'val_ce_q{k + 1}'] = ce_q metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q) masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) for k in range(self.audiolm.code_depth): self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) self.val_steps.append(metrics) metrics['acc'] = [] metrics['acc_top10'] = [] for k in range(self.audiolm.code_depth): metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])) metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10'])) return metrics['acc'] def on_validation_epoch_end(self) -> None: final_metrics = {} for i in self.val_steps: for k in i: final_metrics[k] = final_metrics.get(k, []) + [i[k]] final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())} self.log_dict(final_metrics) q_acc = [] q_acc10 = [] for i in range(self.audiolm.code_depth): q_acc.append(self.top1_acc_metric[i].compute()) q_acc10.append(self.top10_acc_metric[i].compute()) self.log(f"val_Top1Acc_{i}", q_acc[-1]) self.log(f"val_Top10Acc_{i}", q_acc10[-1]) self.top1_acc_metric[i].reset() self.top10_acc_metric[i].reset() self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth) self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth) return super().on_validation_epoch_end() def on_validation_epoch_start(self) -> None: self.val_steps = [] for i in range(self.audiolm.code_depth): self.top1_acc_metric[i].reset() self.top10_acc_metric[i].reset() if len(self.train_steps) > 0: train_metrics = {} for i in self.train_steps: for k in i: train_metrics[k] = train_metrics.get(k, []) + [i[k]] train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())} self.log('train_summary_Top1Acc', train_metrics['acc']) self.log('train_summary_ce', train_metrics['ce']) self.train_steps = [] return super().on_validation_epoch_start() # 定义优化器 def configure_optimizers(self): total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch optim_dict = {} param_groups = [] missing_params = [] other_params = [] cnt = 0 # 去掉开头的‘audiolm.' print('before missing len', len(self.missing)) self.missing = [name.replace('audiolm.', '') for name in self.missing] print('after missing len', len(self.missing)) for name, param in self.audiolm.named_parameters(): if name in self.missing: cnt += 1 print(name) missing_params.append(param) else: other_params.append(param) print(cnt) assert cnt == len(self.missing) param_groups.append({'params': other_params, 'lr': self.cfg.optim.old_lr}) param_groups.append({ 'params': missing_params, 'lr': self.cfg.optim.new_lr # 为missing参数设置10倍的学习率,你可以调整这个倍数 }) if self.cfg.optim.optimizer == "adamw": optim_dict['optimizer'] = torch.optim.AdamW( param_groups, # 使用参数分组替代原来的 self.audiolm.parameters() betas=tuple(self.cfg.optim.adam.betas), weight_decay=self.cfg.optim.adam.weight_decay, eps=self.cfg.optim.adam.eps, ) else: raise NotImplementedError if self.cfg.schedule is None: pass elif self.cfg.schedule.lr_scheduler == "cosine": scheduler = CosineLRScheduler(optim_dict['optimizer'], total_steps=total_updates, warmup_steps=self.cfg.schedule.cosine.warmup, lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio, cycle_length=self.cfg.schedule.cosine.cycle_length, ) optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"} else: raise NotImplementedError return optim_dict def _compute_cross_entropy( self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: """Compute cross entropy between multi-codebook targets and model's logits. The cross entropy is computed per codebook to provide codebook-level cross entropy. Valid timesteps for each of the codebook are pulled from the mask, where invalid timesteps are set to 0. Args: logits (torch.Tensor): Model's logits of shape [B, K, T, card]. targets (torch.Tensor): Target codes, of shape [B, K, T]. mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. Returns: ce (torch.Tensor): Cross entropy averaged over the codebooks ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). """ # import pdb; pdb.set_trace() B, K, T = targets.shape assert logits.shape[:-1] == targets.shape assert mask.shape == targets.shape ce = torch.zeros([], device=targets.device) ce_per_codebook: tp.List[torch.Tensor] = [] for k in range(K): logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] ce_targets = targets_k[mask_k] ce_logits = logits_k[mask_k] q_ce = F.cross_entropy(ce_logits, ce_targets) ce += q_ce ce_per_codebook.append(q_ce.detach()) # average cross entropy across codebooks ce = ce / K return ce, ce_per_codebook class CodecLM_PL_FT(pl.LightningModule): def __init__(self, cfg): super().__init__() self.cfg = cfg # 1) Build audio tokenizer (usually None during training) self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg) if self.audio_tokenizer is not None: for param in self.audio_tokenizer.parameters(): param.requires_grad = False # 2) Build LM self.audiolm = builders.get_lm_model(self.cfg) # 3) Load pretrained checkpoint (if any) if self.cfg.use_pretrained == 'deepspeed': checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu') missing, unexpected = self.load_state_dict(checkpoint, strict=False) print(f'-------------Missing--------------\n{missing}') print(f'-------------Unexpected--------------\n{unexpected}') print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint)) # 4) Build metrics self.val_steps = [] self.train_slide_acc = [] self.train_steps = [] self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy( self.audiolm.code_size, top_k=1, average="micro", multidim_average="global", ignore_index=self.cfg.lm.code_size, # ignore EOS token prediction ) for _ in range(self.audiolm.code_depth)]) self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy( self.audiolm.code_size, top_k=10, average="micro", multidim_average="global", ignore_index=self.cfg.lm.code_size, ) for _ in range(self.audiolm.code_depth)]) self.epoch = 0 print("++++++++++++++++ training +++++++++++++++++") # TODO: move this part to loader def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384): batch_size = sequence_lengths.size(0) max_length = x.size(2) # pad one frame, if the maximum sequence length is equal to the input length if max_length == sequence_lengths.max(): x = F.pad(x, (0, 1), value=end_id) max_length = x.size(2) if max_length <= sequence_lengths.max() + 1: sequence_lengths = sequence_lengths - (sequence_lengths.max()+1 - max_length) # Add end token to x according to the sequence length x[torch.arange(batch_size), :, sequence_lengths] = end_id sequence_lengths += 1 mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1) mask = mask.to(x.device) mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length) x = torch.where(mask_3d, x, end_id+1) return x, mask_3d @torch.no_grad() def preprocess_batch(self, batch): # this function is usually called during training # 处理 dataloader 返回的数据 audio, text_lyric, time_stamp, lang_type, prompt_audio = batch dur, valid_st, valid_et = zip(*time_stamp) if self.audio_tokenizer is not None: # only used in inference self.audio_tokenizer.eval() with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): audio_tokens, scale = self.audio_tokenizer.encode(audio) audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:] audio_tokens = audio_tokens.long() else: audio_tokens = audio.long() token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int() audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur, end_id=self.audiolm.eos_token_id) condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric), text=text_lyric, audio_qt_emb=prompt_audio) return condition_tensors, audio_tokens, audio_padding_mask def get_time(self): # 获取当前的日期和时间 now = datetime.now() # 使用strftime函数格式化日期和时间 formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f") return formatted_now def training_step(self, batch, batch_idx): # 1) data processing condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) # 2) compute model predictions (model forward) model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors, training_steps=self.global_step) # this input can be ignored logits = model_output.logits.float() mask = padding_mask & model_output.mask # 3) compute loss (float) with torch.cuda.amp.autocast(enabled=False): ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) total_loss = ce if torch.isnan(total_loss): print(self.trainer.global_rank, ce, padding_mask, batch[1]) # print('------------------------------------------------------------------------') torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000) import pdb; pdb.set_trace() return None # 4) compute metrics and log metrics = {} self.log('ce', ce, prog_bar=True) metrics['ppl'] = torch.exp(ce) for k, ce_q in enumerate(ce_per_codebook): metrics[f'ce_q{k + 1}'] = ce_q metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) metrics['acc'] = [] for k in range(self.audiolm.code_depth): metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:, k]).item()) metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item() self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']}) self.log('train_acc', metrics['acc']+1e-8, prog_bar=True) self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True) self.log_dict(metrics) return total_loss @torch.no_grad() def validation_step(self, batch, batch_idx): # 1) data processing condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) # 2) compute model predictions model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors) logits = model_output.logits mask = padding_mask & model_output.mask # 3) compute loss and metrics ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) metrics = {} metrics['val_ce'] = ce metrics['val_ppl'] = torch.exp(ce) for k, ce_q in enumerate(ce_per_codebook): metrics[f'val_ce_q{k + 1}'] = ce_q metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q) masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) for k in range(self.audiolm.code_depth): self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) self.val_steps.append(metrics) metrics['acc'] = [] metrics['acc_top10'] = [] for k in range(self.audiolm.code_depth): metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])) metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10'])) return metrics['acc'] def on_validation_epoch_end(self) -> None: final_metrics = {} for i in self.val_steps: for k in i: final_metrics[k] = final_metrics.get(k, []) + [i[k]] final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())} self.log_dict(final_metrics) q_acc = [] q_acc10 = [] for i in range(self.audiolm.code_depth): q_acc.append(self.top1_acc_metric[i].compute()) q_acc10.append(self.top10_acc_metric[i].compute()) self.log(f"val_Top1Acc_{i}", q_acc[-1]) self.log(f"val_Top10Acc_{i}", q_acc10[-1]) self.top1_acc_metric[i].reset() self.top10_acc_metric[i].reset() self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth) self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth) return super().on_validation_epoch_end() def on_validation_epoch_start(self) -> None: self.val_steps = [] for i in range(self.audiolm.code_depth): self.top1_acc_metric[i].reset() self.top10_acc_metric[i].reset() if len(self.train_steps) > 0: train_metrics = {} for i in self.train_steps: for k in i: train_metrics[k] = train_metrics.get(k, []) + [i[k]] train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())} self.log('train_summary_Top1Acc', train_metrics['acc']) self.log('train_summary_ce', train_metrics['ce']) self.train_steps = [] return super().on_validation_epoch_start() # 定义优化器 def configure_optimizers(self): total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch optim_dict = {} if self.cfg.optim.optimizer == "adamw": optim_dict['optimizer'] = torch.optim.AdamW( self.audiolm.parameters(), lr=self.cfg.optim.lr, betas=tuple(self.cfg.optim.adam.betas), weight_decay=self.cfg.optim.adam.weight_decay, eps=self.cfg.optim.adam.eps, ) else: raise NotImplementedError if self.cfg.schedule is None: pass elif self.cfg.schedule.lr_scheduler == "cosine": scheduler = CosineLRScheduler(optim_dict['optimizer'], total_steps=total_updates, warmup_steps=self.cfg.schedule.cosine.warmup, lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio, cycle_length=self.cfg.schedule.cosine.cycle_length, ) optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"} else: raise NotImplementedError return optim_dict def _compute_cross_entropy( self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: """Compute cross entropy between multi-codebook targets and model's logits. The cross entropy is computed per codebook to provide codebook-level cross entropy. Valid timesteps for each of the codebook are pulled from the mask, where invalid timesteps are set to 0. Args: logits (torch.Tensor): Model's logits of shape [B, K, T, card]. targets (torch.Tensor): Target codes, of shape [B, K, T]. mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. Returns: ce (torch.Tensor): Cross entropy averaged over the codebooks ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). """ # import pdb; pdb.set_trace() B, K, T = targets.shape assert logits.shape[:-1] == targets.shape assert mask.shape == targets.shape ce = torch.zeros([], device=targets.device) ce_per_codebook: tp.List[torch.Tensor] = [] for k in range(K): logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] ce_targets = targets_k[mask_k] ce_logits = logits_k[mask_k] q_ce = F.cross_entropy(ce_logits, ce_targets) ce += q_ce ce_per_codebook.append(q_ce.detach()) # average cross entropy across codebooks ce = ce / K return ce, ce_per_codebook class CosineLRScheduler(_LRScheduler):# """Cosine LR scheduler. Args: optimizer (Optimizer): Torch optimizer. warmup_steps (int): Number of warmup steps. total_steps (int): Total number of steps. lr_min_ratio (float): Minimum learning rate. cycle_length (float): Cycle length. """ def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, lr_min_ratio: float = 0.0, cycle_length: float = 1.0): self.warmup_steps = warmup_steps assert self.warmup_steps >= 0 self.total_steps = total_steps assert self.total_steps >= 0 self.lr_min_ratio = lr_min_ratio self.cycle_length = cycle_length super().__init__(optimizer) def _get_sched_lr(self, lr: float, step: int): if step < self.warmup_steps: lr_ratio = step / self.warmup_steps lr = lr_ratio * lr elif step <= self.total_steps: s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \ (1. + math.cos(math.pi * s / self.cycle_length)) lr = lr_ratio * lr else: lr_ratio = self.lr_min_ratio lr = lr_ratio * lr return lr def get_lr(self): return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]