hieupt committed
Commit 6fd5219 · verified · 1 Parent(s): 028e95a

Upload test.py

Files changed (1)
  1. test.py +204 -0
test.py ADDED
@@ -0,0 +1,204 @@
import museval
from tqdm import tqdm

import numpy as np
import torch

import data.utils
import model.utils as model_utils
import utils

def compute_model_output(model, inputs):
    '''
    Computes outputs of model with given inputs. Does NOT allow propagating gradients! See compute_loss for training.
    Procedure depends on whether we have one model for each source or not.
    :param model: Model to predict with
    :param inputs: Input mixture excerpt batch
    :return: Model outputs, dictionary with source names as keys
    '''
    all_outputs = {}

    if model.separate:
        for inst in model.instruments:
            output = model(inputs, inst)
            all_outputs[inst] = output[inst].detach().clone()
    else:
        all_outputs = model(inputs)

    return all_outputs

def predict(audio, model):
    '''
    Predict sources for a given audio input signal, with a given model. Audio is split into chunks to make predictions on each chunk before they are concatenated.
    :param audio: Audio input tensor, either Pytorch tensor or numpy array
    :param model: Pytorch model
    :return: Source predictions, dictionary with source names as keys
    '''
    if isinstance(audio, torch.Tensor):
        is_cuda = audio.is_cuda
        audio = audio.detach().cpu().numpy()
        return_mode = "pytorch"
    else:
        return_mode = "numpy"

    expected_outputs = audio.shape[1]

    # Pad input if it is not divisible in length by the frame shift number
    output_shift = model.shapes["output_frames"]
    pad_back = audio.shape[1] % output_shift
    pad_back = 0 if pad_back == 0 else output_shift - pad_back
    if pad_back > 0:
        audio = np.pad(audio, [(0, 0), (0, pad_back)], mode="constant", constant_values=0.0)

    target_outputs = audio.shape[1]
    outputs = {key: np.zeros(audio.shape, np.float32) for key in model.instruments}

    # Pad mixture across time at beginning and end so that the network can make predictions at the beginning and end of the signal
    pad_front_context = model.shapes["output_start_frame"]
    pad_back_context = model.shapes["input_frames"] - model.shapes["output_end_frame"]
    audio = np.pad(audio, [(0, 0), (pad_front_context, pad_back_context)], mode="constant", constant_values=0.0)

    # Iterate over the mixture in output-sized hops, fetching the network prediction for each chunk
    with torch.no_grad():
        for target_start_pos in range(0, target_outputs, model.shapes["output_frames"]):
            # Prepare mixture excerpt by selecting time interval.
            # Since audio was front-padded, the input [target_start_pos:target_start_pos + input_frames] actually predicts the [target_start_pos:target_start_pos + output_frames] target range.
            curr_input = audio[:, target_start_pos:target_start_pos + model.shapes["input_frames"]]

            # Convert to Pytorch tensor for model prediction
            curr_input = torch.from_numpy(curr_input).unsqueeze(0)

            # Predict
            for key, curr_targets in compute_model_output(model, curr_input).items():
                outputs[key][:, target_start_pos:target_start_pos + model.shapes["output_frames"]] = curr_targets.squeeze(0).cpu().numpy()

    # Crop to expected length (since we padded to handle the frame shift)
    outputs = {key: outputs[key][:, :expected_outputs] for key in outputs.keys()}

    if return_mode == "pytorch":
        # Convert each source estimate back to a Pytorch tensor, on GPU if the input was on GPU
        outputs = {key: torch.from_numpy(outputs[key]) for key in outputs.keys()}
        if is_cuda:
            outputs = {key: outputs[key].cuda() for key in outputs.keys()}
    return outputs

def predict_song(args, audio_path, model):
    '''
    Predicts sources for an audio file whose path is given, using a given model.
    Takes care of resampling the input audio to the model's sampling rate and resampling predictions back to the input sampling rate.
    :param args: Options dictionary
    :param audio_path: Path to mixture audio file
    :param model: Pytorch model
    :return: Source estimates given as dictionary with keys as source names
    '''
    model.eval()

    # Load mixture in original sampling rate
    mix_audio, mix_sr = data.utils.load(audio_path, sr=None, mono=False)
    mix_channels = mix_audio.shape[0]
    mix_len = mix_audio.shape[1]

    # Adapt mixture channels to required input channels
    if args.channels == 1:
        mix_audio = np.mean(mix_audio, axis=0, keepdims=True)
    else:
        if mix_channels == 1:  # Duplicate channels if input is mono but model is stereo
            mix_audio = np.tile(mix_audio, [args.channels, 1])
        else:
            assert mix_channels == args.channels

    # Resample to model sampling rate
    mix_audio = data.utils.resample(mix_audio, mix_sr, args.sr)

    sources = predict(mix_audio, model)

    # Resample back to mixture sampling rate in case the model ran at a different sampling rate
    sources = {key: data.utils.resample(sources[key], args.sr, mix_sr) for key in sources.keys()}

    # In case we had to pad the mixture at the end, or we have a few samples too many due to inconsistent down- and upsampling, remove those samples from the source prediction now
    for key in sources.keys():
        diff = sources[key].shape[1] - mix_len
        if diff > 0:
            print("WARNING: Cropping " + str(diff) + " samples")
            sources[key] = sources[key][:, :-diff]
        elif diff < 0:
            print("WARNING: Padding output by " + str(-diff) + " samples")
            sources[key] = np.pad(sources[key], [(0, 0), (0, -diff)], mode="constant", constant_values=0.0)

        # Adapt channels
        if mix_channels > args.channels:
            assert args.channels == 1
            # Duplicate mono predictions
            sources[key] = np.tile(sources[key], [mix_channels, 1])
        elif mix_channels < args.channels:
            assert mix_channels == 1
            # Reduce model output to mono
            sources[key] = np.mean(sources[key], axis=0, keepdims=True)

        sources[key] = np.asfortranarray(sources[key])  # So librosa does not complain if we want to save it

    return sources

def evaluate(args, dataset, model, instruments):
    '''
    Evaluates a given model on a given dataset
    :param args: Options dict
    :param dataset: Dataset object
    :param model: Pytorch model
    :param instruments: List of source names
    :return: List of performance metric dictionaries, one per dataset sample
    '''
    perfs = list()
    model.eval()
    with torch.no_grad():
        for example in dataset:
            print("Evaluating " + example["mix"])

            # Load source references in their original sr and channel number
            target_sources = np.stack([data.utils.load(example[instrument], sr=None, mono=False)[0].T for instrument in instruments])

            # Predict using mixture
            pred_sources = predict_song(args, example["mix"], model)
            pred_sources = np.stack([pred_sources[key].T for key in instruments])

            # Evaluate
            SDR, ISR, SIR, SAR, _ = museval.metrics.bss_eval(target_sources, pred_sources)
            song = {}
            for idx, name in enumerate(instruments):
                song[name] = {"SDR": SDR[idx], "ISR": ISR[idx], "SIR": SIR[idx], "SAR": SAR[idx]}
            perfs.append(song)

    return perfs


def validate(args, model, criterion, test_data):
    '''
    Iterate with a given model over a given test dataset and compute the desired loss
    :param args: Options dictionary
    :param model: Pytorch model
    :param criterion: Loss function to use (similar to Pytorch criterions)
    :param test_data: Test dataset (Pytorch dataset)
    :return: Average loss over the test set
    '''
    # PREPARE DATA
    dataloader = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers)

    # VALIDATE
    model.eval()
    total_loss = 0.
    with tqdm(total=len(test_data) // args.batch_size) as pbar, torch.no_grad():
        for example_num, (x, targets) in enumerate(dataloader):
            if args.cuda:
                x = x.cuda()
                for k in list(targets.keys()):
                    targets[k] = targets[k].cuda()

            _, avg_loss = model_utils.compute_loss(model, x, targets, criterion)

            # Running mean of the per-batch losses
            total_loss += (1. / float(example_num + 1)) * (avg_loss - total_loss)

            pbar.set_description("Current loss: {:.4f}".format(total_loss))
            pbar.update(1)

    return total_loss
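
Usage note (a minimal sketch, not part of the diff): predict_song is the entry point a driver script would call to separate a single file, and it expects the same args namespace the other functions read (fields such as channels, sr, cuda, batch_size, num_workers). The snippet below assumes a trained model object has already been constructed and that data.utils provides a write_wav(path, audio, sr) helper alongside the load/resample calls this file imports; both are assumptions, since neither is shown in this commit.

import os

import data.utils
from test import predict_song  # module added in this commit


def separate_file(args, model, mix_path, out_dir):
    # Run the chunked separation; returns {source name: (channels, samples) numpy array}
    sources = predict_song(args, mix_path, model)

    # predict_song resamples estimates back to the mixture's original rate, so query that rate for saving
    _, mix_sr = data.utils.load(mix_path, sr=None, mono=False)

    os.makedirs(out_dir, exist_ok=True)
    for name, estimate in sources.items():
        # write_wav(path, audio, sr) is assumed to exist in data.utils (not defined in this commit)
        data.utils.write_wav(os.path.join(out_dir, name + ".wav"), estimate, mix_sr)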