jhtonyKoo commited on
Commit
dad5520
·
1 Parent(s): 0e6e391

Delete inference/mastering_transfer.py

Browse files
Files changed (1) hide show
  1. inference/mastering_transfer.py +0 -360
inference/mastering_transfer.py DELETED
@@ -1,360 +0,0 @@
1
- """
2
- Inference code of music style transfer
3
- of the work "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
4
- Process : converts the mastering style of the input music recording to that of the refernce music.
5
- files inside the target directory should be organized as follow
6
- "path_to_data_directory"/"song_name_#1"/input.wav
7
- "path_to_data_directory"/"song_name_#1"/reference.wav
8
- ...
9
- "path_to_data_directory"/"song_name_#n"/input.wav
10
- "path_to_data_directory"/"song_name_#n"/reference.wav
11
- where the 'input' and 'reference' should share the same names.
12
- """
13
- import numpy as np
14
- from glob import glob
15
- import os
16
- import torch
17
-
18
- import sys
19
- currentdir = os.path.dirname(os.path.realpath(__file__))
20
- sys.path.append(os.path.join(os.path.dirname(currentdir), "mixing_style_transfer"))
21
- from networks import FXencoder, TCNModel
22
- from data_loader import *
23
- import librosa
24
- import pyloudnorm
25
-
26
-
27
-
28
- class Mastering_Style_Transfer_Inference:
29
- def __init__(self, args, trained_w_ddp=True):
30
- if torch.cuda.is_available():
31
- self.device = torch.device("cuda:0")
32
- else:
33
- self.device = torch.device("cpu")
34
-
35
- # inference computational hyperparameters
36
- self.args = args
37
- self.segment_length = args.segment_length
38
- self.batch_size = args.batch_size
39
- self.sample_rate = 44100 # sampling rate should be 44100
40
- self.time_in_seconds = int(args.segment_length // self.sample_rate)
41
-
42
- # directory configuration
43
- self.output_dir = args.target_dir if args.output_dir==None else args.output_dir
44
- self.target_dir = args.target_dir
45
-
46
- # load model and its checkpoint weights
47
- self.models = {}
48
- self.models['effects_encoder'] = FXencoder(args.cfg_encoder).to(self.device)
49
- self.models['mastering_converter'] = TCNModel(nparams=args.cfg_converter["condition_dimension"], \
50
- ninputs=2, \
51
- noutputs=2, \
52
- nblocks=args.cfg_converter["nblocks"], \
53
- dilation_growth=args.cfg_converter["dilation_growth"], \
54
- kernel_size=args.cfg_converter["kernel_size"], \
55
- channel_width=args.cfg_converter["channel_width"], \
56
- stack_size=args.cfg_converter["stack_size"], \
57
- cond_dim=args.cfg_converter["condition_dimension"], \
58
- causal=args.cfg_converter["causal"]).to(self.device)
59
-
60
- ckpt_paths = {'effects_encoder' : args.ckpt_path_enc, \
61
- 'mastering_converter' : args.ckpt_path_conv}
62
- # reload saved model weights
63
- ddp = trained_w_ddp
64
- self.reload_weights(ckpt_paths, ddp=ddp)
65
-
66
-
67
- # reload model weights from the target checkpoint path
68
- def reload_weights(self, ckpt_paths, ddp=True):
69
- for cur_model_name in self.models.keys():
70
- checkpoint = torch.load(ckpt_paths[cur_model_name], map_location=self.device)
71
-
72
- from collections import OrderedDict
73
- new_state_dict = OrderedDict()
74
- for k, v in checkpoint["model"].items():
75
- # remove `module.` if the model was trained with DDP
76
- name = k[7:] if ddp else k
77
- new_state_dict[name] = v
78
-
79
- # load params
80
- self.models[cur_model_name].load_state_dict(new_state_dict)
81
-
82
- print(f"---reloaded checkpoint weights : {cur_model_name} ---")
83
-
84
-
85
- # Inference whole song
86
- def inference(self, input_track_path, reference_track_path):
87
- print("\n======= Start to inference music mastering style transfer =======")
88
-
89
- # load input wavs
90
- input_aud = load_wav_segment(input_track_path, axis=0)
91
- reference_aud = load_wav_segment(reference_track_path, axis=0)
92
-
93
- # loudness normalization for stability
94
- meter = pyloudnorm.Meter(44100)
95
- norm_loudness_gain = -16.
96
- loudness_in = meter.integrated_loudness(input_aud.transpose(-1, -2))
97
- loudness_ref = meter.integrated_loudness(reference_aud.transpose(-1, -2))
98
-
99
- input_aud = pyloudnorm.normalize.loudness(input_aud, loudness_in, norm_loudness_gain)
100
- input_aud = np.clip(input_aud, -1., 1.)
101
- reference_aud = pyloudnorm.normalize.loudness(reference_aud, loudness_ref, norm_loudness_gain)
102
- reference_aud = np.clip(reference_aud, -1., 1.)
103
-
104
- input_aud = torch.FloatTensor(input_aud).to(self.device)
105
- reference_aud = torch.FloatTensor(reference_aud).to(self.device)
106
-
107
- cur_out_dir = './yt_dir/0/'
108
- os.makedirs(cur_out_dir, exist_ok=True)
109
- ''' segmentize whole songs into batch '''
110
- if input_aud.shape[1] > self.args.segment_length:
111
- cur_inst_input_stem = self.batchwise_segmentization(input_aud, \
112
- "input", \
113
- segment_length=self.args.segment_length, \
114
- discard_last=False)
115
- else:
116
- cur_inst_input_stem = [input_aud.unsqueeze(0)]
117
- if reference_aud.shape[1] > self.args.segment_length*2:
118
- cur_inst_reference_stem = self.batchwise_segmentization(reference_aud, \
119
- "reference", \
120
- segment_length=self.args.segment_length_ref, \
121
- discard_last=False)
122
- else:
123
- cur_inst_reference_stem = [reference_aud.unsqueeze(0)]
124
-
125
- ''' inference '''
126
- # first extract reference style embedding
127
- infered_ref_data_list = []
128
- for cur_ref_data in cur_inst_reference_stem:
129
- cur_ref_data = cur_ref_data.to(self.device)
130
- # Effects Encoder inference
131
- with torch.no_grad():
132
- self.models["effects_encoder"].eval()
133
- reference_feature = self.models["effects_encoder"](cur_ref_data)
134
- infered_ref_data_list.append(reference_feature)
135
- # compute average value from the extracted embeddings
136
- infered_ref_data = torch.stack(infered_ref_data_list)
137
- infered_ref_data_avg = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
138
-
139
- # mastering style converter
140
- infered_data_list = []
141
- for cur_data in cur_inst_input_stem:
142
- cur_data = cur_data.to(self.device)
143
- with torch.no_grad():
144
- self.models["mastering_converter"].eval()
145
- infered_data = self.models["mastering_converter"](cur_data, infered_ref_data_avg.unsqueeze(0))
146
- infered_data_list.append(infered_data.cpu().detach())
147
-
148
- # combine back to whole song
149
- for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
150
- cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
151
- fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
152
- # final output of current instrument
153
- fin_data_out_mastered = fin_data_out[:, :input_aud.shape[-1]].numpy()
154
-
155
- # adjust to reference's loudness
156
- loudness_out = meter.integrated_loudness(fin_data_out_mastered.transpose(-1, -2))
157
- fin_data_out_mastered = pyloudnorm.normalize.loudness(fin_data_out_mastered, loudness_out, loudness_ref)
158
- fin_data_out_mastered = np.clip(fin_data_out_mastered, -1., 1.)
159
-
160
- # remix
161
- fin_output_path_mastering = os.path.join(cur_out_dir, f"remastered_output.wav")
162
- sf.write(fin_output_path_mastering, fin_data_out_mastered.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
163
-
164
- return fin_output_path_mastering
165
-
166
-
167
- # Inference whole song
168
- def inference_interpolation(self, ):
169
- print("\n======= Start to inference interpolation examples =======")
170
- # normalized input
171
- output_name_tag = 'output_interpolation' if self.args.normalize_input else 'output_notnormed_interpolation'
172
-
173
- for step, (input_stems, reference_stems_A, reference_stems_B, dir_name) in enumerate(self.data_loader):
174
- print(f"---inference file name : {dir_name[0]}---")
175
- cur_out_dir = dir_name[0].replace(self.target_dir, self.output_dir)
176
- os.makedirs(cur_out_dir, exist_ok=True)
177
- ''' stem-level inference '''
178
- inst_outputs = []
179
- for cur_inst_idx, cur_inst_name in enumerate(self.args.instruments):
180
- print(f'\t{cur_inst_name}...')
181
- ''' segmentize whole song '''
182
- # segmentize input according to number of interpolating segments
183
- interpolate_segment_length = input_stems[0][cur_inst_idx].shape[1] // self.args.interpolate_segments + 1
184
- cur_inst_input_stem = self.batchwise_segmentization(input_stems[0][cur_inst_idx], \
185
- dir_name[0], \
186
- segment_length=interpolate_segment_length, \
187
- discard_last=False)
188
- # batchwise segmentize 2 reference tracks
189
- if len(reference_stems_A[0][cur_inst_idx][0]) > self.args.segment_length_ref:
190
- cur_inst_reference_stem_A = self.batchwise_segmentization(reference_stems_A[0][cur_inst_idx], \
191
- dir_name[0], \
192
- segment_length=self.args.segment_length_ref, \
193
- discard_last=False)
194
- else:
195
- cur_inst_reference_stem_A = [reference_stems_A[:, cur_inst_idx]]
196
- if len(reference_stems_B[0][cur_inst_idx][0]) > self.args.segment_length_ref:
197
- cur_inst_reference_stem_B = self.batchwise_segmentization(reference_stems_B[0][cur_inst_idx], \
198
- dir_name[0], \
199
- segment_length=self.args.segment_length, \
200
- discard_last=False)
201
- else:
202
- cur_inst_reference_stem_B = [reference_stems_B[:, cur_inst_idx]]
203
-
204
- ''' inference '''
205
- # first extract reference style embeddings
206
- # reference A
207
- infered_ref_data_list = []
208
- for cur_ref_data in cur_inst_reference_stem_A:
209
- cur_ref_data = cur_ref_data.to(self.device)
210
- # Effects Encoder inference
211
- with torch.no_grad():
212
- self.models["effects_encoder"].eval()
213
- reference_feature = self.models["effects_encoder"](cur_ref_data)
214
- infered_ref_data_list.append(reference_feature)
215
- # compute average value from the extracted exbeddings
216
- infered_ref_data = torch.stack(infered_ref_data_list)
217
- infered_ref_data_avg_A = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
218
-
219
- # reference B
220
- infered_ref_data_list = []
221
- for cur_ref_data in cur_inst_reference_stem_B:
222
- cur_ref_data = cur_ref_data.to(self.device)
223
- # Effects Encoder inference
224
- with torch.no_grad():
225
- self.models["effects_encoder"].eval()
226
- reference_feature = self.models["effects_encoder"](cur_ref_data)
227
- infered_ref_data_list.append(reference_feature)
228
- # compute average value from the extracted exbeddings
229
- infered_ref_data = torch.stack(infered_ref_data_list)
230
- infered_ref_data_avg_B = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
231
-
232
- # mixing style converter
233
- infered_data_list = []
234
- for cur_idx, cur_data in enumerate(cur_inst_input_stem):
235
- cur_data = cur_data.to(self.device)
236
- # perform linear interpolation on embedding space
237
- cur_weight = (self.args.interpolate_segments-1-cur_idx) / (self.args.interpolate_segments-1)
238
- cur_ref_emb = cur_weight * infered_ref_data_avg_A + (1-cur_weight) * infered_ref_data_avg_B
239
- with torch.no_grad():
240
- self.models["mastering_converter"].eval()
241
- infered_data = self.models["mastering_converter"](cur_data, cur_ref_emb.unsqueeze(0))
242
- infered_data_list.append(infered_data.cpu().detach())
243
-
244
- # combine back to whole song
245
- for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
246
- cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
247
- fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
248
- # final output of current instrument
249
- fin_data_out_inst = fin_data_out[:, :input_stems[0][cur_inst_idx].shape[-1]].numpy()
250
- inst_outputs.append(fin_data_out_inst)
251
-
252
- # save output of each instrument
253
- if self.args.save_each_inst:
254
- sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
255
- # remix
256
- fin_data_out_mix = sum(inst_outputs)
257
- fin_output_path = os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav")
258
- sf.write(fin_output_path, fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
259
-
260
- return fin_output_path
261
-
262
-
263
- # function that segmentize an entire song into batch
264
- def batchwise_segmentization(self, target_song, song_name, segment_length, discard_last=False):
265
- assert target_song.shape[-1] >= self.args.segment_length, \
266
- f"Error : Insufficient duration!\n\t \
267
- Target song's length is shorter than segment length.\n\t \
268
- Song name : {song_name}\n\t \
269
- Consider changing the 'segment_length' or song with sufficient duration"
270
-
271
- # discard restovers (last segment)
272
- if discard_last:
273
- target_length = target_song.shape[-1] - target_song.shape[-1] % segment_length
274
- target_song = target_song[:, :target_length]
275
- # pad last segment
276
- else:
277
- pad_length = segment_length - target_song.shape[-1] % segment_length
278
- target_song = torch.cat((target_song, torch.zeros(2, pad_length)), axis=-1)
279
-
280
- # segmentize according to the given segment_length
281
- whole_batch_data = []
282
- batch_wise_data = []
283
- for cur_segment_idx in range(target_song.shape[-1]//segment_length):
284
- batch_wise_data.append(target_song[..., cur_segment_idx*segment_length:(cur_segment_idx+1)*segment_length])
285
- if len(batch_wise_data)==self.args.batch_size:
286
- whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
287
- batch_wise_data = []
288
- if batch_wise_data:
289
- whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
290
-
291
- return whole_batch_data
292
-
293
-
294
-
295
- def set_up_mastering(start_point_in_second=0, duration_in_second=30):
296
- os.environ['MASTER_ADDR'] = '127.0.0.1'
297
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
298
- os.environ['MASTER_PORT'] = '8888'
299
-
300
- def str2bool(v):
301
- if v.lower() in ('yes', 'true', 't', 'y', '1'):
302
- return True
303
- elif v.lower() in ('no', 'false', 'f', 'n', '0'):
304
- return False
305
- else:
306
- raise argparse.ArgumentTypeError('Boolean value expected.')
307
-
308
- ''' Configurations for music mixing style transfer '''
309
- currentdir = os.path.dirname(os.path.realpath(__file__))
310
- default_ckpt_path_enc = os.path.join(os.path.dirname(currentdir), 'weights', 'FXencoder_ps.pt')
311
- default_ckpt_path_conv = os.path.join(os.path.dirname(currentdir), 'weights', 'MixFXcloner_ps.pt')
312
- default_ckpt_path_master = os.path.join(os.path.dirname(currentdir), 'weights', 'MasterFXcloner_ps.pt')
313
- default_norm_feature_path = os.path.join(os.path.dirname(currentdir), 'weights', 'musdb18_fxfeatures_eqcompimagegain.npy')
314
-
315
- import argparse
316
- import yaml
317
- parser = argparse.ArgumentParser()
318
-
319
- directory_args = parser.add_argument_group('Directory args')
320
- # directory paths
321
- directory_args.add_argument('--target_dir', type=str, default='./yt_dir/')
322
- directory_args.add_argument('--output_dir', type=str, default=None, help='if no output_dir is specified (None), the results will be saved inside the target_dir')
323
- directory_args.add_argument('--input_file_name', type=str, default='input')
324
- directory_args.add_argument('--reference_file_name', type=str, default='reference')
325
- directory_args.add_argument('--reference_file_name_2interpolate', type=str, default='reference_B')
326
- # saved weights
327
- directory_args.add_argument('--ckpt_path_enc', type=str, default=default_ckpt_path_enc)
328
- directory_args.add_argument('--ckpt_path_conv', type=str, default=default_ckpt_path_master)
329
- directory_args.add_argument('--precomputed_normalization_feature', type=str, default=default_norm_feature_path)
330
-
331
- inference_args = parser.add_argument_group('Inference args')
332
- inference_args.add_argument('--sample_rate', type=int, default=44100)
333
- inference_args.add_argument('--segment_length', type=int, default=2**19) # segmentize input according to this duration
334
- inference_args.add_argument('--segment_length_ref', type=int, default=2**19) # segmentize reference according to this duration
335
- # stem-level instruments & separation
336
- inference_args.add_argument('--instruments', type=str2bool, default=["drums", "bass", "other", "vocals"], help='instrumental tracks to perform style transfer')
337
- inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
338
- inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
339
- inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
340
- inference_args.add_argument('--separation_model', type=str, default='htdemucs')
341
- # FX normalization
342
- inference_args.add_argument('--normalize_input', type=str2bool, default=False)
343
- inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
344
- # interpolation
345
- inference_args.add_argument('--interpolation', type=str2bool, default=False)
346
- inference_args.add_argument('--interpolate_segments', type=int, default=30)
347
-
348
- device_args = parser.add_argument_group('Device args')
349
- device_args.add_argument('--workers', type=int, default=1)
350
- device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
351
-
352
- args = parser.parse_args()
353
-
354
- # load network configurations
355
- with open(os.path.join(currentdir, 'configs.yaml'), 'r') as f:
356
- configs = yaml.full_load(f)
357
- args.cfg_encoder = configs['Effects_Encoder']['default']
358
- args.cfg_converter = configs['TCN']['default']
359
-
360
- return args