"""
Main model for using CodecLM. This will combine all the required components
and provide easy access to the generation API.
"""
import typing as tp
import warnings
import sys
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio
import numpy as np
import lightning as pl
from torchmetrics.classification import MulticlassAccuracy
from codeclm.models import builders
import math
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from peft import LoraConfig, get_peft_model
from datetime import datetime
import os
os.environ['TOKENIZERS_PARALLELISM'] = "false"


class CodecLM_PL(pl.LightningModule):
    def __init__(self, cfg, ckpt_path):
        super().__init__()
        self.cfg = cfg

        # 1) Build the audio tokenizer (usually None during training)
        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if self.audio_tokenizer is not None:
            for param in self.audio_tokenizer.parameters():
                param.requires_grad = False
        if "audio_tokenizer_checkpoint_sep" in self.cfg.keys():
            self.seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
            for param in self.seperate_tokenizer.parameters():
                param.requires_grad = False
        else:
            self.seperate_tokenizer = None

        # 2) Build the language model
        self.audiolm = builders.get_lm_model(self.cfg)
        print(self.audiolm)

        # 3) Load a pretrained checkpoint (if any)
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        missing, unexpected = self.load_state_dict(checkpoint, strict=False)
        print(f'-------------Missing--------------\n{missing}')
        print(f'-------------Unexpected--------------\n{unexpected}')
        print(f"Successfully loaded DeepSpeed pretrained model {ckpt_path}")
        # 4) Build metrics
        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=1,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,  # ignore EOS token prediction
        ) for _ in range(self.audiolm.code_depth)])
        self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=10,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])
        self.epoch = 0
        print("++++++++++++++++ training <song> +++++++++++++++++")

    # TODO: move this part to the data loader
    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)
        # Pad one frame if the maximum sequence length equals the input length,
        # so there is room to write the end token.
        if max_length == sequence_lengths.max():
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)
        # Clamp the sequence lengths so the end token always fits inside x.
        if max_length <= sequence_lengths.max() + 1:
            sequence_lengths = sequence_lengths - (sequence_lengths.max() + 1 - max_length)
        # Write the end token into x at each sample's sequence end.
        x[torch.arange(batch_size), :, sequence_lengths] = end_id
        sequence_lengths += 1
        # Build the mask on the same device as sequence_lengths to avoid a device mismatch.
        mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1)
        mask = mask.to(x.device)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = torch.where(mask_3d, x, end_id + 1)
        return x, mask_3d
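
    # Illustrative shapes (an assumption for clarity, not from the source): with
    # x of shape [B=2, K=4, T=10] and sequence_lengths = tensor([10, 7]), sample 0
    # triggers the one-frame pad (T becomes 11), end_id is written at positions
    # 10 and 7 respectively, and mask_3d marks each sample's K codebook rows as
    # valid up to its incremented length; positions beyond it become end_id + 1.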

    def get_time(self):
        # Get the current date and time
        now = datetime.now()
        # Format the date and time with strftime
        formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
        return formatted_now
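
# Hedged usage sketch for CodecLM_PL (illustrative only; the config path,
# checkpoint path, and OmegaConf-style loader below are assumptions, not
# part of this file):
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.load("conf/example.yaml")    # hypothetical config path
#     model = CodecLM_PL(cfg, "ckpt/example.pt")   # hypothetical checkpoint path
#     model.eval()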


class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with linear warmup.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate, expressed as a ratio of the base learning rate.
        cycle_length (float): Cosine cycle length, as a fraction of the decay phase.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            # Linear warmup from 0 up to the base LR.
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            # Cosine decay from the base LR down to lr_min_ratio * base LR.
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            # Past total_steps: hold the LR at its minimum.
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
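

# Minimal usage sketch for CosineLRScheduler (the step counts and base LR here
# are illustrative, not values from this repo): attach it to an optimizer and
# call step() once per training step; the LR ramps linearly during warmup,
# then follows a cosine decay to lr_min_ratio * base LR.
if __name__ == "__main__":
    _model = nn.Linear(8, 8)
    _opt = torch.optim.AdamW(_model.parameters(), lr=1e-3)
    _sched = CosineLRScheduler(_opt, total_steps=1000, warmup_steps=100, lr_min_ratio=0.01)
    for _ in range(3):
        _opt.step()    # optimizer step first, then scheduler step
        _sched.step()
    print(_sched.get_lr())  # LR after 3 steps of linear warmup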