# Source: SongGeneration/codeclm/trainer/codec_song_pl.py
# Commit 258fd02 by hainazhu: "Add application file" (30.3 kB)
"""
Main model for using CodecLM. This will combine all the required components
and provide easy access to the generation API.
"""
import typing as tp
import warnings
import sys
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio
import numpy as np
import lightning as pl
from torchmetrics.classification import MulticlassAccuracy
import pdb
from codeclm.models import builders
import math
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from peft import LoraConfig, get_peft_model
from datetime import datetime
import os
# Set the TOKENIZERS_PARALLELISM environment variable to "false" at import time.
# Presumably this disables parallelism in the HuggingFace `tokenizers` library to
# avoid its fork-related warnings with multi-worker data loading — TODO confirm.
os.environ['TOKENIZERS_PARALLELISM'] = "false"
class CodecLM_PL(pl.LightningModule):
    """Lightning module that trains the CodecLM audio language model.

    The module bundles:
      * optional, frozen audio tokenizers built by ``builders.get_audio_tokenizer_model``;
      * the language model built by ``builders.get_lm_model`` (optionally wrapped
        with LoRA adapters when ``cfg`` has a ``lora`` section);
      * per-codebook top-1 / top-10 accuracy metrics;
      * an AdamW setup that puts parameters missing from the pretrained
        checkpoint into their own parameter group (``cfg.optim.new_lr``).
    """

    def __init__(self, cfg):
        """Build tokenizers, LM, optional LoRA wrapper and metrics from ``cfg``.

        Args:
            cfg: Configuration object supporting both attribute access and
                ``.keys()`` (presumably an OmegaConf ``DictConfig`` — TODO confirm).
        """
        super().__init__()
        self.cfg = cfg

        # 1) Build audio tokenizer (usually None during training)
        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if self.audio_tokenizer is not None:
            for param in self.audio_tokenizer.parameters():
                param.requires_grad = False
        if "audio_tokenizer_checkpoint_sep" in self.cfg.keys():
            self.seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
            # Guard against a None tokenizer exactly like the main tokenizer above;
            # the builder is shown to be able to return None.
            if self.seperate_tokenizer is not None:
                for param in self.seperate_tokenizer.parameters():
                    param.requires_grad = False
        else:
            self.seperate_tokenizer = None

        # 2) Build LM
        self.audiolm = builders.get_lm_model(self.cfg)
        print(self.audiolm)
        # Print the parameter count
        print('Number of parameters: ', sum(p.numel() for p in self.audiolm.parameters()))

        # 3) Load pretrained checkpoint (if any)
        if self.cfg.use_pretrained == 'deepspeed':
            # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
            missing, unexpected = self.load_state_dict(checkpoint, strict=False)
            print(f'-------------Missing--------------\n{missing}')
            print(f'-------------Unexpected--------------\n{unexpected}')
            print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))
            # Remember which keys were absent from the checkpoint; configure_optimizers
            # gives those parameters their own learning rate.
            self.missing = missing
        else:
            self.missing = []

        # If the config has a `lora` section, wrap the LM with LoRA adapters
        if hasattr(self.cfg, 'lora'):
            peft_config = LoraConfig(
                r=self.cfg.lora.r,
                lora_alpha=self.cfg.lora.lora_alpha,
                target_modules=self.cfg.lora.target_modules,
                lora_dropout=self.cfg.lora.lora_dropout,
                bias=self.cfg.lora.bias,
                task_type=self.cfg.lora.task_type,
            )
            self.audiolm = get_peft_model(self.audiolm, peft_config)

        # 4) Build metrics
        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        self.top1_acc_metric = self._make_accuracy_metrics(top_k=1)
        self.top10_acc_metric = self._make_accuracy_metrics(top_k=10)

        self.epoch = 0
        print("++++++++++++++++ training <song> +++++++++++++++++")

    def _make_accuracy_metrics(self, top_k):
        """Return one top-``top_k`` accuracy metric per codebook, as an ``nn.ModuleList``."""
        return nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=top_k,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,  # ignore EOS token prediction
        ) for _ in range(self.audiolm.code_depth)])

    # TODO: move this part to loader
    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        """Write an end token after each sequence and build the validity mask.

        Args:
            x: Token tensor indexed as ``[batch, codebook, time]``. When no padding
                frame is needed, ``x`` is written to in place.
            sequence_lengths: 1-D integer tensor of per-item lengths (in frames).
                It is NOT modified.
            end_id: Id written at the end position; positions after it are set
                to ``end_id + 1``.

        Returns:
            Tuple ``(x, mask_3d)`` where ``mask_3d`` is True for every position up
            to and including the end token, expanded over the codebook dimension.
        """
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)
        longest = sequence_lengths.max()

        # pad one frame, if the maximum sequence length is equal to the input length
        if max_length == longest:
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)

        # Shift every length down so that the longest item's end token still fits in x
        if max_length <= longest + 1:
            sequence_lengths = sequence_lengths - (longest + 1 - max_length)

        # Add end token to x according to the sequence length
        index_device = sequence_lengths.device
        x[torch.arange(batch_size, device=index_device), :, sequence_lengths] = end_id

        # Out-of-place add: the original `+= 1` silently modified the caller's tensor.
        lengths_with_end = sequence_lengths + 1

        positions = torch.arange(max_length, device=index_device).expand(batch_size, max_length)
        mask = (positions < lengths_with_end.unsqueeze(1)).to(x.device)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = torch.where(mask_3d, x, end_id + 1)
        return x, mask_3d

    @torch.no_grad()
    def preprocess_batch(self, batch):  # this function is usually called during training
        """Turn a dataloader batch into LM inputs.

        Args:
            batch: 6-tuple ``(audio, text_lyric, time_stamp, structure_dur,
                prompt_audio, structure_labels)``. Each ``time_stamp`` entry is a
                ``(dur, valid_st, valid_et)`` triple; only ``dur`` is used here.

        Returns:
            ``(condition_tensors, audio_tokens, audio_padding_mask)``.
        """
        # Unpack the data returned by the dataloader
        audio, text_lyric, time_stamp, structure_dur, prompt_audio, structure_labels = batch
        dur, valid_st, valid_et = zip(*time_stamp)

        if self.audio_tokenizer is not None:
            # only used in inference
            self.audio_tokenizer.eval()
            # Tokenize with autocast disabled (gradients are already off via the decorator)
            with torch.cuda.amp.autocast(enabled=False):
                audio_tokens, _ = self.audio_tokenizer.encode(audio)
            audio_tokens = audio_tokens[:, :self.cfg.lm.code_depth, :]
            audio_tokens = audio_tokens.long()
        else:
            # Without a tokenizer the batch's `audio` entry is used directly as token ids
            audio_tokens = audio.long()

        # Duration times frame rate gives the per-item length in token frames
        token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
        audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(
            audio_tokens, token_dur, end_id=self.audiolm.eos_token_id)
        condition_tensors = self.audiolm.prepare_condition_tensors(
            batch_size=len(text_lyric), text=text_lyric, audio_qt_emb=prompt_audio)
        return condition_tensors, audio_tokens, audio_padding_mask

    def get_time(self):
        """Return the current local date and time as ``YYYY-MM-DD HH:MM:SS.ffffff``."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")

    def training_step(self, batch, batch_idx):
        """Run one training step.

        Returns:
            The cross-entropy loss averaged over codebooks, or ``None`` when the
            loss is NaN (the diagnostics are printed and the batch is skipped).
        """
        # 1) data processing
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        # 2) compute model predictions (model forward)
        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
                                                        training_steps=self.global_step)  # this input can be ignored
        logits = model_output.logits.float()
        mask = padding_mask & model_output.mask

        # 3) compute loss (float)
        with torch.cuda.amp.autocast(enabled=False):
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            total_loss = ce
        if torch.isnan(total_loss):
            print(self.trainer.global_rank, ce, padding_mask, batch[1])
            print('--------------------------------------------------------------')
            return None

        # 4) compute metrics and log
        metrics = {}
        self.log('ce', ce, prog_bar=True)
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        # Masked-out positions get the metric's ignore_index so they do not count
        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
        per_codebook_acc = [
            self.top1_acc_metric[k](logits[:, k].transpose(1, 2).detach(), masked_labels[:, k]).item()
            for k in range(self.audiolm.code_depth)
        ]
        metrics['acc'] = torch.mean(torch.Tensor(per_codebook_acc)).item()

        self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
        self.log('train_acc', metrics['acc'] + 1e-8, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
        self.log_dict(metrics)
        return total_loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Run one validation step and return the mean top-1 accuracy of the batch.

        The per-batch metric dict is appended to ``self.val_steps`` and averaged
        in ``on_validation_epoch_end``.
        """
        # 1) data processing
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        # 2) compute model predictions
        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
        # Cast to fp32 as in training_step so the loss is not computed in reduced precision
        logits = model_output.logits.float()
        mask = padding_mask & model_output.mask

        # 3) compute loss and metrics
        ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
        metrics = {}
        metrics['val_ce'] = ce
        metrics['val_ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'val_ce_q{k + 1}'] = ce_q
            metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)

        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
        top1_per_codebook = []
        top10_per_codebook = []
        for k in range(self.audiolm.code_depth):
            preds_k = logits[:, k].transpose(1, 2).detach()
            labels_k = masked_labels[:, k]
            # Calling the metric returns the batch value AND accumulates the epoch
            # state, so the separate update() the original also issued (which fed
            # every batch to the metric twice) is not needed.
            top1_per_codebook.append(self.top1_acc_metric[k](preds_k, labels_k).item())
            top10_per_codebook.append(self.top10_acc_metric[k](preds_k, labels_k).item())
        metrics['acc'] = torch.mean(torch.Tensor(top1_per_codebook))
        metrics['acc_top10'] = torch.mean(torch.Tensor(top10_per_codebook))

        self.val_steps.append(metrics)
        return metrics['acc']

    def on_validation_epoch_end(self) -> None:
        """Log the averaged per-step validation metrics and the epoch-level accuracies."""
        collected = {}
        for step_metrics in self.val_steps:
            for key, value in step_metrics.items():
                collected.setdefault(key, []).append(value)
        final_metrics = {key: sum(values) / len(values) for key, values in collected.items()}
        self.log_dict(final_metrics)

        q_acc = []
        q_acc10 = []
        for i in range(self.audiolm.code_depth):
            q_acc.append(self.top1_acc_metric[i].compute())
            q_acc10.append(self.top10_acc_metric[i].compute())
            self.log(f"val_Top1Acc_{i}", q_acc[-1])
            self.log(f"val_Top10Acc_{i}", q_acc10[-1])
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
        self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)
        return super().on_validation_epoch_end()

    def on_validation_epoch_start(self) -> None:
        """Reset validation state and log a summary of the training steps seen so far."""
        self.val_steps = []
        # training_step shares the top-1 metrics, so clear whatever it accumulated
        for i in range(self.audiolm.code_depth):
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        if len(self.train_steps) > 0:
            collected = {}
            for step_metrics in self.train_steps:
                for key, value in step_metrics.items():
                    collected.setdefault(key, []).append(value)
            train_metrics = {key: sum(values) / len(values) for key, values in collected.items()}
            self.log('train_summary_Top1Acc', train_metrics['acc'])
            self.log('train_summary_ce', train_metrics['ce'])
        self.train_steps = []
        return super().on_validation_epoch_start()

    # Define the optimizer
    def configure_optimizers(self):
        """Create AdamW with two parameter groups and an optional cosine schedule.

        Parameters whose names were missing from the pretrained checkpoint use
        ``cfg.optim.new_lr``; all other LM parameters use ``cfg.optim.old_lr``.

        Returns:
            Dict with ``'optimizer'`` and, when a schedule is configured,
            ``'lr_scheduler'`` (stepped every optimizer step).

        Raises:
            NotImplementedError: For an unsupported optimizer or scheduler name.
            AssertionError: If some missing key does not match an LM parameter.
        """
        total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
        optim_dict = {}

        # Strip only the LEADING 'audiolm.' so the names line up with
        # self.audiolm.named_parameters(); str.replace would also rewrite
        # occurrences in the middle of a name.
        prefix = 'audiolm.'
        print('before missing len', len(self.missing))
        self.missing = [name[len(prefix):] if name.startswith(prefix) else name
                        for name in self.missing]
        print('after missing len', len(self.missing))

        missing_names = set(self.missing)  # O(1) membership per parameter
        missing_params = []
        other_params = []
        for name, param in self.audiolm.named_parameters():
            if name in missing_names:
                print(name)
                missing_params.append(param)
            else:
                other_params.append(param)
        cnt = len(missing_params)
        print(cnt)
        assert cnt == len(self.missing), (
            f"{len(self.missing) - cnt} missing checkpoint key(s) do not match any audiolm parameter"
        )

        param_groups = [
            {'params': other_params, 'lr': self.cfg.optim.old_lr},
            # Separate learning rate for the parameters absent from the checkpoint
            {'params': missing_params, 'lr': self.cfg.optim.new_lr},
        ]

        if self.cfg.optim.optimizer == "adamw":
            optim_dict['optimizer'] = torch.optim.AdamW(
                param_groups,  # parameter groups instead of self.audiolm.parameters()
                betas=tuple(self.cfg.optim.adam.betas),
                weight_decay=self.cfg.optim.adam.weight_decay,
                eps=self.cfg.optim.adam.eps,
            )
        else:
            raise NotImplementedError

        if self.cfg.schedule is None:
            pass
        elif self.cfg.schedule.lr_scheduler == "cosine":
            scheduler = CosineLRScheduler(optim_dict['optimizer'],
                                          total_steps=total_updates,
                                          warmup_steps=self.cfg.schedule.cosine.warmup,
                                          lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
                                          cycle_length=self.cfg.schedule.cosine.cycle_length,
                                          )
            optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
        else:
            raise NotImplementedError
        return optim_dict

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.

        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Only the timesteps selected by ``mask`` contribute. NOTE(review): a codebook
        with no selected timestep feeds an empty batch to ``F.cross_entropy`` —
        presumably the source of the NaN that ``training_step`` checks for.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook
class CodecLM_PL_FT(pl.LightningModule):
    """Lightning module for fine-tuning the CodecLM audio language model.

    Bundles an optional frozen audio tokenizer, the language model built by
    ``builders.get_lm_model``, per-codebook top-1 / top-10 accuracy metrics and
    a single-learning-rate AdamW setup with an optional cosine schedule.
    """

    def __init__(self, cfg):
        """Build the tokenizer, the LM and the metrics from ``cfg``.

        Args:
            cfg: Configuration object with attribute access (presumably an
                OmegaConf ``DictConfig`` — TODO confirm).
        """
        super().__init__()
        self.cfg = cfg

        # 1) Build audio tokenizer (usually None during training)
        # NOTE(review): CodecLM_PL in this file calls the same builder with
        # (checkpoint, cfg); confirm that the one-argument form is still supported.
        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg)
        if self.audio_tokenizer is not None:
            for param in self.audio_tokenizer.parameters():
                param.requires_grad = False

        # 2) Build LM
        self.audiolm = builders.get_lm_model(self.cfg)

        # 3) Load pretrained checkpoint (if any)
        if self.cfg.use_pretrained == 'deepspeed':
            # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
            missing, unexpected = self.load_state_dict(checkpoint, strict=False)
            print(f'-------------Missing--------------\n{missing}')
            print(f'-------------Unexpected--------------\n{unexpected}')
            print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))

        # 4) Build metrics
        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        self.top1_acc_metric = self._make_accuracy_metrics(top_k=1)
        self.top10_acc_metric = self._make_accuracy_metrics(top_k=10)

        self.epoch = 0
        print("++++++++++++++++ training <song> +++++++++++++++++")

    def _make_accuracy_metrics(self, top_k):
        """Return one top-``top_k`` accuracy metric per codebook, as an ``nn.ModuleList``."""
        return nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=top_k,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,  # ignore EOS token prediction
        ) for _ in range(self.audiolm.code_depth)])

    # TODO: move this part to loader
    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        """Write an end token after each sequence and build the validity mask.

        Args:
            x: Token tensor indexed as ``[batch, codebook, time]``. When no padding
                frame is needed, ``x`` is written to in place.
            sequence_lengths: 1-D integer tensor of per-item lengths (in frames).
                It is NOT modified.
            end_id: Id written at the end position; positions after it are set
                to ``end_id + 1``.

        Returns:
            Tuple ``(x, mask_3d)`` where ``mask_3d`` is True for every position up
            to and including the end token, expanded over the codebook dimension.
        """
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)
        longest = sequence_lengths.max()

        # pad one frame, if the maximum sequence length is equal to the input length
        if max_length == longest:
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)

        # Shift every length down so that the longest item's end token still fits in x
        if max_length <= longest + 1:
            sequence_lengths = sequence_lengths - (longest + 1 - max_length)

        # Add end token to x according to the sequence length
        index_device = sequence_lengths.device
        x[torch.arange(batch_size, device=index_device), :, sequence_lengths] = end_id

        # Out-of-place add: the original `+= 1` silently modified the caller's tensor.
        lengths_with_end = sequence_lengths + 1

        positions = torch.arange(max_length, device=index_device).expand(batch_size, max_length)
        mask = (positions < lengths_with_end.unsqueeze(1)).to(x.device)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = torch.where(mask_3d, x, end_id + 1)
        return x, mask_3d

    @torch.no_grad()
    def preprocess_batch(self, batch):  # this function is usually called during training
        """Turn a dataloader batch into LM inputs.

        Args:
            batch: 5-tuple ``(audio, text_lyric, time_stamp, lang_type, prompt_audio)``.
                Each ``time_stamp`` entry is a ``(dur, valid_st, valid_et)`` triple;
                only ``dur`` is used here.

        Returns:
            ``(condition_tensors, audio_tokens, audio_padding_mask)``.
        """
        # Unpack the data returned by the dataloader
        audio, text_lyric, time_stamp, lang_type, prompt_audio = batch
        dur, valid_st, valid_et = zip(*time_stamp)

        if self.audio_tokenizer is not None:
            # only used in inference
            self.audio_tokenizer.eval()
            # Tokenize with autocast disabled (gradients are already off via the decorator)
            with torch.cuda.amp.autocast(enabled=False):
                audio_tokens, _ = self.audio_tokenizer.encode(audio)
            audio_tokens = audio_tokens[:, :self.cfg.lm.code_depth, :]
            audio_tokens = audio_tokens.long()
        else:
            # Without a tokenizer the batch's `audio` entry is used directly as token ids
            audio_tokens = audio.long()

        # Duration times frame rate gives the per-item length in token frames
        token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
        audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(
            audio_tokens, token_dur, end_id=self.audiolm.eos_token_id)
        condition_tensors = self.audiolm.prepare_condition_tensors(
            batch_size=len(text_lyric), text=text_lyric, audio_qt_emb=prompt_audio)
        return condition_tensors, audio_tokens, audio_padding_mask

    def get_time(self):
        """Return the current local date and time as ``YYYY-MM-DD HH:MM:SS.ffffff``."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")

    def training_step(self, batch, batch_idx):
        """Run one training step.

        Returns:
            The cross-entropy loss averaged over codebooks, or ``None`` when the
            loss is NaN (the diagnostics are printed and the batch is skipped).
        """
        # 1) data processing
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        # 2) compute model predictions (model forward)
        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
                                                        training_steps=self.global_step)  # this input can be ignored
        logits = model_output.logits.float()
        mask = padding_mask & model_output.mask

        # 3) compute loss (float)
        with torch.cuda.amp.autocast(enabled=False):
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            total_loss = ce
        if torch.isnan(total_loss):
            # Report and skip the batch, as CodecLM_PL does. The original also dropped
            # into pdb here, which blocks non-interactive / multi-process training,
            # and tried to save batch[0] as a wav even though it may hold token ids.
            print(self.trainer.global_rank, ce, padding_mask, batch[1])
            return None

        # 4) compute metrics and log
        metrics = {}
        self.log('ce', ce, prog_bar=True)
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        # Masked-out positions get the metric's ignore_index so they do not count
        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
        per_codebook_acc = [
            self.top1_acc_metric[k](logits[:, k].transpose(1, 2).detach(), masked_labels[:, k]).item()
            for k in range(self.audiolm.code_depth)
        ]
        metrics['acc'] = torch.mean(torch.Tensor(per_codebook_acc)).item()

        self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
        self.log('train_acc', metrics['acc'] + 1e-8, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
        self.log_dict(metrics)
        return total_loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Run one validation step and return the mean top-1 accuracy of the batch.

        The per-batch metric dict is appended to ``self.val_steps`` and averaged
        in ``on_validation_epoch_end``.
        """
        # 1) data processing
        condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)

        # 2) compute model predictions
        model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
        # Cast to fp32 as in training_step so the loss is not computed in reduced precision
        logits = model_output.logits.float()
        mask = padding_mask & model_output.mask

        # 3) compute loss and metrics
        ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
        metrics = {}
        metrics['val_ce'] = ce
        metrics['val_ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'val_ce_q{k + 1}'] = ce_q
            metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)

        masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
        top1_per_codebook = []
        top10_per_codebook = []
        for k in range(self.audiolm.code_depth):
            preds_k = logits[:, k].transpose(1, 2).detach()
            labels_k = masked_labels[:, k]
            # Calling the metric returns the batch value AND accumulates the epoch
            # state, so the separate update() the original also issued (which fed
            # every batch to the metric twice) is not needed.
            top1_per_codebook.append(self.top1_acc_metric[k](preds_k, labels_k).item())
            top10_per_codebook.append(self.top10_acc_metric[k](preds_k, labels_k).item())
        metrics['acc'] = torch.mean(torch.Tensor(top1_per_codebook))
        metrics['acc_top10'] = torch.mean(torch.Tensor(top10_per_codebook))

        self.val_steps.append(metrics)
        return metrics['acc']

    def on_validation_epoch_end(self) -> None:
        """Log the averaged per-step validation metrics and the epoch-level accuracies."""
        collected = {}
        for step_metrics in self.val_steps:
            for key, value in step_metrics.items():
                collected.setdefault(key, []).append(value)
        final_metrics = {key: sum(values) / len(values) for key, values in collected.items()}
        self.log_dict(final_metrics)

        q_acc = []
        q_acc10 = []
        for i in range(self.audiolm.code_depth):
            q_acc.append(self.top1_acc_metric[i].compute())
            q_acc10.append(self.top10_acc_metric[i].compute())
            self.log(f"val_Top1Acc_{i}", q_acc[-1])
            self.log(f"val_Top10Acc_{i}", q_acc10[-1])
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
        self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)
        return super().on_validation_epoch_end()

    def on_validation_epoch_start(self) -> None:
        """Reset validation state and log a summary of the training steps seen so far."""
        self.val_steps = []
        # training_step shares the top-1 metrics, so clear whatever it accumulated
        for i in range(self.audiolm.code_depth):
            self.top1_acc_metric[i].reset()
            self.top10_acc_metric[i].reset()

        if len(self.train_steps) > 0:
            collected = {}
            for step_metrics in self.train_steps:
                for key, value in step_metrics.items():
                    collected.setdefault(key, []).append(value)
            train_metrics = {key: sum(values) / len(values) for key, values in collected.items()}
            self.log('train_summary_Top1Acc', train_metrics['acc'])
            self.log('train_summary_ce', train_metrics['ce'])
        self.train_steps = []
        return super().on_validation_epoch_start()

    # Define the optimizer
    def configure_optimizers(self):
        """Create AdamW over all LM parameters plus an optional cosine schedule.

        Returns:
            Dict with ``'optimizer'`` and, when a schedule is configured,
            ``'lr_scheduler'`` (stepped every optimizer step).

        Raises:
            NotImplementedError: For an unsupported optimizer or scheduler name.
        """
        total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
        optim_dict = {}

        if self.cfg.optim.optimizer == "adamw":
            optim_dict['optimizer'] = torch.optim.AdamW(
                self.audiolm.parameters(),
                lr=self.cfg.optim.lr,
                betas=tuple(self.cfg.optim.adam.betas),
                weight_decay=self.cfg.optim.adam.weight_decay,
                eps=self.cfg.optim.adam.eps,
            )
        else:
            raise NotImplementedError

        if self.cfg.schedule is None:
            pass
        elif self.cfg.schedule.lr_scheduler == "cosine":
            scheduler = CosineLRScheduler(optim_dict['optimizer'],
                                          total_steps=total_updates,
                                          warmup_steps=self.cfg.schedule.cosine.warmup,
                                          lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
                                          cycle_length=self.cfg.schedule.cosine.cycle_length,
                                          )
            optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
        else:
            raise NotImplementedError
        return optim_dict

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.

        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Only the timesteps selected by ``mask`` contribute. NOTE(review): a codebook
        with no selected timestep feeds an empty batch to ``F.cross_entropy`` —
        presumably the source of the NaN that ``training_step`` checks for.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook
class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with linear warmup.

    For ``step < warmup_steps`` the learning rate grows linearly from 0 to the
    base rate. Up to ``total_steps`` it then follows a cosine curve scaled
    between the base rate and ``lr_min_ratio`` times the base rate. After
    ``total_steps`` it stays at ``lr_min_ratio`` times the base rate.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate, as a ratio of the base rate.
        cycle_length (float): Cycle length; must be non-zero.
    """

    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        # These attributes must exist before super().__init__, which performs an
        # initial step and therefore already calls get_lr().
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0, "warmup_steps must be non-negative"
        self.total_steps = total_steps
        assert self.total_steps >= 0, "total_steps must be non-negative"
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        assert self.cycle_length != 0, "cycle_length must be non-zero"
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the scheduled learning rate for base rate ``lr`` at ``step``."""
        if step < self.warmup_steps:
            # Linear warmup
            lr_ratio = step / self.warmup_steps
        elif step <= self.total_steps:
            decay_steps = self.total_steps - self.warmup_steps
            # A zero-length decay phase (total_steps == warmup_steps) used to raise
            # ZeroDivisionError; treat it as the very start of the cosine curve.
            s = (step - self.warmup_steps) / decay_steps if decay_steps > 0 else 0.0
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
        else:
            # Past the end of the schedule: stay at the floor
            lr_ratio = self.lr_min_ratio
        return lr_ratio * lr

    def get_lr(self):
        """Return the scheduled learning rate of every parameter group."""
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]