hieupt committed on
Commit
7b18186
·
verified ·
1 Parent(s): 727d6ec

Upload cog_predict.py

Browse files
Files changed (1) hide show
  1. cog_predict.py +144 -0
cog_predict.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cog
3
+ import tempfile
4
+ import zipfile
5
+ from pathlib import Path
6
+ import argparse
7
+ import data.utils
8
+ import model.utils as model_utils
9
+ from test import predict_song
10
+ from model.waveunet import Waveunet
11
+
12
+
13
class waveunetPredictor(cog.Predictor):
    """Cog predictor that separates an audio mixture into per-instrument tracks.

    ``setup`` builds a ``Waveunet`` from a fixed set of default
    hyper-parameters and loads a checkpoint into it. ``predict`` runs the
    model on one input file and returns the path of a zip archive that
    contains one WAV file per separated instrument.
    """

    @staticmethod
    def _build_arg_parser():
        """Return the parser that declares the model/inference hyper-parameters."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--instruments",
            type=str,
            nargs="+",
            default=["bass", "drums", "other", "vocals"],
            help='List of instruments to separate (default: "bass drums other vocals")',
        )
        parser.add_argument(
            "--cuda", action="store_true", help="Use CUDA (default: False)"
        )
        parser.add_argument(
            "--features",
            type=int,
            default=32,
            help="Number of feature channels per layer",
        )
        parser.add_argument(
            "--load_model",
            type=str,
            default="checkpoints/waveunet/model",
            help="Reload a previously trained model",
        )
        parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
        parser.add_argument(
            "--levels", type=int, default=6, help="Number of DS/US blocks"
        )
        parser.add_argument(
            "--depth", type=int, default=1, help="Number of convs per block"
        )
        parser.add_argument("--sr", type=int, default=44100, help="Sampling rate")
        parser.add_argument(
            "--channels", type=int, default=2, help="Number of input audio channels"
        )
        parser.add_argument(
            "--kernel_size",
            type=int,
            default=5,
            help="Filter width of kernels. Has to be an odd number",
        )
        parser.add_argument(
            "--output_size", type=float, default=2.0, help="Output duration"
        )
        parser.add_argument(
            "--strides", type=int, default=4, help="Strides in Waveunet"
        )
        parser.add_argument(
            "--conv_type",
            type=str,
            default="gn",
            help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn",
        )
        parser.add_argument(
            "--res",
            type=str,
            default="fixed",
            help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned",
        )
        parser.add_argument(
            "--separate",
            type=int,
            default=1,
            help="Train separate model for each source (1) or only one (0)",
        )
        parser.add_argument(
            "--feature_growth",
            type=str,
            default="double",
            help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)",
        )
        # No --input/--output options: the mixture path is passed to predict()
        # per request and the outputs go to temporary directories.
        return parser

    def setup(self):
        """Init wave u net model"""
        # Parse an empty argv so that only the declared defaults are used,
        # regardless of how the hosting process itself was invoked.
        args = self._build_arg_parser().parse_args([])
        self.args = args

        # Per-level channel counts: linear growth for "add", otherwise doubling.
        num_features = (
            [args.features * i for i in range(1, args.levels + 1)]
            if args.feature_growth == "add"
            else [args.features * 2 ** i for i in range(0, args.levels)]
        )
        # Requested output length in samples (duration * sampling rate).
        target_outputs = int(args.output_size * args.sr)
        self.model = Waveunet(
            args.channels,
            num_features,
            args.channels,
            args.instruments,
            kernel_size=args.kernel_size,
            target_output_size=target_outputs,
            depth=args.depth,
            strides=args.strides,
            conv_type=args.conv_type,
            res=args.res,
            separate=args.separate,
        )

        if args.cuda:
            # Wrap the model built above; the previous code referenced an
            # undefined local ``model`` here and would raise NameError.
            self.model = model_utils.DataParallel(self.model)
            print("move model to gpu")
            self.model.cuda()

        print("Loading model from checkpoint " + str(args.load_model))
        state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
        print("Step", state["step"])

    @cog.input("input", type=Path, help="audio mixture path")
    def predict(self, input):
        """Separate tracks from input mixture audio

        Args:
            input: Path of the audio mixture to separate.

        Returns:
            Path of a zip archive holding one ``<input name>_<instrument>.wav``
            file for every entry returned by ``predict_song``.
        """
        out_path = Path(tempfile.mkdtemp())
        zip_path = Path(tempfile.mkdtemp()) / "output.zip"

        preds = predict_song(self.args, input, self.model)

        out_names = []
        for inst, audio in preds.items():
            temp_n = os.path.join(
                str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
            )
            data.utils.write_wav(temp_n, audio, self.args.sr)
            out_names.append(temp_n)

        with zipfile.ZipFile(str(zip_path), "w") as zf:
            for wav_file in out_names:
                # Store only the file name so the archive does not embed the
                # random temporary directory path of this machine.
                zf.write(wav_file, arcname=os.path.basename(wav_file))

        return zip_path