# SongGeneration / codeclm / trainer / codec_song_pl.py
# Last commit by waytan22: "add auto prompt and interface" (d658154)
"""
Main model for using CodecLM. This will combine all the required components
and provide easy access to the generation API.
"""
import typing as tp
import warnings
import sys
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio
import numpy as np
import lightning as pl
from torchmetrics.classification import MulticlassAccuracy
import pdb
from codeclm.models import builders
import math
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from peft import LoraConfig, get_peft_model
from datetime import datetime
import os
# Module-level side effect: runs at import time, before any model is built.
# NOTE(review): presumably turns off the Hugging Face `tokenizers` library's
# internal parallelism (and its fork warning) — confirm against the text
# tokenizer actually used by the builders.
os.environ['TOKENIZERS_PARALLELISM'] = "false"
class CodecLM_PL(pl.LightningModule):
    """Lightning module bundling the CodecLM language model with its audio tokenizers.

    The constructor builds the frozen audio tokenizer(s) and the language
    model from ``cfg``, optionally restores weights from ``ckpt_path`` and
    creates the accuracy metrics.
    """

    def __init__(self, cfg, ckpt_path):
        """Build tokenizers, language model and metrics.

        Args:
            cfg: Configuration object. This constructor reads
                ``audio_tokenizer_checkpoint``, the optional
                ``audio_tokenizer_checkpoint_sep`` and ``lm.code_size``; the
                whole object is also forwarded to the ``builders`` helpers.
            ckpt_path: Path to a state dict readable by ``torch.load``. It is
                loaded non-strictly into this module. ``None`` skips loading.
        """
        super().__init__()
        self.cfg = cfg

        # 1) Build audio tokenizer (usually None during training)
        self.audio_tokenizer = builders.get_audio_tokenizer_model(
            self.cfg.audio_tokenizer_checkpoint, self.cfg)
        self._freeze_parameters(self.audio_tokenizer)
        if "audio_tokenizer_checkpoint_sep" in self.cfg.keys():
            self.seperate_tokenizer = builders.get_audio_tokenizer_model(
                self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
            # Same None guard as for the primary tokenizer.
            self._freeze_parameters(self.seperate_tokenizer)
        else:
            self.seperate_tokenizer = None

        # 2) Build LM
        self.audiolm = builders.get_lm_model(self.cfg)
        print(self.audiolm)

        # 3) Load pretrained checkpoint (if any)
        if ckpt_path is not None:
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            # Non-strict: mismatching keys are only reported, not fatal.
            missing, unexpected = self.load_state_dict(checkpoint, strict=False)
            print(f'-------------Missing--------------\n{missing}')
            print(f'-------------Unexpected--------------\n{unexpected}')
            print("successfully load deepspeed pretrained model {}".format(ckpt_path))
        else:
            print("no pretrained checkpoint given, skip loading")

        # 4) Build metrics (created after loading, so their state is never
        #    looked up in the checkpoint)
        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        self.top1_acc_metric = self._build_accuracy_metrics(top_k=1)
        self.top10_acc_metric = self._build_accuracy_metrics(top_k=10)

        self.epoch = 0
        print("++++++++++++++++ training <song> +++++++++++++++++")

    @staticmethod
    def _freeze_parameters(module):
        """Disable gradients for every parameter of ``module``; no-op for ``None``."""
        if module is None:
            return
        for param in module.parameters():
            param.requires_grad = False

    def _build_accuracy_metrics(self, top_k):
        """Return one ``MulticlassAccuracy`` per entry of ``audiolm.code_depth``."""
        return nn.ModuleList([
            MulticlassAccuracy(
                self.audiolm.code_size,
                top_k=top_k,
                average="micro", multidim_average="global",
                ignore_index=self.cfg.lm.code_size,  # ignore EOS token prediction
            )
            for _ in range(self.audiolm.code_depth)
        ])

    # TODO: move this part to loader
    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        """Append an end token to every sequence and mask out the padding.

        Neither ``x`` nor ``sequence_lengths`` is modified; new tensors are
        returned.

        Args:
            x: Integer tensor indexed as ``(batch, channel, time)``; the middle
                axis is presumably the codebook axis (TODO confirm).
            sequence_lengths: 1-D integer tensor with the number of valid
                frames of each batch item.
            end_id: Value written at the end position of each sequence.
                ``end_id + 1`` is used as the fill value behind it.

        Returns:
            A tuple ``(x, mask_3d)``. ``x`` holds ``end_id`` at each item's
            end position and ``end_id + 1`` at every later position.
            ``mask_3d`` is a boolean tensor with the shape of ``x`` that is
            ``True`` up to and including the end position.
        """
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)
        longest = int(sequence_lengths.max())

        # Pad one frame if the longest sequence fills the input exactly, so
        # its end token has a slot.
        if max_length == longest:
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)

        # If the lengths exceed the available frames, shift all of them down
        # so the longest one ends on the last frame. The subtraction always
        # yields a fresh tensor, so the caller's lengths stay untouched.
        shift = max(longest + 1 - max_length, 0)
        end_positions = (sequence_lengths - shift).to(x.device)

        # Build positions on x's device so lengths on any device work.
        positions = torch.arange(max_length, device=x.device)

        # Add end token to x according to the sequence length (out of place).
        is_end = positions.view(1, 1, max_length) == end_positions.view(batch_size, 1, 1)
        x = x.masked_fill(is_end, end_id)

        # Valid region = original frames plus the end token.
        mask = positions.unsqueeze(0) < (end_positions + 1).unsqueeze(1)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = x.masked_fill(~mask_3d, end_id + 1)
        return x, mask_3d

    def get_time(self):
        """Return the current local date and time as ``YYYY-MM-DD HH:MM:SS.ffffff``."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with linear warmup.

    The learning rate rises linearly from 0 during the first ``warmup_steps``
    steps. It then follows a cosine curve until ``total_steps`` and stays at
    ``lr_min_ratio * base_lr`` afterwards.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate, expressed as a ratio of
            the base learning rate.
        cycle_length (float): Cycle length; divides the cosine phase.
    """

    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        # Attributes must exist before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the scheduled value of base rate ``lr`` at ``step``."""
        if step < self.warmup_steps:
            # Linear warmup from 0 towards the base rate.
            lr_ratio = step / self.warmup_steps
        elif step <= self.total_steps:
            decay_steps = self.total_steps - self.warmup_steps
            # decay_steps is 0 only when step == warmup_steps == total_steps.
            # Treat that single step as the start of the cosine (full rate)
            # instead of dividing by zero.
            s = (step - self.warmup_steps) / decay_steps if decay_steps > 0 else 0.0
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
        else:
            # Past the schedule: hold the floor.
            lr_ratio = self.lr_min_ratio
        return lr_ratio * lr

    def get_lr(self):
        """Return the scheduled learning rate of every parameter group."""
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]