|
import yaml
|
|
import random
|
|
import argparse
|
|
import os
|
|
import time
|
|
from tqdm import tqdm
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from accelerate import Accelerator
|
|
from diffusers import DDIMScheduler
|
|
|
|
from configs.plugin import get_params
|
|
from model.p2e_cross import P2E_Cross
|
|
|
|
from openvoice.api import ToneColorConverter
|
|
from transformers import T5Tokenizer, T5EncoderModel
|
|
from inference import eval_plugin_light
|
|
from dataset.dreamvc import DreamData
|
|
|
|
from utils import minmax_norm_diff, reverse_minmax_norm_diff
|
|
|
|
parser = argparse.ArgumentParser()

# Command-line options as (flag, type, default), registered in --help order.
_CLI_OPTIONS = (
    ('--config-name', str, 'Plugin_base'),  # key handed to get_params(); also names the log/ckpt subdirs
    ("--amp", str, 'fp16'),                 # Accelerator(mixed_precision=...) mode
    ('--epochs', int, 100),
    ('--batch-size', int, 32),
    ('--num-workers', int, 4),              # DataLoader worker processes
    ('--num-threads', int, 1),              # torch.set_num_threads value
    ('--save-every', int, 5),               # epochs between eval + checkpoint
    ('--random-seed', int, 2023),           # seeds python `random` and torch RNGs
    ('--log-step', int, 200),               # optimizer steps between loss log lines
    ('--log-dir', str, '../logs/'),         # base dir; the config name is appended later
    ('--save-dir', str, '../ckpts/'),       # base dir for checkpoints
)

for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)

args = parser.parse_args()
|
|
# Hyper-parameter set (diffusion, optimizer, text encoder, ...) selected on the CLI.
params = get_params(args.config_name)

# Per-config log directory. os.path.join tolerates a --log-dir given without a
# trailing slash; the final '' keeps the trailing separator that later string
# concatenations such as `args.log_dir + 'output/'` rely on. For the default
# '../logs/' this yields exactly the same '../logs/<config-name>/' as before.
args.log_dir = os.path.join(args.log_dir, args.config_name, '')

# Model architecture settings for P2E_Cross (its 'diffwrap' section is what the
# model constructor in the main block consumes).
with open('model/p2e_cross.yaml', 'r', encoding='utf-8') as fp:
    config = yaml.safe_load(fp)

# exist_ok=True removes the check-then-create race: when several processes
# execute this module-level code concurrently (e.g. under `accelerate launch`),
# a separate exists() check lets two ranks both try to create the directory and
# one of them dies with FileExistsError.
# NOTE: the checkpoint path is kept as plain concatenation so it matches the
# `args.save_dir + args.config_name + '/' + ...` path used when saving.
os.makedirs(args.save_dir + args.config_name, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)
|
|
|
|
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Reproducibility / backend setup
    # ------------------------------------------------------------------
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.set_num_threads(args.num_threads)

    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = False
    else:
        args.device = 'cpu'

    # ------------------------------------------------------------------
    # Data, accelerator, frozen helper models
    # ------------------------------------------------------------------
    train_set = DreamData(data_dir='../prepare/spk/', meta_dir='../prepare/plugin_meta.csv',
                          subset='train', prompt_dir='../prepare/prompts.csv',)
    train_loader = DataLoader(train_set, num_workers=args.num_workers,
                              batch_size=args.batch_size, shuffle=True)

    accelerator = Accelerator(mixed_precision=args.amp)

    # Voice-conversion model, used only by the evaluation helper below.
    # Use the accelerator's device (as a string, like the original 'cuda'
    # literal) instead of hard-coding 'cuda', so CPU runs and non-default
    # GPUs (e.g. cuda:1 in a multi-GPU launch) work.
    ckpt_converter = '../prepare/checkpoints_v2/converter'
    vc_model = ToneColorConverter(f'{ckpt_converter}/config.json', device=str(accelerator.device))
    vc_model.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

    # Frozen T5 text encoder: it is never passed to the optimizer, and it is
    # only run under torch.no_grad() in the training loop.
    tokenizer = T5Tokenizer.from_pretrained(params.text_encoder.model)
    text_encoder = T5EncoderModel.from_pretrained(params.text_encoder.model).to(accelerator.device)
    text_encoder.eval()

    # ------------------------------------------------------------------
    # Trainable model, scheduler, optimizer
    # ------------------------------------------------------------------
    model = P2E_Cross(config['diffwrap']).to(accelerator.device)

    total_params = sum(param.nelement() for param in model.parameters())
    accelerator.print("Number of parameter: %.2fM" % (total_params / 1e6))

    # The scheduler's prediction type must agree with the training target
    # selected below (velocity vs. noise); otherwise sampling at eval time
    # would decode the model output with the wrong parameterization.
    prediction_type = 'v_prediction' if params.diff.v_prediction else 'epsilon'
    noise_scheduler = DDIMScheduler(num_train_timesteps=params.diff.num_train_steps,
                                    beta_start=params.diff.beta_start, beta_end=params.diff.beta_end,
                                    rescale_betas_zero_snr=True,
                                    timestep_spacing="trailing",
                                    clip_sample=False,
                                    prediction_type=prediction_type)

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=params.opt.learning_rate,
                                  betas=(params.opt.beta1, params.opt.beta2),
                                  weight_decay=params.opt.weight_decay,
                                  eps=params.opt.adam_epsilon,)
    loss_func = torch.nn.MSELoss()

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    # NOTE(review): machine-specific absolute path; consider making it a CLI option.
    VAL_FOLDER = '/home/jerry/Projects/Dataset/Speech/vctk_libritts/'

    def run_eval(epoch_tag, ddim_steps, val_num):
        """Run the light validation pass; outputs go to <log_dir>/output/.

        Call only from the main process. `epoch_tag` is forwarded as the
        `epoch` label, `ddim_steps` as the number of DDIM sampling steps and
        `val_num` as the number of validation items.
        """
        # (1, 256, 1) is forwarded unchanged; presumably the shape of one
        # generated embedding -- TODO confirm against eval_plugin_light.
        eval_plugin_light(vc_model, [tokenizer, text_encoder],
                          model, noise_scheduler, (1, 256, 1),
                          val_meta='../prepare/val_meta.csv',
                          val_folder=VAL_FOLDER,
                          guidance_scale=3, guidance_rescale=0.0,
                          ddim_steps=ddim_steps, eta=1, random_seed=2024,
                          device=accelerator.device,
                          epoch=epoch_tag, save_path=args.log_dir + 'output/', val_num=val_num)

    global_step = 0
    losses = 0.0  # running loss sum since the last log line

    # Sanity-check evaluation before any training step.
    if accelerator.is_main_process:
        run_eval('test', ddim_steps=100, val_num=1)
    accelerator.wait_for_everyone()

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    for epoch in range(args.epochs):
        model.train()
        progress = tqdm(train_loader, disable=not accelerator.is_local_main_process)
        for step, batch in enumerate(progress):
            spk_embed, prompt = batch

            # Encode the text prompts with the frozen T5 encoder.
            with torch.no_grad():
                text_batch = tokenizer(prompt,
                                       max_length=32,
                                       padding='max_length', truncation=True, return_tensors="pt")
                text = text_batch.input_ids.to(spk_embed.device)
                text_mask = text_batch.attention_mask.to(spk_embed.device)
                text = text_encoder(input_ids=text, attention_mask=text_mask)[0]

            # Forward diffusion: sample noise and a random timestep per example.
            # Noise is drawn directly on the device (no per-step host->device copy).
            noise = torch.randn(spk_embed.shape, device=accelerator.device)
            timesteps = torch.randint(0, params.diff.num_train_steps, (spk_embed.shape[0],),
                                      device=accelerator.device, ).long()
            noisy_target = noise_scheduler.add_noise(spk_embed, noise, timesteps)

            # train_cfg / cfg_prob presumably randomly drop the text condition
            # for classifier-free guidance (eval uses guidance_scale=3) -- confirm in P2E_Cross.
            pred = model(noisy_target, timesteps, text, text_mask, train_cfg=True, cfg_prob=0.25)

            if params.diff.v_prediction:
                target = noise_scheduler.get_velocity(spk_embed, noise, timesteps)
            else:
                target = noise
            loss = loss_func(pred, target)

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            losses += loss.item()

            if global_step % args.log_step == 0:
                if accelerator.is_main_process:
                    with open(args.log_dir + 'diff_vc.txt', mode='a') as log_file:
                        log_file.write(time.asctime(time.localtime(time.time())))
                        log_file.write('\n')
                        log_file.write('Epoch: [{}][{}] Batch: [{}][{}] Loss: {:.6f}\n'.format(
                            epoch + 1, args.epochs, step + 1, len(train_loader), losses / args.log_step))
                # Reset on every rank so non-main processes don't accumulate forever.
                losses = 0.0

        accelerator.wait_for_everyone()

        # Periodic evaluation + checkpoint (main process only).
        if (epoch + 1) % args.save_every == 0:
            if accelerator.is_main_process:
                run_eval(epoch, ddim_steps=50, val_num=10)
                unwrapped_unet = accelerator.unwrap_model(model)
                accelerator.save({
                    "model": unwrapped_unet.state_dict(),
                }, args.save_dir + args.config_name + '/' + str(epoch) + '.pt')
            # Keep ranks in step while the main process evaluates and saves.
            accelerator.wait_for_everyone()
|
|
|