import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '../utils'))
import numpy as np
import argparse
import time
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

from utilities import (create_folder, get_filename, create_logging, Mixup,
    StatisticsContainer)
from models import (PVT, PVT2, PVT_lr, PVT_nopretrain, PVT_2layer, Cnn14,
    Cnn14_no_specaug, Cnn14_no_dropout, Cnn6, Cnn10, ResNet22, ResNet38,
    ResNet54, Cnn14_emb512, Cnn14_emb128, Cnn14_emb32, MobileNetV1,
    MobileNetV2, LeeNet11, LeeNet24, DaiNet19, Res1dNet31, Res1dNet51,
    Wavegram_Cnn14, Wavegram_Logmel_Cnn14, Wavegram_Logmel128_Cnn14,
    Cnn14_16k, Cnn14_8k, Cnn14_mel32, Cnn14_mel128, Cnn14_mixup_time_domain,
    Cnn14_DecisionLevelMax, Cnn14_DecisionLevelAtt, Cnn6_Transformer, GLAM,
    GLAM2, GLAM3, Cnn4, EAT)
from pytorch_utils import (move_data_to_device, count_parameters, count_flops,
    do_mixup)
from data_generator import (AudioSetDataset, TrainSampler, BalancedTrainSampler,
    AlternateTrainSampler, EvaluateSampler, collate_fn)
from evaluate import Evaluator
import config
from losses import get_loss_func


def train(args):
    """Train an AudioSet tagging model.

    Args:
      workspace: str
      data_type: 'balanced_train' | 'full_train'
      sample_rate: int
      window_size: int
      hop_size: int
      mel_bins: int
      fmin: int
      fmax: int
      model_type: str
      loss_type: 'clip_bce'
      balanced: 'none' | 'balanced' | 'alternate'
      augmentation: 'none' | 'mixup'
      batch_size: int
      learning_rate: float
      resume_iteration: int
      early_stop: int
      cuda: bool
    """
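
    # Arguments & parameters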
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
    filename = args.filename

    num_workers = 8
    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    black_list_csv = None
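
    # Paths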
    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
        '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace,
        'hdf5s', 'indexes', 'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
        'eval.h5')

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size))
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size))

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')
        device = 'cpu'

    # Model: model_type must name one of the model classes imported above
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate, window_size=window_size,
        hop_size=hop_size, mel_bins=mel_bins, fmin=fmin, fmax=fmax,
        classes_num=classes_num)

    total = sum(p.numel() for p in model.parameters())
    logging.info('Total params: %.2fM' % (total / 1e6))
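
    # Dataset (shared by all of the train/eval data loaders below)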
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Batch sampler: 'balanced' equalizes class frequencies; 'alternate'
    # alternates between balanced and uniform sampling
    if balanced == 'none':
        Sampler = TrainSampler
    elif balanced == 'balanced':
        Sampler = BalancedTrainSampler
    elif balanced == 'alternate':
        Sampler = AlternateTrainSampler

    # With mixup the sampler yields 2x batches, since pairs of clips are mixed
    train_sampler = Sampler(
        indexes_hdf5_path=train_indexes_hdf5_path,
        batch_size=batch_size * 2 if 'mixup' in augmentation else batch_size,
        black_list_csv=black_list_csv)

    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    train_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=train_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=eval_bal_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=eval_test_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    if 'mixup' in augmentation:
        mixup_alpha = 0.5
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
        logging.info('Mixup alpha: {}'.format(mixup_alpha))
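
    # Assuming the PANNs-style Mixup utility: get_lambda(batch_size=2*N) draws
    # N coefficients lam ~ Beta(alpha, alpha) and returns
    # (lam_1, 1 - lam_1, lam_2, 1 - lam_2, ...), so each consecutive pair of
    # clips in the doubled batch is mixed into one training example.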

    # Evaluator and statistics container
    evaluator = Evaluator(model=model)
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer and scheduler: ReduceLROnPlateau (mode='max') halves the
    # learning rate whenever the monitored test mAP stops improving
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0.05, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
        factor=0.5, patience=4, min_lr=1e-06, verbose=True)

    train_bgn_time = time.time()

    # Resume training from a saved checkpoint
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(checkpoints_dir,
            '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)
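
    # Note: torch.nn.DataParallel splits each batch across all visible GPUs;
    # when no GPU is available it simply calls the wrapped module, so the same
    # code path works on CPU.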

    if resume_iteration > 0:
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logging.info('Resumed learning rate: {}'.format(
            optimizer.state_dict()['param_groups'][0]['lr']))

    time1 = time.time()

    for batch_data_dict in train_loader:
        """batch_data_dict: {
            'audio_name': (batch_size [*2 if mixup],),
            'waveform': (batch_size [*2 if mixup], clip_samples),
            'target': (batch_size [*2 if mixup], classes_num),
            (if mixup) 'mixup_lambda': (batch_size * 2,)}
        """

        if (iteration % 2000 == 0 and iteration >= resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))
            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration, bal_statistics, data_type='bal')
            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))
            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save checkpoint every 2000 iterations
        if iteration % 2000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to the target device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Forward
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                batch_data_dict['mixup_lambda'])
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': do_mixup(batch_data_dict['target'],
                batch_data_dict['mixup_lambda'])}
            """{'target': (batch_size, classes_num)}"""
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': batch_data_dict['target']}
            """{'target': (batch_size, classes_num)}"""

        loss = loss_func(batch_output_dict, batch_target_dict)
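
        # loss_type='clip_bce' (see losses.py) is expected to compute binary
        # cross-entropy between 'clipwise_output' and the (possibly mixed)
        # 'target', the standard multi-label objective for AudioSet tagging.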

        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('iteration: {}, loss: {:.3f}'.format(iteration, loss.item()))

        # Step the scheduler on the latest test mAP and log the learning rate
        if iteration % 2000 == 0:
            scheduler.step(np.mean(test_statistics['average_precision']))
            logging.info('Learning rate: {}'.format(
                optimizer.state_dict()['param_groups'][0]['lr']))

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Example of parser.')
    subparsers = parser.add_subparsers(dest='mode')

    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--workspace', type=str, required=True)
    parser_train.add_argument('--data_type', type=str, default='full_train', choices=['balanced_train', 'full_train'])
    parser_train.add_argument('--sample_rate', type=int, default=32000)
    parser_train.add_argument('--window_size', type=int, default=1024)
    parser_train.add_argument('--hop_size', type=int, default=320)
    parser_train.add_argument('--mel_bins', type=int, default=64)
    parser_train.add_argument('--fmin', type=int, default=50)
    parser_train.add_argument('--fmax', type=int, default=14000)
    parser_train.add_argument('--model_type', type=str, required=True)
    parser_train.add_argument('--loss_type', type=str, default='clip_bce', choices=['clip_bce'])
    parser_train.add_argument('--balanced', type=str, default='balanced', choices=['none', 'balanced', 'alternate'])
    parser_train.add_argument('--augmentation', type=str, default='mixup', choices=['none', 'mixup'])
    parser_train.add_argument('--batch_size', type=int, default=32)
    parser_train.add_argument('--learning_rate', type=float, default=1e-3)
    parser_train.add_argument('--resume_iteration', type=int, default=0)
    parser_train.add_argument('--early_stop', type=int, default=1000000)
    parser_train.add_argument('--cuda', action='store_true', default=False)

    args = parser.parse_args()
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)
    else:
        raise Exception('Error argument!')
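
# Example invocation (the script name is illustrative):
#   python train.py train --workspace=/path/to/workspace \
#       --data_type=full_train --model_type=EAT --loss_type=clip_bce \
#       --balanced=balanced --augmentation=mixup --batch_size=32 \
#       --learning_rate=1e-3 --cuda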