# Provenance (from the hosting page this file was copied from):
# uploaded by mlobj in commit 29a525e ("Upload 14 files"); original size 1.33 kB.
import argparse
import os

import numpy as np
import tensorflow as tf

import hparams
import multiresunet_model
import utils
def parse_args(argv=None):
    """Parse the training script's command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the positional attributes ``Path``,
        ``Source``, ``Spectrum`` and ``Outpath``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('Path', metavar='path', type=str,
                        help='Path to DSD100 pickled spectrograms. See preprocess_data.py for more details')
    parser.add_argument('Source', metavar='source', type=str,
                        help='Desired source to separate')
    parser.add_argument('Spectrum', metavar='spectrum', type=str,
                        help='Low (lf) or High (hf) frequencies training')
    parser.add_argument('Outpath', metavar='model_out_path', type=str,
                        help='Path to save the model to')
    return parser.parse_args(argv)


def spectrogram_path(directory, name, spectrum):
    """Return the path of ``<name>_<spectrum>.npy`` inside *directory*.

    ``os.path.join`` is used instead of string concatenation so the directory
    argument works with or without a trailing separator.
    """
    return os.path.join(directory, name + '_' + spectrum + '.npy')


def main(argv=None):
    """Train a separation model for one source/spectrum pair and save it.

    Loads ``mixture_<spectrum>.npy`` (inputs) and ``<source>_<spectrum>.npy``
    (targets) from the data directory. It then builds
    ``multiresunet_model.Steminator`` with the input shape
    ``(frequency_bins, chunk_size, n_channels)`` from ``hparams`` and trains it
    with Adam on a mean-absolute-error loss. Finally it writes the model to
    ``<model_out_path>/<source>_<spectrum>.h5``.

    Args:
        argv: Optional argument list, forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)
    source = args.Source
    spectrum = args.Spectrum
    output_dir = args.Outpath

    ### Load Data ###
    # NOTE(review): assumes x and y are already shaped to match the model's
    # input/output -- TODO confirm against preprocess_data.py.
    x = np.load(spectrogram_path(args.Path, 'mixture', spectrum))
    y = np.load(spectrogram_path(args.Path, source, spectrum))

    ### Construct model ###
    model = multiresunet_model.Steminator(
        (hparams.frequency_bins, hparams.chunk_size, hparams.n_channels))
    # `lr` is a deprecated alias that newer Keras versions reject;
    # `learning_rate` is the supported keyword.
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.learning_rate)
    model.compile(optimizer, loss='mean_absolute_error')

    ### Training ###
    model.fit(x, y, epochs=hparams.epochs, batch_size=hparams.batch_size)

    ### Save model ###
    # Create the destination up front so a missing directory does not make
    # the save fail after training has already completed.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    model.save(os.path.join(output_dir, source + '_' + spectrum + '.h5'))


if __name__ == '__main__':
    main()