|
import os |
|
import sys |
|
sys.path.insert(1, os.path.join(sys.path[0], '../utils')) |
|
import numpy as np |
|
import argparse |
|
import h5py |
|
import math |
|
import time |
|
import logging |
|
import matplotlib.pyplot as plt |
|
|
|
import torch |
|
torch.backends.cudnn.benchmark=True |
|
torch.manual_seed(0) |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import torch.utils.data |
|
|
|
from utilities import get_filename |
|
from models import * |
|
import config |
|
|
|
|
|
class Transfer_Cnn14(nn.Module):

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,

        fmax, classes_num, freeze_base):

        """Classifier for a new task using pretrained Cnn14 as a sub module.

        Args:
          sample_rate: int, audio sample rate, forwarded to the Cnn14 base.
          window_size: int, STFT window size, forwarded to the Cnn14 base.
          hop_size: int, STFT hop size, forwarded to the Cnn14 base.
          mel_bins: int, number of mel filterbank bins.
          fmin: int, lowest mel frequency.
          fmax: int, highest mel frequency.
          classes_num: int, number of classes of the new (transfer) task.
          freeze_base: bool, if True the pretrained base is frozen and only
            the transfer head is trained.
        """

        super(Transfer_Cnn14, self).__init__()

        # Cnn14 was pretrained on AudioSet, which has 527 sound classes.
        audioset_classes_num = 527



        self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin,

            fmax, audioset_classes_num)




        # Transfer head: maps the 2048-d Cnn14 embedding to the new task's
        # classes. 2048 is the embedding size produced by Cnn14.
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)



        if freeze_base:

            # Freeze all base parameters so only fc_transfer is updated.

            for param in self.base.parameters():

                param.requires_grad = False



        self.init_weights()



    def init_weights(self):
        """Initialize the transfer head only; the base keeps its own weights."""

        init_layer(self.fc_transfer)



    def load_from_pretrain(self, pretrained_checkpoint_path):
        """Load pretrained Cnn14 weights into the base sub-module.

        map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        machines; the caller moves the model to the target device afterwards.
        """

        checkpoint = torch.load(pretrained_checkpoint_path, map_location='cpu')

        self.base.load_state_dict(checkpoint['model'])



    def forward(self, input, mixup_lambda=None):

        """Input: (batch_size, data_length)

        Returns the base model's output_dict with 'clipwise_output' replaced
        by log-softmax scores from the transfer head.
        """

        output_dict = self.base(input, mixup_lambda)

        embedding = output_dict['embedding']



        clipwise_output = torch.log_softmax(self.fc_transfer(embedding), dim=-1)

        output_dict['clipwise_output'] = clipwise_output



        return output_dict
|
|
|
|
|
def train(args):
    """Build the transfer model, optionally load pretrained weights, wrap it
    in DataParallel, and move it to the target device.

    Args:
      args: argparse.Namespace with attributes sample_rate, window_size,
        hop_size, mel_bins, fmin, fmax, model_type,
        pretrained_checkpoint_path, freeze_base and cuda (see the 'train'
        subparser in __main__).
    """

    sample_rate = args.sample_rate

    window_size = args.window_size

    hop_size = args.hop_size

    mel_bins = args.mel_bins

    fmin = args.fmin

    fmax = args.fmax

    model_type = args.model_type

    pretrained_checkpoint_path = args.pretrained_checkpoint_path

    freeze_base = args.freeze_base

    # Fall back to CPU when --cuda was not requested or no GPU is available.
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'



    classes_num = config.classes_num

    pretrain = bool(pretrained_checkpoint_path)



    # SECURITY NOTE: eval() on a command-line string executes arbitrary
    # code; model_type must come from a trusted caller. A dict mapping
    # names to model classes would be safer.
    Model = eval(model_type)

    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,

        classes_num, freeze_base)




    if pretrain:

        logging.info('Load pretrained model from {}'.format(pretrained_checkpoint_path))

        model.load_from_pretrain(pretrained_checkpoint_path)




    print('GPU number: {}'.format(torch.cuda.device_count()))

    # DataParallel also works (as a no-op wrapper) on CPU.
    model = torch.nn.DataParallel(model)



    if 'cuda' in device:

        model.to(device)



    print('Load pretrained model successfully!')
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Example of parser. ')
    subparsers = parser.add_subparsers(dest='mode')

    # 'train' sub-command: audio front-end settings, model selection and
    # training flags.
    train_parser = subparsers.add_parser('train')
    for int_arg in ('sample_rate', 'window_size', 'hop_size', 'mel_bins',
            'fmin', 'fmax'):
        train_parser.add_argument('--' + int_arg, type=int, required=True)
    train_parser.add_argument('--model_type', type=str, required=True)
    train_parser.add_argument('--pretrained_checkpoint_path', type=str)
    train_parser.add_argument('--freeze_base', action='store_true', default=False)
    train_parser.add_argument('--cuda', action='store_true', default=False)

    args = parser.parse_args()
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)
    else:
        raise Exception('Error argument!')