import yaml
import random
import argparse
import os
import time
from tqdm import tqdm
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from diffusers import DDIMScheduler
from transformers import T5Tokenizer, T5EncoderModel

from configs.plugin import get_params
from model.p2e_cross import P2E_Cross
# from modules.speaker_encoder.encoder import inference as spk_encoder
from openvoice.api import ToneColorConverter
from inference import eval_plugin_light
from dataset.dreamvc import DreamData
# from vc_wrapper import load_diffvc_models
from utils import minmax_norm_diff, reverse_minmax_norm_diff


parser = argparse.ArgumentParser()
# config settings
parser.add_argument('--config-name', type=str, default='Plugin_freevc')
# training settings
parser.add_argument('--amp', type=str, default='fp16')
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--num-threads', type=int, default=1)
parser.add_argument('--save-every', type=int, default=5)
# log and random seed
parser.add_argument('--random-seed', type=int, default=2023)
parser.add_argument('--log-step', type=int, default=200)
parser.add_argument('--log-dir', type=str, default='../logs/')
parser.add_argument('--save-dir', type=str, default='../ckpts/')

args = parser.parse_args()
params = get_params(args.config_name)

args.log_dir = args.log_dir + args.config_name + '/'

with open('model/p2e_cross.yaml', 'r') as fp:
    config = yaml.safe_load(fp)

if not os.path.exists(args.save_dir + args.config_name):
    os.makedirs(args.save_dir + args.config_name)
if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)


if __name__ == '__main__':
    # Fix the random seed
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # Set device
    torch.set_num_threads(args.num_threads)
    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = False
    else:
        args.device = 'cpu'

    # dataset of (speaker embedding, text prompt) pairs
    train_set = DreamData(data_dir='../prepare_freevc/spk/',
                          meta_dir='../prepare/plugin_meta.csv',
                          subset='train',
                          prompt_dir='../prepare/prompts.csv',)
    train_loader = DataLoader(train_set, num_workers=args.num_workers,
                              batch_size=args.batch_size, shuffle=True)

    # use accelerator for multi-gpu training
    accelerator = Accelerator(mixed_precision=args.amp)

    # vc model (used only for evaluation)
    ckpt_converter = '../prepare/checkpoints_v2/converter'
    vc_model = ToneColorConverter(f'{ckpt_converter}/config.json', device='cuda')
    vc_model.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

    # text encoder
    tokenizer = T5Tokenizer.from_pretrained(params.text_encoder.model)
    text_encoder = T5EncoderModel.from_pretrained(params.text_encoder.model).to(accelerator.device)
    text_encoder.eval()

    # main U-Net
    model = P2E_Cross(config['diffwrap']).to(accelerator.device)
    # model.load_state_dict(torch.load('64.pt')['model'])

    total_params = sum(param.nelement() for param in model.parameters())
    print("Number of parameters: %.2fM" % (total_params / 1e6))

    noise_scheduler = DDIMScheduler(num_train_timesteps=params.diff.num_train_steps,
                                    beta_start=params.diff.beta_start,
                                    beta_end=params.diff.beta_end,
                                    rescale_betas_zero_snr=True,
                                    timestep_spacing="trailing",
                                    clip_sample=False,
                                    prediction_type='v_prediction')

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=params.opt.learning_rate,
                                  betas=(params.opt.beta1, params.opt.beta2),
                                  weight_decay=params.opt.weight_decay,
                                  eps=params.opt.adam_epsilon,)
    loss_func = torch.nn.MSELoss()

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    global_step = 0
    losses = 0.0

    # sanity-check the evaluation pipeline once before training (main process only)
    if accelerator.is_main_process:
        eval_plugin_light(vc_model, [tokenizer, text_encoder], model, noise_scheduler,
                          (1, 256, 1),
                          val_meta='../prepare/val_meta.csv',
                          val_folder='/home/jerry/Projects/Dataset/Speech/vctk_libritts/',
                          guidance_scale=3, guidance_rescale=0.0,
                          ddim_steps=100, eta=1, random_seed=2024,
                          device=accelerator.device,
                          epoch='test', save_path=args.log_dir + 'output/', val_num=1)
    accelerator.wait_for_everyone()

    for epoch in range(args.epochs):
        model.train()
        for step, batch in enumerate(tqdm(train_loader)):
            spk_embed, prompt = batch

            # encode the text prompt with the frozen T5 encoder (no gradients)
            with torch.no_grad():
                # audio_clip = minmax_norm_diff(logmel(audio_clip)).unsqueeze(1)
                text_batch = tokenizer(prompt, max_length=32, padding='max_length',
                                       truncation=True, return_tensors="pt")
                text = text_batch.input_ids.to(spk_embed.device)
                text_mask = text_batch.attention_mask.to(spk_embed.device)
                text = text_encoder(input_ids=text, attention_mask=text_mask)[0]
                # spk_embed = minmax_norm_diff(spk_embed, vmax=0.5, vmin=0.0)

            # add noise to the target speaker embedding
            noise = torch.randn(spk_embed.shape).to(accelerator.device)
            timesteps = torch.randint(0, params.diff.num_train_steps, (noise.shape[0],),
                                      device=accelerator.device).long()
            noisy_target = noise_scheduler.add_noise(spk_embed, noise, timesteps)
            # v-prediction target
            velocity = noise_scheduler.get_velocity(spk_embed, noise, timesteps)

            # forward pass with condition dropout for classifier-free guidance
            pred = model(noisy_target, timesteps, text, text_mask,
                         train_cfg=True, cfg_prob=0.25)

            # backward
            if params.diff.v_prediction:
                loss = loss_func(pred, velocity)
            else:
                loss = loss_func(pred, noise)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            losses += loss.item()

            if accelerator.is_main_process:
                if global_step % args.log_step == 0:
                    with open(args.log_dir + 'diff_vc.txt', mode='a') as n:
                        n.write(time.asctime(time.localtime(time.time())))
                        n.write('\n')
                        n.write('Epoch: [{}][{}]    Batch: [{}][{}]    Loss: {:.6f}\n'.format(
                            epoch + 1, args.epochs, step + 1, len(train_loader),
                            losses / args.log_step))
                    losses = 0.0

        accelerator.wait_for_everyone()

        # periodic evaluation and checkpointing
        if (epoch + 1) % args.save_every == 0:
            if accelerator.is_main_process:
                eval_plugin_light(vc_model, [tokenizer, text_encoder], model, noise_scheduler,
                                  (1, 256, 1),
                                  val_meta='../prepare/val_meta.csv',
                                  val_folder='/home/jerry/Projects/Dataset/Speech/vctk_libritts/',
                                  guidance_scale=3, guidance_rescale=0.0,
                                  ddim_steps=50, eta=1, random_seed=2024,
                                  device=accelerator.device,
                                  epoch=epoch, save_path=args.log_dir + 'output/', val_num=10)

                unwrapped_unet = accelerator.unwrap_model(model)
                accelerator.save({
                    "model": unwrapped_unet.state_dict(),
                }, args.save_dir + args.config_name + '/' + str(epoch) + '.pt')
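
# Usage sketch (assumption: the script file name and launch command below are illustrative,
# not taken from the repo docs; flags mirror the argparse options defined above):
#   accelerate launch train_plugin.py --config-name Plugin_freevc --batch-size 32 --epochs 100
# Multi-GPU and mixed-precision behavior come from the `accelerate` configuration and the
# --amp flag passed to Accelerator(mixed_precision=...).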