import argparse
from datetime import datetime
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data

from lib import dataset
from lib import nets
from lib import spec_utils


def setup_logger(name, logfile='LOGFILENAME.log'):
    # Write DEBUG and above to the log file; mirror INFO and above to the console.
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    fh = logging.FileHandler(logfile, encoding='utf8')
    fh.setLevel(logging.DEBUG)
    fh_formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')
    fh.setFormatter(fh_formatter)

    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)

    logger.addHandler(fh)
    logger.addHandler(sh)

    return logger


def train_epoch(dataloader, model, device, optimizer, accumulation_steps):
    model.train()
    sum_loss = 0
    crit = nn.L1Loss()

    for itr, (X_batch, y_batch) in enumerate(dataloader):
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        # The network outputs a main mask and an auxiliary mask; both are
        # applied to the mixture spectrogram and compared against the target.
        pred, aux = model(X_batch)

        loss_main = crit(pred * X_batch, y_batch)
        loss_aux = crit(aux * X_batch, y_batch)

        loss = loss_main * 0.8 + loss_aux * 0.2

        # Scale the loss so gradients accumulated over `accumulation_steps`
        # mini-batches match a single large-batch update.
        accum_loss = loss / accumulation_steps
        accum_loss.backward()

        if (itr + 1) % accumulation_steps == 0:
            optimizer.step()
            model.zero_grad()

        sum_loss += loss.item() * len(X_batch)

    # Flush gradients left over from batches that did not fill a full
    # accumulation window.
    if (itr + 1) % accumulation_steps != 0:
        optimizer.step()
        model.zero_grad()

    return sum_loss / len(dataloader.dataset)


def validate_epoch(dataloader, model, device):
    model.eval()
    sum_loss = 0
    crit = nn.L1Loss()

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            pred = model.predict(X_batch)

            # Crop the target to the model's valid output region before
            # computing the loss.
            y_batch = spec_utils.crop_center(y_batch, pred)
            loss = crit(pred, y_batch)

            sum_loss += loss.item() * len(X_batch)

    return sum_loss / len(dataloader.dataset)


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--dataset', '-d', required=True)
    p.add_argument('--split_mode', '-S', type=str, choices=['random', 'subdirs'], default='random')
    p.add_argument('--learning_rate', '-l', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay_factor', type=float, default=0.9)
    p.add_argument('--lr_decay_patience', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--accumulation_steps', '-A', type=int, default=1)
    p.add_argument('--cropsize', '-C', type=int, default=256)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--val_rate', '-v', type=float, default=0.2)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--val_batchsize', '-b', type=int, default=6)
    p.add_argument('--val_cropsize', '-c', type=int, default=256)
    p.add_argument('--num_workers', '-w', type=int, default=6)
    p.add_argument('--epoch', '-E', type=int, default=200)
    p.add_argument('--reduction_rate', '-R', type=float, default=0.0)
    p.add_argument('--reduction_level', '-L', type=float, default=0.2)
    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    p.add_argument('--debug', action='store_true')
    args = p.parse_args()

    logger.debug(vars(args))

    random.seed(args.seed)
    np.random.seed(args.seed)
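    # torch.manual_seed below also seeds the CUDA RNGs on current PyTorch
    # builds, so no separate torch.cuda.manual_seed_all call is needed.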
    torch.manual_seed(args.seed)

    val_filelist = []
    if args.val_filelist is not None:
        with open(args.val_filelist, 'r', encoding='utf8') as f:
            val_filelist = json.load(f)

    train_filelist, val_filelist = dataset.train_val_split(
        dataset_dir=args.dataset,
        split_mode=args.split_mode,
        val_rate=args.val_rate,
        val_filelist=val_filelist
    )

    if args.debug:
        logger.info('### DEBUG MODE')
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]
    elif args.val_filelist is None and args.split_mode == 'random':
        # Persist the random split so the run can be reproduced later.
        with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        logger.info('{} {} {}'.format(
            i + 1, os.path.basename(X_fname), os.path.basename(y_fname)))

    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, 32, 128)
    if args.pretrained_model is not None:
        model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.learning_rate
    )

    # Shrink the learning rate whenever the validation loss stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        threshold=1e-6,
        min_lr=args.lr_min,
        verbose=True
    )

    # Frequency-dependent weights for the vocal-reduction augmentation:
    # ramp up over the lowest ~200 Hz, taper linearly to zero at 22050 Hz,
    # and stay zero for any bins above that.
    bins = args.n_fft // 2 + 1
    freq_to_bin = 2 * bins / args.sr
    unstable_bins = int(200 * freq_to_bin)
    stable_bins = int(22050 * freq_to_bin)
    reduction_weight = np.concatenate([
        np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None],
        np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None],
        np.zeros((bins - stable_bins, 1), dtype=np.float32),
    ], axis=0) * args.reduction_level

    training_set = dataset.make_training_set(
        filelist=train_filelist,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft
    )

    train_dataset = dataset.VocalRemoverTrainingSet(
        training_set * args.patches,
        cropsize=args.cropsize,
        reduction_rate=args.reduction_rate,
        reduction_weight=reduction_weight,
        mixup_rate=args.mixup_rate,
        mixup_alpha=args.mixup_alpha
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers
    )

    patch_list = dataset.make_validation_set(
        filelist=val_filelist,
        cropsize=args.val_cropsize,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft,
        offset=model.offset
    )

    val_dataset = dataset.VocalRemoverValidationSet(
        patch_list=patch_list
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.num_workers
    )

    # torch.save does not create directories, so make sure the checkpoint
    # directory exists before the first save.
    os.makedirs('models', exist_ok=True)

    log = []
    best_loss = np.inf
    for epoch in range(args.epoch):
        logger.info('# epoch {}'.format(epoch))
        train_loss = train_epoch(
            train_dataloader, model, device, optimizer, args.accumulation_steps)
        val_loss = validate_epoch(val_dataloader, model, device)

        logger.info(
            '  * training loss = {:.6f}, validation loss = {:.6f}'
            .format(train_loss, val_loss)
        )

        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            logger.info('  * best validation loss')
            model_path = 'models/model_iter{}.pth'.format(epoch)
            torch.save(model.state_dict(), model_path)

        log.append([train_loss, val_loss])
        with open('loss_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(log, f, ensure_ascii=False)


if __name__ == '__main__':
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    logger = setup_logger(__name__, 'train_{}.log'.format(timestamp))

    try:
        main()
    except Exception as e:
        logger.exception(e)
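
# Example invocation, assuming this script is saved as train.py (the dataset
# path is a placeholder; the flags are the ones defined in main() above):
#
#   python train.py --dataset path/to/dataset --gpu 0 \
#       --batchsize 4 --accumulation_steps 4 --epoch 200
#
# With --batchsize 4 and --accumulation_steps 4, train_epoch accumulates
# gradients over four mini-batches before each optimizer step, giving an
# effective batch size of 16 without the memory cost of larger batches.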