hieupt committed on
Commit
df578bd
·
verified ·
1 Parent(s): 7b18186

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +75 -0
predict.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import data.utils
5
+ import model.utils as model_utils
6
+
7
+ from test import predict_song
8
+ from model.waveunet import Waveunet
9
+
10
def main(args):
    """Separate the mixture at ``args.input`` and write one WAV file per source.

    Builds a ``Waveunet`` from the architecture options in ``args``, restores
    its weights from the checkpoint at ``args.load_model``, runs
    ``predict_song`` on the input file and writes every returned source to
    ``<output folder>/<input file name>_<source>.wav``.

    Args:
        args: Parsed command-line namespace. Reads ``features``, ``levels``,
            ``feature_growth``, ``output_size``, ``sr``, ``channels``,
            ``instruments``, ``kernel_size``, ``depth``, ``strides``,
            ``conv_type``, ``res``, ``separate``, ``cuda``, ``load_model``,
            ``input`` and ``output``. The whole namespace is also forwarded
            to ``predict_song``.
    """
    # MODEL
    # Channel count per level: grows linearly for "add"; any other value of
    # feature_growth falls through to doubling at every level.
    if args.feature_growth == "add":
        num_features = [args.features * i for i in range(1, args.levels + 1)]
    else:
        num_features = [args.features * 2 ** i for i in range(0, args.levels)]
    # Requested output duration (output_size * sr) expressed in samples.
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        model = model_utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print("Loading model from checkpoint " + str(args.load_model))
    state = model_utils.load_model(model, None, args.load_model, args.cuda)
    print('Step', state['step'])

    preds = predict_song(args, args.input, model)

    output_folder = os.path.dirname(args.input) if args.output is None else args.output
    # Create the destination up front so a missing --output folder does not
    # throw away the finished separation. An empty string means "current
    # directory" (input given as a bare file name) and needs no creation.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    base_name = os.path.basename(args.input)
    for inst, audio in preds.items():
        data.utils.write_wav(os.path.join(output_folder, base_name + "_" + inst + ".wav"), audio, args.sr)
33
+
34
if __name__ == '__main__':
    # Command-line interface. Each option is declared once in the table below
    # as (flag, keyword arguments for add_argument); the order here is the
    # order in which the options are registered on the parser.
    cli_options = [
        ('--instruments', dict(type=str, nargs='+', default=["bass", "drums", "other", "vocals"],
                               help='List of instruments to separate (default: "bass drums other vocals")')),
        ('--cuda', dict(action='store_true',
                        help='Use CUDA (default: False)')),
        ('--features', dict(type=int, default=32,
                            help='Number of feature channels per layer')),
        ('--load_model', dict(type=str, default='checkpoints/waveunet/model',
                              help='Reload a previously trained model')),
        ('--batch_size', dict(type=int, default=4,
                              help="Batch size")),
        ('--levels', dict(type=int, default=6,
                          help="Number of DS/US blocks")),
        ('--depth', dict(type=int, default=1,
                         help="Number of convs per block")),
        ('--sr', dict(type=int, default=44100,
                      help="Sampling rate")),
        ('--channels', dict(type=int, default=2,
                            help="Number of input audio channels")),
        ('--kernel_size', dict(type=int, default=5,
                               help="Filter width of kernels. Has to be an odd number")),
        ('--output_size', dict(type=float, default=2.0,
                               help="Output duration")),
        ('--strides', dict(type=int, default=4,
                           help="Strides in Waveunet")),
        ('--conv_type', dict(type=str, default="gn",
                             help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")),
        ('--res', dict(type=str, default="fixed",
                       help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")),
        ('--separate', dict(type=int, default=1,
                            help="Train separate model for each source (1) or only one (0)")),
        ('--feature_growth', dict(type=str, default="double",
                                  help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)")),

        # Input / output locations.
        ('--input', dict(type=str, default=os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"),
                         help="Path to input mixture to be separated")),
        ('--output', dict(type=str, default=None,
                          help="Output path (same folder as input path if not set)")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    main(args)