diff --git a/.gitattributes b/.gitattributes
index c7d9f3332a950355d5a77d85000f05e6f45435ea..e832a12027ab7698d80304afad130dd240a00897 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -32,3 +32,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/input.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/reference_B.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/reference.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/input/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/input/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/input/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/input/vocals.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference_B/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference_B/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference_B/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference_B/vocals.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/interpolation/\#0/separated/mdx_extra/reference/vocals.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/input.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/reference.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/input/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/input/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/input/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/input/vocals.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/reference/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/reference/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/reference/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#0/separated/mdx_extra/reference/vocals.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/input.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/reference.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/input/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/input/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/input/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/input/vocals.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/reference/bass.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/reference/drums.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/reference/other.wav filter=lfs diff=lfs merge=lfs -text
+samples/style_transfer/\#2/separated/mdx_extra/reference/vocals.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 96b0dc38120e15bb1d7cd5eedfd09a0c078df677..dac150e11107ccd6e25fccdf4763fad7eca4c827 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,126 @@
----
-title: Music Mixing Style Transfer
-emoji: 🏃
-colorFrom: gray
-colorTo: green
-sdk: gradio
-sdk_version: 3.21.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Music Mixing Style Transfer
+
+This repository includes source code and pre-trained models of the work *Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects* by [Junghyun Koo](https://linkedin.com/in/junghyun-koo-525a31251), [Marco A. Martínez-Ramírez](https://m-marco.com/about/), [Wei-Hsiang Liao](https://jp.linkedin.com/in/wei-hsiang-liao-66283154), [Stefan Uhlich](https://scholar.google.de/citations?user=hja8ejYAAAAJ&hl=de), [Kyogu Lee](https://linkedin.com/in/kyogu-lee-7a93b611), and [Yuki Mitsufuji](https://www.yukimitsufuji.com/).
+
+
+[![arXiv](https://img.shields.io/badge/arXiv-2211.02247-b31b1b.svg)](https://arxiv.org/abs/2211.02247)
+[![Web](https://img.shields.io/badge/Web-Demo_Page-green.svg)](https://jhtonyKoo.github.io/MixingStyleTransfer/)
+[![Supplementary](https://img.shields.io/badge/Supplementary-Materials-white.svg)](https://tinyurl.com/4math4pm)
+
+
+
+## Pre-trained Models
+| Model | Configuration | Training Dataset |
+|-------------|-------------|-------------|
+[FXencoder (Φp.s.)](https://drive.google.com/file/d/1BFABsJRUVgJS5UE5iuM03dbfBjmI9LT5/view?usp=sharing) | Used *FX normalization* and *probability scheduling* techniques for training | Trained with [MUSDB18](https://sigsep.github.io/datasets/musdb.html) Dataset
+[MixFXcloner](https://drive.google.com/file/d/1Qu8rD7HpTNA1gJUVp2IuaeU_Nue8-VA3/view?usp=sharing) | Mixing style converter trained with Φp.s. | Trained with [MUSDB18](https://sigsep.github.io/datasets/musdb.html) Dataset
+
+
+## Installation
+```
+pip install -r "requirements.txt"
+```
+
+# Inference
+
+## Mixing Style Transfer
+
+To run the inference code for mixing style transfer,
+1. Download pre-trained models above and place them under the folder named 'weights' (default)
+2. Prepare input and reference tracks under the folder named 'samples/style_transfer' (default)
+Target files should be organized as follows:
+```
+ "path_to_data_directory"/"song_name_#1"/"input_file_name".wav
+ "path_to_data_directory"/"song_name_#1"/"reference_file_name".wav
+ ...
+ "path_to_data_directory"/"song_name_#n"/"input_file_name".wav
+ "path_to_data_directory"/"song_name_#n"/"reference_file_name".wav
+```
+3. Run 'inference/style_transfer.py'
+```
+python inference/style_transfer.py \
+ --ckpt_path_enc "path_to_checkpoint_of_FXencoder" \
+ --ckpt_path_conv "path_to_checkpoint_of_MixFXcloner" \
+ --target_dir "path_to_directory_containing_inference_samples"
+```
+4. Outputs will be stored in the same folder as the inference data directory (default)
+
+*Note: The system accepts stereo WAV files at 44.1 kHz with 16-bit depth. We recommend using audio samples that are not too loud: the system transfers such samples better when the loudness of the mixture-wise input is reduced (while maintaining the overall balance of each instrument).*
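+
+If you are unsure whether a file matches this format, the minimal sketch below (not part of the repository; the file path is a placeholder) checks and, if needed, converts an input file using the same `soundfile` / `librosa` libraries the inference code already depends on:
+```
+import librosa
+import numpy as np
+import soundfile as sf
+
+in_path = "samples/style_transfer/song_name/input.wav"  # placeholder path
+audio, sr = sf.read(in_path, always_2d=True)  # shape: (samples, channels)
+if sr != 44100:
+    # resample to 44.1 kHz; librosa expects the channel axis first
+    audio = librosa.resample(audio.T, orig_sr=sr, target_sr=44100).T
+    sr = 44100
+if audio.shape[1] == 1:
+    audio = np.repeat(audio, 2, axis=1)  # duplicate mono to stereo
+sf.write(in_path, audio, sr, subtype="PCM_16")  # write back as 16-bit PCM
+```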
+
+
+
+## Interpolation With 2 Different Reference Tracks
+
+The inference procedure with two reference tracks is almost the same as for mixing style transfer.
+1. Download pre-trained models above and place them under the folder named 'weights' (default)
+2. Prepare input and 2 reference tracks under the folder named 'samples/style_transfer' (default)
+Target files should be organized as follows:
+```
+ "path_to_data_directory"/"song_name_#1"/"input_track_name".wav
+ "path_to_data_directory"/"song_name_#1"/"reference_file_name".wav
+ "path_to_data_directory"/"song_name_#1"/"reference_file_name_2interpolate".wav
+ ...
+ "path_to_data_directory"/"song_name_#n"/"input_track_name".wav
+ "path_to_data_directory"/"song_name_#n"/"reference_file_name".wav
+ "path_to_data_directory"/"song_name_#n"/"reference_file_name_2interpolate".wav
+```
+3. Run 'inference/style_transfer.py'
+```
+python inference/style_transfer.py \
+ --ckpt_path_enc "path_to_checkpoint_of_FXencoder" \
+ --ckpt_path_conv "path_to_checkpoint_of_MixFXcloner" \
+ --target_dir "path_to_directory_containing_inference_samples" \
+ --interpolation True \
+ --interpolate_segments "number of segments to perform interpolation"
+```
+4. Outputs will be stored in the same folder as the inference data directory (default)
+
+*Note: Interpolating between 2 different reference tracks is not covered in the paper, but it demonstrates the potential for controllable style transfer via the latent space.*
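+
+For reference, `inference/style_transfer.py` realizes the interpolation by linearly blending the two reference embeddings with a per-segment weight. The sketch below (not part of the repository; tensor contents are dummy values) only illustrates that weighting scheme:
+```
+import torch
+
+num_segments = 30  # corresponds to --interpolate_segments
+emb_A = torch.randn(2048)  # FXencoder embedding of 'reference'
+emb_B = torch.randn(2048)  # FXencoder embedding of 'reference_B'
+
+for cur_idx in range(num_segments):
+    # weight moves linearly from 1.0 (pure reference A) to 0.0 (pure reference B)
+    cur_weight = (num_segments - 1 - cur_idx) / (num_segments - 1)
+    cur_ref_emb = cur_weight * emb_A + (1 - cur_weight) * emb_B
+    # cur_ref_emb conditions MixFXcloner for the cur_idx-th input segment
+```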
+
+
+
+## Feature Extraction Using *FXencoder*
+
+This inference code extracts audio effects-related embeddings using our proposed FXencoder. It processes all .wav files under the target directory.
+
+1. Download FXencoder's pre-trained model above and place it under the folder named 'weights' (default)
+2. Run 'inference/feature_extraction.py'
+```
+python inference/feature_extraction.py \
+ --ckpt_path_enc "path_to_checkpoint_of_FXencoder" \
+ --target_dir "path_to_directory_containing_inference_samples"
+```
+3. Outputs will be stored in the same folder as the inference data directory (default)
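+
+Each embedding is saved as a NumPy array named `"original_file_name"_fx_embedding.npy`, placed next to the corresponding .wav file (or mirrored under `--output_dir` if specified). A minimal sketch of loading and comparing two embeddings (paths are placeholders; the cosine-similarity comparison is only an illustration):
+```
+import numpy as np
+
+emb_a = np.load("samples/song_A/input_fx_embedding.npy")  # placeholder paths
+emb_b = np.load("samples/song_B/input_fx_embedding.npy")
+
+# cosine similarity between the two FX embeddings
+cos_sim = float(np.dot(emb_a, emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b)))
+print(f"FX embedding cosine similarity: {cos_sim:.3f}")
+```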
+
+
+
+
+# Implementation
+
+All the details of our system implementation are under the folder "mixing_style_transfer".
+
+- FXmanipulator -> `mixing_style_transfer/mixing_manipulator/`
+- network architectures -> `mixing_style_transfer/networks/`
+- configuration of each sub-network -> `mixing_style_transfer/networks/configs.yaml`
+- data loader -> `mixing_style_transfer/data_loader/`
+
+
+# Citation
+
+Please consider citing this work if you use it.
+
+```
+@article{koo2022music,
+ title={Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects},
+ author={Koo, Junghyun and Martinez-Ramirez, Marco A and Liao, Wei-Hsiang and Uhlich, Stefan and Lee, Kyogu and Mitsufuji, Yuki},
+ journal={arXiv preprint arXiv:2211.02247},
+ year={2022}
+}
+```
+
+
+
diff --git a/inference/configs.yaml b/inference/configs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ed02a9ec0b889824491aa2a72ce0c9a3515ace3d
--- /dev/null
+++ b/inference/configs.yaml
@@ -0,0 +1,30 @@
+# model architecture configurations
+
+
+# Music Effects Encoder
+Effects_Encoder:
+
+ default:
+ channels: [16, 32, 64, 128, 256, 256, 512, 512, 1024, 1024, 2048, 2048]
+ kernels: [25, 25, 15, 15, 10, 10, 10, 10, 5, 5, 5, 5]
+ strides: [4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]
+ dilation: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ bias: True
+ norm: 'batch'
+ conv_block: 'res'
+ activation: "relu"
+
+
+# TCN
+TCN:
+
+ # receptive field = 5.2 seconds
+ default:
+ condition_dimension: 2048
+ nblocks: 14
+ dilation_growth: 2
+ kernel_size: 15
+ channel_width: 128
+ stack_size: 15
+ causal: False
+
diff --git a/inference/feature_extraction.py b/inference/feature_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..a846232991200df780a3938a5e19d3892e9a812e
--- /dev/null
+++ b/inference/feature_extraction.py
@@ -0,0 +1,194 @@
+"""
+ Inference code of extracting embeddings from music recordings using FXencoder
+ of the work "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
+
+ Process : extracts FX embeddings of each song inside the target directory.
+"""
+from glob import glob
+import os
+import librosa
+import numpy as np
+import torch
+
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.join(os.path.dirname(currentdir), "mixing_style_transfer"))
+from networks import FXencoder
+from data_loader import *
+
+
+class FXencoder_Inference:
+ def __init__(self, args, trained_w_ddp=True):
+ if args.inference_device!='cpu' and torch.cuda.is_available():
+ self.device = torch.device("cuda:0")
+ else:
+ self.device = torch.device("cpu")
+
+ # inference computational hyperparameters
+ self.segment_length = args.segment_length
+ self.batch_size = args.batch_size
+ self.sample_rate = 44100 # sampling rate should be 44100
+ self.time_in_seconds = int(args.segment_length // self.sample_rate)
+
+ # directory configuration
+ self.output_dir = args.target_dir if args.output_dir==None else args.output_dir
+ self.target_dir = args.target_dir
+
+ # load model and its checkpoint weights
+ self.models = {}
+ self.models['effects_encoder'] = FXencoder(args.cfg_encoder).to(self.device)
+ ckpt_paths = {'effects_encoder' : args.ckpt_path_enc}
+ # reload saved model weights
+ ddp = trained_w_ddp
+ self.reload_weights(ckpt_paths, ddp=ddp)
+
+ # save current arguments
+ self.save_args(args)
+
+
+ # reload model weights from the target checkpoint path
+ def reload_weights(self, ckpt_paths, ddp=True):
+ for cur_model_name in self.models.keys():
+ checkpoint = torch.load(ckpt_paths[cur_model_name], map_location=self.device)
+
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint["model"].items():
+ # remove `module.` if the model was trained with DDP
+ name = k[7:] if ddp else k
+ new_state_dict[name] = v
+
+ # load params
+ self.models[cur_model_name].load_state_dict(new_state_dict)
+
+ print(f"---reloaded checkpoint weights : {cur_model_name} ---")
+
+
+ # save averaged embedding from whole songs
+ def save_averaged_embeddings(self, ):
+ # embedding output directory path
+ emb_out_dir = f"{self.output_dir}"
+ print(f'\n\n=====Inference seconds : {self.time_in_seconds}=====')
+
+ # target_file_paths = glob(f"{self.target_dir}/**/*.wav", recursive=True)
+ target_file_paths = glob(os.path.join(self.target_dir, '**', '*.wav'), recursive=True)
+ for step, target_file_path in enumerate(target_file_paths):
+ print(f"\nInference step : {step+1}/{len(target_file_paths)}")
+ print(f"---current file path : {target_file_path}---")
+
+ ''' load waveform signal '''
+ target_song_whole = load_wav_segment(target_file_path, axis=0)
+ # check if mono -> convert to stereo by duplicating mono signal
+ if len(target_song_whole.shape)==1:
+ target_song_whole = np.stack((target_song_whole, target_song_whole), axis=0)
+ # check axis dimension
+ # signal shape should be : [channel, signal duration]
+ elif target_song_whole.shape[1]==2:
+ target_song_whole = target_song_whole.transpose()
+ target_song_whole = torch.from_numpy(target_song_whole).float()
+ ''' segmentize whole songs into batch '''
+ whole_batch_data = self.batchwise_segmentization(target_song_whole, target_file_path)
+
+ ''' inference '''
+ # infer whole song
+ infered_data_list = []
+ infered_c_list = []
+ infered_z_list = []
+ for cur_idx, cur_data in enumerate(whole_batch_data):
+ cur_data = cur_data.to(self.device)
+
+ with torch.no_grad():
+ self.models["effects_encoder"].eval()
+ # FXencoder
+ out_c_emb = self.models["effects_encoder"](cur_data)
+ infered_c_list.append(out_c_emb.cpu().detach())
+ avg_c_feat = torch.mean(torch.cat(infered_c_list, dim=0), dim=0).squeeze().cpu().detach().numpy()
+
+ # save outputs
+ cur_output_path = target_file_path.replace(self.target_dir, self.output_dir).replace('.wav', '_fx_embedding.npy')
+ os.makedirs(os.path.dirname(cur_output_path), exist_ok=True)
+ np.save(cur_output_path, avg_c_feat)
+
+
+ # function that segmentize an entire song into batch
+ def batchwise_segmentization(self, target_song, target_file_path, discard_last=False):
+ assert target_song.shape[-1] >= self.segment_length, \
+ f"Error : Insufficient duration!\n\t \
+ Target song's length is shorter than segment length.\n\t \
+ Song name : {target_file_path}\n\t \
+                Consider changing the 'segment_length' or using a song with sufficient duration"
+
+        # discard leftovers (last segment)
+ if discard_last:
+ target_length = target_song.shape[-1] - target_song.shape[-1] % self.segment_length
+ target_song = target_song[:, :target_length]
+ # pad last segment
+ else:
+ pad_length = self.segment_length - target_song.shape[-1] % self.segment_length
+ target_song = torch.cat((target_song, torch.zeros(2, pad_length)), axis=-1)
+
+ whole_batch_data = []
+ batch_wise_data = []
+ for cur_segment_idx in range(target_song.shape[-1]//self.segment_length):
+ batch_wise_data.append(target_song[..., cur_segment_idx*self.segment_length:(cur_segment_idx+1)*self.segment_length])
+ if len(batch_wise_data)==self.batch_size:
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
+ batch_wise_data = []
+ if batch_wise_data:
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
+
+ return whole_batch_data
+
+
+ # save current inference arguments
+ def save_args(self, params):
+ info = '\n[args]\n'
+ for sub_args in parser._action_groups:
+ if sub_args.title in ['positional arguments', 'optional arguments', 'options']:
+ continue
+ size_sub = len(sub_args._group_actions)
+ info += f' {sub_args.title} ({size_sub})\n'
+ for i, arg in enumerate(sub_args._group_actions):
+ prefix = '-'
+ info += f' {prefix} {arg.dest:20s}: {getattr(params, arg.dest)}\n'
+ info += '\n'
+
+ os.makedirs(self.output_dir, exist_ok=True)
+ record_path = f"{self.output_dir}feature_extraction_inference_configurations.txt"
+ f = open(record_path, 'w')
+ np.savetxt(f, [info], delimiter=" ", fmt="%s")
+ f.close()
+
+
+
+if __name__ == '__main__':
+ ''' Configurations for inferencing music effects encoder '''
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ default_ckpt_path = os.path.join(os.path.dirname(currentdir), 'weights', 'FXencoder_ps.pt')
+
+ import argparse
+ import yaml
+ parser = argparse.ArgumentParser()
+
+ directory_args = parser.add_argument_group('Directory args')
+ directory_args.add_argument('--target_dir', type=str, default='./samples/')
+ directory_args.add_argument('--output_dir', type=str, default=None, help='if no output_dir is specified (None), the results will be saved inside the target_dir')
+ directory_args.add_argument('--ckpt_path_enc', type=str, default=default_ckpt_path)
+
+ inference_args = parser.add_argument_group('Inference args')
+ inference_args.add_argument('--segment_length', type=int, default=44100*10) # segmentize input according to this duration
+ inference_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
+    inference_args.add_argument('--inference_device', type=str, default='cpu', help="unless set to 'cpu', inference runs on GPU when one is available")
+
+ args = parser.parse_args()
+
+ # load network configurations
+ with open(os.path.join(currentdir, 'configs.yaml'), 'r') as f:
+ configs = yaml.full_load(f)
+ args.cfg_encoder = configs['Effects_Encoder']['default']
+
+ # Extract features using pre-trained FXencoder
+ inference_encoder = FXencoder_Inference(args)
+ inference_encoder.save_averaged_embeddings()
+
+
\ No newline at end of file
diff --git a/inference/style_transfer.py b/inference/style_transfer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cf6c3dbf1e0e1986141542a57792e3d0c10dd5b
--- /dev/null
+++ b/inference/style_transfer.py
@@ -0,0 +1,400 @@
+"""
+ Inference code of music style transfer
+ of the work "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
+
+    Process : converts the mixing style of the input music recording to that of the reference music.
+    Files inside the target directory should be organized as follows:
+ "path_to_data_directory"/"song_name_#1"/input.wav
+ "path_to_data_directory"/"song_name_#1"/reference.wav
+ ...
+ "path_to_data_directory"/"song_name_#n"/input.wav
+ "path_to_data_directory"/"song_name_#n"/reference.wav
+    where all songs must use the same 'input' and 'reference' file names (matching --input_file_name and --reference_file_name).
+"""
+import numpy as np
+from glob import glob
+import os
+import torch
+
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.join(os.path.dirname(currentdir), "mixing_style_transfer"))
+from networks import FXencoder, TCNModel
+from data_loader import *
+
+
+
+class Mixing_Style_Transfer_Inference:
+ def __init__(self, args, trained_w_ddp=True):
+ if args.inference_device!='cpu' and torch.cuda.is_available():
+ self.device = torch.device("cuda:0")
+ else:
+ self.device = torch.device("cpu")
+
+ # inference computational hyperparameters
+ self.args = args
+ self.segment_length = args.segment_length
+ self.batch_size = args.batch_size
+ self.sample_rate = 44100 # sampling rate should be 44100
+ self.time_in_seconds = int(args.segment_length // self.sample_rate)
+
+ # directory configuration
+ self.output_dir = args.target_dir if args.output_dir==None else args.output_dir
+ self.target_dir = args.target_dir
+
+ # load model and its checkpoint weights
+ self.models = {}
+ self.models['effects_encoder'] = FXencoder(args.cfg_encoder).to(self.device)
+ self.models['mixing_converter'] = TCNModel(nparams=args.cfg_converter["condition_dimension"], \
+ ninputs=2, \
+ noutputs=2, \
+ nblocks=args.cfg_converter["nblocks"], \
+ dilation_growth=args.cfg_converter["dilation_growth"], \
+ kernel_size=args.cfg_converter["kernel_size"], \
+ channel_width=args.cfg_converter["channel_width"], \
+ stack_size=args.cfg_converter["stack_size"], \
+ cond_dim=args.cfg_converter["condition_dimension"], \
+ causal=args.cfg_converter["causal"]).to(self.device)
+
+ ckpt_paths = {'effects_encoder' : args.ckpt_path_enc, \
+ 'mixing_converter' : args.ckpt_path_conv}
+ # reload saved model weights
+ ddp = trained_w_ddp
+ self.reload_weights(ckpt_paths, ddp=ddp)
+
+ # load data loader for the inference procedure
+ inference_dataset = Song_Dataset_Inference(args)
+ self.data_loader = DataLoader(inference_dataset, \
+ batch_size=1, \
+ shuffle=False, \
+ num_workers=args.workers, \
+ drop_last=False)
+
+ # save current arguments
+ self.save_args(args)
+
+ ''' check stem-wise result '''
+ if not self.args.do_not_separate:
+ os.environ['MKL_THREADING_LAYER'] = 'GNU'
+ separate_file_names = [args.input_file_name, args.reference_file_name]
+ if self.args.interpolation:
+ separate_file_names.append(args.reference_file_name_2interpolate)
+ for cur_idx, cur_inf_dir in enumerate(sorted(glob(f"{args.target_dir}*/"))):
+ for cur_file_name in separate_file_names:
+ cur_sep_file_path = os.path.join(cur_inf_dir, cur_file_name+'.wav')
+ cur_sep_output_dir = os.path.join(cur_inf_dir, args.stem_level_directory_name)
+ if os.path.exists(os.path.join(cur_sep_output_dir, self.args.separation_model, cur_file_name, 'drums.wav')):
+ print(f'\talready separated current file : {cur_sep_file_path}')
+ else:
+ cur_cmd_line = f"demucs {cur_sep_file_path} -n {self.args.separation_model} -d {self.args.separation_device} -o {cur_sep_output_dir}"
+ os.system(cur_cmd_line)
+
+
+ # reload model weights from the target checkpoint path
+ def reload_weights(self, ckpt_paths, ddp=True):
+ for cur_model_name in self.models.keys():
+ checkpoint = torch.load(ckpt_paths[cur_model_name], map_location=self.device)
+
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint["model"].items():
+ # remove `module.` if the model was trained with DDP
+ name = k[7:] if ddp else k
+ new_state_dict[name] = v
+
+ # load params
+ self.models[cur_model_name].load_state_dict(new_state_dict)
+
+ print(f"---reloaded checkpoint weights : {cur_model_name} ---")
+
+
+ # Inference whole song
+ def inference(self, ):
+ print("\n======= Start to inference music mixing style transfer =======")
+ # normalized input
+ output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
+
+ for step, (input_stems, reference_stems, dir_name) in enumerate(self.data_loader):
+ print(f"---inference file name : {dir_name[0]}---")
+ cur_out_dir = dir_name[0].replace(self.target_dir, self.output_dir)
+ os.makedirs(cur_out_dir, exist_ok=True)
+ ''' stem-level inference '''
+ inst_outputs = []
+ for cur_inst_idx, cur_inst_name in enumerate(self.args.instruments):
+ print(f'\t{cur_inst_name}...')
+ ''' segmentize whole songs into batch '''
+ if len(input_stems[0][cur_inst_idx][0]) > self.args.segment_length:
+ cur_inst_input_stem = self.batchwise_segmentization(input_stems[0][cur_inst_idx], \
+ dir_name[0], \
+ segment_length=self.args.segment_length, \
+ discard_last=False)
+ else:
+ cur_inst_input_stem = [input_stems[:, cur_inst_idx]]
+ if len(reference_stems[0][cur_inst_idx][0]) > self.args.segment_length*2:
+ cur_inst_reference_stem = self.batchwise_segmentization(reference_stems[0][cur_inst_idx], \
+ dir_name[0], \
+ segment_length=self.args.segment_length_ref, \
+ discard_last=False)
+ else:
+ cur_inst_reference_stem = [reference_stems[:, cur_inst_idx]]
+
+ ''' inference '''
+ # first extract reference style embedding
+ infered_ref_data_list = []
+ for cur_ref_data in cur_inst_reference_stem:
+ cur_ref_data = cur_ref_data.to(self.device)
+ # Effects Encoder inference
+ with torch.no_grad():
+ self.models["effects_encoder"].eval()
+ reference_feature = self.models["effects_encoder"](cur_ref_data)
+ infered_ref_data_list.append(reference_feature)
+                # compute average value from the extracted embeddings
+ infered_ref_data = torch.stack(infered_ref_data_list)
+ infered_ref_data_avg = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
+
+ # mixing style converter
+ infered_data_list = []
+ for cur_data in cur_inst_input_stem:
+ cur_data = cur_data.to(self.device)
+ with torch.no_grad():
+ self.models["mixing_converter"].eval()
+ infered_data = self.models["mixing_converter"](cur_data, infered_ref_data_avg.unsqueeze(0))
+ infered_data_list.append(infered_data.cpu().detach())
+
+ # combine back to whole song
+ for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
+ cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
+ fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
+ # final output of current instrument
+ fin_data_out_inst = fin_data_out[:, :input_stems[0][cur_inst_idx].shape[-1]].numpy()
+
+ inst_outputs.append(fin_data_out_inst)
+ # save output of each instrument
+ if self.args.save_each_inst:
+ sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
+ # remix
+ fin_data_out_mix = sum(inst_outputs)
+ sf.write(os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav"), fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
+
+
+ # Inference whole song
+ def inference_interpolation(self, ):
+ print("\n======= Start to inference interpolation examples =======")
+ # normalized input
+ output_name_tag = 'output_interpolation' if self.args.normalize_input else 'output_notnormed_interpolation'
+
+ for step, (input_stems, reference_stems_A, reference_stems_B, dir_name) in enumerate(self.data_loader):
+ print(f"---inference file name : {dir_name[0]}---")
+ cur_out_dir = dir_name[0].replace(self.target_dir, self.output_dir)
+ os.makedirs(cur_out_dir, exist_ok=True)
+ ''' stem-level inference '''
+ inst_outputs = []
+ for cur_inst_idx, cur_inst_name in enumerate(self.args.instruments):
+ print(f'\t{cur_inst_name}...')
+ ''' segmentize whole song '''
+ # segmentize input according to number of interpolating segments
+ interpolate_segment_length = input_stems[0][cur_inst_idx].shape[1] // self.args.interpolate_segments + 1
+ cur_inst_input_stem = self.batchwise_segmentization(input_stems[0][cur_inst_idx], \
+ dir_name[0], \
+ segment_length=interpolate_segment_length, \
+ discard_last=False)
+ # batchwise segmentize 2 reference tracks
+ if len(reference_stems_A[0][cur_inst_idx][0]) > self.args.segment_length_ref:
+ cur_inst_reference_stem_A = self.batchwise_segmentization(reference_stems_A[0][cur_inst_idx], \
+ dir_name[0], \
+ segment_length=self.args.segment_length_ref, \
+ discard_last=False)
+ else:
+ cur_inst_reference_stem_A = [reference_stems_A[:, cur_inst_idx]]
+ if len(reference_stems_B[0][cur_inst_idx][0]) > self.args.segment_length_ref:
+                    cur_inst_reference_stem_B = self.batchwise_segmentization(reference_stems_B[0][cur_inst_idx], \
+                                                                              dir_name[0], \
+                                                                              segment_length=self.args.segment_length_ref, \
+                                                                              discard_last=False)
+ else:
+ cur_inst_reference_stem_B = [reference_stems_B[:, cur_inst_idx]]
+
+ ''' inference '''
+ # first extract reference style embeddings
+ # reference A
+ infered_ref_data_list = []
+ for cur_ref_data in cur_inst_reference_stem_A:
+ cur_ref_data = cur_ref_data.to(self.device)
+ # Effects Encoder inference
+ with torch.no_grad():
+ self.models["effects_encoder"].eval()
+ reference_feature = self.models["effects_encoder"](cur_ref_data)
+ infered_ref_data_list.append(reference_feature)
+                # compute average value from the extracted embeddings
+ infered_ref_data = torch.stack(infered_ref_data_list)
+ infered_ref_data_avg_A = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
+
+ # reference B
+ infered_ref_data_list = []
+ for cur_ref_data in cur_inst_reference_stem_B:
+ cur_ref_data = cur_ref_data.to(self.device)
+ # Effects Encoder inference
+ with torch.no_grad():
+ self.models["effects_encoder"].eval()
+ reference_feature = self.models["effects_encoder"](cur_ref_data)
+ infered_ref_data_list.append(reference_feature)
+                # compute average value from the extracted embeddings
+ infered_ref_data = torch.stack(infered_ref_data_list)
+ infered_ref_data_avg_B = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
+
+ # mixing style converter
+ infered_data_list = []
+ for cur_idx, cur_data in enumerate(cur_inst_input_stem):
+ cur_data = cur_data.to(self.device)
+ # perform linear interpolation on embedding space
+ cur_weight = (self.args.interpolate_segments-1-cur_idx) / (self.args.interpolate_segments-1)
+ cur_ref_emb = cur_weight * infered_ref_data_avg_A + (1-cur_weight) * infered_ref_data_avg_B
+ with torch.no_grad():
+ self.models["mixing_converter"].eval()
+ infered_data = self.models["mixing_converter"](cur_data, cur_ref_emb.unsqueeze(0))
+ infered_data_list.append(infered_data.cpu().detach())
+
+ # combine back to whole song
+ for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
+ cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
+ fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
+ # final output of current instrument
+ fin_data_out_inst = fin_data_out[:, :input_stems[0][cur_inst_idx].shape[-1]].numpy()
+ inst_outputs.append(fin_data_out_inst)
+
+ # save output of each instrument
+ if self.args.save_each_inst:
+ sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
+ # remix
+ fin_data_out_mix = sum(inst_outputs)
+ sf.write(os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav"), fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
+
+
+ # function that segmentize an entire song into batch
+ def batchwise_segmentization(self, target_song, song_name, segment_length, discard_last=False):
+ assert target_song.shape[-1] >= self.args.segment_length, \
+ f"Error : Insufficient duration!\n\t \
+ Target song's length is shorter than segment length.\n\t \
+ Song name : {song_name}\n\t \
+                Consider changing the 'segment_length' or using a song with sufficient duration"
+
+        # discard leftovers (last segment)
+ if discard_last:
+ target_length = target_song.shape[-1] - target_song.shape[-1] % segment_length
+ target_song = target_song[:, :target_length]
+ # pad last segment
+ else:
+ pad_length = segment_length - target_song.shape[-1] % segment_length
+ target_song = torch.cat((target_song, torch.zeros(2, pad_length)), axis=-1)
+
+ # segmentize according to the given segment_length
+ whole_batch_data = []
+ batch_wise_data = []
+ for cur_segment_idx in range(target_song.shape[-1]//segment_length):
+ batch_wise_data.append(target_song[..., cur_segment_idx*segment_length:(cur_segment_idx+1)*segment_length])
+ if len(batch_wise_data)==self.args.batch_size:
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
+ batch_wise_data = []
+ if batch_wise_data:
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
+
+ return whole_batch_data
+
+
+ # save current inference arguments
+ def save_args(self, params):
+ info = '\n[args]\n'
+ for sub_args in parser._action_groups:
+ if sub_args.title in ['positional arguments', 'optional arguments', 'options']:
+ continue
+ size_sub = len(sub_args._group_actions)
+ info += f' {sub_args.title} ({size_sub})\n'
+ for i, arg in enumerate(sub_args._group_actions):
+ prefix = '-'
+ info += f' {prefix} {arg.dest:20s}: {getattr(params, arg.dest)}\n'
+ info += '\n'
+
+ os.makedirs(self.output_dir, exist_ok=True)
+ record_path = f"{self.output_dir}style_transfer_inference_configurations.txt"
+ f = open(record_path, 'w')
+ np.savetxt(f, [info], delimiter=" ", fmt="%s")
+ f.close()
+
+
+
+if __name__ == '__main__':
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+ os.environ['MASTER_PORT'] = '8888'
+
+ def str2bool(v):
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+ ''' Configurations for music mixing style transfer '''
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ default_ckpt_path_enc = os.path.join(os.path.dirname(currentdir), 'weights', 'FXencoder_ps.pt')
+ default_ckpt_path_conv = os.path.join(os.path.dirname(currentdir), 'weights', 'MixFXcloner_ps.pt')
+ default_norm_feature_path = os.path.join(os.path.dirname(currentdir), 'weights', 'musdb18_fxfeatures_eqcompimagegain.npy')
+
+ import argparse
+ import yaml
+ parser = argparse.ArgumentParser()
+
+ directory_args = parser.add_argument_group('Directory args')
+ # directory paths
+ directory_args.add_argument('--target_dir', type=str, default='./samples/style_transfer/')
+ directory_args.add_argument('--output_dir', type=str, default=None, help='if no output_dir is specified (None), the results will be saved inside the target_dir')
+ directory_args.add_argument('--input_file_name', type=str, default='input')
+ directory_args.add_argument('--reference_file_name', type=str, default='reference')
+ directory_args.add_argument('--reference_file_name_2interpolate', type=str, default='reference_B')
+ # saved weights
+ directory_args.add_argument('--ckpt_path_enc', type=str, default=default_ckpt_path_enc)
+ directory_args.add_argument('--ckpt_path_conv', type=str, default=default_ckpt_path_conv)
+ directory_args.add_argument('--precomputed_normalization_feature', type=str, default=default_norm_feature_path)
+
+ inference_args = parser.add_argument_group('Inference args')
+ inference_args.add_argument('--sample_rate', type=int, default=44100)
+ inference_args.add_argument('--segment_length', type=int, default=2**19) # segmentize input according to this duration
+ inference_args.add_argument('--segment_length_ref', type=int, default=2**19) # segmentize reference according to this duration
+ # stem-level instruments & separation
+    inference_args.add_argument('--instruments', type=str, nargs='+', default=["drums", "bass", "other", "vocals"], help='instrumental tracks to perform style transfer')
+ inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
+ inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
+ inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
+ inference_args.add_argument('--separation_model', type=str, default='mdx_extra')
+ # FX normalization
+ inference_args.add_argument('--normalize_input', type=str2bool, default=True)
+    inference_args.add_argument('--normalization_order', type=str, nargs='+', default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # effects to be normalized; order matters
+ # interpolation
+ inference_args.add_argument('--interpolation', type=str2bool, default=False)
+ inference_args.add_argument('--interpolate_segments', type=int, default=30)
+
+ device_args = parser.add_argument_group('Device args')
+ device_args.add_argument('--workers', type=int, default=1)
+    device_args.add_argument('--inference_device', type=str, default='gpu', help="unless set to 'cpu', inference runs on GPU when one is available")
+ device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
+ device_args.add_argument('--separation_device', type=str, default='cpu', help="device for performing source separation using Demucs")
+
+ args = parser.parse_args()
+
+ # load network configurations
+ with open(os.path.join(currentdir, 'configs.yaml'), 'r') as f:
+ configs = yaml.full_load(f)
+ args.cfg_encoder = configs['Effects_Encoder']['default']
+ args.cfg_converter = configs['TCN']['default']
+
+
+ # Perform music mixing style transfer
+ inference_style_transfer = Mixing_Style_Transfer_Inference(args)
+ if args.interpolation:
+ inference_style_transfer.inference_interpolation()
+ else:
+ inference_style_transfer.inference()
+
+
+
diff --git a/mixing_style_transfer/data_loader/__init__.py b/mixing_style_transfer/data_loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd5f27a7d2742aaf3301599d1c5c9a8b58aa3ef4
--- /dev/null
+++ b/mixing_style_transfer/data_loader/__init__.py
@@ -0,0 +1,2 @@
+from .data_loader import *
+from .loader_utils import *
\ No newline at end of file
diff --git a/mixing_style_transfer/data_loader/data_loader.py b/mixing_style_transfer/data_loader/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5821e7f20a7a40a4c0483d97d4c35b9c3b1aeddc
--- /dev/null
+++ b/mixing_style_transfer/data_loader/data_loader.py
@@ -0,0 +1,672 @@
+"""
+ Data Loaders for
+ 1. contrastive learning of audio effects
+ 2. music mixing style transfer
+ introduced in "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
+"""
+import numpy as np
+import wave
+import soundfile as sf
+import time
+import random
+from glob import glob
+
+import torch
+import torch.utils.data as data
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+import os
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+sys.path.append(os.path.dirname(currentdir))
+sys.path.append(os.path.dirname(os.path.dirname(currentdir)))
+from loader_utils import *
+from mixing_manipulator import *
+
+
+
+'''
+ Collate Functions
+'''
+class Collate_Variable_Length_Segments:
+ def __init__(self, args):
+ self.segment_length = args.segment_length
+ self.random_length = args.reference_length
+ self.num_strong_negatives = args.num_strong_negatives
+ if 'musdb' in args.using_dataset.lower():
+ self.instruments = ["drums", "bass", "other", "vocals"]
+ else:
+ raise NotImplementedError
+
+
+ # collate function to trim segments A and B to random duration
+ # this function can handle different number of 'strong negative' inputs
+ def random_duration_segments_strong_negatives(self, batch):
+ num_inst = len(self.instruments)
+ # randomize current input length
+ max_length = batch[0][0].shape[-1]
+ min_length = max_length//2
+ input_length_a, input_length_b = torch.randint(low=min_length, high=max_length, size=(2,))
+
+ output_dict_A = {}
+ output_dict_B = {}
+ for cur_inst in self.instruments:
+ output_dict_A[cur_inst] = []
+ output_dict_B[cur_inst] = []
+ for cur_item in batch:
+ # set starting points
+ start_point_a = torch.randint(low=0, high=max_length-input_length_a, size=(1,))[0]
+ start_point_b = torch.randint(low=0, high=max_length-input_length_b, size=(1,))[0]
+ # append to output dictionary
+ for cur_i, cur_inst in enumerate(self.instruments):
+ # append A# and B# with its strong negative samples
+ for cur_neg_idx in range(self.num_strong_negatives+1):
+ output_dict_A[cur_inst].append(cur_item[cur_i*(self.num_strong_negatives+1)*2+2*cur_neg_idx][:, start_point_a : start_point_a+input_length_a])
+ output_dict_B[cur_inst].append(cur_item[cur_i*(self.num_strong_negatives+1)*2+1+2*cur_neg_idx][:, start_point_b : start_point_b+input_length_b])
+
+ '''
+ Output format :
+ [drums_A, bass_A, other_A, vocals_A],
+ [drums_B, bass_B, other_B, vocals_B]
+ '''
+ return [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_A.items()], \
+ [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_B.items()]
+
+
+ # collate function for training mixing style transfer
+ def style_transfer_collate(self, batch):
+ output_dict_A1 = {}
+ output_dict_A2 = {}
+ output_dict_B2 = {}
+ for cur_inst in self.instruments:
+ output_dict_A1[cur_inst] = []
+ output_dict_A2[cur_inst] = []
+ output_dict_B2[cur_inst] = []
+ for cur_item in batch:
+ # append to output dictionary
+ for cur_i, cur_inst in enumerate(self.instruments):
+ output_dict_A1[cur_inst].append(cur_item[cur_i*3])
+ output_dict_A2[cur_inst].append(cur_item[cur_i*3+1])
+ output_dict_B2[cur_inst].append(cur_item[cur_i*3+2])
+
+ '''
+ Output format :
+ [drums_A1, bass_A1, other_A1, vocals_A1],
+ [drums_A2, bass_A2, other_A2, vocals_A2],
+ [drums_B2, bass_B2, other_B2, vocals_B2]
+ '''
+ return [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_A1.items()], \
+ [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_A2.items()], \
+ [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_B2.items()]
+
+
+'''
+ Data Loaders
+'''
+
+# Data loader for training the 'FXencoder'
+ # randomly loads two segments (A and B) from the dataset
+ # both segments are manipulated via FXmanipulator using (1+number of strong negative samples) sets of parameters (resulting A1, A2, ..., A#, and B1, B2, ..., B#) (# = number of strong negative samples)
+ # segments with the same effects applied (A1 and B1) are assigned as the positive pair during the training
+    # segments with the same content but with different effects applied (A2, A3, ..., A# for A1) are also formed in a batch as 'strong negative' samples
+ # in the paper, we use strong negative samples = 1
+class MUSDB_Dataset_Mixing_Manipulated_FXencoder(Dataset):
+ def __init__(self, args, \
+ mode, \
+ applying_effects='full', \
+ apply_prob_dict=None):
+ self.args = args
+ self.data_dir = args.data_dir + mode + "/"
+ self.mode = mode
+ self.applying_effects = applying_effects
+ self.normalization_order = args.normalization_order
+ self.fixed_random_seed = args.random_seed
+ self.pad_b4_manipulation = args.pad_b4_manipulation
+ self.pad_length = 2048
+
+ if 'musdb' in args.using_dataset.lower():
+ self.instruments = ["drums", "bass", "other", "vocals"]
+ else:
+ raise NotImplementedError
+
+ # path to contents
+ self.data_paths = {}
+ self.data_length_ratio_list = {}
+ # load data paths for each instrument
+ for cur_inst in self.instruments:
+ self.data_paths[cur_inst] = glob(f'{self.data_dir}{cur_inst}_normalized_{self.normalization_order}_silence_trimmed*.wav') \
+ if args.use_normalized else glob(f'{self.data_dir}{cur_inst}_silence_trimmed*.wav')
+ self.data_length_ratio_list[cur_inst] = []
+ # compute audio duration and its ratio
+ for cur_file_path in self.data_paths[cur_inst]:
+ cur_wav_length = load_wav_length(cur_file_path)
+ cur_inst_length_ratio = cur_wav_length / get_total_audio_length(self.data_paths[cur_inst])
+ self.data_length_ratio_list[cur_inst].append(cur_inst_length_ratio)
+
+ # load effects chain
+ if applying_effects=='full':
+ if apply_prob_dict==None:
+ # initial (default) applying probabilities of each FX
+ apply_prob_dict = {'eq' : 0.9, \
+ 'comp' : 0.9, \
+ 'pan' : 0.3, \
+ 'imager' : 0.8, \
+ 'gain': 0.5}
+ reverb_prob = {'drums' : 0.5, \
+ 'bass' : 0.01, \
+ 'vocals' : 0.9, \
+ 'other' : 0.7}
+
+ self.mixing_manipulator = {}
+ for cur_inst in self.data_paths.keys():
+ if 'reverb' in apply_prob_dict.keys():
+ if cur_inst=='drums':
+ cur_reverb_weight = 0.5
+ elif cur_inst=='bass':
+ cur_reverb_weight = 0.1
+ else:
+ cur_reverb_weight = 1.0
+ apply_prob_dict['reverb'] *= cur_reverb_weight
+ else:
+ apply_prob_dict['reverb'] = reverb_prob[cur_inst]
+ # create FXmanipulator for current instrument
+                self.mixing_manipulator[cur_inst] = create_inst_effects_augmentation_chain(cur_inst, \
+ apply_prob_dict=apply_prob_dict, \
+ ir_dir_path=args.ir_dir_path, \
+ sample_rate=args.sample_rate)
+ # for single effects
+ else:
+ self.mixing_manipulator = {}
+ if not isinstance(applying_effects, list):
+ applying_effects = [applying_effects]
+ for cur_inst in self.data_paths.keys():
+ self.mixing_manipulator[cur_inst] = create_effects_augmentation_chain(applying_effects, \
+ ir_dir_path=args.ir_dir_path)
+
+
+ def __len__(self):
+ if self.mode=='train':
+ return self.args.batch_size_total * 40
+ else:
+ return self.args.batch_size_total
+
+
+ def __getitem__(self, idx):
+ if self.mode=="train":
+ torch.manual_seed(int(time.time())*(idx+1) % (2**32-1))
+ np.random.seed(int(time.time())*(idx+1) % (2**32-1))
+ random.seed(int(time.time())*(idx+1) % (2**32-1))
+ else:
+ # fixed random seed for evaluation
+ torch.manual_seed(idx*self.fixed_random_seed)
+ np.random.seed(idx*self.fixed_random_seed)
+ random.seed(idx*self.fixed_random_seed)
+
+ manipulated_segments = {}
+ for cur_neg_idx in range(self.args.num_strong_negatives+1):
+ manipulated_segments[cur_neg_idx] = {}
+
+ # load already-saved data to save time for on-the-fly manipulation
+ cur_data_dir_path = f"{self.data_dir}manipulated_encoder/{self.args.data_save_name}/{self.applying_effects}/{idx}/"
+ if self.mode=="val" and os.path.exists(cur_data_dir_path):
+ for cur_inst in self.instruments:
+ for cur_neg_idx in range(self.args.num_strong_negatives+1):
+ cur_A_file_path = f"{cur_data_dir_path}{cur_inst}_A{cur_neg_idx+1}.wav"
+ cur_B_file_path = f"{cur_data_dir_path}{cur_inst}_B{cur_neg_idx+1}.wav"
+ cur_A = load_wav_segment(cur_A_file_path, axis=0, sample_rate=self.args.sample_rate)
+ cur_B = load_wav_segment(cur_B_file_path, axis=0, sample_rate=self.args.sample_rate)
+ manipulated_segments[cur_neg_idx][cur_inst] = [torch.from_numpy(cur_A).float(), torch.from_numpy(cur_B).float()]
+ else:
+ # repeat for number of instruments
+ for cur_inst, cur_paths in self.data_paths.items():
+ # choose file_path to be loaded
+ cur_chosen_paths = np.random.choice(cur_paths, 2, p = self.data_length_ratio_list[cur_inst])
+ # get random 2 starting points for each instrument
+ last_point_A = load_wav_length(cur_chosen_paths[0])-self.args.segment_length_ref
+ last_point_B = load_wav_length(cur_chosen_paths[1])-self.args.segment_length_ref
+ # simply load more data to prevent artifacts likely to be caused by the manipulator
+ if self.pad_b4_manipulation:
+ last_point_A -= self.pad_length*2
+ last_point_B -= self.pad_length*2
+ cur_inst_start_point_A = torch.randint(low=0, \
+ high=last_point_A, \
+ size=(1,))[0]
+ cur_inst_start_point_B = torch.randint(low=0, \
+ high=last_point_B, \
+ size=(1,))[0]
+ # load wav segments from the selected starting points
+ load_duration = self.args.segment_length_ref+self.pad_length*2 if self.pad_b4_manipulation else self.args.segment_length_ref
+ cur_inst_segment_A = load_wav_segment(cur_chosen_paths[0], \
+ start_point=cur_inst_start_point_A, \
+ duration=load_duration, \
+ axis=1, \
+ sample_rate=self.args.sample_rate)
+ cur_inst_segment_B = load_wav_segment(cur_chosen_paths[1], \
+ start_point=cur_inst_start_point_B, \
+ duration=load_duration, \
+ axis=1, \
+ sample_rate=self.args.sample_rate)
+ # mixing manipulation
+ # append A# and B# with its strong negative samples
+ for cur_neg_idx in range(self.args.num_strong_negatives+1):
+ cur_manipulated_segment_A, cur_manipulated_segment_B = self.mixing_manipulator[cur_inst]([cur_inst_segment_A, cur_inst_segment_B])
+
+ # remove over-loaded area
+ if self.pad_b4_manipulation:
+ cur_manipulated_segment_A = cur_manipulated_segment_A[self.pad_length:-self.pad_length]
+ cur_manipulated_segment_B = cur_manipulated_segment_B[self.pad_length:-self.pad_length]
+ manipulated_segments[cur_neg_idx][cur_inst] = [torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_A).float(), 1, 0), min=-1, max=1), \
+ torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_B).float(), 1, 0), min=-1, max=1)]
+
+ # check manipulated data by saving them
+ if self.mode=="val" and not os.path.exists(cur_data_dir_path):
+            os.makedirs(cur_data_dir_path, exist_ok=True)
+            for cur_inst in manipulated_segments[0].keys():
+                for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
+                    sf.write(f"{cur_data_dir_path}{cur_inst}_A{cur_manipulated_key+1}.wav", torch.transpose(cur_manipualted_dict[cur_inst][0], 1, 0), self.args.sample_rate, 'PCM_16')
+                    sf.write(f"{cur_data_dir_path}{cur_inst}_B{cur_manipulated_key+1}.wav", torch.transpose(cur_manipualted_dict[cur_inst][1], 1, 0), self.args.sample_rate, 'PCM_16')
+
+ output_list = []
+ output_list_param = []
+ for cur_inst in manipulated_segments[0].keys():
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
+ output_list.extend(cur_manipualted_dict[cur_inst])
+
+ '''
+ Output format:
+ list of effects manipulated stems of each instrument
+ drums_A1, drums_B1, drums_A2, drums_B2, drums_A3, drums_B3, ... ,
+ bass_A1, bass_B1, bass_A2, bass_B2, bass_A3, bass_B3, ... ,
+ other_A1, other_B1, other_A2, other_B2, other_A3, other_B3, ... ,
+ vocals_A1, vocals_B1, vocals_A2, vocals_B2, vocals_A3, vocals_B3, ...
+ each stem has the shape of (number of channels, segment duration)
+ '''
+ return output_list
+
+
+ # generate random manipulated results for evaluation
+ def generate_contents_w_effects(self, num_content, num_effects, out_dir):
+ print(f"start generating random effects of {self.applying_effects} applied contents")
+ os.makedirs(out_dir, exist_ok=True)
+
+ manipulated_segments = {}
+ for cur_fx_idx in range(num_effects):
+ manipulated_segments[cur_fx_idx] = {}
+ # repeat for number of instruments
+ for cur_inst, cur_paths in self.data_paths.items():
+ # choose file_path to be loaded
+ cur_path = np.random.choice(cur_paths, 1, p = self.data_length_ratio_list[cur_inst])[0]
+ print(f"\tgenerating instrument : {cur_inst}")
+ # get random 2 starting points for each instrument
+ last_point = load_wav_length(cur_path)-self.args.segment_length_ref
+ # simply load more data to prevent artifacts likely to be caused by the manipulator
+ if self.pad_b4_manipulation:
+ last_point -= self.pad_length*2
+ cur_inst_start_points = torch.randint(low=0, \
+ high=last_point, \
+ size=(num_content,))
+ # load wav segments from the selected starting points
+ cur_inst_segments = []
+ for cur_num_content in range(num_content):
+ cur_ori_sample = load_wav_segment(cur_path, \
+ start_point=cur_inst_start_points[cur_num_content], \
+ duration=self.args.segment_length_ref, \
+ axis=1, \
+ sample_rate=self.args.sample_rate)
+ cur_inst_segments.append(cur_ori_sample)
+
+ sf.write(f"{out_dir}{cur_inst}_ori_{cur_num_content}.wav", cur_ori_sample, self.args.sample_rate, 'PCM_16')
+
+ # mixing manipulation
+ for cur_fx_idx in range(num_effects):
+ cur_manipulated_segments = self.mixing_manipulator[cur_inst](cur_inst_segments)
+ # remove over-loaded area
+ if self.pad_b4_manipulation:
+ for cur_man_idx in range(len(cur_manipulated_segments)):
+ cur_segment_trimmed = cur_manipulated_segments[cur_man_idx][self.pad_length:-self.pad_length]
+ cur_manipulated_segments[cur_man_idx] = torch.clamp(torch.transpose(torch.from_numpy(cur_segment_trimmed).float(), 1, 0), min=-1, max=1)
+ manipulated_segments[cur_fx_idx][cur_inst] = cur_manipulated_segments
+
+ # write generated data
+ # save each instruments
+ for cur_inst in manipulated_segments[0].keys():
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
+ for cur_content_idx in range(num_content):
+ sf.write(f"{out_dir}{cur_inst}_{chr(65+cur_content_idx//26)}{chr(65+cur_content_idx%26)}{cur_manipulated_key+1}.wav", torch.transpose(cur_manipualted_dict[cur_inst][cur_content_idx], 1, 0), self.args.sample_rate, 'PCM_16')
+ # save mixture
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
+ for cur_content_idx in range(num_content):
+ for cur_idx, cur_inst in enumerate(manipulated_segments[0].keys()):
+ if cur_idx==0:
+ cur_mixture = cur_manipualted_dict[cur_inst][cur_content_idx]
+ else:
+ cur_mixture += cur_manipualted_dict[cur_inst][cur_content_idx]
+ sf.write(f"{out_dir}mixture_{chr(65+cur_content_idx//26)}{chr(65+cur_content_idx%26)}{cur_manipulated_key+1}.wav", torch.transpose(cur_mixture, 1, 0), self.args.sample_rate, 'PCM_16')
+
+ return
+
+
+
+# Data loader for training the 'Mixing Style Converter' (MixFXcloner)
+    # loads two segments (A and B) from the dataset
+    # both segments are manipulated via the FXmanipulator (resulting A1, A2, and B2)
+    # one of the manipulated segments is used as the reference segment (B2), which is manipulated with the same randomly sampled effects as the ground truth segment (A2)
+class MUSDB_Dataset_Mixing_Manipulated_Style_Transfer(Dataset):
+ def __init__(self, args, \
+ mode, \
+ applying_effects='full', \
+ apply_prob_dict=None):
+ self.args = args
+ self.data_dir = args.data_dir + mode + "/"
+ self.mode = mode
+ self.applying_effects = applying_effects
+ self.fixed_random_seed = args.random_seed
+ self.pad_b4_manipulation = args.pad_b4_manipulation
+ self.pad_length = 2048
+
+ if 'musdb' in args.using_dataset.lower():
+ self.instruments = ["drums", "bass", "other", "vocals"]
+ else:
+ raise NotImplementedError
+
+ # load data paths for each instrument
+ self.data_paths = {}
+ self.data_length_ratio_list = {}
+ for cur_inst in self.instruments:
+ self.data_paths[cur_inst] = glob(f'{self.data_dir}{cur_inst}_normalized_{self.args.normalization_order}_silence_trimmed*.wav') \
+ if args.use_normalized else glob(f'{self.data_dir}{cur_inst}_silence_trimmed.wav')
+ self.data_length_ratio_list[cur_inst] = []
+ # compute audio duration and its ratio
+ for cur_file_path in self.data_paths[cur_inst]:
+ cur_wav_length = load_wav_length(cur_file_path)
+ cur_inst_length_ratio = cur_wav_length / get_total_audio_length(self.data_paths[cur_inst])
+ self.data_length_ratio_list[cur_inst].append(cur_inst_length_ratio)
+
+ self.mixing_manipulator = {}
+ if applying_effects=='full':
+ if apply_prob_dict==None:
+ # initial (default) applying probabilities of each FX
+ # we don't update these probabilities for training the MixFXcloner
+ apply_prob_dict = {'eq' : 0.9, \
+ 'comp' : 0.9, \
+ 'pan' : 0.3, \
+ 'imager' : 0.8, \
+ 'gain': 0.5}
+ reverb_prob = {'drums' : 0.5, \
+ 'bass' : 0.01, \
+ 'vocals' : 0.9, \
+ 'other' : 0.7}
+ for cur_inst in self.data_paths.keys():
+ if 'reverb' in apply_prob_dict.keys():
+ if cur_inst=='drums':
+ cur_reverb_weight = 0.5
+ elif cur_inst=='bass':
+ cur_reverb_weight = 0.1
+ else:
+ cur_reverb_weight = 1.0
+ apply_prob_dict['reverb'] *= cur_reverb_weight
+ else:
+ apply_prob_dict['reverb'] = reverb_prob[cur_inst]
+ self.mixing_manipulator[cur_inst] = create_inst_effects_augmentation_chain(cur_inst, \
+ apply_prob_dict=apply_prob_dict, \
+ ir_dir_path=args.ir_dir_path, \
+ sample_rate=args.sample_rate)
+ # for single effects
+ else:
+ if not isinstance(applying_effects, list):
+ applying_effects = [applying_effects]
+ for cur_inst in self.data_paths.keys():
+ self.mixing_manipulator[cur_inst] = create_effects_augmentation_chain(applying_effects, \
+ ir_dir_path=args.ir_dir_path)
+
+
+ def __len__(self):
+ min_length = get_total_audio_length(glob(f'{self.data_dir}vocals_normalized_{self.args.normalization_order}*.wav'))
+ data_len = min_length // self.args.segment_length
+ return data_len
+
+
+ def __getitem__(self, idx):
+ if self.mode=="train":
+ torch.manual_seed(int(time.time())*(idx+1) % (2**32-1))
+ np.random.seed(int(time.time())*(idx+1) % (2**32-1))
+ random.seed(int(time.time())*(idx+1) % (2**32-1))
+ else:
+ # fixed random seed for evaluation
+ torch.manual_seed(idx*self.fixed_random_seed)
+ np.random.seed(idx*self.fixed_random_seed)
+ random.seed(idx*self.fixed_random_seed)
+
+ manipulated_segments = {}
+
+ # load already-saved data to save time for on-the-fly manipulation
+ cur_data_dir_path = f"{self.data_dir}manipulated_converter/{self.args.data_save_name}/{self.applying_effects}/{idx}/"
+ if self.mode=="val" and os.path.exists(cur_data_dir_path):
+ for cur_inst in self.instruments:
+ cur_A1_file_path = f"{cur_data_dir_path}{cur_inst}_A1.wav"
+ cur_A2_file_path = f"{cur_data_dir_path}{cur_inst}_A2.wav"
+ cur_B2_file_path = f"{cur_data_dir_path}{cur_inst}_B2.wav"
+ cur_manipulated_segment_A1 = load_wav_segment(cur_A1_file_path, axis=0, sample_rate=self.args.sample_rate)
+ cur_manipulated_segment_A2 = load_wav_segment(cur_A2_file_path, axis=0, sample_rate=self.args.sample_rate)
+ cur_manipulated_segment_B2 = load_wav_segment(cur_B2_file_path, axis=0, sample_rate=self.args.sample_rate)
+ manipulated_segments[cur_inst] = [torch.from_numpy(cur_manipulated_segment_A1).float(), \
+ torch.from_numpy(cur_manipulated_segment_A2).float(), \
+ torch.from_numpy(cur_manipulated_segment_B2).float()]
+ else:
+ # repeat for number of instruments
+ for cur_inst, cur_paths in self.data_paths.items():
+ # choose file_path to be loaded
+ cur_chosen_paths = np.random.choice(cur_paths, 2, p = self.data_length_ratio_list[cur_inst])
+ # cur_chosen_paths = [cur_paths[idx], cur_paths[idx+1]]
+ # get random 2 starting points for each instrument
+ last_point_A = load_wav_length(cur_chosen_paths[0])-self.args.segment_length_ref
+ last_point_B = load_wav_length(cur_chosen_paths[1])-self.args.segment_length_ref
+ # simply load more data to prevent artifacts likely to be caused by the manipulator
+ if self.pad_b4_manipulation:
+ last_point_A -= self.pad_length*2
+ last_point_B -= self.pad_length*2
+ cur_inst_start_point_A = torch.randint(low=0, \
+ high=last_point_A, \
+ size=(1,))[0]
+ cur_inst_start_point_B = torch.randint(low=0, \
+ high=last_point_B, \
+ size=(1,))[0]
+ # load wav segments from the selected starting points
+ load_duration = self.args.segment_length_ref+self.pad_length*2 if self.pad_b4_manipulation else self.args.segment_length_ref
+ cur_inst_segment_A = load_wav_segment(cur_chosen_paths[0], \
+ start_point=cur_inst_start_point_A, \
+ duration=load_duration, \
+ axis=1, \
+ sample_rate=self.args.sample_rate)
+ cur_inst_segment_B = load_wav_segment(cur_chosen_paths[1], \
+ start_point=cur_inst_start_point_B, \
+ duration=load_duration, \
+ axis=1, \
+ sample_rate=self.args.sample_rate)
+ ''' mixing manipulation '''
+ # manipulate segment A and B to produce
+ # input : A1 (normalized sample)
+ # ground truth : A2
+ # reference : B2
+ cur_manipulated_segment_A1 = cur_inst_segment_A
+ cur_manipulated_segment_A2, cur_manipulated_segment_B2 = self.mixing_manipulator[cur_inst]([cur_inst_segment_A, cur_inst_segment_B])
+ # remove over-loaded area
+ if self.pad_b4_manipulation:
+ cur_manipulated_segment_A1 = cur_manipulated_segment_A1[self.pad_length:-self.pad_length]
+ cur_manipulated_segment_A2 = cur_manipulated_segment_A2[self.pad_length:-self.pad_length]
+ cur_manipulated_segment_B2 = cur_manipulated_segment_B2[self.pad_length:-self.pad_length]
+ manipulated_segments[cur_inst] = [torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_A1).float(), 1, 0), min=-1, max=1), \
+ torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_A2).float(), 1, 0), min=-1, max=1), \
+ torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_B2).float(), 1, 0), min=-1, max=1)]
+
+ # check manipulated data by saving them
+ if (self.mode=="val" and not os.path.exists(cur_data_dir_path)):
+ mixture_dict = {}
+ for cur_inst in manipulated_segments.keys():
+ cur_inst_dir_path = f"{cur_data_dir_path}{idx}/{cur_inst}/"
+ os.makedirs(cur_inst_dir_path, exist_ok=True)
+ sf.write(f"{cur_inst_dir_path}A1.wav", torch.transpose(manipulated_segments[cur_inst][0], 1, 0), self.args.sample_rate, 'PCM_16')
+ sf.write(f"{cur_inst_dir_path}A2.wav", torch.transpose(manipulated_segments[cur_inst][1], 1, 0), self.args.sample_rate, 'PCM_16')
+ sf.write(f"{cur_inst_dir_path}B2.wav", torch.transpose(manipulated_segments[cur_inst][2], 1, 0), self.args.sample_rate, 'PCM_16')
+ mixture_dict['A1'] = torch.transpose(manipulated_segments[cur_inst][0], 1, 0)
+ mixture_dict['A2'] = torch.transpose(manipulated_segments[cur_inst][1], 1, 0)
+ mixture_dict['B2'] = torch.transpose(manipulated_segments[cur_inst][2], 1, 0)
+ cur_mix_dir_path = f"{cur_data_dir_path}{idx}/mixture/"
+ os.makedirs(cur_mix_dir_path, exist_ok=True)
+ sf.write(f"{cur_mix_dir_path}A1.wav", mixture_dict['A1'], self.args.sample_rate, 'PCM_16')
+ sf.write(f"{cur_mix_dir_path}A2.wav", mixture_dict['A2'], self.args.sample_rate, 'PCM_16')
+ sf.write(f"{cur_mix_dir_path}B2.wav", mixture_dict['B2'], self.args.sample_rate, 'PCM_16')
+
+ output_list = []
+ for cur_inst in manipulated_segments.keys():
+ output_list.extend(manipulated_segments[cur_inst])
+
+ '''
+ Output format:
+ list of effects manipulated stems of each instrument
+ drums_A1, drums_A2, drums_B2,
+ bass_A1, bass_A2, bass_B2,
+ other_A1, other_A2, other_B2,
+ vocals_A1, vocals_A2, vocals_B2,
+ each stem has the shape of (number of channels, segment duration)
+ Notation :
+ A1 = input of the network
+ A2 = ground truth
+ B2 = reference track
+ '''
+ return output_list
+
+
+
+# Data loader for inference on the task 'Mixing Style Transfer'
+### loads the whole mixture or stems from the target directory
+class Song_Dataset_Inference(Dataset):
+ def __init__(self, args):
+ self.args = args
+ self.data_dir = args.target_dir
+ self.interpolate = args.interpolation
+
+ self.instruments = args.instruments
+
+ self.data_dir_paths = sorted(glob(f"{self.data_dir}*/"))
+
+ self.input_name = args.input_file_name
+ self.reference_name = args.reference_file_name
+ self.stem_level_directory_name = args.stem_level_directory_name \
+ if self.args.do_not_separate else os.path.join(args.stem_level_directory_name, args.separation_model)
+ if self.interpolate:
+ self.reference_name_B = args.reference_file_name_2interpolate
+
+ # audio effects normalizer
+ if args.normalize_input:
+ self.normalization_chain = Audio_Effects_Normalizer(precomputed_feature_path=args.precomputed_normalization_feature, \
+ STEMS=args.instruments, \
+ EFFECTS=args.normalization_order)
+
+
+ def __len__(self):
+ return len(self.data_dir_paths)
+
+
+ def __getitem__(self, idx):
+ ''' stem-level conversion '''
+ input_stems = []
+ reference_stems = []
+ reference_B_stems = []
+ for cur_inst in self.instruments:
+ cur_input_file_path = os.path.join(self.data_dir_paths[idx], self.stem_level_directory_name, self.input_name, cur_inst+'.wav')
+ cur_reference_file_path = os.path.join(self.data_dir_paths[idx], self.stem_level_directory_name, self.reference_name, cur_inst+'.wav')
+
+ # load wav
+ cur_input_wav = load_wav_segment(cur_input_file_path, axis=0, sample_rate=self.args.sample_rate)
+ cur_reference_wav = load_wav_segment(cur_reference_file_path, axis=0, sample_rate=self.args.sample_rate)
+
+ if self.args.normalize_input:
+ cur_input_wav = self.normalization_chain.normalize_audio(cur_input_wav.transpose(), src=cur_inst).transpose()
+
+ input_stems.append(torch.clamp(torch.from_numpy(cur_input_wav).float(), min=-1, max=1))
+ reference_stems.append(torch.clamp(torch.from_numpy(cur_reference_wav).float(), min=-1, max=1))
+
+ # for interpolation
+ if self.interpolate:
+ cur_reference_B_file_path = os.path.join(self.data_dir_paths[idx], self.stem_level_directory_name, self.reference_name_B, cur_inst+'.wav')
+ cur_reference_B_wav = load_wav_segment(cur_reference_B_file_path, axis=0, sample_rate=self.args.sample_rate)
+ reference_B_stems.append(torch.clamp(torch.from_numpy(cur_reference_B_wav).float(), min=-1, max=1))
+
+ dir_name = os.path.dirname(self.data_dir_paths[idx])
+
+ if self.interpolate:
+ return torch.stack(input_stems, dim=0), torch.stack(reference_stems, dim=0), torch.stack(reference_B_stems, dim=0), dir_name
+ else:
+ return torch.stack(input_stems, dim=0), torch.stack(reference_stems, dim=0), dir_name
+
+
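+# usage sketch (illustrative, assuming interpolation is disabled and `args` is the inference config):
+#   inference_loader = DataLoader(Song_Dataset_Inference(args), batch_size=1, shuffle=False)
+#   for input_stems, reference_stems, dir_name in inference_loader:
+#       ...   # input_stems: tensor of shape (1, n_instruments, n_channels, n_samples)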
+
+# check dataset
+if __name__ == '__main__':
+ """
+ Test code of data loaders
+ """
+ import time
+ print('checking dataset...')
+
+ total_epochs = 1
+ bs = 5
+ check_step_size = 3
+
+
+ print('\n========== Effects Encoder ==========')
+ from config import args
+ collate_class = Collate_Variable_Length_Segments(args)
+ ##### generate samples with random configuration
+ # args.normalization_order = 'eqcompimagegain'
+ # for cur_effect in ['full', 'gain', 'comp', 'reverb', 'eq', 'imager', 'pan']:
+ # start_time = time.time()
+ # dataset = MUSDB_Dataset_Mixing_Manipulated_FXencoder(args, mode='val', applying_effects=cur_effect, check_data=True)
+ # dataset.generate_contents_w_effects(num_content=25, num_effects=10)
+ # print(f'\t---time taken : {time.time()-start_time}---')
+
+ ### training data loader
+ dataset = MUSDB_Dataset_Mixing_Manipulated_FXencoder(args, mode='train', applying_effects=['comp'])
+ data_loader = DataLoader(dataset, \
+ batch_size=bs, \
+ shuffle=False, \
+ collate_fn=collate_class.random_duration_segments_strong_negatives, \
+ drop_last=False, \
+ num_workers=0)
+
+ for epoch in range(total_epochs):
+ start_time_loader = time.time()
+ for step, output_list in enumerate(data_loader):
+ if step==check_step_size:
+ break
+ print(f'Epoch {epoch+1}/{total_epochs}\tStep {step+1}/{len(data_loader)}')
+ print(f'num contents : {len(output_list)}\tnum instruments : {len(output_list[0])}\tcontent A shape : {output_list[0][0].shape}\t content B shape : {output_list[1][0].shape} \ttime taken: {time.time()-start_time_loader:.4f}')
+ start_time_loader = time.time()
+
+
+ print('\n========== Mixing Style Transfer ==========')
+ from trainer_mixing_transfer.config_conv import args
+ ### training data loader
+ dataset = MUSDB_Dataset_Mixing_Manipulated_Style_Transfer(args, mode='train')
+ data_loader = DataLoader(dataset, \
+ batch_size=bs, \
+ shuffle=False, \
+ collate_fn=collate_class.style_transfer_collate, \
+ drop_last=False, \
+ num_workers=0)
+
+ for epoch in range(total_epochs):
+ start_time_loader = time.time()
+ for step, output_list in enumerate(data_loader):
+ if step==check_step_size:
+ break
+ print(f'Epoch {epoch+1}/{total_epochs}\tStep {step+1}/{len(data_loader)}')
+ print(f'num contents : {len(output_list)}\tnum instruments : {len(output_list[0])}\tA1 shape : {output_list[0][0].shape}\tA2 shape : {output_list[1][0].shape}\tB2 shape : {output_list[2][0].shape}\ttime taken: {time.time()-start_time_loader:.4f}')
+ start_time_loader = time.time()
+
+
+ print('\n--- checking dataset completed ---')
+
diff --git a/mixing_style_transfer/data_loader/loader_utils.py b/mixing_style_transfer/data_loader/loader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed63ce98c5e3f6e56d82f47ac7e38bfc7b76e8f2
--- /dev/null
+++ b/mixing_style_transfer/data_loader/loader_utils.py
@@ -0,0 +1,71 @@
+""" Utility file for loaders """
+
+import numpy as np
+import soundfile as sf
+import wave
+
+
+
+# Function to convert a frame count into an HH:MM:SS time string
+def frames_to_time(total_length, sr=44100):
+ in_time = total_length / sr
+ hour = int(in_time / 3600)
+ minute = int((in_time - hour*3600) / 60)
+ second = int(in_time - hour*3600 - minute*60)
+ return f"{hour:02d}:{minute:02d}:{second:02d}"
+
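+# worked example (illustrative): frames_to_time(3969000) -> "00:01:30" at the default 44.1 kHz (3969000 frames = 90 seconds)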
+
+# Function to convert an HH:MM:SS time string into frames or seconds
+def time_to_frames(input_time, to_frames=True, sr=44100):
+ hour, minute, second = input_time.split(':')
+ total_seconds = int(hour)*3600 + int(minute)*60 + int(second)
+ return total_seconds*sr if to_frames else total_seconds
+
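+# worked example (illustrative): time_to_frames("00:01:30") -> 3969000 frames at 44.1 kHz; with to_frames=False -> 90 seconds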
+
+# Function to convert seconds into an HH:MM:SS time string
+def sec_to_time(input_time):
+ return frames_to_time(input_time, sr=1)
+
+
+# Function to compute the total length (in frames) of the given audio files
+def get_total_audio_length(audio_paths):
+ total_length = 0
+ for cur_audio_path in audio_paths:
+ cur_wav = wave.open(cur_audio_path, 'r')
+ total_length += cur_wav.getnframes() # here, length = # of frames
+ return total_length
+
+
+# Function to get the length (in frames) of an input wav file
+def load_wav_length(audio_path):
+ pt_wav = wave.open(audio_path, 'r')
+ length = pt_wav.getnframes()
+ return length
+
+
+# Function to load a selected segment from a 16- or 32-bit wav file (mono or stereo)
+def load_wav_segment(audio_path, start_point=None, duration=None, axis=1, sample_rate=44100):
+ start_point = 0 if start_point is None else start_point
+ duration = load_wav_length(audio_path) if duration is None else duration
+ pt_wav = wave.open(audio_path, 'r')
+
+ if pt_wav.getframerate()!=sample_rate:
+ raise ValueError(f"ValueError: input audio's sample rate should be {sample_rate}")
+ pt_wav.setpos(start_point)
+ x = pt_wav.readframes(duration)
+ if pt_wav.getsampwidth()==2:
+ x = np.frombuffer(x, dtype=np.int16)
+ X = x / float(2**15) # scale 16-bit integer samples to [-1, 1]
+ elif pt_wav.getsampwidth()==4:
+ x = np.frombuffer(x, dtype=np.int32)
+ X = x / float(2**31) # scale 32-bit integer samples to [-1, 1]
+ else:
+ raise ValueError("input audio's bit depth should be 16 or 32 bits")
+
+ # exception for stereo channels
+ if pt_wav.getnchannels()==2:
+ X_l = np.expand_dims(X[::2], axis=axis)
+ X_r = np.expand_dims(X[1::2], axis=axis)
+ X = np.concatenate((X_l, X_r), axis=axis)
+ return X
+
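+# usage sketch (illustrative, 'mix.wav' is a placeholder path): read a 1-second stereo chunk starting
+# 10 seconds into a 44.1 kHz file, returned as an array of shape (44100, 2) when axis=1
+#   segment = load_wav_segment('mix.wav', start_point=10*44100, duration=44100, axis=1)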
diff --git a/mixing_style_transfer/mixing_manipulator/__init__.py b/mixing_style_transfer/mixing_manipulator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f1174376e3ee1509c5b8963046abc2a936ac9cf
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/__init__.py
@@ -0,0 +1,4 @@
+from .audio_effects_chain import *
+from .common_audioeffects import *
+from .common_dataprocessing import create_dataset
+from data_normalization import Audio_Effects_Normalizer
\ No newline at end of file
diff --git a/mixing_style_transfer/mixing_manipulator/audio_effects_chain.py b/mixing_style_transfer/mixing_manipulator/audio_effects_chain.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8518666f4e39564b39c86ae524f9b46c68c6eac
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/audio_effects_chain.py
@@ -0,0 +1,165 @@
+"""
+ Implementation of Audio Effects Chain Manipulation for the task 'Mixing Style Transfer'
+"""
+from glob import glob
+import os
+import sys
+
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+sys.path.append(os.path.dirname(currentdir))
+from common_audioeffects import *
+from common_dataprocessing import create_dataset
+
+
+
+# create augmentation effects chain according to targeted effects with their applying probability
+def create_effects_augmentation_chain(effects, \
+ ir_dir_path=None, \
+ sample_rate=44100, \
+ shuffle=False, \
+ parallel=False, \
+ parallel_weight_factor=None):
+ '''
+ Args:
+ effects (list) : each element is either an FX name (string), a Processor / AugmentationChain instance,
+ or a tuple whose first element is one of those and whose second element is the probability of applying it
+ ir_dir_path (string) : directory path that contains directories of impulse responses organized according to RT60
+ sample_rate (int) : sampling rate to use
+ shuffle (boolean) : shuffle FXs inside the current FX chain
+ parallel (boolean) : apply parallel (wet/dry) FX processing (alpha * input + (1-alpha) * manipulated output)
+ parallel_weight_factor : value of alpha for parallel FX processing. If None (default), a random value in (0.0, 0.5) is drawn
+ '''
+ fx_list = []
+ apply_prob = []
+ for cur_fx in effects:
+ # store the probability of applying the current effect. defaults to 100%
+ if isinstance(cur_fx, tuple):
+ apply_prob.append(cur_fx[1])
+ cur_fx = cur_fx[0]
+ else:
+ apply_prob.append(1)
+
+ # processors of each audio effects
+ if isinstance(cur_fx, AugmentationChain) or isinstance(cur_fx, Processor):
+ fx_list.append(cur_fx)
+ elif cur_fx.lower()=='gain':
+ fx_list.append(Gain())
+ elif 'eq' in cur_fx.lower():
+ fx_list.append(Equaliser(n_channels=2, sample_rate=sample_rate))
+ elif 'comp' in cur_fx.lower():
+ fx_list.append(Compressor(sample_rate=sample_rate))
+ elif 'expand' in cur_fx.lower():
+ fx_list.append(Expander(sample_rate=sample_rate))
+ elif 'pan' in cur_fx.lower():
+ fx_list.append(Panner())
+ elif 'image'in cur_fx.lower():
+ fx_list.append(MidSideImager())
+ elif 'algorithmic' in cur_fx.lower():
+ fx_list.append(AlgorithmicReverb(sample_rate=sample_rate))
+ elif 'reverb' in cur_fx.lower():
+ # apply algorithmic reverberation if ir_dir_path is not defined
+ if ir_dir_path==None:
+ fx_list.append(AlgorithmicReverb(sample_rate=sample_rate))
+ # apply convolution reverberation
+ else:
+ IR_paths = glob(f"{ir_dir_path}*/RT60_avg/[!0-]*")
+ IR_list = []
+ IR_dict = {}
+ for IR_path in IR_paths:
+ cur_rt = IR_path.split('/')[-1]
+ if cur_rt not in IR_dict:
+ IR_dict[cur_rt] = []
+ IR_dict[cur_rt].extend(create_dataset(path=IR_path, \
+ accepted_sampling_rates=[sample_rate], \
+ sources=['impulse_response'], \
+ mapped_sources={}, load_to_memory=True, debug=False)[0])
+ long_ir_list = []
+ for cur_rt in IR_dict:
+ cur_rt_len = int(cur_rt.split('-')[0])
+ if cur_rt_len < 3000:
+ IR_list.append(IR_dict[cur_rt])
+ else:
+ long_ir_list.extend(IR_dict[cur_rt])
+
+ IR_list.append(long_ir_list)
+ fx_list.append(ConvolutionalReverb(IR_list, sample_rate))
+ else:
+ raise ValueError(f"make sure the target effects are in the Augment FX chain : received fx called {cur_fx}")
+
+ aug_chain_in = []
+ for cur_i, cur_fx in enumerate(fx_list):
+ normalize = False if isinstance(cur_fx, AugmentationChain) or cur_fx.name=='Gain' else True
+ aug_chain_in.append((cur_fx, apply_prob[cur_i], normalize))
+
+ return AugmentationChain(fxs=aug_chain_in, shuffle=shuffle, parallel=parallel, parallel_weight_factor=parallel_weight_factor)
+
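+# usage sketch (illustrative): build a chain that applies an equaliser and a compressor, each with
+# 90% probability, shuffling their order on every call; passing several segments in one list applies
+# the exact same randomized settings to all of them
+#   fx_chain = create_effects_augmentation_chain([('eq', 0.9), ('comp', 0.9)], sample_rate=44100, shuffle=True)
+#   seg_a_out, seg_b_out = fx_chain([seg_a, seg_b])   # seg_a / seg_b: arrays of shape (n_samples, 2)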
+
+# create audio FX-chain according to input instrument
+def create_inst_effects_augmentation_chain(inst, \
+ apply_prob_dict, \
+ ir_dir_path=None, \
+ algorithmic=False, \
+ sample_rate=44100):
+ '''
+ Args:
+ inst (string) : target instrument of the FX manipulator. The current version only treats 'drums' differently, for applying reverberation
+ apply_prob_dict (dictionary of (FX name, probability)) : probabilities of applying each FX
+ ir_dir_path (string) : directory path that contains directories of impulse responses organized according to RT60
+ algorithmic (boolean) : whether to use algorithmic reverberation (True) or convolutional reverberation (False)
+ sample_rate (int) : sampling rate to use
+ '''
+ reverb_type = 'algorithmic' if algorithmic else 'reverb'
+ eq_comp_rand = create_effects_augmentation_chain([('eq', apply_prob_dict['eq']), ('comp', apply_prob_dict['comp'])], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate, \
+ shuffle=True)
+ pan_image_rand = create_effects_augmentation_chain([('pan', apply_prob_dict['pan']), ('imager', apply_prob_dict['imager'])], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate, \
+ shuffle=True)
+ if inst=='drums':
+ # apply reverberation to the low-frequency band with low probability
+ low_pass_eq_params = ParameterList()
+ low_pass_eq_params.add(Parameter('high_shelf_gain', -50.0, 'float', minimum=-50.0, maximum=-50.0))
+ low_pass_eq_params.add(Parameter('high_shelf_freq', 100.0, 'float', minimum=100.0, maximum=100.0))
+ low_pass_eq = Equaliser(n_channels=2, \
+ sample_rate=sample_rate, \
+ bands=['high_shelf'], \
+ parameters=low_pass_eq_params)
+ reverb_parallel_low = create_effects_augmentation_chain([low_pass_eq, (reverb_type, apply_prob_dict['reverb']*0.01)], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate, \
+ parallel=True, \
+ parallel_weight_factor=0.8)
+ # high pass eq for drums reverberation
+ high_pass_eq_params = ParameterList()
+ high_pass_eq_params.add(Parameter('low_shelf_gain', -50.0, 'float', minimum=-50.0, maximum=-50.0))
+ high_pass_eq_params.add(Parameter('low_shelf_freq', 100.0, 'float', minimum=100.0, maximum=100.0))
+ high_pass_eq = Equaliser(n_channels=2, \
+ sample_rate=sample_rate, \
+ bands=['low_shelf'], \
+ parameters=high_pass_eq_params)
+ reverb_parallel_high = create_effects_augmentation_chain([high_pass_eq, (reverb_type, apply_prob_dict['reverb'])], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate, \
+ parallel=True, \
+ parallel_weight_factor=0.6)
+ reverb_parallel = create_effects_augmentation_chain([reverb_parallel_low, reverb_parallel_high], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate)
+ else:
+ reverb_parallel = create_effects_augmentation_chain([(reverb_type, apply_prob_dict['reverb'])], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate, \
+ parallel=True)
+ # full effects chain
+ effects_chain = create_effects_augmentation_chain([eq_comp_rand, \
+ pan_image_rand, \
+ reverb_parallel, \
+ ('gain', apply_prob_dict['gain'])], \
+ ir_dir_path=ir_dir_path, \
+ sample_rate=sample_rate)
+
+ return effects_chain
+
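+# usage sketch (illustrative): build the per-instrument chain; the probability values below are
+# assumptions mirroring the defaults used by the style transfer data loader earlier in this diff
+#   probs = {'eq': 0.9, 'comp': 0.9, 'pan': 0.3, 'imager': 0.8, 'gain': 0.5, 'reverb': 0.9}
+#   vocal_chain = create_inst_effects_augmentation_chain('vocals', probs, ir_dir_path=None, sample_rate=44100)
+#   manipulated = vocal_chain([stereo_segment])[0]   # ir_dir_path=None falls back to algorithmic reverb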
diff --git a/mixing_style_transfer/mixing_manipulator/common_audioeffects.py b/mixing_style_transfer/mixing_manipulator/common_audioeffects.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d2932ee6813d71962397fb75d72a1df8bdbec26
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/common_audioeffects.py
@@ -0,0 +1,1537 @@
+"""
+Audio effects for data augmentation.
+
+Several audio effects can be combined into an augmentation chain.
+
+Important note: We assume that the parallelization during training is done using
+ multi-processing and not multi-threading. Hence, we do not need the
+ `@sox.sox_context()` decorators as discussed in this
+ [thread](https://github.com/pseeth/soxbindings/issues/4).
+
+AI Music Technology Group, Sony Group Corporation
+AI Speech and Sound Group, Sony Europe
+
+
+This implementation originally belongs to Sony Group Corporation,
+ which has been introduced in the work "Automatic music mixing with deep learning and out-of-domain data".
+ Original repo link: https://github.com/sony/FxNorm-automix
+This work modifies a few implementations from the original repo to suit the task.
+"""
+
+from itertools import permutations
+import logging
+import numpy as np
+import pymixconsole as pymc
+from pymixconsole.parameter import Parameter
+from pymixconsole.parameter_list import ParameterList
+from pymixconsole.processor import Processor
+from random import shuffle
+from scipy.signal import oaconvolve
+import soxbindings as sox
+from typing import List, Optional, Tuple, Union
+from numba import jit
+
+# prevent pysox from logging warnings regarding non-optimal timestretch factors
+logging.getLogger('sox').setLevel(logging.ERROR)
+
+
+# Monkey-Patch `Processor` for convenience
+# (a) Allow `None` as blocksize if processor can work on variable-length audio
+def new_init(self, name, parameters, block_size, sample_rate, dtype='float32'):
+ """
+ Initialize processor.
+
+ Args:
+ self: Reference to object
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ block_size (int): Size of blocks for blockwise processing.
+ Can also be `None` if full audio can be processed at once.
+ sample_rate (int): Sample rate of input audio. Use `None` if effect is independent of this value.
+ dtype (str): data type of samples
+ """
+ self.name = name
+ self.parameters = parameters
+ self.block_size = block_size
+ self.sample_rate = sample_rate
+ self.dtype = dtype
+
+
+# (b) make code simpler
+def new_update(self, parameter_name):
+ """
+ Update processor after randomization of parameters.
+
+ Args:
+ self: Reference to object.
+ parameter_name (str): Parameter whose value has changed.
+ """
+ pass
+
+
+# (c) representation for nice print
+def new_repr(self):
+ """
+ Create human-readable representation.
+
+ Args:
+ self: Reference to object.
+
+ Returns:
+ string representation of object.
+ """
+ return f'Processor(name={self.name!r}, parameters={self.parameters!r})'
+
+
+Processor.__init__ = new_init
+Processor.__repr__ = new_repr
+Processor.update = new_update
+
+
+class AugmentationChain:
+ """Basic audio Fx chain which is used for data augmentation."""
+
+ def __init__(self,
+ fxs: Optional[List[Tuple[Union[Processor, 'AugmentationChain'], float, bool]]] = [],
+ shuffle: Optional[bool] = False,
+ parallel: Optional[bool] = False,
+ parallel_weight_factor = None,
+ randomize_param_value=True):
+ """
+ Create augmentation chain from the dictionary `fxs`.
+
+ Args:
+ fxs (list of tuples): First tuple element is an instances of `pymc.processor` or `AugmentationChain` that
+ we want to use for data augmentation. Second element gives probability that effect should be applied.
+ Third element defines, whether the processed signal is normalized by the RMS of the input.
+ shuffle (bool): If `True` then order of Fx are changed whenever chain is applied.
+ """
+ self.fxs = fxs
+ self.shuffle = shuffle
+ self.parallel = parallel
+ self.parallel_weight_factor = parallel_weight_factor
+ self.randomize_param_value = randomize_param_value
+
+ def apply_processor(self, x, processor: Processor, rms_normalize):
+ """
+ Pass audio in `x` through `processor` and output the respective processed audio.
+
+ Args:
+ x (Numpy array): Input audio of shape `n_samples` x `n_channels`.
+ processor (Processor): Audio effect that we want to apply.
+ rms_normalize (bool): If `True`, the processed signal is normalized by the RMS of the signal.
+
+ Returns:
+ Numpy array: Processed audio of shape `n_samples` x `n_channels` (same size as `x')
+ """
+
+ n_samples_input = x.shape[0]
+
+ if processor.block_size is None:
+ y = processor.process(x)
+ else:
+ # make sure that n_samples is a multiple of `processor.block_size`
+ if x.shape[0] % processor.block_size != 0:
+ n_pad = processor.block_size - x.shape[0] % processor.block_size
+ x = np.pad(x, ((0, n_pad), (0, 0)), mode='reflect')
+
+ y = np.zeros_like(x)
+ for idx in range(0, x.shape[0], processor.block_size):
+ y[idx:idx+processor.block_size, :] = processor.process(x[idx:idx+processor.block_size, :])
+
+ if rms_normalize:
+ # normalize output energy such that it is the same as the input energy
+ scale = np.sqrt(np.mean(np.square(x)) / np.maximum(1e-7, np.mean(np.square(y))))
+ y *= scale
+
+ # return audio of same length as x
+ return y[:n_samples_input, :]
+
+ def apply_same_processor(self, x_list, processor: Processor, rms_normalize):
+ for i in range(len(x_list)):
+ x_list[i] = self.apply_processor(x_list[i], processor, rms_normalize)
+
+ return x_list
+
+ def __call__(self, x_list):
+ """
+ Apply the same augmentation chain to audio tracks in list `x_list`.
+
+ Args:
+ x_list (list of Numpy array) : List of audio samples of shape `n_samples` x `n_channels`.
+
+ Returns:
+ y_list (list of Numpy array) : List of processed audio of same shape as `x_list` where the same effects have been applied.
+ """
+ # randomly shuffle effect order if `self.shuffle` is True
+ if self.shuffle:
+ shuffle(self.fxs)
+
+ # apply effects with probabilities given in `self.fxs`
+ y_list = x_list.copy()
+ for fx, p, rms_normalize in self.fxs:
+ if np.random.rand() < p:
+ if isinstance(fx, Processor):
+ # randomize all effect parameters (also calls `update()` for each processor)
+ if self.randomize_param_value:
+ fx.randomize()
+ else:
+ fx.update(None)
+
+ # apply processor
+ y_list = self.apply_same_processor(y_list, fx, rms_normalize)
+ else:
+ y_list = fx(y_list)
+
+ if self.parallel:
+ # weighting factor of input signal in the range of (0.0 ~ 0.5)
+ weight_in = self.parallel_weight_factor if self.parallel_weight_factor else np.random.rand() / 2.
+ for i in range(len(y_list)):
+ y_list[i] = weight_in*x_list[i] + (1-weight_in)*y_list[i]
+
+ return y_list
+
+ def __repr__(self):
+ """
+ Human-readable representation.
+
+ Returns:
+ string representation of object.
+ """
+ return f'AugmentationChain(fxs={self.fxs!r}, shuffle={self.shuffle!r})'
+
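+# usage sketch (illustrative): chain a Gain and a Panner (both defined below), each applied with
+# 50% probability and in random order; the third tuple element requests RMS re-normalization
+#   chain = AugmentationChain(fxs=[(Gain(), 0.5, False), (Panner(), 0.5, True)], shuffle=True)
+#   processed = chain([audio])[0]   # audio: np.ndarray of shape (n_samples, 2)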
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% DISTORTION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+def hard_clip(x, threshold_dB, drive):
+ """
+ Hard clip distortion.
+
+ Args:
+ x: input audio
+ threshold_dB: threshold
+ drive: drive
+
+ Returns:
+ (Numpy array): distorted audio
+ """
+ drive_linear = np.power(10., drive / 20.).astype(np.float32)
+ threshold_linear = 10. ** (threshold_dB / 20.)
+ return np.clip(x * drive_linear, -threshold_linear, threshold_linear)
+
+
+def overdrive(x, drive, colour, sample_rate):
+ """
+ Overdrive distortion.
+
+ Args:
+ x: input audio
+ drive: Controls the amount of distortion (dB).
+ colour: Controls the amount of even harmonic content in the output (dB)
+ sample_rate: sampling rate
+
+ Returns:
+ (Numpy array): distorted audio
+ """
+ scale = np.max(np.abs(x))
+ if scale > 0.9:
+ clips = True
+ x = x * (0.9 / scale)
+ else:
+ clips = False
+
+ tfm = sox.Transformer()
+ tfm.overdrive(gain_db=drive, colour=colour)
+ y = tfm.build_array(input_array=x, sample_rate_in=sample_rate).astype(np.float32)
+
+ if clips:
+ y *= scale / 0.9 # rescale output to original scale
+ return y
+
+
+def hyperbolic_tangent(x, drive):
+ """
+ Hyperbolic Tanh distortion.
+
+ Args:
+ x: input audio
+ drive: drive
+
+ Returns:
+ (Numpy array): distorted audio
+ """
+ drive_linear = np.power(10., drive / 20.).astype(np.float32)
+ return np.tanh(2. * x * drive_linear)
+
+
+def soft_sine(x, drive):
+ """
+ Soft sine distortion.
+
+ Args:
+ x: input audio
+ drive: drive
+
+ Returns:
+ (Numpy array): distorted audio
+ """
+ drive_linear = np.power(10., drive / 20.).astype(np.float32)
+ y = np.clip(x * drive_linear, -np.pi/4.0, np.pi/4.0)
+ return np.sin(2. * y)
+
+
+def bit_crusher(x, bits):
+ """
+ Bit crusher distortion.
+
+ Args:
+ x: input audio
+ bits: bits
+
+ Returns:
+ (Numpy array): distorted audio
+ """
+ return np.rint(x * (2 ** bits)) / (2 ** bits)
+
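+# worked check (illustrative) for bit_crusher above: bit_crusher(0.3, 8) quantizes to steps of 1/256,
+# i.e. np.rint(0.3 * 256) / 256 = 77 / 256 ~= 0.3008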
+
+class Distortion(Processor):
+ """
+ Distortion processor.
+
+ Processor parameters:
+ mode (str): Currently supports the following five modes: hard_clip, overdrive, soft_sine, tanh, bit_crusher.
+ Each mode has different parameters such as threshold, factor, or bits.
+ threshold (float): threshold
+ drive (float): drive
+ factor (float): factor
+ limit_range (float): limit range
+ bits (int): bits
+ """
+
+ def __init__(self, sample_rate, name='Distortion', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): sample rate.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name, None, block_size=None, sample_rate=sample_rate)
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('mode', 'hard_clip', 'string',
+ options=['hard_clip',
+ 'overdrive',
+ 'soft_sine',
+ 'tanh',
+ 'bit_crusher']))
+ self.parameters.add(Parameter('threshold', 0.0, 'float',
+ units='dB', maximum=0.0, minimum=-20.0))
+ self.parameters.add(Parameter('drive', 0.0, 'float',
+ units='dB', maximum=20.0, minimum=0.0))
+ self.parameters.add(Parameter('colour', 20.0, 'float',
+ maximum=100.0, minimum=0.0))
+ self.parameters.add(Parameter('bits', 12, 'int',
+ maximum=12, minimum=8))
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): distorted audio of size `n_samples x n_channels`.
+ """
+ if self.parameters.mode.value == 'hard_clip':
+ y = hard_clip(x, self.parameters.threshold.value, self.parameters.drive.value)
+ elif self.parameters.mode.value == 'overdrive':
+ y = overdrive(x, self.parameters.drive.value,
+ self.parameters.colour.value, self.sample_rate)
+ elif self.parameters.mode.value == 'soft_sine':
+ y = soft_sine(x, self.parameters.drive.value)
+ elif self.parameters.mode.value == 'tanh':
+ y = hyperbolic_tangent(x, self.parameters.drive.value)
+ elif self.parameters.mode.value == 'bit_crusher':
+ y = bit_crusher(x, self.parameters.bits.value)
+
+ # If the output has low amplitude (some distortion settings can "crush" down the amplitude),
+ # then it's normalised to the input's amplitude
+ x_max = np.max(np.abs(x)) + 1e-8
+ o_max = np.max(np.abs(y)) + 1e-8
+ if x_max > o_max:
+ y = y*(x_max/o_max)
+
+ return y
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% EQUALISER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class Equaliser(Processor):
+ """
+ Five band parametric equaliser (two shelves and three central bands).
+
+ All gains are set in dB values and range from `MIN_GAIN` dB to `MAX_GAIN` dB.
+ This processor is implemented as a cascade of five biquad IIR filters
+ that are implemented using the infamous cookbook formulae from RBJ.
+
+ Processor parameters:
+ low_shelf_gain (float), low_shelf_freq (float)
+ first_band_gain (float), first_band_freq (float), first_band_q (float)
+ second_band_gain (float), second_band_freq (float), second_band_q (float)
+ third_band_gain (float), third_band_freq (float), third_band_q (float)
+ high_shelf_gain (float), high_shelf_freq (float)
+
+ original from https://github.com/csteinmetz1/pymixconsole/blob/master/pymixconsole/processors/equaliser.py
+ """
+
+ def __init__(self, n_channels,
+ sample_rate,
+ gain_range=(-15.0, 15.0),
+ q_range=(0.1, 2.0),
+ bands=['low_shelf', 'first_band', 'second_band', 'third_band', 'high_shelf'],
+ hard_clip=False,
+ name='Equaliser', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ n_channels (int): Number of audio channels.
+ sample_rate (int): Sample rate of audio.
+ gain_range (tuple of floats): minimum and maximum gain that can be used.
+ q_range (tuple of floats): minimum and maximum q value.
+ hard_clip (bool): Whether we clip to [-1, 1.] after processing.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ self.n_channels = n_channels
+
+ MIN_GAIN, MAX_GAIN = gain_range
+ MIN_Q, MAX_Q = q_range
+
+ if not parameters:
+ self.parameters = ParameterList()
+ # low shelf parameters -------
+ self.parameters.add(Parameter('low_shelf_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
+ self.parameters.add(Parameter('low_shelf_freq', 80.0, 'float', minimum=30.0, maximum=200.0))
+ # first band parameters ------
+ self.parameters.add(Parameter('first_band_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
+ self.parameters.add(Parameter('first_band_freq', 400.0, 'float', minimum=200.0, maximum=1000.0))
+ self.parameters.add(Parameter('first_band_q', 0.7, 'float', minimum=MIN_Q, maximum=MAX_Q))
+ # second band parameters -----
+ self.parameters.add(Parameter('second_band_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
+ self.parameters.add(Parameter('second_band_freq', 2000.0, 'float', minimum=1000.0, maximum=3000.0))
+ self.parameters.add(Parameter('second_band_q', 0.7, 'float', minimum=MIN_Q, maximum=MAX_Q))
+ # third band parameters ------
+ self.parameters.add(Parameter('third_band_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
+ self.parameters.add(Parameter('third_band_freq', 4000.0, 'float', minimum=3000.0, maximum=8000.0))
+ self.parameters.add(Parameter('third_band_q', 0.7, 'float', minimum=MIN_Q, maximum=MAX_Q))
+ # high shelf parameters ------
+ self.parameters.add(Parameter('high_shelf_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
+ self.parameters.add(Parameter('high_shelf_freq', 8000.0, 'float', minimum=5000.0, maximum=10000.0))
+
+ self.bands = bands
+ self.filters = self.setup_filters()
+ self.hard_clip = hard_clip
+
+ def setup_filters(self):
+ """
+ Create IIR filters.
+
+ Returns:
+ IIR filters
+ """
+ filters = {}
+
+ for band in self.bands:
+
+ G = getattr(self.parameters, band + '_gain').value
+ fc = getattr(self.parameters, band + '_freq').value
+ rate = self.sample_rate
+
+ if band in ['low_shelf', 'high_shelf']:
+ Q = 0.707
+ filter_type = band
+ else:
+ Q = getattr(self.parameters, band + '_q').value
+ filter_type = 'peaking'
+
+ filters[band] = pymc.components.iirfilter.IIRfilter(G, Q, fc, rate, filter_type, n_channels=self.n_channels)
+
+ return filters
+
+ def update_filter(self, band):
+ """
+ Update filters.
+
+ Args:
+ band (str): Band that should be updated.
+ """
+ self.filters[band].G = getattr(self.parameters, band + '_gain').value
+ self.filters[band].fc = getattr(self.parameters, band + '_freq').value
+ self.filters[band].rate = self.sample_rate
+
+ if band in ['first_band', 'second_band', 'third_band']:
+ self.filters[band].Q = getattr(self.parameters, band + '_q').value
+
+ def update(self, parameter_name=None):
+ """
+ Update processor after randomization of parameters.
+
+ Args:
+ parameter_name (str): Parameter whose value has changed.
+ """
+ if parameter_name is not None:
+ bands = ['_'.join(parameter_name.split('_')[:2])]
+ else:
+ bands = self.bands
+
+ for band in bands:
+ self.update_filter(band)
+
+ for _band, iirfilter in self.filters.items():
+ iirfilter.reset_state()
+
+ def reset_state(self):
+ """Reset state."""
+ for _band, iirfilter in self.filters.items():
+ iirfilter.reset_state()
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): equalized audio of size `n_samples x n_channels`.
+ """
+ for _band, iirfilter in self.filters.items():
+ iirfilter.reset_state()
+ x = iirfilter.apply_filter(x)
+
+ if self.hard_clip:
+ x = np.clip(x, -1.0, 1.0)
+
+ # make sure that we have float32 as IIR filtering returns float64
+ x = x.astype(np.float32)
+
+ # make sure that we have two dimensions (if `n_channels == 1`)
+ if x.ndim == 1:
+ x = x[:, np.newaxis]
+
+ return x
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% COMPRESSOR %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+@jit(nopython=True)
+def compressor_process(x, threshold, attack_time, release_time, ratio, makeup_gain, sample_rate, yL_prev):
+ """
+ Apply compressor.
+
+ Args:
+ x (Numpy array): audio data.
+ threshold: threshold in dB.
+ attack_time: attack_time in ms.
+ release_time: release_time in ms.
+ ratio: ratio.
+ makeup_gain: makeup_gain.
+ sample_rate: sample rate.
+ yL_prev: internal state of the envelope gain.
+
+ Returns:
+ compressed audio.
+ """
+ M = x.shape[0]
+ x_g = np.zeros(M)
+ x_l = np.zeros(M)
+ y_g = np.zeros(M)
+ y_l = np.zeros(M)
+ c = np.zeros(M)
+ yL_prev = 0.
+
+ alpha_attack = np.exp(-1/(0.001 * sample_rate * attack_time))
+ alpha_release = np.exp(-1/(0.001 * sample_rate * release_time))
+
+ for i in np.arange(M):
+ if np.abs(x[i]) < 0.000001:
+ x_g[i] = -120.0
+ else:
+ x_g[i] = 20 * np.log10(np.abs(x[i]))
+
+ if ratio > 1:
+ if x_g[i] >= threshold:
+ y_g[i] = threshold + (x_g[i] - threshold) / ratio
+ else:
+ y_g[i] = x_g[i]
+ elif ratio < 1:
+ if x_g[i] <= threshold:
+ y_g[i] = threshold + (x_g[i] - threshold) / (1/ratio)
+ else:
+ y_g[i] = x_g[i]
+
+ x_l[i] = x_g[i] - y_g[i]
+
+ if x_l[i] > yL_prev:
+ y_l[i] = alpha_attack * yL_prev + (1 - alpha_attack) * x_l[i]
+ else:
+ y_l[i] = alpha_release * yL_prev + (1 - alpha_release) * x_l[i]
+
+ c[i] = np.power(10.0, (makeup_gain - y_l[i]) / 20.0)
+ yL_prev = y_l[i]
+
+ y = x * c
+
+ return y, yL_prev
+
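+# note (illustrative) on compressor_process above: the attack/release coefficients follow standard
+# one-pole ballistics; e.g. at 44.1 kHz, attack_time=2 ms gives alpha_attack = exp(-1 / (0.001 * 44100 * 2)) ~= 0.9887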
+
+class Compressor(Processor):
+ """
+ Single band stereo dynamic range compressor.
+
+ Processor parameters:
+ threshold (float)
+ attack_time (float)
+ release_time (float)
+ ratio (float)
+ makeup_gain (float)
+ """
+
+ def __init__(self, sample_rate, name='Compressor', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): Sample rate of input audio.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('threshold', -20.0, 'float', units='dB', minimum=-80.0, maximum=-5.0))
+ self.parameters.add(Parameter('attack_time', 2.0, 'float', units='ms', minimum=1., maximum=20.0))
+ self.parameters.add(Parameter('release_time', 100.0, 'float', units='ms', minimum=50.0, maximum=500.0))
+ self.parameters.add(Parameter('ratio', 4.0, 'float', minimum=4., maximum=40.0))
+ # we remove makeup_gain parameter inside the Compressor
+
+ # store internal state (for block-wise processing)
+ self.yL_prev = None
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): compressed audio of size `n_samples x n_channels`.
+ """
+ if self.yL_prev is None:
+ self.yL_prev = [0.] * x.shape[1]
+
+ if self.parameters.threshold.value != 0.0 or self.parameters.ratio.value != 1.0:
+ y = np.zeros_like(x)
+
+ for ch in range(x.shape[1]):
+ y[:, ch], self.yL_prev[ch] = compressor_process(x[:, ch],
+ self.parameters.threshold.value,
+ self.parameters.attack_time.value,
+ self.parameters.release_time.value,
+ self.parameters.ratio.value,
+ 0.0, # makeup_gain = 0
+ self.sample_rate,
+ self.yL_prev[ch])
+ else:
+ y = x
+
+ return y
+
+ def update(self, parameter_name=None):
+ """
+ Update processor after randomization of parameters.
+
+ Args:
+ parameter_name (str): Parameter whose value has changed.
+ """
+ self.yL_prev = None
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%% CONVOLUTIONAL REVERB %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class ConvolutionalReverb(Processor):
+ """
+ Convolutional Reverb.
+
+ Processor parameters:
+ wet_dry (float): Wet/dry ratio.
+ decay (float): Applies a fade out to the impulse response.
+ pre_delay (float): Value in ms. Shifts the IR in time.
+ A positive value produces a traditional delay between the dry signal and the wet.
+ A negative delay is, in reality, zero delay, but it effectively trims off the start of the IR,
+ so the reverb response begins at a point further in.
+ """
+
+ def __init__(self, impulse_responses, sample_rate, name='ConvolutionalReverb', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ impulse_responses (list): List with impulse responses created by `common_dataprocessing.create_dataset`
+ sample_rate (int): Sample rate that we should assume (used for fade-out computation)
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+
+ Raises:
+ ValueError: if no impulse responses are provided.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if impulse_responses is None:
+ raise ValueError('List of impulse responses must be provided for ConvolutionalReverb processor.')
+ self.impulse_responses = impulse_responses
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.max_ir_num = len(max(impulse_responses, key=len))
+ self.parameters.add(Parameter('index', 0, 'int', minimum=0, maximum=len(impulse_responses)))
+ self.parameters.add(Parameter('index_ir', 0, 'int', minimum=0, maximum=self.max_ir_num))
+ self.parameters.add(Parameter('wet', 1.0, 'float', minimum=1.0, maximum=1.0))
+ self.parameters.add(Parameter('dry', 0.0, 'float', minimum=0.0, maximum=0.0))
+ self.parameters.add(Parameter('decay', 1.0, 'float', minimum=1.0, maximum=1.0))
+ self.parameters.add(Parameter('pre_delay', 0, 'int', units='ms', minimum=0, maximum=0))
+
+ def update(self, parameter_name=None):
+ """
+ Update processor after randomization of parameters.
+
+ Args:
+ parameter_name (str): Parameter whose value has changed.
+ """
+ # we sample IR with a uniform random distribution according to RT60 values
+ chosen_ir_duration = self.impulse_responses[self.parameters.index.value]
+ chosen_ir_idx = self.parameters.index_ir.value % len(chosen_ir_duration)
+ self.h = np.copy(chosen_ir_duration[chosen_ir_idx]['impulse_response']())
+
+ # fade out the impulse based on the decay setting (starting from peak value)
+ if self.parameters.decay.value < 1.:
+ idx_peak = np.argmax(np.max(np.abs(self.h), axis=1), axis=0)
+ fstart = np.minimum(self.h.shape[0],
+ idx_peak + int(self.parameters.decay.value * (self.h.shape[0] - idx_peak)))
+ fstop = np.minimum(self.h.shape[0], fstart + int(0.020*self.sample_rate)) # constant 20 ms fade out
+ flen = fstop - fstart
+
+ fade = np.arange(1, flen+1, dtype=self.dtype)/flen
+ fade = np.power(0.1, fade * 5)
+ self.h[fstart:fstop, :] *= fade[:, np.newaxis]
+ self.h = self.h[:fstop]
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): reverbed audio of size `n_samples x n_channels`.
+ """
+ # reshape IR to the correct size
+ n_channels = x.shape[1]
+ if self.h.shape[1] == 1 and n_channels > 1:
+ self.h = np.hstack([self.h] * n_channels) # repeat mono IR for multi-channel input
+ if self.h.shape[1] > 1 and n_channels == 1:
+ self.h = self.h[:, np.random.randint(self.h.shape[1]), np.newaxis] # randomly choose one IR channel
+
+ if self.parameters.wet.value == 0.0:
+ return x
+ else:
+ # perform convolution to get wet signal
+ y = oaconvolve(x, self.h, mode='full', axes=0)
+
+ # cut out wet signal (compensating for the delay that the IR is introducing + predelay)
+ idx = np.argmax(np.max(np.abs(self.h), axis=1), axis=0)
+ idx += int(0.001 * np.abs(self.parameters.pre_delay.value) * self.sample_rate)
+
+ idx = np.clip(idx, 0, self.h.shape[0]-1)
+
+ y = y[idx:idx+x.shape[0], :]
+
+ # return weighted sum of dry and wet signal
+ return self.parameters.dry.value * x + self.parameters.wet.value * y
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%% HAAS EFFECT %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+def haas_process(x, delay, feedback, wet_channel):
+ """
+ Add Haas effect to audio.
+
+ Args:
+ x (Numpy array): input audio.
+ delay: Delay that we apply to one of the channels (in samples).
+ feedback: Feedback value.
+ wet_channel: Which channel we process (`left` or `right`).
+
+ Returns:
+ (Numpy array): Audio with Haas effect.
+ """
+ y = np.copy(x)
+ if wet_channel == 'left':
+ y[:, 0] += feedback * np.roll(x[:, 0], delay)
+ elif wet_channel == 'right':
+ y[:, 1] += feedback * np.roll(x[:, 1], delay)
+
+ return y
+
+
+class Haas(Processor):
+ """
+ Haas Effect Processor.
+
+ Randomly selects one channel and applies a short delay to it.
+
+ Processor parameters:
+ delay (int)
+ feedback (float)
+ wet_channel (string)
+ """
+
+ def __init__(self, sample_rate, delay_range=(-0.040, 0.040), name='Haas', parameters=None,
+ ):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): Sample rate of input audio.
+ delay_range (tuple of floats): minimum/maximum delay for Haas effect.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('delay', int(delay_range[1] * sample_rate), 'int', units='samples',
+ minimum=int(delay_range[0] * sample_rate),
+ maximum=int(delay_range[1] * sample_rate)))
+ self.parameters.add(Parameter('feedback', 0.35, 'float', minimum=0.33, maximum=0.66))
+ self.parameters.add(Parameter('wet_channel', 'left', 'string', options=['left', 'right']))
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): audio with Haas effect of size `n_samples x n_channels`.
+ """
+ assert x.shape[1] == 1 or x.shape[1] == 2, 'Haas effect only works with monaural or stereo audio.'
+
+ if x.shape[1] < 2:
+ x = np.repeat(x, 2, axis=1)
+
+ y = haas_process(x, self.parameters.delay.value,
+ self.parameters.feedback.value, self.parameters.wet_channel.value)
+
+ return y
+
+ def update(self, parameter_name=None):
+ """
+ Update processor after randomization of parameters.
+
+ Args:
+ parameter_name (str): Parameter whose value has changed.
+ """
+ self.reset_state()
+
+ def reset_state(self):
+ """Reset state."""
+ self.read_idx = 0
+ self.write_idx = self.parameters.delay.value
+ self.buffer = np.zeros((65536, 2))
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% PANNER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class Panner(Processor):
+ """
+ Simple stereo panner.
+
+ If input is mono, output is stereo.
+ Original edited from https://github.com/csteinmetz1/pymixconsole/blob/master/pymixconsole/processors/panner.py
+ """
+
+ def __init__(self, name='Panner', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ # default processor class constructor
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=None)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('pan', 0.5, 'float', minimum=0., maximum=1.))
+ self.parameters.add(Parameter('pan_law', '-4.5dB', 'string',
+ options=['-4.5dB', 'linear', 'constant_power']))
+
+ # setup the coefficents based on default params
+ self.update()
+
+ def _calculate_pan_coefficents(self):
+ """
+ Calculate panning coefficients from the chosen pan law.
+
+ Based on the chosen pan law, determine the gain values
+ to apply to the left and right channels to achieve the panning effect.
+ This operates on the assumption that the input channel is mono.
+ The output data will be stereo at the moment, but could be expanded
+ to a higher channel count format.
+ The panning value is in the range [0, 1], where
+ 0 means the signal is panned completely to the left, and
+ 1 means the signal is panned completely to the right.
+
+ Raises:
+ ValueError: `self.parameters.pan_law` is not supported.
+ """
+ self.gains = np.zeros(2, dtype=self.dtype)
+
+ # first scale the linear [0, 1] to [0, pi/2]
+ theta = self.parameters.pan.value * (np.pi/2)
+
+ if self.parameters.pan_law.value == 'linear':
+ self.gains[0] = ((np.pi/2) - theta) * (2/np.pi)
+ self.gains[1] = theta * (2/np.pi)
+ elif self.parameters.pan_law.value == 'constant_power':
+ self.gains[0] = np.cos(theta)
+ self.gains[1] = np.sin(theta)
+ elif self.parameters.pan_law.value == '-4.5dB':
+ self.gains[0] = np.sqrt(((np.pi/2) - theta) * (2/np.pi) * np.cos(theta))
+ self.gains[1] = np.sqrt(theta * (2/np.pi) * np.sin(theta))
+ else:
+ raise ValueError(f'Invalid pan_law {self.parameters.pan_law.value}.')
+
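+ # worked check (illustrative): with pan=0.5 and the default '-4.5dB' law, theta = pi/4 and both
+ # gains equal sqrt(0.5 * cos(pi/4)) ~= 0.595, i.e. roughly -4.5 dB per channel at center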
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): panned audio of size `n_samples x n_channels`.
+ """
+ assert x.shape[1] == 1 or x.shape[1] == 2, 'Panner only works with monaural or stereo audio.'
+
+ if x.shape[1] < 2:
+ x = np.repeat(x, 2, axis=1)
+
+
+ return x * self.gains
+
+ def update(self, parameter_name=None):
+ """
+ Update processor after randomization of parameters.
+
+ Args:
+ parameter_name (str): Parameter whose value has changed.
+ """
+ self._calculate_pan_coefficents()
+
+ def reset_state(self):
+ """Reset state."""
+ self._output_buffer = np.empty([self.block_size, 2])
+ self.update()
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% STEREO IMAGER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class MidSideImager(Processor):
+ def __init__(self, name='IMAGER', parameters=None):
+ super().__init__(name, parameters=parameters, block_size=None, sample_rate=None)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ # values in 0.0~1.0 make the signal more centered, while values in 1.0~2.0 make it wider
+ self.parameters.add(Parameter("bal", 0.0, "float", processor=self, minimum=0.0, maximum=2.0))
+
+ def process(self, data):
+ """
+ # input shape : [signal length, 2]
+ ### note! stereo imager won't work if the input signal is a mono signal (left==right)
+ ### if you want to apply stereo imager to a mono signal, first stereoize it with Haas effects
+ """
+
+ # to mid-side channels
+ mid, side = self.lr_to_ms(data[:,0], data[:,1])
+ # apply mid-side weights according to energy
+ mid_e, side_e = np.sum(mid**2), np.sum(side**2)
+ total_e = mid_e + side_e
+ # apply weights
+ max_side_multiplier = np.sqrt(total_e / (side_e + 1e-3))
+ # compute current multiply factor
+ cur_bal = round(getattr(self.parameters, "bal").value, 3)
+ side_gain = cur_bal if cur_bal <= 1. else max_side_multiplier * (cur_bal-1)
+ # multiply weighting factor
+ new_side = side * side_gain
+ new_side_e = side_e * (side_gain ** 2)
+ left_mid_e = total_e - new_side_e
+ mid_gain = np.sqrt(left_mid_e / (mid_e + 1e-3))
+ new_mid = mid * mid_gain
+ # convert back to left-right channels
+ left, right = self.ms_to_lr(new_mid, new_side)
+ imaged = np.stack([left, right], 1)
+
+ return imaged
+
+ # left-right channeled signal to mid-side signal
+ def lr_to_ms(self, left, right):
+ mid = left + right
+ side = left - right
+ return mid, side
+
+ # mid-side channeled signal to left-right signal
+ def ms_to_lr(self, mid, side):
+ left = (mid + side) / 2
+ right = (mid - side) / 2
+ return left, right
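+ # note (illustrative): lr_to_ms followed by ms_to_lr is an exact round trip,
+ # since ((l+r) + (l-r)) / 2 == l and ((l+r) - (l-r)) / 2 == r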
+
+ def update(self, parameter_name=None):
+ return parameter_name
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% GAIN %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class Gain(Processor):
+ """
+ Gain Processor.
+
+ Applies gain in dB and can also randomly invert the polarity.
+
+ Processor parameters:
+ gain (float): Gain that should be applied (dB scale).
+ invert (bool): If True, then we also invert the waveform.
+ """
+
+ def __init__(self, name='Gain', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name, parameters=parameters, block_size=None, sample_rate=None)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ # self.parameters.add(Parameter('gain', 1.0, 'float', units='dB', minimum=-12.0, maximum=6.0))
+ self.parameters.add(Parameter('gain', 1.0, 'float', units='dB', minimum=-6.0, maximum=9.0))
+ self.parameters.add(Parameter('invert', False, 'bool'))
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): gain-augmented audio of size `n_samples x n_channels`.
+ """
+ gain = 10 ** (self.parameters.gain.value / 20.)
+ if self.parameters.invert.value:
+ gain = -gain
+ return gain * x
+
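+# worked check (illustrative) for the Gain processor above: gain=6.0 dB corresponds to a linear
+# factor of 10**(6/20) ~= 1.995; invert=True additionally flips the waveform's polarity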
+
+# %%%%%%%%%%%%%%%%%%%%%%% SIMPLE CHANNEL SWAP %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class SwapChannels(Processor):
+ """
+ Swap channels in multi-channel audio.
+
+ Processor parameters:
+ index (int) Selects the permutation that we are using.
+ Please note that "no permutation" is one of the permutations in `self.permutations` at index `0`.
+ """
+
+ def __init__(self, n_channels, name='SwapChannels', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ n_channels (int): Number of channels in audio that we want to process.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=None)
+
+ self.permutations = tuple(permutations(range(n_channels), n_channels))
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('index', 0, 'int', minimum=0, maximum=len(self.permutations)))
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): channel-swapped audio of size `n_samples x n_channels`.
+ """
+ return x[:, self.permutations[self.parameters.index.value]]
+
+
+# %%%%%%%%%%%%%%%%%%%%%%% Monauralize %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class Monauralize(Processor):
+ """
+ Monauralizes audio (i.e., removes spatial information).
+
+ Process parameters:
+ seed_channel (int): channel that we use for overwriting the others.
+ """
+
+ def __init__(self, n_channels, name='Monauralize', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ n_channels (int): Number of channels in audio that we want to process.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=None)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('seed_channel', 0, 'int', minimum=0, maximum=n_channels))
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): monauralized audio of size `n_samples x n_channels`.
+ """
+ return np.tile(x[:, [self.parameters.seed_channel.value]], (1, x.shape[1]))
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% PITCH SHIFT %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class PitchShift(Processor):
+ """
+ Simple pitch shifter using SoX and soxbindings (https://github.com/pseeth/soxbindings).
+
+ Processor parameters:
+ steps (float): Pitch shift as positive/negative semitones
+ quick (bool): If True, this effect will run faster but with lower sound quality.
+ """
+
+ def __init__(self, sample_rate, fix_length=True, name='PitchShift', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): Sample rate of input audio.
+ fix_length (bool): If True, then output has same length as input.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('steps', 0.0, 'float', minimum=-6., maximum=6.))
+ self.parameters.add(Parameter('quick', False, 'bool'))
+
+ self.fix_length = fix_length
+ self.clips = False
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): pitch-shifted audio of size `n_samples x n_channels`.
+ """
+ if self.parameters.steps.value == 0.0:
+ y = x
+ else:
+ scale = np.max(np.abs(x))
+ if scale > 0.9:
+ clips = True
+ x = x * (0.9 / scale)
+ else:
+ clips = False
+
+ tfm = sox.Transformer()
+ tfm.pitch(self.parameters.steps.value, quick=bool(self.parameters.quick.value))
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
+
+ if clips:
+ y *= scale / 0.9 # rescale output to original scale
+
+ if self.fix_length:
+ n_samples_input = x.shape[0]
+ n_samples_output = y.shape[0]
+ if n_samples_input < n_samples_output:
+ idx1 = (n_samples_output - n_samples_input) // 2
+ idx2 = idx1 + n_samples_input
+ y = y[idx1:idx2]
+ elif n_samples_input > n_samples_output:
+ n_pad = n_samples_input - n_samples_output
+ y = np.pad(y, ((n_pad//2, n_pad - n_pad//2), (0, 0)))
+
+ return y
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% TIME STRETCH %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class TimeStretch(Processor):
+ """
+ Simple time stretcher using SoX and soxbindings (https://github.com/pseeth/soxbindings).
+
+ Processor parameters:
+ factor (float): Time stretch factor.
+ quick (bool): If True, this effect will run faster but with lower sound quality.
+ stretch_type (str): Algorithm used for stretching (`tempo` or `stretch`).
+ audio_type (str): Hint about the audio material ('m' music, 's' speech, 'l' linear) that SoX
+ uses to find the best overlapping points for time stretching.
+ """
+
+ def __init__(self, sample_rate, fix_length=True, name='TimeStretch', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): Sample rate of input audio.
+ fix_length (bool): If True, then output has same length as input.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('factor', 1.0, 'float', minimum=1/1.33, maximum=1.33))
+ self.parameters.add(Parameter('quick', False, 'bool'))
+ self.parameters.add(Parameter('stretch_type', 'tempo', 'string', options=['tempo', 'stretch']))
+ self.parameters.add(Parameter('audio_type', 'l', 'string', options=['m', 's', 'l']))
+
+ self.fix_length = fix_length
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): time-stretched audio of size `n_samples x n_channels`.
+ """
+ if self.parameters.factor.value == 1.0:
+ y = x
+ else:
+ scale = np.max(np.abs(x))
+ if scale > 0.9:
+ clips = True
+ x = x * (0.9 / scale)
+ else:
+ clips = False
+
+ tfm = sox.Transformer()
+ if self.parameters.stretch_type.value == 'stretch':
+ tfm.stretch(self.parameters.factor.value)
+ elif self.parameters.stretch_type.value == 'tempo':
+ tfm.tempo(self.parameters.factor.value,
+ audio_type=self.parameters.audio_type.value,
+ quick=bool(self.parameters.quick.value))
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
+
+ if clips:
+ y *= scale / 0.9 # rescale output to original scale
+
+ if self.fix_length:
+ n_samples_input = x.shape[0]
+ n_samples_output = y.shape[0]
+ if n_samples_input < n_samples_output:
+ idx1 = (n_samples_output - n_samples_input) // 2
+ idx2 = idx1 + n_samples_input
+ y = y[idx1:idx2]
+ elif n_samples_input > n_samples_output:
+ n_pad = n_samples_input - n_samples_output
+ y = np.pad(y, ((n_pad//2, n_pad - n_pad//2), (0, 0)))
+
+ return y
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% PLAYBACK SPEED %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class PlaybackSpeed(Processor):
+ """
+ Simple playback speed effect using SoX and soxbindings (https://github.com/pseeth/soxbindings).
+
+ Processor parameters:
+ factor (float): Playback speed factor.
+ """
+
+ def __init__(self, sample_rate, fix_length=True, name='PlaybackSpeed', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): Sample rate of input audio.
+ fix_length (bool): If True, then output has same length as input.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('factor', 1.0, 'float', minimum=1./1.33, maximum=1.33))
+
+ self.fix_length = fix_length
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): resampled audio of size `n_samples x n_channels`.
+ """
+ if self.parameters.factor.value == 1.0:
+ y = x
+ else:
+ scale = np.max(np.abs(x))
+ if scale > 0.9:
+ clips = True
+ x = x * (0.9 / scale)
+ else:
+ clips = False
+
+ tfm = sox.Transformer()
+ tfm.speed(self.parameters.factor.value)
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
+
+ if clips:
+ y *= scale / 0.9 # rescale output to original scale
+
+ if self.fix_length:
+ n_samples_input = x.shape[0]
+ n_samples_output = y.shape[0]
+ if n_samples_input < n_samples_output:
+ idx1 = (n_samples_output - n_samples_input) // 2
+ idx2 = idx1 + n_samples_input
+ y = y[idx1:idx2]
+ elif n_samples_input > n_samples_output:
+ n_pad = n_samples_input - n_samples_output
+ y = np.pad(y, ((n_pad//2, n_pad - n_pad//2), (0, 0)))
+
+ return y
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% BEND %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class Bend(Processor):
+ """
+ Simple bend effect using SoX and soxbindings (https://github.com/pseeth/soxbindings).
+
+ Processor parameters:
+ n_bends (int): Number of segments or intervals to pitch shift
+ """
+
+ def __init__(self, sample_rate, pitch_range=(-600, 600), fix_length=True, name='Bend', parameters=None):
+ """
+ Initialize processor.
+
+ Args:
+ sample_rate (int): Sample rate of input audio.
+ pitch_range (tuple of ints): min and max pitch bending ranges in cents
+ fix_length (bool): If True, then output has same length as input.
+ name (str): Name of processor.
+ parameters (parameter_list): Parameters for this processor.
+ """
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter('n_bends', 2, 'int', minimum=2, maximum=10))
+ self.pitch_range_min, self.pitch_range_max = pitch_range
+
+ def process(self, x):
+ """
+ Process audio.
+
+ Args:
+ x (Numpy array): input audio of size `n_samples x n_channels`.
+
+ Returns:
+ (Numpy array): pitch-bent audio of size `n_samples x n_channels`.
+ """
+ n_bends = self.parameters.n_bends.value
+ max_length = x.shape[0] / self.sample_rate
+
+ # Generates random non-overlapping segments
+ delta = 1. / self.sample_rate
+ boundaries = np.sort(delta + np.random.rand(n_bends-1) * (max_length - delta))
+
+ start, end = np.zeros(n_bends), np.zeros(n_bends)
+ start[0] = delta
+ for i, b in enumerate(boundaries):
+ end[i] = b
+ start[i+1] = b
+ end[-1] = max_length
+
+ # randomly sample pitch-shifts in cents
+ cents = np.random.randint(self.pitch_range_min, self.pitch_range_max+1, n_bends)
+
+ # remove segment if cent value is zero or start == end (as SoX does not allow such values)
+ idx_keep = np.logical_and(cents != 0, start != end)
+ n_bends, start, end, cents = sum(idx_keep), start[idx_keep], end[idx_keep], cents[idx_keep]
+
+ scale = np.max(np.abs(x))
+ if scale > 0.9:
+ clips = True
+ x = x * (0.9 / scale)
+ else:
+ clips = False
+
+ tfm = sox.Transformer()
+ tfm.bend(n_bends=int(n_bends), start_times=list(start), end_times=list(end), cents=list(cents))
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
+
+ if clips:
+ y *= scale / 0.9 # rescale output to original scale
+
+ return y
+
+
+# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ALGORITHMIC REVERB %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+class AlgorithmicReverb(Processor):
+ def __init__(self, name="algoreverb", parameters=None, sample_rate=44100, **kwargs):
+
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate, **kwargs)
+
+ if not parameters:
+ self.parameters = ParameterList()
+ self.parameters.add(Parameter("room_size", 0.5, "float", minimum=0.05, maximum=0.85))
+ self.parameters.add(Parameter("damping", 0.1, "float", minimum=0.0, maximum=1.0))
+ self.parameters.add(Parameter("dry_mix", 0.9, "float", minimum=0.0, maximum=1.0))
+ self.parameters.add(Parameter("wet_mix", 0.1, "float", minimum=0.0, maximum=1.0))
+ self.parameters.add(Parameter("width", 0.7, "float", minimum=0.0, maximum=1.0))
+
+ # Tuning
+ self.stereospread = 23
+ self.scalegain = 0.2
+
+
+ def process(self, data):
+
+ if data.ndim >= 2:
+ dataL = data[:,0]
+ if data.shape[1] == 2:
+ dataR = data[:,1]
+ else:
+ dataR = data[:,0]
+ else:
+ dataL = data
+ dataR = data
+
+ output = np.zeros((data.shape[0], 2))
+
+ xL, xR = self.process_filters(dataL.copy(), dataR.copy())
+
+ wet1_g = self.parameters.wet_mix.value * ((self.parameters.width.value/2) + 0.5)
+ wet2_g = self.parameters.wet_mix.value * ((1-self.parameters.width.value)/2)
+ dry_g = self.parameters.dry_mix.value
+
+ output[:,0] = (wet1_g * xL) + (wet2_g * xR) + (dry_g * dataL)
+ output[:,1] = (wet1_g * xR) + (wet2_g * xL) + (dry_g * dataR)
+
+ return output
+
+ def process_filters(self, dataL, dataR):
+
+ xL = self.combL1.process(dataL.copy() * self.scalegain)
+ xL += self.combL2.process(dataL.copy() * self.scalegain)
+ xL += self.combL3.process(dataL.copy() * self.scalegain)
+ xL += self.combL4.process(dataL.copy() * self.scalegain)
+ xL += self.combL5.process(dataL.copy() * self.scalegain)
+ xL += self.combL6.process(dataL.copy() * self.scalegain)
+ xL += self.combL7.process(dataL.copy() * self.scalegain)
+ xL += self.combL8.process(dataL.copy() * self.scalegain)
+
+ xR = self.combR1.process(dataR.copy() * self.scalegain)
+ xR += self.combR2.process(dataR.copy() * self.scalegain)
+ xR += self.combR3.process(dataR.copy() * self.scalegain)
+ xR += self.combR4.process(dataR.copy() * self.scalegain)
+ xR += self.combR5.process(dataR.copy() * self.scalegain)
+ xR += self.combR6.process(dataR.copy() * self.scalegain)
+ xR += self.combR7.process(dataR.copy() * self.scalegain)
+ xR += self.combR8.process(dataR.copy() * self.scalegain)
+
+ yL1 = self.allpassL1.process(xL)
+ yL2 = self.allpassL2.process(yL1)
+ yL3 = self.allpassL3.process(yL2)
+ yL4 = self.allpassL4.process(yL3)
+
+ yR1 = self.allpassR1.process(xR)
+ yR2 = self.allpassR2.process(yR1)
+ yR3 = self.allpassR3.process(yR2)
+ yR4 = self.allpassR4.process(yR3)
+
+ return yL4, yR4
+
+ def update(self, parameter_name):
+
+ rs = self.parameters.room_size.value
+ dp = self.parameters.damping.value
+ ss = self.stereospread
+
+ # initialize allpass and feedback comb-filters
+ # (with coefficients optimized for fs=44.1kHz)
+ self.allpassL1 = pymc.components.allpass.Allpass(556, rs, self.block_size)
+ self.allpassR1 = pymc.components.allpass.Allpass(556+ss, rs, self.block_size)
+ self.allpassL2 = pymc.components.allpass.Allpass(441, rs, self.block_size)
+ self.allpassR2 = pymc.components.allpass.Allpass(441+ss, rs, self.block_size)
+ self.allpassL3 = pymc.components.allpass.Allpass(341, rs, self.block_size)
+ self.allpassR3 = pymc.components.allpass.Allpass(341+ss, rs, self.block_size)
+ self.allpassL4 = pymc.components.allpass.Allpass(225, rs, self.block_size)
+ self.allpassR4 = pymc.components.allpass.Allpass(225+ss, rs, self.block_size)
+
+ self.combL1 = pymc.components.comb.Comb(1116, dp, rs, self.block_size)
+ self.combR1 = pymc.components.comb.Comb(1116+ss, dp, rs, self.block_size)
+ self.combL2 = pymc.components.comb.Comb(1188, dp, rs, self.block_size)
+ self.combR2 = pymc.components.comb.Comb(1188+ss, dp, rs, self.block_size)
+ self.combL3 = pymc.components.comb.Comb(1277, dp, rs, self.block_size)
+ self.combR3 = pymc.components.comb.Comb(1277+ss, dp, rs, self.block_size)
+ self.combL4 = pymc.components.comb.Comb(1356, dp, rs, self.block_size)
+ self.combR4 = pymc.components.comb.Comb(1356+ss, dp, rs, self.block_size)
+ self.combL5 = pymc.components.comb.Comb(1422, dp, rs, self.block_size)
+ self.combR5 = pymc.components.comb.Comb(1422+ss, dp, rs, self.block_size)
+ self.combL6 = pymc.components.comb.Comb(1491, dp, rs, self.block_size)
+ self.combR6 = pymc.components.comb.Comb(1491+ss, dp, rs, self.block_size)
+ self.combL7 = pymc.components.comb.Comb(1557, dp, rs, self.block_size)
+ self.combR7 = pymc.components.comb.Comb(1557+ss, dp, rs, self.block_size)
+ self.combL8 = pymc.components.comb.Comb(1617, dp, rs, self.block_size)
+ self.combR8 = pymc.components.comb.Comb(1617+ss, dp, rs, self.block_size)
+
diff --git a/mixing_style_transfer/mixing_manipulator/common_dataprocessing.py b/mixing_style_transfer/mixing_manipulator/common_dataprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0a5a4faf0c410ea8ffb7bc034448f41ed3a8cf1
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/common_dataprocessing.py
@@ -0,0 +1,535 @@
+"""
+Module with common functions for loading training data and preparing minibatches.
+
+AI Music Technology Group, Sony Group Corporation
+AI Speech and Sound Group, Sony Europe
+
+This implementation originally belongs to Sony Group Corporation,
+ which has been introduced in the work "Automatic music mixing with deep learning and out-of-domain data".
+ Original repo link: https://github.com/sony/FxNorm-automix
+"""
+
+import numpy as np
+import os
+import sys
+import functools
+import scipy.io.wavfile as wav
+import soundfile as sf
+from typing import Tuple
+
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+from common_audioeffects import AugmentationChain
+from common_miscellaneous import uprint
+
+
+def load_wav(file_path, mmap=False, convert_float=False):
+ """
+ Load a WAV file in C_CONTIGUOUS format.
+
+ Args:
+ file_path: Path to WAV file (16bit, 24bit or 32bit PCM supported)
+ mmap: If `True`, then we do not load the WAV data into memory but use a memory-mapped representation
+ convert_float: If `True`, convert the integer PCM samples to `float32` in the range [-1, 1)
+
+ Returns:
+ fs: Sample rate
+ samples: Numpy array (np.int16 or np.int32) with audio [n_samples x n_channels]
+ """
+ fs, samples = wav.read(file_path, mmap=mmap)
+
+ # ensure that we have a 2d array (monaural files are just loaded as vectors)
+ if samples.ndim == 1:
+ samples = samples[:, np.newaxis]
+
+ # make sure that we have loaded an integer PCM WAV file as we assume this later
+ # when we scale the amplitude
+ assert(samples.dtype == np.int16 or samples.dtype == np.int32)
+
+ if convert_float:
+ conversion_scale = 1. / (1. + np.iinfo(samples.dtype).max)
+ samples = samples.astype(dtype=np.float32) * conversion_scale
+
+ return fs, samples
+
+
+def save_wav(file_path, fs, samples, subtype='PCM_16'):
+ """
+ Save a WAV file (16bit or 32bit PCM).
+
+ Important note: We save here using the same conversion as is used in
+ `generate_data`, i.e., we multiply by `1 + np.iinfo(np.int16).max`
+ or `1 + np.iinfo(np.int32).max` which is a different behavior
+ than `libsndfile` as described here:
+ http://www.mega-nerd.com/libsndfile/FAQ.html#Q010
+
+ Args:
+ file_path: Path where to store the WAV file
+ fs: Sample rate
+ samples: Numpy array (float32 with values in [-1, 1) and shape [n_samples x n_channels])
+ subtype: Either `PCM_16` or `PCM_24` or `PCM_32` in order to store as 16bit, 24bit or 32bit PCM file
+ """
+ assert subtype in ['PCM_16', 'PCM_24', 'PCM_32'], subtype
+
+ if subtype == 'PCM_16':
+ dtype = np.int16
+ else:
+ dtype = np.int32
+
+ # scale to integer range and check for clipping
+ samples = samples * (1 + np.iinfo(dtype).max)
+ if np.min(samples) < np.iinfo(dtype).min or np.max(samples) > np.iinfo(dtype).max:
+ uprint(f'WARNING: Clipping occurs for {file_path}.')
+ samples_ = samples / (1 + np.iinfo(dtype).max)
+ print('max value ', np.max(np.abs(samples_)))
+ samples = np.clip(samples, np.iinfo(dtype).min, np.iinfo(dtype).max)
+ samples = samples.astype(dtype)
+
+ # store WAV file
+ sf.write(file_path, samples, fs, subtype=subtype)
+
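+# Worked example of the scaling described above (PCM_16): a float sample of 0.5 is stored as
+# 0.5 * (1 + 32767) = 16384 and -1.0 maps to -32768, while exactly +1.0 would map to 32768,
+# which exceeds `np.iinfo(np.int16).max` and triggers the clipping warning. A minimal sketch
+# (file name is hypothetical):
+#
+#   fs = 44100
+#   tone = 0.5 * np.sin(2 * np.pi * 440 * np.arange(fs) / fs).astype(np.float32)[:, np.newaxis]
+#   save_wav('tone_440hz.wav', fs, tone, subtype='PCM_16')
+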
+
+def load_files_lists(path):
+ """
+ Auxiliary function to find the paths for all mixtures in a database.
+
+ Args:
+ path: path to the folder containing the files to list
+
+ Returns:
+ list_of_directories: list of directories (= list of songs) in `path`
+ """
+ # get directories in `path`
+ list_of_directories = []
+ for folder in os.listdir(path):
+ list_of_directories.append(folder)
+
+ return list_of_directories
+
+
+def create_dataset(path, accepted_sampling_rates, sources, mapped_sources, n_channels=-1, load_to_memory=False,
+ debug=False, verbose=False):
+ """
+ Prepare data in `path` for training/validation/test set generation.
+
+ Args:
+ path: path to the dataset
+ accepted_sampling_rates: list of accepted sampling rates
+ sources: list of canonical source names
+ mapped_sources: dictionary that maps alternative source names to the canonical names in `sources`
+ n_channels: expected number of channels (`-1` disables the check)
+ load_to_memory: whether to load the audio data into main memory
+ debug: if `True`, then we only load the first `NUM_SAMPLES_SMALL_DATASET` songs
+ verbose: if `True`, print progress information
+
+ Raises:
+ ValueError: if source mapping is requested but the data is not loaded into memory
+
+ Returns:
+ data: list of dictionaries with function handles (to load the data)
+ directories: list of directories
+ """
+ NUM_SAMPLES_SMALL_DATASET = 16
+
+ # source mapping currently only works if we load everything into the memory
+ if mapped_sources and not load_to_memory:
+ raise ValueError('Mapping of sources only supported if data is loaded into the memory.')
+
+ # get directories for dataset
+ directories = load_files_lists(path)
+
+ # load all songs for dataset
+ if debug:
+ data = [dict() for _x in range(np.minimum(NUM_SAMPLES_SMALL_DATASET, len(directories)))]
+ else:
+ data = [dict() for _x in range(len(directories))]
+
+ material_length = {} # in seconds
+ for i, d in enumerate(directories):
+ if verbose:
+ uprint(f'Processing mixture ({i+1} of {len(directories)}): {d}')
+
+ # add names of all files in this folder
+ files = os.listdir(os.path.join(path, d))
+ for f in files:
+ src_name = os.path.splitext(f)[0]
+ if ((src_name not in sources
+ and src_name not in mapped_sources)):
+ if verbose:
+ uprint(f'\tIgnoring unknown source from file {f}')
+ else:
+ if src_name not in sources:
+ src_name = mapped_sources[src_name]
+ if verbose:
+ uprint(f'\tAdding function handle for "{src_name}" from file {f}')
+
+ _data = load_wav(os.path.join(path, d, f), mmap=not load_to_memory)
+
+ # determine properties from loaded data
+ _samplingrate = _data[0]
+ _n_channels = _data[1].shape[1]
+ _duration = _data[1].shape[0] / _samplingrate
+
+ # collect statistics about data for each source
+ if src_name in material_length:
+ material_length[src_name] += _duration
+ else:
+ material_length[src_name] = _duration
+
+ # make sure that sample rate and number of channels matches
+ if n_channels != -1 and _n_channels != n_channels:
+ raise ValueError(f'File has {_n_channels} '
+ f'channels but expected {n_channels}.')
+
+ if _samplingrate not in accepted_sampling_rates:
+ raise ValueError(f'File has fs = {_samplingrate}Hz '
+ f'but expected {accepted_sampling_rates}Hz.')
+
+ # if we already loaded data for this source then append data
+ if src_name in data[i]:
+ _data = (_data[0], np.vstack((_data[1],
+ data[i][src_name].keywords['file_path_or_data'][1])))
+ data[i][src_name] = functools.partial(generate_data,
+ file_path_or_data=_data)
+
+ if debug and i == NUM_SAMPLES_SMALL_DATASET-1:
+ # load only first `NUM_SAMPLES_SMALL_DATASET` songs
+ break
+
+ # delete all entries where we did not find a source file
+ idx_empty = [_ for _ in range(len(data)) if len(data[_]) == 0]
+ for idx in sorted(idx_empty, reverse=True):
+ del data[idx]
+
+ return data, directories
+
+def create_dataset_mixing(path, accepted_sampling_rates, sources, mapped_sources, n_channels=-1, load_to_memory=False,
+ debug=False, pad_wrap_samples=None):
+ """
+ Prepare data in `path` for training/validation/test set generation.
+
+ Args:
+ path: path to the dataset
+ accepted_sampling_rates: list of accepted sampling rates
+ sources: list of canonical source names
+ mapped_sources: dictionary that maps alternative source names to the canonical names in `sources`
+ n_channels: expected number of channels (`-1` disables the check; mono files are converted to stereo by channel repetition)
+ load_to_memory: whether to load the audio data into main memory
+ debug: if `True`, then we only load the first `NUM_SAMPLES_SMALL_DATASET` songs
+ pad_wrap_samples: if not None, wrap-pad each stem with this many samples at the beginning
+
+ Raises:
+ ValueError: if source mapping is requested but the data is not loaded into memory
+
+ Returns:
+ data: list of dictionaries with function handles (to load the data)
+ directories: list of directories
+ """
+ NUM_SAMPLES_SMALL_DATASET = 16
+
+ # source mapping currently only works if we load everything into the memory
+ if mapped_sources and not load_to_memory:
+ raise ValueError('Mapping of sources only supported if data is loaded into the memory.')
+
+ # get directories for dataset
+ directories = load_files_lists(path)
+ directories.sort()
+
+ # load all songs for dataset
+ uprint(f'\nCreating dataset for path={path} ...')
+
+ if debug:
+ data = [dict() for _x in range(np.minimum(NUM_SAMPLES_SMALL_DATASET, len(directories)))]
+ else:
+ data = [dict() for _x in range(len(directories))]
+
+ material_length = {} # in seconds
+ for i, d in enumerate(directories):
+ uprint(f'Processing mixture ({i+1} of {len(directories)}): {d}')
+
+ # add names of all files in this folder
+ files = os.listdir(os.path.join(path, d))
+ _data_mix = []
+ _stems_name = []
+ for f in files:
+ src_name = os.path.splitext(f)[0]
+ if ((src_name not in sources
+ and src_name not in mapped_sources)):
+ uprint(f'\tIgnoring unknown source from file {f}')
+ else:
+ if src_name not in sources:
+ src_name = mapped_sources[src_name]
+ uprint(f'\tAdding function handle for "{src_name}" from file {f}')
+
+ _data = load_wav(os.path.join(path, d, f), mmap=not load_to_memory)
+
+ if pad_wrap_samples:
+ _data = (_data[0], np.pad(_data[1], [(pad_wrap_samples, 0), (0,0)], 'wrap'))
+
+ # determine properties from loaded data
+ _samplingrate = _data[0]
+ _n_channels = _data[1].shape[1]
+ _duration = _data[1].shape[0] / _samplingrate
+
+ # collect statistics about data for each source
+ if src_name in material_length:
+ material_length[src_name] += _duration
+ else:
+ material_length[src_name] = _duration
+
+ # make sure that sample rate and number of channels matches
+ if n_channels != -1 and _n_channels != n_channels:
+ if _n_channels == 1: # Converts mono to stereo with repeated channels
+ _data = (_data[0], np.repeat(_data[1], 2, axis=-1))
+ print("Converted file to stereo by repeating mono channel")
+ else:
+ raise ValueError(f'File has {_n_channels} '
+ f'channels but expected {n_channels}.')
+
+ if _samplingrate not in accepted_sampling_rates:
+ raise ValueError(f'File has fs = {_samplingrate}Hz '
+ f'but expected {accepted_sampling_rates}Hz.')
+
+ # if we already loaded data for this source then append data
+ if src_name in data[i]:
+ _data = (_data[0], np.vstack((_data[1],
+ data[i][src_name].keywords['file_path_or_data'][1])))
+
+ _data_mix.append(_data)
+ _stems_name.append(src_name)
+
+ data[i]["-".join(_stems_name)] = functools.partial(generate_data,
+ file_path_or_data=_data_mix)
+
+ if debug and i == NUM_SAMPLES_SMALL_DATASET-1:
+ # load only first `NUM_SAMPLES_SMALL_DATASET` songs
+ break
+
+ # delete all entries where we did not find a source file
+ idx_empty = [_ for _ in range(len(data)) if len(data[_]) == 0]
+ for idx in sorted(idx_empty, reverse=True):
+ del data[idx]
+
+ uprint(f'Finished preparation of dataset. '
+ f'Found in total the following material (in {len(data)} directories):')
+ for src in material_length:
+ uprint(f'\t{src}: {material_length[src] / 60.0 / 60.0:.2f} hours')
+ return data, directories
+
+
+def generate_data(file_path_or_data, random_sample_size=None):
+ """
+ Load one stem/several stems specified by `file_path_or_data`.
+
+ Alternatively, can also be the result of `wav.read()` if the data has already been loaded previously.
+
+ If `file_path_or_data` is a tuple/list, then we load several files and will return also a tuple/list.
+ This is useful for cases where we want to make sure to have the same random chunk for several stems.
+
+ If `random_sample_size` is not None, then only `random_sample_size` samples are randomly selected.
+
+ Args:
+ file_path_or_data: either path to data or the data itself
+ random_sample_size: if `random_sample_size` is not None, only `random_sample_size` samples are randomly selected
+
+ Returns:
+ samples: data with size `num_samples x num_channels` or a list of samples
+ """
+ needs_wrapping = False
+ if isinstance(file_path_or_data, str):
+ needs_wrapping = True # single file path -> wrap
+ if ((type(file_path_or_data[0]) is not list
+ and type(file_path_or_data[0]) is not tuple)):
+ needs_wrapping = True # single data -> wrap
+ if needs_wrapping:
+ file_path_or_data = (file_path_or_data,)
+
+ # create list where we store all samples
+ samples = [None] * len(file_path_or_data)
+
+ # load samples from wav file
+ for i, fpod in enumerate(file_path_or_data):
+ if isinstance(fpod, str):
+ _fs, samples[i] = load_wav(fpod)
+ else:
+ _fs, samples[i] = fpod
+
+ # if `random_sample_size` is not None, then only select subset
+ if random_sample_size is not None:
+ # get maximum length of all stems (at least `random_sample_size`)
+ max_length = random_sample_size
+ for s in samples:
+ max_length = np.maximum(max_length, s.shape[0])
+
+ # make sure that we can select enough audio and that all have the same length `max_length`
+ # (for short loops, `random_sample_size` can be larger than `s.shape[0]`)
+ for i, s in enumerate(samples):
+ if s.shape[0] < max_length:
+ required_padding = max_length - s.shape[0]
+ zeros = np.zeros((required_padding // 2 + 1, s.shape[1]),
+ dtype=s.dtype, order='F')
+ samples[i] = np.concatenate([zeros, s, zeros])
+
+ # select random part of audio
+ idx_start = np.random.randint(max_length)
+
+ for i, s in enumerate(samples):
+ if idx_start + random_sample_size < s.shape[0]:
+ samples[i] = s[idx_start:idx_start + random_sample_size]
+ else:
+ samples[i] = np.concatenate([s[idx_start:],
+ s[:random_sample_size - (s.shape[0] - idx_start)]])
+
+ # convert from `int16/int32` to `float32` precision (this will also make a copy)
+ for i, s in enumerate(samples):
+ conversion_scale = 1. / (1. + np.iinfo(s.dtype).max)
+ samples[i] = s.astype(dtype=np.float32) * conversion_scale
+
+ if len(samples) == 1:
+ return samples[0]
+ else:
+ return samples
+
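+# Usage sketch for `generate_data` (file names are hypothetical): passing a tuple of already
+# loaded `(fs, samples)` pairs makes all stems share the same random chunk position, which is
+# how time-aligned excerpts of several equal-length stereo stems of one song can be drawn:
+#
+#   stems = (wav.read('drums.wav'), wav.read('bass.wav'))            # two (fs, int PCM) tuples
+#   drums_chunk, bass_chunk = generate_data(stems, random_sample_size=44100)
+#   # both chunks are float32 in [-1, 1) and start at the same random sample index
+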
+
+def create_minibatch(data: list, sources: list,
+ present_prob: dict, overlap_prob: dict,
+ augmenter: AugmentationChain, augmenter_padding: Tuple[int, int],
+ batch_size: int, n_samples: int, n_channels: int, idx_songs: dict):
+ """
+ Create a minibatch.
+
+ This function also handles the case that we do not have a source in one mixture.
+ This can, e.g., happen for instrumental pieces that do not have vocals.
+
+ Args:
+ data (list): data to create the minibatch from.
+ sources (list): list of sources.
+ present_prob (dict): probability of a source to be present.
+ overlap_prob (dict): probability of overlap.
+ augmenter (AugmentationChain): audio effect chain that we want to apply for data augmentation
+ augmenter_padding (tuple of ints): padding that we should apply to left/right side of data to avoid
+ boundary effects of `augmenter`.
+ batch_size (int): number of training samples in one minibatch.
+ n_samples (int): number of time samples.
+ n_channels (int): number of channels.
+ idx_songs (dict): index of songs.
+
+ Returns:
+ inp (Numpy array): minibatch, input to the network (i.e. the mixture) of size
+ `batch_size x n_samples x n_channels`
+ tar (dict with Numpy arrays): dictionary which contains for each source the targets,
+ each of the `c_contiguous` ndarrays is `batch_size x n_samples x n_channels`
+ """
+ # initialize numpy arrays which keep input/targets
+ shp = (batch_size, n_samples, n_channels)
+ inp = np.zeros(shape=shp, dtype=np.float32, order='C')
+ tar = {src: np.zeros(shape=shp, dtype=np.float32, order='C') for src in sources}
+
+ # use padding to avoid boundary effects of augmenter
+ pad_left = None if augmenter_padding[0] == 0 else augmenter_padding[0]
+ pad_right = None if augmenter_padding[1] == 0 else -augmenter_padding[1]
+
+ def augm(i, s, n):
+ return augmenter(data[i][s](random_sample_size=n+sum(augmenter_padding)))[pad_left:pad_right]
+
+ # create mini-batch
+ for src in sources:
+
+ for j in range(batch_size):
+ # get song index for this source
+ _idx_song = idx_songs[src][j]
+
+ # determine whether this source is present/whether we overlap
+ is_present = src not in present_prob or np.random.rand() < present_prob[src]
+ is_overlap = src in overlap_prob and np.random.rand() < overlap_prob[src]
+
+ # if song contains source, then add it to input/target
+ if src in data[_idx_song] and is_present:
+ tar[src][j, ...] = augm(_idx_song, src, n_samples)
+
+ # overlap source with same source from randomly chosen other song
+ if is_overlap:
+ idx_overlap_ = np.random.randint(len(data))
+ if idx_overlap_ != _idx_song and src in data[idx_overlap_]:
+ tar[src][j, ...] += augm(idx_overlap_, src, n_samples)
+
+ # compute input
+ inp += tar[src]
+
+ # make sure that all have not too large amplitude (check only mixture)
+ maxabs_amp = np.maximum(1.0, 1e-6 + np.max(np.abs(inp), axis=(1, 2), keepdims=True))
+ inp /= maxabs_amp
+ for src in sources:
+ tar[src] /= maxabs_amp
+
+ return inp, tar
+
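+# Usage sketch (all values are hypothetical): `data` comes from `create_dataset` and
+# `idx_songs[src]` holds one song index per minibatch entry. With `present_prob={'vocals': 0.8}`
+# the vocal stem is dropped in roughly 20% of the examples, and with `overlap_prob={'vocals': 0.1}`
+# a second, randomly chosen vocal stem is mixed in about 10% of the time:
+#
+#   srcs = ['drums', 'bass', 'other', 'vocals']
+#   idx = {src: np.random.randint(len(data), size=8) for src in srcs}
+#   inp, tar = create_minibatch(data, srcs,
+#                               present_prob={'vocals': 0.8}, overlap_prob={'vocals': 0.1},
+#                               augmenter=AugmentationChain(), augmenter_padding=(0, 0),
+#                               batch_size=8, n_samples=44100, n_channels=2, idx_songs=idx)
+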
+def create_minibatch_mixing(data: list, sources: list, inputs: list, outputs: list,
+ present_prob: dict, overlap_prob: dict,
+ augmenter: AugmentationChain, augmenter_padding: Tuple[int, int], augmenter_sources: list,
+ batch_size: int, n_samples: int, n_channels: int, idx_songs: dict):
+ """
+ Create a minibatch.
+
+ This function also handles the case that we do not have a source in one mixture.
+ This can, e.g., happen for instrumental pieces that do not have vocals.
+
+ Args:
+ data (list): data to create the minibatch from.
+ sources (list): list of sources.
+ inputs (list): list of input (stem) source names.
+ outputs (list): list of output (mix) source names.
+ present_prob (dict): probability of a source to be present.
+ overlap_prob (dict): probability of overlap.
+ augmenter (AugmentationChain): audio effect chain that we want to apply for data augmentation
+ augmenter_padding (tuple of ints): padding that we should apply to left/right side of data to avoid
+ boundary effects of `augmenter`.
+ augmenter_sources (list): list of sources to augment
+ batch_size (int): number of training samples in one minibatch.
+ n_samples (int): number of time samples.
+ n_channels (int): number of channels.
+ idx_songs (dict): index of songs.
+
+ Returns:
+ stems (dict with Numpy arrays): dictionary which contains, for each input source, the (augmented) stems,
+ each of size `batch_size x n_samples x n_channels`
+ mix (dict with Numpy arrays): dictionary which contains, for each output source, the target audio,
+ each of size `batch_size x n_samples x n_channels`
+ """
+ # initialize numpy arrays which keep input/targets
+ shp = (batch_size, n_samples, n_channels)
+ stems = {src: np.zeros(shape=shp, dtype=np.float32, order='C') for src in inputs}
+ mix = {src: np.zeros(shape=shp, dtype=np.float32, order='C') for src in outputs}
+
+ # use padding to avoid boundary effects of augmenter
+ pad_left = None if augmenter_padding[0] == 0 else augmenter_padding[0]
+ pad_right = None if augmenter_padding[1] == 0 else -augmenter_padding[1]
+
+ def augm(i, n):
+ s = list(data[i])[0]
+ input_multitracks = data[i][s](random_sample_size=n+sum(augmenter_padding))
+ audio_tags = list(data[i])[0].split("-")
+
+ # Only applies augmentation to inputs, not output.
+ for k, tag in enumerate(audio_tags):
+ if tag in augmenter_sources:
+ input_multitracks[k] = augmenter(input_multitracks[k])[pad_left:pad_right]
+ else:
+ input_multitracks[k] = input_multitracks[k][pad_left:pad_right]
+ return input_multitracks
+
+ # create mini-batch
+ for src in outputs:
+
+ for j in range(batch_size):
+ # get song index for this source
+ _idx_song = idx_songs[src][j]
+
+ multitrack_audio = augm(_idx_song, n_samples)
+
+ audio_tags = list(data[_idx_song])[0].split("-")
+
+ for i, tag in enumerate(audio_tags):
+ if tag in inputs:
+ stems[tag][j, ...] = multitrack_audio[i]
+ if tag in outputs:
+ mix[tag][j, ...] = multitrack_audio[i]
+
+ return stems, mix
+
diff --git a/mixing_style_transfer/mixing_manipulator/common_miscellaneous.py b/mixing_style_transfer/mixing_manipulator/common_miscellaneous.py
new file mode 100644
index 0000000000000000000000000000000000000000..a996f9b3b1b2732d8b30e1e9d816d8e6de28f749
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/common_miscellaneous.py
@@ -0,0 +1,219 @@
+"""
+Common miscellaneous functions.
+
+AI Music Technology Group, Sony Group Corporation
+AI Speech and Sound Group, Sony Europe
+
+This implementation originally belongs to Sony Group Corporation,
+ which has been introduced in the work "Automatic music mixing with deep learning and out-of-domain data".
+ Original repo link: https://github.com/sony/FxNorm-automix
+"""
+import os
+import psutil
+import sys
+import numpy as np
+import librosa
+import torch
+import math
+
+
+def uprint(s):
+ """
+ Unbuffered print to stdout.
+
+ We also flush stderr to have the log-file in sync.
+
+ Args:
+ s: string to print
+ """
+ print(s)
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+
+def recursive_getattr(obj, attr):
+ """
+ Run `getattr` recursively (e.g., for `fc1.weight`).
+
+ Args:
+ obj: object
+ attr: attribute to get
+
+ Returns:
+ object
+ """
+ for a in attr.split('.'):
+ obj = getattr(obj, a)
+ return obj
+
+
+def compute_stft(samples, hop_length, fft_size, stft_window):
+ """
+ Compute the STFT of `samples` applying the analysis window `stft_window` of size `fft_size`, shifted for each frame by `hop_length`.
+
+ Args:
+ samples: num samples x channels
+ hop_length: window shift in samples
+ fft_size: FFT size which is also the window size
+ stft_window: STFT analysis window
+
+ Returns:
+ stft: frames x channels x freqbins
+ """
+ n_channels = samples.shape[1]
+ n_frames = 1+int((samples.shape[0] - fft_size)/hop_length)
+ stft = np.empty((n_frames, n_channels, fft_size//2+1), dtype=np.complex64)
+
+ # convert into f_contiguous (such that [:,n] slicing is c_contiguous)
+ samples = np.asfortranarray(samples)
+
+ for n in range(n_channels):
+ # compute STFT (output has size `n_frames x N_BINS`)
+ stft[:, n, :] = librosa.stft(samples[:, n],
+ n_fft=fft_size,
+ hop_length=hop_length,
+ window=stft_window,
+ center=False).transpose()
+ return stft
+
+
+def compute_istft(stft, hop_length, stft_window):
+ """
+ Compute the inverse STFT of `stft`.
+
+ Args:
+ stft: frames x channels x freqbins
+ hop_length: window shift in samples
+ stft_window: STFT synthesis window
+
+ Returns:
+ samples: num samples x channels
+ """
+ for n in range(stft.shape[1]):
+ s = librosa.istft(stft[:, n, :].transpose(),
+ hop_length=hop_length, window=stft_window, center=False)
+ if n == 0:
+ samples = s
+ else:
+ samples = np.column_stack((samples, s))
+
+ # ensure that we have a 2d array (monaural files are just loaded as vectors)
+ if samples.ndim == 1:
+ samples = samples[:, np.newaxis]
+
+ return samples
+
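+# Round-trip sketch: `compute_istft` inverts `compute_stft` up to boundary effects when both use
+# the same window and hop (librosa's ISTFT compensates for the window overlap internally):
+#
+#   n_fft, hop = 4096, 1024
+#   w = np.sqrt(np.hanning(n_fft + 1)[:-1])
+#   X = compute_stft(x, hop, n_fft, w)      # frames x channels x freqbins
+#   x_hat = compute_istft(X, hop, w)        # num samples x channels (slightly shorter than x)
+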
+
+def get_size(obj):
+ """
+ Recursively find size of objects (in bytes).
+
+ Args:
+ obj: object
+
+ Returns:
+ size of object
+ """
+ size = sys.getsizeof(obj)
+
+ import functools
+
+ if isinstance(obj, dict):
+ size += sum([get_size(v) for v in obj.values()])
+ size += sum([get_size(k) for k in obj.keys()])
+ elif isinstance(obj, functools.partial):
+ size += sum([get_size(v) for v in obj.keywords.values()])
+ size += sum([get_size(k) for k in obj.keywords.keys()])
+ elif isinstance(obj, list):
+ size += sum([get_size(i) for i in obj])
+ elif isinstance(obj, tuple):
+ size += sum([get_size(i) for i in obj])
+ return size
+
+
+def get_process_memory():
+ """
+ Return memory consumption in GBytes.
+
+ Returns:
+ memory used by the process
+ """
+ return psutil.Process(os.getpid()).memory_info()[0] / (2 ** 30)
+
+
+def check_complete_convolution(input_size, kernel_size, stride=1,
+ padding=0, dilation=1, note=''):
+ """
+ Check whether the convolution covers the input completely.
+
+ Prints `True` if no time steps are left over in a Conv1d.
+
+ Args:
+ input_size: size of input
+ kernel_size: size of kernel
+ stride: stride
+ padding: padding
+ dilation: dilation
+ note: string for additional notes
+ """
+ is_complete = ((input_size + 2*padding - dilation * (kernel_size - 1) - 1)
+ / stride + 1).is_integer()
+ uprint(f'{note} {is_complete}')
+
+
+def pad_to_shape(x: torch.Tensor, y: int) -> torch.Tensor:
+ """
+ Right-pad or right-trim first argument last dimension to have same size as second argument.
+
+ Args:
+ x: Tensor to be padded.
+ y: Size to pad/trim x last dimension to
+
+ Returns:
+ `x` padded to match `y`'s dimension.
+ """
+ inp_len = y
+ output_len = x.shape[-1]
+ return torch.nn.functional.pad(x, [0, inp_len - output_len])
+
+
+def valid_length(input_size, kernel_size, stride=1, padding=0, dilation=1):
+ """
+ Return the nearest valid upper length to use with the model so that there is no time steps left over in a 1DConv.
+
+ For all layers, `(input_size - kernel_size) % stride == 0` must hold.
+ Here, 'valid' means that no leftover frames are neglected and discarded.
+
+ Args:
+ input_size: size of input
+ kernel_size: size of kernel
+ stride: stride
+ padding: padding
+ dilation: dilation
+
+ Returns:
+ valid length for convolution
+ """
+ length = math.ceil((input_size + 2*padding - dilation * (kernel_size - 1) - 1)/stride) + 1
+ length = (length - 1) * stride - 2*padding + dilation * (kernel_size - 1) + 1
+
+ return int(length)
+
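+# Worked example: valid_length(1001, kernel_size=8, stride=4) returns 1004, the nearest length
+# >= 1001 for which (length - kernel_size) % stride == 0, so a Conv1d with these settings leaves
+# no samples over; an already valid input size (e.g. 1000) is returned unchanged.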
+
+def td_length_from_fd(fd_length: int, fft_size: int, fft_hop: int) -> int:
+ """
+ Return the length in time domain, given the length in frequency domain.
+
+ Return the necessary length in the time domain of a signal to be transformed into
+ a signal of length `fd_length` in time-frequency domain with the given STFT
+ parameters `fft_size` and `fft_hop`. No padding is assumed.
+
+ Args:
+ fd_length: length in frequency domain
+ fft_size: size of FFT
+ fft_hop: hop length
+
+ Returns:
+ length in time domain
+ """
+ return (fd_length - 1) * fft_hop + fft_size
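+# Worked example: with fft_size=4096 and fft_hop=1024, a spectrogram of fd_length=100 frames
+# requires (100 - 1) * 1024 + 4096 = 105472 time-domain samples, which matches the frame-count
+# formula used in `compute_stft` (no padding assumed).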
diff --git a/mixing_style_transfer/mixing_manipulator/data_normalization.py b/mixing_style_transfer/mixing_manipulator/data_normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..368aaefdc0cc2d9c56b87c107da9630639664e52
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/data_normalization.py
@@ -0,0 +1,173 @@
+"""
+ Implementation of the 'audio effects chain normalization'
+"""
+import numpy as np
+import scipy
+
+import os
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+from utils_data_normalization import *
+from normalization_imager import *
+
+
+'''
+ Audio Effects Chain Normalization
+ process: normalizes input stems according to given precomputed features
+'''
+class Audio_Effects_Normalizer:
+ def __init__(self, precomputed_feature_path, \
+ STEMS=['drums', 'bass', 'other', 'vocals'], \
+ EFFECTS=['eq', 'compression', 'imager', 'loudness']):
+ self.STEMS = STEMS # Stems to be normalized
+ self.EFFECTS = EFFECTS # Effects to be normalized, order matters
+
+ # Audio settings
+ self.SR = 44100
+ self.SUBTYPE = 'PCM_16'
+
+ # General Settings
+ self.FFT_SIZE = 2**16
+ self.HOP_LENGTH = self.FFT_SIZE//4
+
+ # Loudness
+ self.NTAPS = 1001
+ self.LUFS = -30
+ self.MIN_DB = -40 # Min amplitude to apply EQ matching
+
+ # Compressor
+ self.COMP_USE_EXPANDER = False
+ self.COMP_PEAK_NORM = -10.0
+ self.COMP_TRUE_PEAK = False
+ self.COMP_PERCENTILE = 75 # features_mean (v1) was done with 25
+ self.COMP_MIN_TH = -40
+ self.COMP_MAX_RATIO = 20
+ comp_settings = {key:{} for key in self.STEMS}
+ for key in comp_settings:
+ if key == 'vocals':
+ comp_settings[key]['attack'] = 7.5
+ comp_settings[key]['release'] = 400.0
+ comp_settings[key]['ratio'] = 4
+ comp_settings[key]['n_mels'] = 128
+ elif key == 'drums':
+ comp_settings[key]['attack'] = 10.0
+ comp_settings[key]['release'] = 180.0
+ comp_settings[key]['ratio'] = 6
+ comp_settings[key]['n_mels'] = 128
+ elif key == 'bass':
+ comp_settings[key]['attack'] = 10.0
+ comp_settings[key]['release'] = 500.0
+ comp_settings[key]['ratio'] = 5
+ comp_settings[key]['n_mels'] = 16
+ elif key == 'other':
+ comp_settings[key]['attack'] = 15.0
+ comp_settings[key]['release'] = 666.0
+ comp_settings[key]['ratio'] = 4
+ comp_settings[key]['n_mels'] = 128
+ self.comp_settings = comp_settings
+
+ # Load Pre-computed Audio Effects Features
+ features_mean = np.load(precomputed_feature_path, allow_pickle='TRUE')[()]
+ self.features_mean = self.smooth_feature(features_mean)
+
+
+ # normalize current audio input with the order of designed audio FX
+ def normalize_audio(self, audio, src):
+ assert src in self.STEMS
+
+ normalized_audio = audio
+ for cur_effect in self.EFFECTS:
+ normalized_audio = self.normalize_audio_per_effect(normalized_audio, src=src, effect=cur_effect)
+
+ return normalized_audio
+
+
+ # normalize current audio input with current targeted audio FX
+ def normalize_audio_per_effect(self, audio, src, effect):
+ audio = audio.astype(dtype=np.float32)
+ audio_track = np.pad(audio, ((self.FFT_SIZE, self.FFT_SIZE), (0, 0)), mode='constant')
+
+ assert len(audio_track.shape) == 2 # Always expects two dimensions
+
+ if audio_track.shape[1] == 1: # Converts mono to stereo with repeated channels
+ audio_track = np.repeat(audio_track, 2, axis=-1)
+
+ output_audio = audio_track.copy()
+
+ max_db = amp_to_db(np.max(np.abs(output_audio)))
+ if max_db > self.MIN_DB:
+
+ if effect == 'eq':
+ # normalize each channel
+ for ch in range(audio_track.shape[1]):
+ audio_eq_matched = get_eq_matching(output_audio[:, ch],
+ self.features_mean[effect][src],
+ sr=self.SR,
+ n_fft=self.FFT_SIZE,
+ hop_length=self.HOP_LENGTH,
+ min_db=self.MIN_DB,
+ ntaps=self.NTAPS,
+ lufs=self.LUFS)
+
+
+ np.copyto(output_audio[:,ch], audio_eq_matched)
+
+ elif effect == 'compression':
+ assert(len(self.features_mean[effect][src])==2)
+ # normalize each channel
+ for ch in range(audio_track.shape[1]):
+ try:
+ audio_comp_matched = get_comp_matching(output_audio[:, ch],
+ self.features_mean[effect][src][0],
+ self.features_mean[effect][src][1],
+ self.comp_settings[src]['ratio'],
+ self.comp_settings[src]['attack'],
+ self.comp_settings[src]['release'],
+ sr=self.SR,
+ min_db=self.MIN_DB,
+ min_th=self.COMP_MIN_TH,
+ comp_peak_norm=self.COMP_PEAK_NORM,
+ max_ratio=self.COMP_MAX_RATIO,
+ n_mels=self.comp_settings[src]['n_mels'],
+ true_peak=self.COMP_TRUE_PEAK,
+ percentile=self.COMP_PERCENTILE,
+ expander=self.COMP_USE_EXPANDER)
+
+ np.copyto(output_audio[:,ch], audio_comp_matched[:, 0])
+ except Exception:
+ # if compression matching fails for a channel, leave the remaining channels untouched
+ break
+
+ elif effect == 'loudness':
+ output_audio = fx_utils.lufs_normalize(output_audio, self.SR, self.features_mean[effect][src], log=False)
+
+ elif effect == 'imager':
+ # threshold of applying Haas effects
+ mono_threshold = 0.99 if src=='bass' else 0.975
+ audio_imager_matched = normalize_imager(output_audio, \
+ target_side_mid_bal=self.features_mean[effect][src], \
+ mono_threshold=mono_threshold, \
+ sr=self.SR)
+
+ np.copyto(output_audio, audio_imager_matched)
+
+ output_audio = output_audio[self.FFT_SIZE:self.FFT_SIZE+audio.shape[0]]
+ return output_audio
+
+
+ def smooth_feature(self, feature_dict_):
+
+ for effect in self.EFFECTS:
+ for key in self.STEMS:
+ if effect == 'eq':
+ if key in ['other', 'vocals']:
+ f = 401
+ else:
+ f = 151
+ feature_dict_[effect][key] = scipy.signal.savgol_filter(feature_dict_[effect][key],
+ f, 1, mode='mirror')
+ elif effect == 'panning':
+ feature_dict_[effect][key] = scipy.signal.savgol_filter(feature_dict_[effect][key],
+ 501, 1, mode='mirror')
+ return feature_dict_
+
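+# Usage sketch (file names are hypothetical): the normalizer is driven entirely by the
+# precomputed feature file, so a typical call is just load -> normalize per stem:
+#
+#   import soundfile as sf
+#   normalizer = Audio_Effects_Normalizer('features_mean.npy',
+#                                         STEMS=['drums', 'bass', 'other', 'vocals'])
+#   audio, sr = sf.read('vocals.wav', always_2d=True)   # n_samples x n_channels, 44.1 kHz assumed
+#   normalized = normalizer.normalize_audio(audio, src='vocals')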
diff --git a/mixing_style_transfer/mixing_manipulator/fx_utils.py b/mixing_style_transfer/mixing_manipulator/fx_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dd3137c8cb5bc3ed0a86a65a1b79fb2ab8cf73e
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/fx_utils.py
@@ -0,0 +1,313 @@
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+import numpy as np
+import scipy
+import math
+import librosa
+import librosa.display
+import fnmatch
+import os
+from functools import partial
+import pyloudnorm
+from scipy.signal import lfilter
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from sklearn.metrics.pairwise import paired_distances
+
+
+import matplotlib.pyplot as plt
+
+def db(x):
+ """Computes the decible energy of a signal"""
+ return 20*np.log10(np.sqrt(np.mean(np.square(x))))
+
+def melspectrogram(y, mirror_pad=False):
+ """Compute melspectrogram feature extraction
+
+ Keyword arguments:
+ y -- input audio as a 1-D numpy array
+ mirror_pad -- pre- and post-pend mirrored copies of the signal
+
+ Returns freq x time
+
+
+ Assumes the input sampling rate is 22050Hz
+ """
+
+ # Extract mel.
+ fftsize = 1024
+ window = 1024
+ hop = 512
+ melBin = 128
+ sr = 22050
+
+ # mirror pad signal
+ # first embedding centered on time 0
+ # last embedding centered on end of signal
+ # (note: `half_frame_length_sec` is not defined in this module and must be provided by the calling scope)
+ if mirror_pad:
+ y = np.insert(y, 0, y[0:int(half_frame_length_sec * sr)][::-1])
+ y = np.insert(y, len(y), y[-int(half_frame_length_sec * sr):][::-1])
+
+ S = librosa.core.stft(y,n_fft=fftsize,hop_length=hop,win_length=window)
+ X = np.abs(S)
+ mel_basis = librosa.filters.mel(sr,n_fft=fftsize,n_mels=melBin)
+ mel_S = np.dot(mel_basis,X)
+
+ # value log compression
+ mel_S = np.log10(1+10*mel_S)
+ mel_S = mel_S.astype(np.float32)
+
+
+ return mel_S
+
+
+def getFilesPath(directory, extension):
+
+ n_path=[]
+ for path, subdirs, files in os.walk(directory):
+ for name in files:
+ if fnmatch.fnmatch(name, extension):
+ n_path.append(os.path.join(path,name))
+ n_path.sort()
+
+ return n_path
+
+
+
+def getRandomTrim(x, length, pad=0, start=None):
+
+ length = length+pad
+ if x.shape[0] <= length:
+ x_ = x
+ while(x_.shape[0] <= length):
+ x_ = np.concatenate((x_,x_))
+ else:
+ if start is None:
+ start = np.random.randint(0, x.shape[0]-length, size=None)
+ end = length+start
+ if end > x.shape[0]:
+ x_ = x[start:]
+ x_ = np.concatenate((x_, x[:length-x.shape[0]]))
+ else:
+ x_ = x[start:length+start]
+
+ return x_[:length]
+
+def fadeIn(x, length=128):
+
+ w = scipy.signal.hann(length*2, sym=True)
+ w1 = w[0:length]
+ ones = np.ones(int(x.shape[0]-length))
+ w = np.append(w1, ones)
+
+ return x*w
+
+def fadeOut(x, length=128):
+
+ w = scipy.signal.hann(length*2, sym=True)
+ w2 = w[length:length*2]
+ ones = np.ones(int(x.shape[0]-length))
+ w = np.append(ones, w2)
+
+ return x*w
+
+
+def plotTimeFreq(audio, sr, n_fft=512, hop_length=128, ylabels=None):
+
+ n = len(audio)
+# plt.figure(figsize=(14, 4*n))
+ colors = list(plt.cm.viridis(np.linspace(0,1,n)))
+
+ X = []
+ X_db = []
+ maxs = np.zeros((n,))
+ mins = np.zeros((n,))
+ maxs_t = np.zeros((n,))
+ for i, x in enumerate(audio):
+
+ if x.ndim == 2 and x.shape[-1] == 2:
+ x = librosa.core.to_mono(x.T)
+ X_ = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
+ X_db_ = librosa.amplitude_to_db(abs(X_))
+ X.append(X_)
+ X_db.append(X_db_)
+ maxs[i] = np.max(X_db_)
+ mins[i] = np.min(X_db_)
+ maxs_t[i] = np.max(np.abs(x))
+ vmax = np.max(maxs)
+ vmin = np.min(mins)
+ tmax = np.max(maxs_t)
+ for i, x in enumerate(audio):
+
+ if x.ndim == 2 and x.shape[-1] == 2:
+ x = librosa.core.to_mono(x.T)
+
+ plt.subplot(n, 2, 2*i+1)
+ librosa.display.waveplot(x, sr=sr, color=colors[i])
+ if ylabels:
+ plt.ylabel(ylabels[i])
+
+ plt.ylim(-tmax,tmax)
+ plt.subplot(n, 2, 2*i+2)
+ librosa.display.specshow(X_db[i], sr=sr, x_axis='time', y_axis='log',
+ hop_length=hop_length, cmap='GnBu', vmax=vmax, vmin=vmin)
+# plt.colorbar(format='%+2.0f dB')
+
+
+def slicing(x, win_length, hop_length, center = True, windowing = False, pad = 0):
+ # Pad the time series so that frames are centered
+ if center:
+# x = np.pad(x, int((win_length-hop_length+pad) // 2), mode='constant')
+ x = np.pad(x, ((int((win_length-hop_length+pad)//2), int((win_length+hop_length+pad)//2)),), mode='constant')
+
+ # Window the time series.
+ y_frames = librosa.util.frame(x, frame_length=win_length, hop_length=hop_length)
+ if windowing:
+ window = scipy.signal.hann(win_length, sym=False)
+ else:
+ window = 1.0
+ f = []
+ for i in range(len(y_frames.T)):
+ f.append(y_frames.T[i]*window)
+ return np.float32(np.asarray(f))
+
+
+def overlap(x, x_len, win_length, hop_length, windowing = True, rate = 1):
+ x = x.reshape(x.shape[0],x.shape[1]).T
+ if windowing:
+ window = scipy.signal.hann(win_length, sym=False)
+ rate = rate*hop_length/win_length
+ else:
+ window = 1
+ rate = 1
+ n_frames = x_len / hop_length
+ expected_signal_len = int(win_length + hop_length * (n_frames))
+ y = np.zeros(expected_signal_len)
+ for i in range(int(n_frames)):
+ sample = i * hop_length
+ w = x[:, i]
+ y[sample:(sample + win_length)] = y[sample:(sample + win_length)] + w*window
+ y = y[int(win_length // 2):-int(win_length // 2)]
+ return np.float32(y*rate)
+
+
+def highpassFiltering(x_list, f0, sr):
+
+ b1, a1 = scipy.signal.butter(4, f0/(sr/2),'highpass')
+ x_f = []
+ for x in x_list:
+ x_f_ = scipy.signal.filtfilt(b1, a1, x).copy(order='F')
+ x_f.append(x_f_)
+ return x_f
+
+def lineartodB(x):
+ return 20*np.log10(x)
+def dBtoLinear(x):
+ return np.power(10,x/20)
+
+def lufs_normalize(x, sr, lufs, log=True):
+
+ # measure the loudness first
+ meter = pyloudnorm.Meter(sr) # create BS.1770 meter
+ loudness = meter.integrated_loudness(x+1e-10)
+ if log:
+ print("original loudness: ", loudness," max value: ", np.max(np.abs(x)))
+
+ loudness_normalized_audio = pyloudnorm.normalize.loudness(x, loudness, lufs)
+
+ maxabs_amp = np.maximum(1.0, 1e-6 + np.max(np.abs(loudness_normalized_audio)))
+ loudness_normalized_audio /= maxabs_amp
+
+ loudness = meter.integrated_loudness(loudness_normalized_audio)
+ if log:
+ print("new loudness: ", loudness," max value: ", np.max(np.abs(loudness_normalized_audio)))
+
+
+ return loudness_normalized_audio
+
+import soxbindings as sox
+
+def lufs_normalize_compand(x, sr, lufs):
+
+ tfm = sox.Transformer()
+ tfm.compand(attack_time = 0.001,
+ decay_time = 0.01,
+ soft_knee_db = 1.0,
+ tf_points = [(-70, -70), (-0.1, -20), (0, 0)])
+
+ x = tfm.build_array(input_array=x, sample_rate_in=sr).astype(np.float32)
+
+ # measure the loudness first
+ meter = pyloudnorm.Meter(sr) # create BS.1770 meter
+ loudness = meter.integrated_loudness(x)
+ print("original loudness: ", loudness," max value: ", np.max(np.abs(x)))
+
+ loudness_normalized_audio = pyloudnorm.normalize.loudness(x, loudness, lufs)
+
+ maxabs_amp = np.maximum(1.0, 1e-6 + np.max(np.abs(loudness_normalized_audio)))
+ loudness_normalized_audio /= maxabs_amp
+
+ loudness = meter.integrated_loudness(loudness_normalized_audio)
+ print("new loudness: ", loudness," max value: ", np.max(np.abs(loudness_normalized_audio)))
+
+ return loudness_normalized_audio
+
+
+def getDistances(x,y):
+
+ distances = {}
+ distances['mae'] = mean_absolute_error(x, y)
+ distances['mse'] = mean_squared_error(x, y)
+ distances['euclidean'] = np.mean(paired_distances(x, y, metric='euclidean'))
+ distances['manhattan'] = np.mean(paired_distances(x, y, metric='manhattan'))
+ distances['cosine'] = np.mean(paired_distances(x, y, metric='cosine'))
+
+ distances['mae'] = round(distances['mae'], 5)
+ distances['mse'] = round(distances['mse'], 5)
+ distances['euclidean'] = round(distances['euclidean'], 5)
+ distances['manhattan'] = round(distances['manhattan'], 5)
+ distances['cosine'] = round(distances['cosine'], 5)
+
+ return distances
+
+def getMFCC(x, sr, mels=128, mfcc=13, mean_norm=False):
+
+ melspec = librosa.feature.melspectrogram(y=x, sr=sr, S=None,
+ n_fft=1024, hop_length=256,
+ n_mels=mels, power=2.0)
+ melspec_dB = librosa.power_to_db(melspec, ref=np.max)
+ mfcc = librosa.feature.mfcc(S=melspec_dB, sr=sr, n_mfcc=mfcc)
+ if mean_norm:
+ mfcc -= (np.mean(mfcc, axis=0))
+ return mfcc
+
+
+def getMSE_MFCC(y_true, y_pred, sr, mels=128, mfcc=13, mean_norm=False):
+
+ ratio = np.mean(np.abs(y_true))/np.mean(np.abs(y_pred))
+ y_pred = ratio*y_pred
+
+ y_mfcc = getMFCC(y_true, sr, mels=mels, mfcc=mfcc, mean_norm=mean_norm)
+ z_mfcc = getMFCC(y_pred, sr, mels=mels, mfcc=mfcc, mean_norm=mean_norm)
+
+ return getDistances(y_mfcc[:,:], z_mfcc[:,:])
\ No newline at end of file
diff --git a/mixing_style_transfer/mixing_manipulator/normalization_imager.py b/mixing_style_transfer/mixing_manipulator/normalization_imager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2848432f798baa8b0afb73e3d31cd42f2114885
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/normalization_imager.py
@@ -0,0 +1,121 @@
+"""
+ Implementation of the normalization process of stereo-imaging and panning effects
+"""
+import numpy as np
+import sys
+import os
+
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+from common_audioeffects import AugmentationChain, Haas
+
+
+'''
+ ### normalization algorithm for stereo imaging and panning effects ###
+ process :
+ 1. inputs 2-channeled audio
+ 2. apply Haas effects if the input audio is almost mono
+ 3. normalize mid-side channels according to target precomputed feature value
+ 4. normalize left-right channels 50-50
+ 5. normalize mid-side channels again
+'''
+def normalize_imager(data, \
+ target_side_mid_bal=0.9, \
+ mono_threshold=0.95, \
+ sr=44100, \
+ eps=1e-04, \
+ verbose=False):
+
+ # to mid-side channels
+ mid, side = lr_to_ms(data[:,0], data[:,1])
+
+ if verbose:
+ print_balance(data[:,0], data[:,1])
+ print_balance(mid, side)
+ print()
+
+ # apply mid-side weights according to energy
+ mid_e, side_e = np.sum(mid**2), np.sum(side**2)
+ total_e = mid_e + side_e
+ # apply haas effect to almost-mono signal
+ if mid_e/total_e > mono_threshold:
+ aug_chain = AugmentationChain(fxs=[(Haas(sample_rate=sr), 1, True)])
+ data = aug_chain([data])[0]
+ mid, side = lr_to_ms(data[:,0], data[:,1])
+
+ if verbose:
+ print_balance(data[:,0], data[:,1])
+ print_balance(mid, side)
+ print()
+
+ # normalize mid-side channels (stereo imaging)
+ new_mid, new_side = process_balance(mid, side, tgt_e1_bal=target_side_mid_bal, eps=eps)
+ left, right = ms_to_lr(new_mid, new_side)
+ imaged = np.stack([left, right], 1)
+
+ if verbose:
+ print_balance(new_mid, new_side)
+ print_balance(left, right)
+ print()
+
+ # normalize panning to have the balance of left-right channels 50-50
+ left, right = process_balance(left, right, tgt_e1_bal=0.5, eps=eps)
+ mid, side = lr_to_ms(left, right)
+
+ if verbose:
+ print_balance(mid, side)
+ print_balance(left, right)
+ print()
+
+ # normalize again mid-side channels (stereo imaging)
+ new_mid, new_side = process_balance(mid, side, tgt_e1_bal=target_side_mid_bal, eps=eps)
+ left, right = ms_to_lr(new_mid, new_side)
+ imaged = np.stack([left, right], 1)
+
+ if verbose:
+ print_balance(new_mid, new_side)
+ print_balance(left, right)
+ print()
+
+ return imaged
+
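+# Usage sketch: `data` must be stereo (n_samples x 2). In the normalization pipeline the target
+# balance comes from the precomputed 'imager' feature; with the defaults this call applies the
+# steps listed above (Haas widening for near-mono input, then mid/side and left/right balancing):
+#
+#   stereo_in = np.stack([left, right], 1)    # hypothetical stereo signal at 44.1 kHz
+#   imaged_out = normalize_imager(stereo_in, sr=44100)
+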
+
+# balance out the energies of 2 input signals according to the given balance
+# tgt_e1_bal range = [0.0, 1.0]
+# tgt_e2_bal = 1.0 - tgt_e1_bal
+def process_balance(data_1, data_2, tgt_e1_bal=0.5, eps=1e-04):
+
+ e_1, e_2 = np.sum(data_1**2), np.sum(data_2**2)
+ total_e = e_1 + e_2
+
+ tgt_1_gain = np.sqrt(tgt_e1_bal * total_e / (e_1 + eps))
+
+ new_data_1 = data_1 * tgt_1_gain
+ new_e_1 = e_1 * (tgt_1_gain ** 2)
+ left_e_1 = total_e - new_e_1
+ tgt_2_gain = np.sqrt(left_e_1 / (e_2 + 1e-3))
+ new_data_2 = data_2 * tgt_2_gain
+
+ return new_data_1, new_data_2
+
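+# Worked example: with tgt_e1_bal=0.9, the gains are chosen so that after scaling, `data_1` holds
+# roughly 90% of the total energy and `data_2` the remaining 10% (total energy is preserved up to
+# the small eps terms). Two equal-energy inputs therefore get gains of about sqrt(1.8) and sqrt(0.2).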
+
+# left-right channeled signal to mid-side signal
+def lr_to_ms(left, right):
+ mid = left + right
+ side = left - right
+ return mid, side
+
+
+# mid-side channeled signal to left-right signal
+def ms_to_lr(mid, side):
+ left = (mid + side) / 2
+ right = (mid - side) / 2
+ return left, right
+
+
+# print energy balance of 2 inputs
+def print_balance(data_1, data_2):
+ e_1, e_2 = np.sum(data_1**2), np.sum(data_2**2)
+ total_e = e_1 + e_2
+ print(total_e, e_1/total_e, e_2/total_e)
+
diff --git a/mixing_style_transfer/mixing_manipulator/utils_data_normalization.py b/mixing_style_transfer/mixing_manipulator/utils_data_normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..69512cd8b0f5f4c95704fe4c4b41c9cb8123a108
--- /dev/null
+++ b/mixing_style_transfer/mixing_manipulator/utils_data_normalization.py
@@ -0,0 +1,906 @@
+import os
+
+import sys
+import time
+import numpy as np
+import scipy
+import librosa
+import pyloudnorm as pyln
+
+sys.setrecursionlimit(int(1e6))
+
+import sklearn
+
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+from common_miscellaneous import compute_stft, compute_istft
+from common_audioeffects import Panner, Compressor, AugmentationChain, ConvolutionalReverb, Equaliser, AlgorithmicReverb
+import fx_utils
+
+import soundfile as sf
+import aubio
+
+
+import warnings
+
+# Functions
+
+def print_dict(dict_):
+ for i in dict_:
+ print(i)
+ for j in dict_[i]:
+ print('\t', j)
+
+def amp_to_db(x):
+ return 20*np.log10(x + 1e-30)
+
+def db_to_amp(x):
+ return 10**(x/20)
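+# quick sanity check (illustrative): amp_to_db(1.0) ~= 0 dB and db_to_amp(-6.0) ~= 0.501;
+# the two functions are inverses up to the 1e-30 offset used to avoid log(0).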
+
+def get_running_stats(x, features, N=20):
+ mean = []
+ std = []
+ for i in range(len(features)):
+ mean_, std_ = running_mean_std(x[:,i], N)
+ mean.append(mean_)
+ std.append(std_)
+ mean = np.asarray(mean)
+ std = np.asarray(std)
+
+ return mean, std
+
+def running_mean_std(x, N):
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ cumsum = np.cumsum(np.insert(x, 0, 0))
+ cumsum2 = np.cumsum(np.insert(x**2, 0, 0))
+ mean = (cumsum[N:] - cumsum[:-N]) / float(N)
+
+ std = np.sqrt(((cumsum2[N:] - cumsum2[:-N]) / N) - (mean * mean))
+
+ return mean, std
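+# note: running_mean_std computes sliding-window mean/std over windows of N consecutive samples
+# via cumulative sums (O(1) per window); both outputs have length len(x) - N + 1.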
+
+def get_eq_matching(audio_t, ref_spec, sr=44100, n_fft=65536, hop_length=16384,
+ min_db=-50, ntaps=101, lufs=-30):
+
+ audio_t = np.copy(audio_t)
+ max_db = amp_to_db(np.max(np.abs(audio_t)))
+ if max_db > min_db:
+
+ audio_t = fx_utils.lufs_normalize(audio_t, sr, lufs, log=False)
+ audio_D = compute_stft(np.expand_dims(audio_t, 1),
+ hop_length,
+ n_fft,
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
+ audio_D = np.abs(audio_D)
+ audio_D_avg = np.mean(audio_D, axis=0)[0]
+
+ m = ref_spec.shape[0]
+
+ Ts = 1.0/sr # sampling interval
+ n = m # length of the signal
+ kk = np.arange(n)
+ T = n/sr
+ frq = kk/T # two sides frequency range
+ frq /=2
+
+ diff_eq = amp_to_db(ref_spec)-amp_to_db(audio_D_avg)
+ diff_eq = db_to_amp(diff_eq)
+ diff_eq = np.sqrt(diff_eq)
+
+ diff_filter = scipy.signal.firwin2(ntaps,
+ frq/np.max(frq),
+ diff_eq,
+ nfreqs=None, window='hamming',
+ nyq=None, antisymmetric=False)
+
+
+ output = scipy.signal.filtfilt(diff_filter, 1, audio_t,
+ axis=-1, padtype='odd', padlen=None,
+ method='pad', irlen=None)
+
+ else:
+ output = audio_t
+
+ return output
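+# note: get_eq_matching loudness-normalizes the input, compares its average magnitude spectrum
+# against ref_spec, designs an FIR filter (firwin2) from the dB difference, and applies it with
+# filtfilt; the sqrt on diff_eq compensates for filtfilt applying the filter forward and backward
+# (which squares its magnitude response). Signals quieter than min_db are returned unchanged.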
+
+def get_SPS(x, n_fft=2048, hop_length=1024, smooth=False, frames=False):
+
+ x = np.copy(x)
+ eps = 1e-20
+
+ audio_D = compute_stft(x,
+ hop_length,
+ n_fft,
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
+
+ audio_D_l = np.abs(audio_D[:, 0, :] + eps)
+ audio_D_r = np.abs(audio_D[:, 1, :] + eps)
+
+ phi = 2 * (np.abs(audio_D_l*np.conj(audio_D_r)))/(np.abs(audio_D_l)**2+np.abs(audio_D_r)**2)
+
+ phi_l = np.abs(audio_D_l*np.conj(audio_D_r))/(np.abs(audio_D_l)**2)
+ phi_r = np.abs(audio_D_r*np.conj(audio_D_l))/(np.abs(audio_D_r)**2)
+ delta = phi_l - phi_r
+ delta_ = np.sign(delta)
+ SPS = (1-phi)*delta_
+
+ phi_mean = np.mean(phi, axis=0)
+ if smooth:
+ phi_mean = scipy.signal.savgol_filter(phi_mean, 501, 1, mode='mirror')
+
+ SPS_mean = np.mean(SPS, axis=0)
+ if smooth:
+ SPS_mean = scipy.signal.savgol_filter(SPS_mean, 501, 1, mode='mirror')
+
+
+ return SPS_mean, phi_mean, SPS, phi
+
+
+def get_mean_side(sps, freqs=[50,2500], sr=44100, n_fft=2048):
+
+ sign = np.sign(sps+ 1e-10)
+
+ idx1 = freqs[0]
+ idx2 = freqs[1]
+
+ f1 = int(np.floor(idx1*n_fft/sr))
+ f2 = int(np.floor(idx2*n_fft/sr))
+
+ sign_mean = np.mean(sign[f1:f2])/np.abs(np.mean(sign[f1:f2]))
+
+ return sign_mean
+
+def get_panning_param_values(phi, side):
+
+ p = np.zeros_like(phi)
+
+ g = (np.clip(phi+1e-30, 0, 1))/2
+
+ for i, g_ in enumerate(g):
+
+ if side > 0:
+ p[i] = 1 - g_
+
+ elif side < 0:
+ p[i] = g_
+
+ else:
+ p[i] = 0.5
+
+ g_l = 1-p
+ g_r = p
+
+ return p, [g_l, g_r]
+
+def get_panning_matching(audio, ref_phi,
+ sr=44100, n_fft=2048, hop_length=1024,
+ min_db_f=-10, max_freq_pan=16000, frames=True):
+
+ eps = 1e-20
+ window = np.sqrt(np.hanning(n_fft+1)[:-1])
+ audio = np.copy(audio)
+ audio_t = np.pad(audio, ((n_fft, n_fft), (0, 0)), mode='constant')
+
+ sps_mean_, phi_mean_, _, _ = get_SPS(audio_t, n_fft=n_fft, hop_length=hop_length, smooth=True)
+
+ side = get_mean_side(sps_mean_, sr=sr, n_fft=n_fft)
+
+ if side > 0:
+ alpha = 0.7
+ else:
+ alpha = 0.3
+
+ processor = Panner()
+ processor.parameters.pan.value = alpha
+ processor.parameters.pan_law.value = 'linear'
+ processor.update()
+ audio_t_ = processor.process(audio_t)
+
+ sps_mean_, phi_mean, sps_frames, phi_frames = get_SPS(audio_t_, n_fft=n_fft,
+ hop_length=hop_length,
+ smooth=True, frames=frames)
+
+ if frames:
+
+ p_i_ = []
+ g_i_ = []
+ p_ref = []
+ g_ref = []
+ for i in range(len(sps_frames)):
+ sps_ = sps_frames[i]
+ phi_ = phi_frames[i]
+ p_, g_ = get_panning_param_values(phi_, side)
+ p_i_.append(p_)
+ g_i_.append(g_)
+ p_, g_ = get_panning_param_values(ref_phi, side)
+ p_ref.append(p_)
+ g_ref.append(g_)
+ ratio = (np.asarray(g_ref)/(np.asarray(g_i_)+eps))
+ g_l = ratio[:,0,:]
+ g_r = ratio[:,1,:]
+
+
+ else:
+ p, g = get_panning_param_values(ref_phi, side)
+ p_i, g_i = get_panning_param_values(phi_mean, side)
+ ratio = (np.asarray(g)/np.asarray(g_i))
+ g_l = ratio[0]
+ g_r = ratio[1]
+
+ audio_new_D = compute_stft(audio_t_,
+ hop_length,
+ n_fft,
+ window)
+
+ audio_new_D_mono = audio_new_D.copy()
+ audio_new_D_mono = audio_new_D_mono[:, 0, :] + audio_new_D_mono[:, 1, :]
+ audio_new_D_mono = np.abs(audio_new_D_mono)
+
+ audio_new_D_phase = np.angle(audio_new_D)
+ audio_new_D = np.abs(audio_new_D)
+
+ audio_new_D_l = audio_new_D[:, 0, :]
+ audio_new_D_r = audio_new_D[:, 1, :]
+
+ if frames:
+ for i, frame in enumerate(audio_new_D_mono):
+ max_db = amp_to_db(np.max(np.abs(frame)))
+ if max_db < min_db_f:
+ g_r[i] = np.ones_like(frame)
+ g_l[i] = np.ones_like(frame)
+
+ idx1 = max_freq_pan
+ f1 = int(np.floor(idx1*n_fft/sr))
+ ones = np.ones_like(g_l)
+ g_l[f1:] = ones[f1:]
+ g_r[f1:] = ones[f1:]
+
+ audio_new_D_l = audio_new_D_l*g_l
+ audio_new_D_r = audio_new_D_r*g_r
+
+ audio_new_D_l = np.expand_dims(audio_new_D_l, 0)
+ audio_new_D_r = np.expand_dims(audio_new_D_r, 0)
+
+ audio_new_D_ = np.concatenate((audio_new_D_l,audio_new_D_r))
+
+ audio_new_D_ = np.moveaxis(audio_new_D_, 0, 1)
+
+ audio_new_D_ = audio_new_D_ * (np.cos(audio_new_D_phase) + np.sin(audio_new_D_phase)*1j)
+
+ audio_new_t = compute_istft(audio_new_D_,
+ hop_length,
+ window)
+
+ audio_new_t = audio_new_t[n_fft:n_fft+audio.shape[0]]
+
+ return audio_new_t
+
+
+
+def get_mean_peak(audio, sr=44100, true_peak=False, n_mels=128, percentile=75):
+
+# Returns [mean, std] of the onset peak values in dB, keeping only peaks above the given percentile (75th by default).
+# Input should be in the shape samples x channels.
+
+ audio_ = audio
+ window_size = 2**10 # FFT size
+ hop_size = window_size
+
+ peak = []
+ std = []
+ for ch in range(audio_.shape[-1]):
+ x = np.ascontiguousarray(audio_[:, ch])
+
+ if true_peak:
+ x = librosa.resample(x, sr, 4*sr)
+ sr = 4*sr
+ window_size = 4*window_size
+ hop_size = 4*hop_size
+
+ onset_func = aubio.onset('hfc', buf_size=window_size, hop_size=hop_size, samplerate=sr)
+
+ frames = np.float32(librosa.util.frame(x, frame_length=window_size, hop_length=hop_size))
+
+ onset_times = []
+ for frame in frames.T:
+
+ if onset_func(frame):
+
+ onset_time = onset_func.get_last()
+ onset_times.append(onset_time)
+
+ samples=[]
+ if onset_times:
+ for i, p in enumerate(onset_times[:-1]):
+ samples.append(onset_times[i]+np.argmax(np.abs(x[onset_times[i]:onset_times[i+1]])))
+ samples.append(onset_times[-1]+np.argmax(np.abs(x[onset_times[-1]:])))
+
+ p_value = []
+ for p in samples:
+ p_ = amp_to_db(np.abs(x[p]))
+ p_value.append(p_)
+ p_value_=[]
+ for p in p_value:
+ if p > np.percentile(p_value, percentile):
+ p_value_.append(p)
+ if p_value_:
+ peak.append(np.mean(p_value_))
+ std.append(np.std(p_value_))
+ elif p_value:
+ peak.append(np.mean(p_value))
+ std.append(np.std(p_value))
+ else:
+ return None
+ return [np.mean(peak), np.mean(std)]
+
+def compress(processor, audio, sr, th, ratio, attack, release):
+
+ eps = 1e-20
+ x = audio
+
+ processor.parameters.threshold.value = th
+ processor.parameters.ratio.value = ratio
+ processor.parameters.attack_time.value = attack
+ processor.parameters.release_time.value = release
+ processor.update()
+ output = processor.process(x)
+
+ if np.max(np.abs(output)) >= 1.0:
+ output = np.clip(output, -1.0, 1.0)
+
+ return output
+
+def get_comp_matching(audio,
+ ref_peak, ref_std,
+ ratio, attack, release, sr=44100,
+ min_db=-50, comp_peak_norm=-10.0,
+ min_th=-40, max_ratio=20, n_mels=128,
+ true_peak=False, percentile=75, expander=True):
+
+ x = audio.copy()
+
+ if x.ndim < 2:
+ x = np.expand_dims(x, 1)
+
+ max_db = amp_to_db(np.max(np.abs(x)))
+ if max_db > min_db:
+
+ x = pyln.normalize.peak(x, comp_peak_norm)
+
+ peak, std = get_mean_peak(x, sr,
+ n_mels=n_mels,
+ true_peak=true_peak,
+ percentile=percentile)
+
+ if peak > (ref_peak - ref_std) and peak < (ref_peak + ref_std):
+ return x
+
+ # DownwardCompress
+ elif peak > (ref_peak - ref_std):
+ processor = Compressor(sample_rate=sr)
+ # print('compress')
+ ratios = np.linspace(ratio, max_ratio, max_ratio-ratio+1)
+ ths = np.linspace(-1-9, min_th, 2*np.abs(min_th)-1-18)
+ for rt in ratios:
+ for th in ths:
+ y = compress(processor, x, sr, th, rt, attack, release)
+ peak, std = get_mean_peak(y, sr,
+ n_mels=n_mels,
+ true_peak=true_peak,
+ percentile=percentile)
+ if peak < (ref_peak + ref_std):
+ break
+ else:
+ continue
+ break
+
+ return y
+
+ # Upward Expand
+ elif peak < (ref_peak + ref_std):
+
+ if expander:
+ processor = Compressor(sample_rate=sr)
+ ratios = np.linspace(ratio, max_ratio, max_ratio-ratio+1)
+ ths = np.linspace(-1, min_th, 2*np.abs(min_th)-1)[::-1]
+
+ for rt in ratios:
+ for th in ths:
+ y = compress(processor, x, sr, th, 1/rt, attack, release)
+ peak, std = get_mean_peak(y, sr,
+ n_mels=n_mels,
+ true_peak=true_peak,
+ percentile=percentile)
+ if peak > (ref_peak - ref_std):
+ break
+ else:
+ continue
+ break
+
+ return y
+
+ else:
+ return x
+ else:
+ return x
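+# note: get_comp_matching grid-searches compressor settings (outer loop over ratios, inner loop over
+# thresholds) and keeps the first setting whose mean onset peak falls within ref_peak +/- ref_std;
+# signals already inside that range (or below min_db) are returned unchanged.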
+
+
+
+# REVERB
+
+
+def get_reverb_send(audio, eq_parameters, rv_parameters, impulse_responses=None,
+ eq_prob=1.0, rv_prob=1.0, parallel=True, shuffle=False, sr=44100, bands=['low_shelf', 'high_shelf']):
+
+ x = audio.copy()
+
+ if x.ndim < 2:
+ x = np.expand_dims(x, 1)
+
+ channels = x.shape[-1]
+ eq_gain = eq_parameters.low_shelf_gain.value
+
+
+ eq = Equaliser(n_channels=channels,
+ sample_rate=sr,
+ gain_range=(eq_gain, eq_gain),
+ bands=bands,
+ hard_clip=False,
+ name='Equaliser', parameters=eq_parameters)
+ eq.randomize()
+
+ if impulse_responses:
+
+ reverb = ConvolutionalReverb(impulse_responses=impulse_responses,
+ sample_rate=sr,
+ parameters=rv_parameters)
+
+ else:
+
+ reverb = AlgorithmicReverb(sample_rate=sr,
+ parameters=rv_parameters)
+
+ reverb.randomize()
+
+ fxchain = AugmentationChain([
+ (eq, eq_prob, False),
+ (reverb, rv_prob, False)
+ ],
+ shuffle=shuffle, parallel=parallel)
+
+ output = fxchain(x)
+
+ return output
+
+
+
+# FUNCTIONS TO COMPUTE FEATURES
+
+def compute_loudness_features(args_):
+
+ audio_out_ = args_[0]
+ audio_tar_ = args_[1]
+ idx = args_[2]
+ sr = args_[3]
+
+ loudness_ = {key:[] for key in ['d_lufs', 'd_peak',]}
+
+ peak_tar = np.max(np.abs(audio_tar_))
+ peak_tar_db = 20.0 * np.log10(peak_tar)
+
+ peak_out = np.max(np.abs(audio_out_))
+ peak_out_db = 20.0 * np.log10(peak_out)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ meter = pyln.Meter(sr) # create BS.1770 meter
+ loudness_tar = meter.integrated_loudness(audio_tar_)
+ loudness_out = meter.integrated_loudness(audio_out_)
+
+ loudness_['d_lufs'].append(sklearn.metrics.mean_absolute_percentage_error([loudness_tar], [loudness_out]))
+ loudness_['d_peak'].append(sklearn.metrics.mean_absolute_percentage_error([peak_tar_db], [peak_out_db]))
+
+ return loudness_
+
+def compute_spectral_features(args_):
+
+ audio_out_ = args_[0]
+ audio_tar_ = args_[1]
+ idx = args_[2]
+ sr = args_[3]
+ fft_size = args_[4]
+ hop_length = args_[5]
+ channels = args_[6]
+
+ audio_out_ = pyln.normalize.peak(audio_out_, -1.0)
+ audio_tar_ = pyln.normalize.peak(audio_tar_, -1.0)
+
+ spec_out_ = compute_stft(audio_out_,
+ hop_length,
+ fft_size,
+ np.sqrt(np.hanning(fft_size+1)[:-1]))
+ spec_out_ = np.transpose(spec_out_, axes=[1, -1, 0])
+ spec_out_ = np.abs(spec_out_)
+
+ spec_tar_ = compute_stft(audio_tar_,
+ hop_length,
+ fft_size,
+ np.sqrt(np.hanning(fft_size+1)[:-1]))
+ spec_tar_ = np.transpose(spec_tar_, axes=[1, -1, 0])
+ spec_tar_ = np.abs(spec_tar_)
+
+ spectral_ = {key:[] for key in ['centroid_mean',
+ 'bandwidth_mean',
+ 'contrast_l_mean',
+ 'contrast_m_mean',
+ 'contrast_h_mean',
+ 'rolloff_mean',
+ 'flatness_mean',
+ 'mape_mean',
+ ]}
+
+ centroid_mean_ = []
+ centroid_std_ = []
+ bandwidth_mean_ = []
+ bandwidth_std_ = []
+ contrast_l_mean_ = []
+ contrast_l_std_ = []
+ contrast_m_mean_ = []
+ contrast_m_std_ = []
+ contrast_h_mean_ = []
+ contrast_h_std_ = []
+ rolloff_mean_ = []
+ rolloff_std_ = []
+ flatness_mean_ = []
+
+ for ch in range(channels):
+ tar = spec_tar_[ch]
+ out = spec_out_[ch]
+
+ tar_sc = librosa.feature.spectral_centroid(y=None, sr=sr, S=tar,
+ n_fft=fft_size, hop_length=hop_length)
+
+ out_sc = librosa.feature.spectral_centroid(y=None, sr=sr, S=out,
+ n_fft=fft_size, hop_length=hop_length)
+
+ tar_bw = librosa.feature.spectral_bandwidth(y=None, sr=sr, S=tar,
+ n_fft=fft_size, hop_length=hop_length,
+ centroid=tar_sc, norm=True, p=2)
+
+ out_bw = librosa.feature.spectral_bandwidth(y=None, sr=sr, S=out,
+ n_fft=fft_size, hop_length=hop_length,
+ centroid=out_sc, norm=True, p=2)
+ # contrast bands: low = 0-250 Hz, mid (bands 1-3) = 250-2000 Hz, high = 2000 Hz - sr/2
+ tar_ct = librosa.feature.spectral_contrast(y=None, sr=sr, S=tar,
+ n_fft=fft_size, hop_length=hop_length,
+ fmin=250.0, n_bands=4, quantile=0.02, linear=False)
+
+ out_ct = librosa.feature.spectral_contrast(y=None, sr=sr, S=out,
+ n_fft=fft_size, hop_length=hop_length,
+ fmin=250.0, n_bands=4, quantile=0.02, linear=False)
+
+ tar_ro = librosa.feature.spectral_rolloff(y=None, sr=sr, S=tar,
+ n_fft=fft_size, hop_length=hop_length,
+ roll_percent=0.85)
+
+ out_ro = librosa.feature.spectral_rolloff(y=None, sr=sr, S=out,
+ n_fft=fft_size, hop_length=hop_length,
+ roll_percent=0.85)
+
+ tar_ft = librosa.feature.spectral_flatness(y=None, S=tar,
+ n_fft=fft_size, hop_length=hop_length,
+ amin=1e-10, power=2.0)
+
+ out_ft = librosa.feature.spectral_flatness(y=None, S=out,
+ n_fft=fft_size, hop_length=hop_length,
+ amin=1e-10, power=2.0)
+
+
+ eps = 1e-0
+ N = 40
+ mean_sc_tar, std_sc_tar = get_running_stats(tar_sc.T+eps, [0], N=N)
+ mean_sc_out, std_sc_out = get_running_stats(out_sc.T+eps, [0], N=N)
+
+ assert np.isnan(mean_sc_tar).any() == False, f'NAN values mean_sc_tar {idx}'
+ assert np.isnan(mean_sc_out).any() == False, f'NAN values mean_sc_out {idx}'
+
+
+ mean_bw_tar, std_bw_tar = get_running_stats(tar_bw.T+eps, [0], N=N)
+ mean_bw_out, std_bw_out = get_running_stats(out_bw.T+eps, [0], N=N)
+
+ assert np.isnan(mean_bw_tar).any() == False, f'NAN values tar mean bw {idx}'
+ assert np.isnan(mean_bw_out).any() == False, f'NAN values out mean bw {idx}'
+
+ mean_ct_tar, std_ct_tar = get_running_stats(tar_ct.T, list(range(tar_ct.shape[0])), N=N)
+ mean_ct_out, std_ct_out = get_running_stats(out_ct.T, list(range(out_ct.shape[0])), N=N)
+
+ assert np.isnan(mean_ct_tar).any() == False, f'NAN values tar mean ct {idx}'
+ assert np.isnan(mean_ct_out).any() == False, f'NAN values out mean ct {idx}'
+
+ mean_ro_tar, std_ro_tar = get_running_stats(tar_ro.T+eps, [0], N=N)
+ mean_ro_out, std_ro_out = get_running_stats(out_ro.T+eps, [0], N=N)
+
+ assert np.isnan(mean_ro_tar).any() == False, f'NAN values tar mean ro {idx}'
+ assert np.isnan(mean_ro_out).any() == False, f'NAN values out mean ro {idx}'
+
+ mean_ft_tar, std_ft_tar = get_running_stats(tar_ft.T, [0], N=800) # flatness uses a larger window (N=800); smaller windows give very high values
+ mean_ft_out, std_ft_out = get_running_stats(out_ft.T, [0], N=800)
+
+ mape_mean_sc = sklearn.metrics.mean_absolute_percentage_error(mean_sc_tar[0], mean_sc_out[0])
+
+ mape_mean_bw = sklearn.metrics.mean_absolute_percentage_error(mean_bw_tar[0], mean_bw_out[0])
+
+ mape_mean_ct_l = sklearn.metrics.mean_absolute_percentage_error(mean_ct_tar[0], mean_ct_out[0])
+
+ mape_mean_ct_m = sklearn.metrics.mean_absolute_percentage_error(np.mean(mean_ct_tar[1:4], axis=0),
+ np.mean(mean_ct_out[1:4], axis=0))
+
+ mape_mean_ct_h = sklearn.metrics.mean_absolute_percentage_error(mean_ct_tar[-1], mean_ct_out[-1])
+
+ mape_mean_ro = sklearn.metrics.mean_absolute_percentage_error(mean_ro_tar[0], mean_ro_out[0])
+
+ mape_mean_ft = sklearn.metrics.mean_absolute_percentage_error(mean_ft_tar[0], mean_ft_out[0])
+
+ centroid_mean_.append(mape_mean_sc)
+ bandwidth_mean_.append(mape_mean_bw)
+ contrast_l_mean_.append(mape_mean_ct_l)
+ contrast_m_mean_.append(mape_mean_ct_m)
+ contrast_h_mean_.append(mape_mean_ct_h)
+ rolloff_mean_.append(mape_mean_ro)
+ flatness_mean_.append(mape_mean_ft)
+
+ spectral_['centroid_mean'].append(np.mean(centroid_mean_))
+
+ spectral_['bandwidth_mean'].append(np.mean(bandwidth_mean_))
+
+ spectral_['contrast_l_mean'].append(np.mean(contrast_l_mean_))
+
+ spectral_['contrast_m_mean'].append(np.mean(contrast_m_mean_))
+
+ spectral_['contrast_h_mean'].append(np.mean(contrast_h_mean_))
+
+ spectral_['rolloff_mean'].append(np.mean(rolloff_mean_))
+
+ spectral_['flatness_mean'].append(np.mean(flatness_mean_))
+
+ spectral_['mape_mean'].append(np.mean([np.mean(centroid_mean_),
+ np.mean(bandwidth_mean_),
+ np.mean(contrast_l_mean_),
+ np.mean(contrast_m_mean_),
+ np.mean(contrast_h_mean_),
+ np.mean(rolloff_mean_),
+ np.mean(flatness_mean_),
+ ]))
+
+ return spectral_
+
+# PANNING
+def get_panning_rms_frame(sps_frame, freqs=[0,22050], sr=44100, n_fft=2048):
+
+ idx1 = freqs[0]
+ idx2 = freqs[1]
+
+ f1 = int(np.floor(idx1*n_fft/sr))
+ f2 = int(np.floor(idx2*n_fft/sr))
+
+ p_rms = np.sqrt((1/(f2-f1)) * np.sum(sps_frame[f1:f2]**2))
+
+ return p_rms
+
+def get_panning_rms(sps, freqs=[[0, 22050]], sr=44100, n_fft=2048):
+
+ p_rms = []
+ for frame in sps:
+ p_rms_ = []
+ for f in freqs:
+ rms = get_panning_rms_frame(frame, freqs=f, sr=sr, n_fft=n_fft)
+ p_rms_.append(rms)
+ p_rms.append(p_rms_)
+
+ return np.asarray(p_rms)
+
+
+
+def compute_panning_features(args_):
+
+ audio_out_ = args_[0]
+ audio_tar_ = args_[1]
+ idx = args_[2]
+ sr = args_[3]
+ fft_size = args_[4]
+ hop_length = args_[5]
+
+ audio_out_ = pyln.normalize.peak(audio_out_, -1.0)
+ audio_tar_ = pyln.normalize.peak(audio_tar_, -1.0)
+
+ panning_ = {}
+
+ freqs=[[0, sr//2], [0, 250], [250, 2500], [2500, sr//2]]
+
+ _, _, sps_frames_tar, _ = get_SPS(audio_tar_, n_fft=fft_size,
+ hop_length=hop_length,
+ smooth=True, frames=True)
+
+ _, _, sps_frames_out, _ = get_SPS(audio_out_, n_fft=fft_size,
+ hop_length=hop_length,
+ smooth=True, frames=True)
+
+
+ p_rms_tar = get_panning_rms(sps_frames_tar,
+ freqs=freqs,
+ sr=sr,
+ n_fft=fft_size)
+
+ p_rms_out = get_panning_rms(sps_frames_out,
+ freqs=freqs,
+ sr=sr,
+ n_fft=fft_size)
+
+ # to avoid numerical instability, delete frames whose RMS is zero in the target
+ if np.min(p_rms_tar) == 0.0:
+ id_zeros = np.where(p_rms_tar.T[0] == 0)
+ p_rms_tar_ = []
+ p_rms_out_ = []
+ for i in range(len(freqs)):
+ temp_tar = np.delete(p_rms_tar.T[i], id_zeros)
+ temp_out = np.delete(p_rms_out.T[i], id_zeros)
+ p_rms_tar_.append(temp_tar)
+ p_rms_out_.append(temp_out)
+ p_rms_tar_ = np.asarray(p_rms_tar_)
+ p_rms_tar = p_rms_tar_.T
+ p_rms_out_ = np.asarray(p_rms_out_)
+ p_rms_out = p_rms_out_.T
+
+ N = 40
+
+ mean_tar, std_tar = get_running_stats(p_rms_tar, freqs, N=N)
+ mean_out, std_out = get_running_stats(p_rms_out, freqs, N=N)
+
+ panning_['P_t_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[0], mean_out[0])]
+ panning_['P_l_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[1], mean_out[1])]
+ panning_['P_m_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[2], mean_out[2])]
+ panning_['P_h_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[3], mean_out[3])]
+
+ panning_['mape_mean'] = [np.mean([panning_['P_t_mean'],
+ panning_['P_l_mean'],
+ panning_['P_m_mean'],
+ panning_['P_h_mean'],
+ ])]
+
+ return panning_
+
+# DYNAMIC
+
+def get_rms_dynamic_crest(x, frame_length, hop_length):
+
+ rms = []
+ dynamic_spread = []
+ crest = []
+ for ch in range(x.shape[-1]):
+ frames = librosa.util.frame(x[:, ch], frame_length=frame_length, hop_length=hop_length)
+ rms_ = []
+ dynamic_spread_ = []
+ crest_ = []
+ for i in frames.T:
+ x_rms = amp_to_db(np.sqrt(np.sum(i**2)/frame_length))
+ x_d = np.sum(amp_to_db(np.abs(i)) - x_rms)/frame_length
+ x_c = amp_to_db(np.max(np.abs(i)))/x_rms
+
+ rms_.append(x_rms)
+ dynamic_spread_.append(x_d)
+ crest_.append(x_c)
+ rms.append(rms_)
+ dynamic_spread.append(dynamic_spread_)
+ crest.append(crest_)
+
+ rms = np.asarray(rms)
+ dynamic_spread = np.asarray(dynamic_spread)
+ crest = np.asarray(crest)
+
+ rms = np.mean(rms, axis=0)
+ dynamic_spread = np.mean(dynamic_spread, axis=0)
+ crest = np.mean(crest, axis=0)
+
+ rms = np.expand_dims(rms, axis=0)
+ dynamic_spread = np.expand_dims(dynamic_spread, axis=0)
+ crest = np.expand_dims(crest, axis=0)
+
+ return rms, dynamic_spread, crest
+
+def lowpassFiltering(x, f0, sr):
+
+ b1, a1 = scipy.signal.butter(4, f0/(sr/2),'lowpass')
+ x_f = []
+ for ch in range(x.shape[-1]):
+ x_f_ = scipy.signal.filtfilt(b1, a1, x[:, ch]).copy(order='F')
+ x_f.append(x_f_)
+ return np.asarray(x_f).T
+
+
+def get_low_freq_weighting(x, sr, n_fft, hop_length, f0 = 1000):
+
+ x_low = lowpassFiltering(x, f0, sr)
+
+ X_low = compute_stft(x_low,
+ hop_length,
+ n_fft,
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
+ X_low = np.transpose(X_low, axes=[1, -1, 0])
+ X_low = np.abs(X_low)
+
+ X = compute_stft(x,
+ hop_length,
+ n_fft,
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
+ X = np.transpose(X, axes=[1, -1, 0])
+ X = np.abs(X)
+
+ eps = 1e-5
+ ratio = (X_low)/(X+eps)
+ ratio = np.sum(ratio, axis = 1)
+ ratio = np.mean(ratio, axis = 0)
+
+ return np.expand_dims(ratio, axis=0)
+
+def compute_dynamic_features(args_):
+
+ audio_out_ = args_[0]
+ audio_tar_ = args_[1]
+ idx = args_[2]
+ sr = args_[3]
+ fft_size = args_[4]
+ hop_length = args_[5]
+
+ audio_out_ = pyln.normalize.peak(audio_out_, -1.0)
+ audio_tar_ = pyln.normalize.peak(audio_tar_, -1.0)
+
+ dynamic_ = {}
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=UserWarning)
+
+ rms_tar, dyn_tar, crest_tar = get_rms_dynamic_crest(audio_tar_, fft_size, hop_length)
+ rms_out, dyn_out, crest_out = get_rms_dynamic_crest(audio_out_, fft_size, hop_length)
+
+ low_ratio_tar = get_low_freq_weighting(audio_tar_, sr, fft_size, hop_length, f0=1000)
+
+ low_ratio_out = get_low_freq_weighting(audio_out_, sr, fft_size, hop_length, f0=1000)
+
+ N = 40
+
+ eps = 1e-10
+
+ rms_tar = (-1*rms_tar) + 1.0
+ rms_out = (-1*rms_out) + 1.0
+ dyn_tar = (-1*dyn_tar) + 1.0
+ dyn_out = (-1*dyn_out) + 1.0
+
+ mean_rms_tar, std_rms_tar = get_running_stats(rms_tar.T, [0], N=N)
+ mean_rms_out, std_rms_out = get_running_stats(rms_out.T, [0], N=N)
+
+ mean_dyn_tar, std_dyn_tar = get_running_stats(dyn_tar.T, [0], N=N)
+ mean_dyn_out, std_dyn_out = get_running_stats(dyn_out.T, [0], N=N)
+
+ mean_crest_tar, std_crest_tar = get_running_stats(crest_tar.T, [0], N=N)
+ mean_crest_out, std_crest_out = get_running_stats(crest_out.T, [0], N=N)
+
+ mean_low_ratio_tar, std_low_ratio_tar = get_running_stats(low_ratio_tar.T, [0], N=N)
+ mean_low_ratio_out, std_low_ratio_out = get_running_stats(low_ratio_out.T, [0], N=N)
+
+ dynamic_['rms_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_rms_tar, mean_rms_out)]
+ dynamic_['dyn_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_dyn_tar, mean_dyn_out)]
+ dynamic_['crest_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_crest_tar, mean_crest_out)]
+
+ dynamic_['l_ratio_mean_mape'] = [sklearn.metrics.mean_absolute_percentage_error(mean_low_ratio_tar, mean_low_ratio_out)]
+ dynamic_['l_ratio_mean_l2'] = [sklearn.metrics.mean_squared_error(mean_low_ratio_tar, mean_low_ratio_out)]
+
+ dynamic_['mape_mean'] = [np.mean([dynamic_['rms_mean'],
+ dynamic_['dyn_mean'],
+ dynamic_['crest_mean'],
+ ])]
+
+ return dynamic_
+
\ No newline at end of file
diff --git a/mixing_style_transfer/modules/__init__.py b/mixing_style_transfer/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5020a03d70b8496dd458912272af077e395f25c
--- /dev/null
+++ b/mixing_style_transfer/modules/__init__.py
@@ -0,0 +1,3 @@
+from .front_back_end import *
+from .loss import *
+from .training_utils import *
\ No newline at end of file
diff --git a/mixing_style_transfer/modules/front_back_end.py b/mixing_style_transfer/modules/front_back_end.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a734ecd916d8e045e9520ae3afcc6b65f58dd1
--- /dev/null
+++ b/mixing_style_transfer/modules/front_back_end.py
@@ -0,0 +1,226 @@
+""" Front-end: processing raw data input """
+import numpy as np   # used by BackEnd.griffin_lim (np.pi)
+import torch
+import torch.nn as nn
+import torchaudio.functional as ta_F
+import torchaudio
+
+
+
+class FrontEnd(nn.Module):
+ def __init__(self, channel='stereo', \
+ n_fft=2048, \
+ hop_length=None, \
+ win_length=None, \
+ window="hann", \
+ device=torch.device("cpu")):
+ super(FrontEnd, self).__init__()
+ self.channel = channel
+ self.n_fft = n_fft
+ self.hop_length = n_fft//4 if hop_length==None else hop_length
+ self.win_length = n_fft if win_length==None else win_length
+ if window=="hann":
+ self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(device)
+ elif window=="hamming":
+ self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(device)
+
+
+ def forward(self, input, mode):
+ # front-end function that channel-wise combines all requested features
+ # input shape : batch x channel x raw waveform
+ # output shape : batch x channel x frequency x time
+
+ front_output_list = []
+ for cur_mode in mode:
+ # Real & Imaginary
+ if cur_mode=="cplx":
+ if self.channel=="mono":
+ output = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ elif self.channel=="stereo":
+ output_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ output_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ output = torch.cat((output_l, output_r), axis=-1)
+ if input.shape[2] % round(self.n_fft/4) == 0:
+ output = output[:, :, :-1]
+ if self.n_fft % 2 == 0:
+ output = output[:, :-1]
+ front_output_list.append(output.permute(0, 3, 1, 2))
+ # Magnitude & Phase
+ elif cur_mode=="mag":
+ if self.channel=="mono":
+ cur_cplx = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ output = self.mag(cur_cplx).unsqueeze(-1)[..., 0:1]
+ elif self.channel=="stereo":
+ cplx_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ cplx_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ mag_l = self.mag(cplx_l).unsqueeze(-1)
+ mag_r = self.mag(cplx_r).unsqueeze(-1)
+ output = torch.cat((mag_l, mag_r), axis=-1)
+
+ if input.shape[-1] % round(self.n_fft/4) == 0:
+ output = output[:, :, :-1]
+ if self.n_fft % 2 == 0: # discard the lowest (DC) frequency bin
+ output = output[:, 1:]
+ front_output_list.append(output.permute(0, 3, 1, 2))
+
+ # combine all demanded features
+ if not front_output_list:
+ raise NameError("NameError at FrontEnd: check using features for front-end")
+ elif len(mode)!=1:
+ for i, cur_output in enumerate(front_output_list):
+ if i==0:
+ front_output = cur_output
+ else:
+ front_output = torch.cat((front_output, cur_output), axis=1)
+ else:
+ front_output = front_output_list[0]
+
+ return front_output
+
+
+ def mag(self, cplx_input, eps=1e-07):
+ mag_summed = cplx_input.pow(2.).sum(-1) + eps
+ return mag_summed.pow(0.5)
+
+
+
+
+class BackEnd(nn.Module):
+ def __init__(self, channel='stereo', \
+ n_fft=2048, \
+ hop_length=None, \
+ win_length=None, \
+ window="hann", \
+ eps=1e-07, \
+ orig_freq=44100, \
+ new_freq=16000, \
+ device=torch.device("cpu")):
+ super(BackEnd, self).__init__()
+ self.device = device
+ self.channel = channel
+ self.n_fft = n_fft
+ self.hop_length = n_fft//4 if hop_length==None else hop_length
+ self.win_length = n_fft if win_length==None else win_length
+ self.eps = eps
+ if window=="hann":
+ self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(self.device)
+ elif window=="hamming":
+ self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(self.device)
+ self.resample_func_8k = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=8000).to(self.device)
+ self.resample_func = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq).to(self.device)
+
+ def magphase_to_cplx(self, magphase_spec):
+ real = magphase_spec[..., 0] * torch.cos(magphase_spec[..., 1])
+ imaginary = magphase_spec[..., 0] * torch.sin(magphase_spec[..., 1])
+ return torch.cat((real.unsqueeze(-1), imaginary.unsqueeze(-1)), dim=-1)
+
+
+ def forward(self, input, phase, mode):
+ # back-end function that converts output spectrograms into waveforms
+ # input shape : batch x channel x frequency x time
+ # output shape : batch x channel x raw waveform
+
+ # convert to shape : batch x frequency x time x channel
+ input = input.permute(0, 2, 3, 1)
+ # pad highest frequency
+ pad = torch.zeros((input.shape[0], 1, input.shape[2], input.shape[3])).to(self.device)
+ input = torch.cat((pad, input), dim=1)
+
+ back_output_list = []
+ channel_count = 0
+ for i, cur_mode in enumerate(mode):
+ # Real & Imaginary
+ if cur_mode=="cplx":
+ if self.channel=="mono":
+ output = ta_F.istft(input[...,channel_count:channel_count+2], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
+ channel_count += 2
+ elif self.channel=="stereo":
+ cplx_spec = torch.cat([input[...,channel_count:channel_count+2], input[...,channel_count+2:channel_count+4]], dim=0)
+ output_wav = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
+ channel_count += 4
+ back_output_list.append(output)
+ # Magnitude & Phase
+ elif cur_mode=="mag_phase" or cur_mode=="mag":
+ if self.channel=="mono":
+ if cur_mode=="mag":
+ input_spec = torch.cat((input[...,channel_count:channel_count+1], phase), axis=-1)
+ channel_count += 1
+ else:
+ input_spec = input[...,channel_count:channel_count+2]
+ channel_count += 2
+ cplx_spec = self.magphase_to_cplx(input_spec)
+ output = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
+ elif self.channel=="stereo":
+ if cur_mode=="mag":
+ input_spec_l = torch.cat((input[...,channel_count:channel_count+1], phase[...,0:1]), axis=-1)
+ input_spec_r = torch.cat((input[...,channel_count+1:channel_count+2], phase[...,1:2]), axis=-1)
+ channel_count += 2
+ else:
+ input_spec_l = input[...,channel_count:channel_count+2]
+ input_spec_r = input[...,channel_count+2:channel_count+4]
+ channel_count += 4
+ cplx_spec_l = self.magphase_to_cplx(input_spec_l)
+ cplx_spec_r = self.magphase_to_cplx(input_spec_r)
+ cplx_spec = torch.cat([cplx_spec_l, cplx_spec_r], dim=0)
+ output_wav = torch.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+ output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
+ back_output_list.append(output)
+ elif cur_mode=="griff":
+ if self.channel=="mono":
+ output = self.griffin_lim(input.squeeze(-1), input.device).unsqueeze(1)
+ # output = self.griff(input.permute(0, 3, 1, 2))
+ else:
+ output_l = self.griffin_lim(input[..., 0], input.device).unsqueeze(1)
+ output_r = self.griffin_lim(input[..., 1], input.device).unsqueeze(1)
+ output = torch.cat((output_l, output_r), axis=1)
+
+ back_output_list.append(output)
+
+ # combine all demanded feature outputs
+ if not back_output_list:
+ raise NameError("NameError at BackEnd: check using features for back-end")
+ elif len(mode)!=1:
+ for i, cur_output in enumerate(back_output_list):
+ if i==0:
+ back_output = cur_output
+ else:
+ back_output = torch.cat((back_output, cur_output), axis=1)
+ else:
+ back_output = back_output_list[0]
+
+ return back_output
+
+
+ def griffin_lim(self, l_est, gpu, n_iter=100):
+ l_est = l_est.cpu().detach()
+
+ l_est = torch.pow(l_est, 1/0.80)
+ # l_est [batch, channel, time]
+ l_mag = l_est.unsqueeze(-1)
+ l_phase = 2 * np.pi * torch.rand_like(l_mag) - np.pi
+ real = l_mag * torch.cos(l_phase)
+ imag = l_mag * torch.sin(l_phase)
+ S = torch.cat((real, imag), axis=-1)
+ S_mag = (real**2 + imag**2 + self.eps) ** 0.5
+ for i in range(n_iter):
+ x = ta_F.istft(S, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
+ S_new = torch.stft(x, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
+ S_new_phase = S_new / self.mag(S_new).unsqueeze(-1)
+ S = S_mag * S_new_phase
+ return x / torch.max(torch.abs(x))
+
+
+
+if __name__ == '__main__':
+
+ batch_size = 16
+ channel = 2
+ segment_length = 512*128*6
+ input_wav = torch.rand((batch_size, channel, segment_length))
+
+ mode = ["cplx", "mag"]
+ fe = FrontEnd(channel="stereo")
+
+ output = fe(input_wav, mode=mode)
+ print(f"Input shape : {input_wav.shape}\nOutput shape : {output.shape}")
diff --git a/mixing_style_transfer/modules/loss.py b/mixing_style_transfer/modules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfab06d6a921e8f3dab3beeb0093b648efe806d2
--- /dev/null
+++ b/mixing_style_transfer/modules/loss.py
@@ -0,0 +1,260 @@
+"""
+ Implementation of objective functions used in the task 'End-to-end Remastering System'
+"""
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+import os
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.dirname(currentdir))
+
+from modules.training_utils import *
+from modules.front_back_end import *
+
+
+
+'''
+ Normalized Temperature-scaled Cross Entropy (NT-Xent) Loss
+ below source code (class NT_Xent) is a replication from the github repository - https://github.com/Spijkervet/SimCLR
+ the original implementation can be found here: https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
+'''
+class NT_Xent(nn.Module):
+ def __init__(self, batch_size, temperature, world_size):
+ super(NT_Xent, self).__init__()
+ self.batch_size = batch_size
+ self.temperature = temperature
+ self.world_size = world_size
+
+ self.mask = self.mask_correlated_samples(batch_size, world_size)
+ self.criterion = nn.CrossEntropyLoss(reduction="sum")
+ self.similarity_f = nn.CosineSimilarity(dim=2)
+
+ def mask_correlated_samples(self, batch_size, world_size):
+ N = 2 * batch_size * world_size
+ mask = torch.ones((N, N), dtype=bool)
+ mask = mask.fill_diagonal_(0)
+ for i in range(batch_size * world_size):
+ mask[i, batch_size + i] = 0
+ mask[batch_size + i, i] = 0
+ # mask[i, batch_size * world_size + i] = 0
+ # mask[batch_size * world_size + i, i] = 0
+ return mask
+
+ def forward(self, z_i, z_j):
+ """
+ We do not sample negative examples explicitly.
+ Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
+ """
+ N = 2 * self.batch_size * self.world_size
+
+ z = torch.cat((z_i, z_j), dim=0)
+ # combine embeddings from all GPUs
+ if self.world_size > 1:
+ z = torch.cat(GatherLayer.apply(z), dim=0)
+
+ sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
+
+ sim_i_j = torch.diag(sim, self.batch_size * self.world_size)
+ sim_j_i = torch.diag(sim, -self.batch_size * self.world_size)
+
+ # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN
+ positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
+ negative_samples = sim[self.mask].reshape(N, -1)
+
+ labels = torch.zeros(N).to(positive_samples.device).long()
+ logits = torch.cat((positive_samples, negative_samples), dim=1)
+ loss = self.criterion(logits, labels)
+ loss /= N
+ return loss
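+ # (illustrative usage) z_i and z_j are the two augmented-view embeddings, each of shape
+ # [batch_size * world_size, dim]; the returned value is a scalar loss averaged over the
+ # 2N stacked samples, with the mask excluding self- and positive-pair similarities.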
+
+
+
+# Root Mean Squared Loss
+# penalizes the volume factor with a non-linearity
+class RMSLoss(nn.Module):
+ def __init__(self, reduce, loss_type="l2"):
+ super(RMSLoss, self).__init__()
+ self.weight_factor = 100.
+ if loss_type=="l2":
+ self.loss = nn.MSELoss(reduce=None)
+
+
+ def forward(self, est_targets, targets):
+ est_targets = est_targets.reshape(est_targets.shape[0]*est_targets.shape[1], est_targets.shape[2])
+ targets = targets.reshape(targets.shape[0]*targets.shape[1], targets.shape[2])
+ normalized_est = torch.sqrt(torch.mean(est_targets**2, dim=-1))
+ normalized_tgt = torch.sqrt(torch.mean(targets**2, dim=-1))
+
+ weight = torch.clamp(torch.abs(normalized_tgt-normalized_est), min=1/self.weight_factor) * self.weight_factor
+
+ return torch.mean(weight**1.5 * self.loss(normalized_est, normalized_tgt))
+
+
+
+# Multi-Scale Spectral Loss proposed at the paper "DDSP: DIFFERENTIABLE DIGITAL SIGNAL PROCESSING" (https://arxiv.org/abs/2001.04643)
+# we extend this loss by applying it to mid/side channels
+class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module):
+ def __init__(self, mode='midside', \
+ reduce=True, \
+ n_filters=None, \
+ windows_size=None, \
+ hops_size=None, \
+ window="hann", \
+ eps=1e-7, \
+ device=torch.device("cpu")):
+ super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__()
+ self.mode = mode
+ self.eps = eps
+ self.mid_weight = 0.5 # value in the range of 0.0 ~ 1.0
+ self.logmag_weight = 0.1
+
+ if n_filters is None:
+ n_filters = [4096, 2048, 1024, 512]
+ # n_filters = [4096]
+ if windows_size is None:
+ windows_size = [4096, 2048, 1024, 512]
+ # windows_size = [4096]
+ if hops_size is None:
+ hops_size = [1024, 512, 256, 128]
+ # hops_size = [1024]
+
+ self.multiscales = []
+ for i in range(len(windows_size)):
+ cur_scale = {'window_size' : float(windows_size[i])}
+ if self.mode=='midside':
+ cur_scale['front_end'] = FrontEnd(channel='mono', \
+ n_fft=n_filters[i], \
+ hop_length=hops_size[i], \
+ win_length=windows_size[i], \
+ window=window, \
+ device=device)
+ elif self.mode=='ori':
+ cur_scale['front_end'] = FrontEnd(channel='stereo', \
+ n_fft=n_filters[i], \
+ hop_length=hops_size[i], \
+ win_length=windows_size[i], \
+ window=window, \
+ device=device)
+ self.multiscales.append(cur_scale)
+
+ self.objective_l1 = nn.L1Loss(reduce=reduce)
+ self.objective_l2 = nn.MSELoss(reduce=reduce)
+
+
+ def forward(self, est_targets, targets):
+ if self.mode=='midside':
+ return self.forward_midside(est_targets, targets)
+ elif self.mode=='ori':
+ return self.forward_ori(est_targets, targets)
+
+
+ def forward_ori(self, est_targets, targets):
+ total_loss = 0.0
+ total_mag_loss = 0.0
+ total_logmag_loss = 0.0
+ for cur_scale in self.multiscales:
+ est_mag = cur_scale['front_end'](est_targets, mode=["mag"])
+ tgt_mag = cur_scale['front_end'](targets, mode=["mag"])
+
+ mag_loss = self.magnitude_loss(est_mag, tgt_mag)
+ logmag_loss = self.log_magnitude_loss(est_mag, tgt_mag)
+ # cur_loss = mag_loss + logmag_loss
+ # total_loss += cur_loss
+ total_mag_loss += mag_loss
+ total_logmag_loss += logmag_loss
+ # return total_loss
+ # print(f"ori - mag : {total_mag_loss}\tlog mag : {total_logmag_loss}")
+ return (1-self.logmag_weight)*total_mag_loss + \
+ (self.logmag_weight)*total_logmag_loss
+
+
+ def forward_midside(self, est_targets, targets):
+ est_mid, est_side = self.to_mid_side(est_targets)
+ tgt_mid, tgt_side = self.to_mid_side(targets)
+ total_loss = 0.0
+ total_mag_loss = 0.0
+ total_logmag_loss = 0.0
+ for cur_scale in self.multiscales:
+ est_mid_mag = cur_scale['front_end'](est_mid, mode=["mag"])
+ est_side_mag = cur_scale['front_end'](est_side, mode=["mag"])
+ tgt_mid_mag = cur_scale['front_end'](tgt_mid, mode=["mag"])
+ tgt_side_mag = cur_scale['front_end'](tgt_side, mode=["mag"])
+
+ mag_loss = self.mid_weight*self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \
+ (1-self.mid_weight)*self.magnitude_loss(est_side_mag, tgt_side_mag)
+ logmag_loss = self.mid_weight*self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \
+ (1-self.mid_weight)*self.log_magnitude_loss(est_side_mag, tgt_side_mag)
+ # cur_loss = mag_loss + logmag_loss
+ # total_loss += cur_loss
+ total_mag_loss += mag_loss
+ total_logmag_loss += logmag_loss
+ # return total_loss
+ # print(f"midside - mag : {total_mag_loss}\tlog mag : {total_logmag_loss}")
+ return (1-self.logmag_weight)*total_mag_loss + \
+ (self.logmag_weight)*total_logmag_loss
+
+
+ def to_mid_side(self, stereo_in):
+ mid = stereo_in[:,0] + stereo_in[:,1]
+ side = stereo_in[:,0] - stereo_in[:,1]
+ return mid, side
+
+
+ def magnitude_loss(self, est_mag_spec, tgt_mag_spec):
+ return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec))
+
+
+ def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec):
+ est_log_mag_spec = torch.log10(est_mag_spec+self.eps)
+ tgt_log_mag_spec = torch.log10(tgt_mag_spec+self.eps)
+ return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec)
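+ # (illustrative usage, not from the original code) for stereo waveforms of shape [batch, 2, samples]:
+ #   criterion = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', device=device)
+ #   loss = criterion(est_wav, tgt_wav)
+ # 'midside' compares mid/side magnitude spectra; 'ori' compares the raw L/R channels.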
+
+
+
+# hinge loss for discriminator
+def dis_hinge(dis_fake, dis_real):
+ return torch.mean(torch.relu(1. - dis_real)) + torch.mean(torch.relu(1. + dis_fake))
+
+
+# hinge loss for generator
+def gen_hinge(dis_fake, dis_real=None):
+ return -torch.mean(dis_fake)
+
+
+# DirectCLR's implementation of infoNCE loss
+def infoNCE(nn, p, temperature=0.1):
+ nn = torch.nn.functional.normalize(nn, dim=1)
+ p = torch.nn.functional.normalize(p, dim=1)
+ nn = gather_from_all(nn)
+ p = gather_from_all(p)
+ logits = nn @ p.T
+ logits /= temperature
+ n = p.shape[0]
+ labels = torch.arange(0, n, dtype=torch.long).cuda()
+ loss = torch.nn.functional.cross_entropy(logits, labels)
+ return loss
+
+
+
+
+# Class of available loss functions
+class Loss:
+ def __init__(self, args, reduce=True):
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device(f"cuda:{args.gpu}")
+ self.l1 = nn.L1Loss(reduce=reduce)
+ self.mse = nn.MSELoss(reduce=reduce)
+ self.ce = nn.CrossEntropyLoss()
+ self.triplet = nn.TripletMarginLoss(margin=1., p=2)
+
+ # self.ntxent = NT_Xent(args.train_batch*2, args.temperature, world_size=len(args.using_gpu.split(",")))
+ self.ntxent = NT_Xent(args.batch_size_total*(args.num_strong_negatives+1), args.temperature, world_size=1)
+ self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device)
+ self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device)
+ self.gain = RMSLoss(reduce=reduce)
+ self.infonce = infoNCE
+
diff --git a/mixing_style_transfer/modules/training_utils.py b/mixing_style_transfer/modules/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5195374a6c2489295f8f4bde96a54418eb732af7
--- /dev/null
+++ b/mixing_style_transfer/modules/training_utils.py
@@ -0,0 +1,174 @@
+""" Utility file for trainers """
+import os
+import shutil
+from glob import glob
+
+import torch
+import torch.distributed as dist
+
+
+
+''' checkpoint functions '''
+# saves checkpoint
+def save_checkpoint(model, \
+ optimizer, \
+ scheduler, \
+ epoch, \
+ checkpoint_dir, \
+ name, \
+ model_name):
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ checkpoint_state = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "epoch": epoch
+ }
+ checkpoint_path = os.path.join(checkpoint_dir,'{}_{}_{}.pt'.format(name, model_name, epoch))
+ torch.save(checkpoint_state, checkpoint_path)
+ print("Saved checkpoint: {}".format(checkpoint_path))
+
+
+# reload model weights from checkpoint file
+def reload_ckpt(args, \
+ network, \
+ optimizer, \
+ scheduler, \
+ gpu, \
+ model_name, \
+ manual_reload_name=None, \
+ manual_reload=False, \
+ manual_reload_dir=None, \
+ epoch=None, \
+ fit_sefa=False):
+ if manual_reload:
+ reload_name = manual_reload_name
+ else:
+ reload_name = args.name
+ if manual_reload_dir:
+ ckpt_dir = manual_reload_dir + reload_name + "/ckpt/"
+ else:
+ ckpt_dir = args.output_dir + reload_name + "/ckpt/"
+ temp_ckpt_dir = f'{args.output_dir}{reload_name}/ckpt_temp/'
+ reload_epoch = epoch
+ # find best or latest epoch
+ if epoch==None:
+ reload_epoch_temp = 0
+ reload_epoch_ckpt = 0
+ if len(os.listdir(temp_ckpt_dir))!=0:
+ reload_epoch_temp = find_best_epoch(temp_ckpt_dir)
+ if len(os.listdir(ckpt_dir))!=0:
+ reload_epoch_ckpt = find_best_epoch(ckpt_dir)
+ if reload_epoch_ckpt >= reload_epoch_temp:
+ reload_epoch = reload_epoch_ckpt
+ else:
+ reload_epoch = reload_epoch_temp
+ ckpt_dir = temp_ckpt_dir
+ else:
+ if os.path.isfile(f"{temp_ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"):
+ ckpt_dir = temp_ckpt_dir
+ # reloading weight
+ if model_name==None:
+ resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{reload_epoch}.pt"
+ else:
+ resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"
+ if gpu==0:
+ print("===Resume checkpoint from: {}===".format(resuming_path))
+ loc = 'cuda:{}'.format(gpu)
+ checkpoint = torch.load(resuming_path, map_location=loc)
+ start_epoch = 0 if manual_reload and not fit_sefa else checkpoint["epoch"]
+
+ if manual_reload_dir is not None and 'parameter_estimation' in manual_reload_dir:
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint["model"].items():
+ name = 'module.' + k # add `module.`
+ new_state_dict[name] = v
+ network.load_state_dict(new_state_dict)
+ else:
+ network.load_state_dict(checkpoint["model"])
+ if not manual_reload:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ scheduler.load_state_dict(checkpoint["scheduler"])
+ if gpu==0:
+ # print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, checkpoint['epoch']))
+ print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, epoch))
+ return start_epoch
+
+
+# find best epoch for reloading current model
+def find_best_epoch(input_dir):
+ cur_epochs = glob("{}*".format(input_dir))
+ return find_by_name(cur_epochs)
+
+
+# sort string epoch names by integers
+def find_by_name(epochs):
+ int_epochs = []
+ for e in epochs:
+ int_epochs.append(int(os.path.basename(e)))
+ int_epochs.sort()
+ return (int_epochs[-1])
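+# e.g. find_by_name(['/out/exp/ckpt/3', '/out/exp/ckpt/12']) -> 12
+# (paths here are hypothetical; only the integer basename of each entry is used)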
+
+
+# remove ckpt files
+def remove_ckpt(cur_ckpt_path_dir, leave=2):
+ ckpt_nums = [int(i) for i in os.listdir(cur_ckpt_path_dir)]
+ ckpt_nums.sort()
+ del_num = len(ckpt_nums) - leave
+ cur_del_num = 0
+ while del_num > 0:
+ shutil.rmtree("{}{}".format(cur_ckpt_path_dir, ckpt_nums[cur_del_num]))
+ del_num -= 1
+ cur_del_num += 1
+
+
+
+''' multi-GPU functions '''
+
+# gather function implemented from DirectCLR
+class GatherLayer_Direct(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
+ dist.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ dist.all_reduce(all_gradients)
+ return all_gradients[dist.get_rank()]
+
+from classy_vision.generic.distributed_util import (
+ convert_to_distributed_tensor,
+ convert_to_normal_tensor,
+ is_distributed_training_run,
+)
+def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Similar to classy_vision.generic.distributed_util.gather_from_all
+ except that it does not cut the gradients
+ """
+ if tensor.ndim == 0:
+ # 0 dim tensors cannot be gathered. so unsqueeze
+ tensor = tensor.unsqueeze(0)
+
+ if is_distributed_training_run():
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
+ gathered_tensors = GatherLayer_Direct.apply(tensor)
+ gathered_tensors = [
+ convert_to_normal_tensor(_tensor, orig_device)
+ for _tensor in gathered_tensors
+ ]
+ else:
+ gathered_tensors = [tensor]
+ gathered_tensor = torch.cat(gathered_tensors, 0)
+ return gathered_tensor
+
+
diff --git a/mixing_style_transfer/networks/__init__.py b/mixing_style_transfer/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fe695aec7c837c75a665bae0091975a9131056a
--- /dev/null
+++ b/mixing_style_transfer/networks/__init__.py
@@ -0,0 +1,2 @@
+from .architectures import *
+from .network_utils import *
\ No newline at end of file
diff --git a/mixing_style_transfer/networks/architectures.py b/mixing_style_transfer/networks/architectures.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0545e17bde5740d5bc2ba86539b68e897af26ca
--- /dev/null
+++ b/mixing_style_transfer/networks/architectures.py
@@ -0,0 +1,290 @@
+"""
+"Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
+
+ Implementation of neural networks used in the task 'Music Mixing Style Transfer'
+ - 'FXencoder'
+ - TCN based 'MixFXcloner'
+
+ We modify the TCN implementation from: https://github.com/csteinmetz1/micro-tcn
+ which was introduced in the work "Efficient neural networks for real-time modeling of analog dynamic range compression" by Christian J. Steinmetz, and Joshua D. Reiss
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from argparse import ArgumentParser   # used by TCNModel.add_model_specific_args
+
+import os
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.dirname(currentdir))
+
+from networks.network_utils import *
+
+
+
+# FXencoder that extracts audio effects from music recordings trained with a contrastive objective
+class FXencoder(nn.Module):
+ def __init__(self, config):
+ super(FXencoder, self).__init__()
+ # input is stereo channeled audio
+ config["channels"].insert(0, 2)
+
+ # encoder layers
+ encoder = []
+ for i in range(len(config["kernels"])):
+ if config["conv_block"]=='res':
+ encoder.append(Res_ConvBlock(dimension=1, \
+ in_channels=config["channels"][i], \
+ out_channels=config["channels"][i+1], \
+ kernel_size=config["kernels"][i], \
+ stride=config["strides"][i], \
+ padding="SAME", \
+ dilation=config["dilation"][i], \
+ norm=config["norm"], \
+ activation=config["activation"], \
+ last_activation=config["activation"]))
+ elif config["conv_block"]=='conv':
+ encoder.append(ConvBlock(dimension=1, \
+ layer_num=1, \
+ in_channels=config["channels"][i], \
+ out_channels=config["channels"][i+1], \
+ kernel_size=config["kernels"][i], \
+ stride=config["strides"][i], \
+ padding="VALID", \
+ dilation=config["dilation"][i], \
+ norm=config["norm"], \
+ activation=config["activation"], \
+ last_activation=config["activation"], \
+ mode='conv'))
+ self.encoder = nn.Sequential(*encoder)
+
+ # pooling method
+ self.glob_pool = nn.AdaptiveAvgPool1d(1)
+
+ # network forward operation
+ def forward(self, input):
+ enc_output = self.encoder(input)
+ glob_pooled = self.glob_pool(enc_output).squeeze(-1)
+
+ # outputs c feature
+ return glob_pooled
+
+
+# MixFXcloner which is based on a Temporal Convolutional Network (TCN) module
+ # original implementation : https://github.com/csteinmetz1/micro-tcn
+import pytorch_lightning as pl
+class TCNModel(pl.LightningModule):
+ """ Temporal convolutional network with conditioning module.
+ Args:
+ nparams (int): Number of conditioning parameters.
+ ninputs (int): Number of input channels (mono = 1, stereo = 2). Default: 1
+ noutputs (int): Number of output channels (mono = 1, stereo = 2). Default: 1
+ nblocks (int): Number of total TCN blocks. Default: 10
+ kernel_size (int): Width of the convolutional kernels. Default: 3
+ dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
+ channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 2
+ channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
+ stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
+ grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
+ causal (bool): Causal TCN configuration does not consider future input values. Default: False
+ skip_connections (bool): Skip connections from each block to the output. Default: False
+ num_examples (int): Number of evaluation audio examples to log after each epoch. Default: 4
+ """
+ def __init__(self,
+ nparams,
+ ninputs=1,
+ noutputs=1,
+ nblocks=10,
+ kernel_size=3,
+ dilation_growth=1,
+ channel_growth=1,
+ channel_width=32,
+ stack_size=10,
+ cond_dim=2048,
+ grouped=False,
+ causal=False,
+ skip_connections=False,
+ num_examples=4,
+ save_dir=None,
+ **kwargs):
+ super(TCNModel, self).__init__()
+ self.save_hyperparameters()
+
+ self.blocks = torch.nn.ModuleList()
+ for n in range(nblocks):
+ in_ch = out_ch if n > 0 else ninputs
+
+ if self.hparams.channel_growth > 1:
+ out_ch = in_ch * self.hparams.channel_growth
+ else:
+ out_ch = self.hparams.channel_width
+
+ dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
+ self.blocks.append(TCNBlock(in_ch,
+ out_ch,
+ kernel_size=self.hparams.kernel_size,
+ dilation=dilation,
+ padding="same" if self.hparams.causal else "valid",
+ causal=self.hparams.causal,
+ cond_dim=cond_dim,
+ grouped=self.hparams.grouped,
+ conditional=True if self.hparams.nparams > 0 else False))
+
+ self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
+
+ def forward(self, x, cond):
+ # iterate over blocks passing conditioning
+ for idx, block in enumerate(self.blocks):
+ # for SeFa
+ if isinstance(cond, list):
+ x = block(x, cond[idx])
+ else:
+ x = block(x, cond)
+ skips = 0
+
+ out = torch.clamp(self.output(x + skips), min=-1, max=1)
+
+ return out
+
+ def compute_receptive_field(self):
+ """ Compute the receptive field in samples."""
+ rf = self.hparams.kernel_size
+ for n in range(1,self.hparams.nblocks):
+ dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
+ rf = rf + ((self.hparams.kernel_size-1) * dilation)
+ return rf
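+ # worked example with the 'default' TCN config in networks/configs.yaml
+ # (kernel_size=15, dilation_growth=2, nblocks=14, stack_size=15):
+ # rf = 15 + sum_{n=1..13} 14 * 2**n = 15 + 14*(2**14 - 2) = 229,363 samples ~= 5.2 s at 44.1 kHz.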
+
+ # add any model hyperparameters here
+ @staticmethod
+ def add_model_specific_args(parent_parser):
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
+ # --- model related ---
+ parser.add_argument('--ninputs', type=int, default=1)
+ parser.add_argument('--noutputs', type=int, default=1)
+ parser.add_argument('--nblocks', type=int, default=4)
+ parser.add_argument('--kernel_size', type=int, default=5)
+ parser.add_argument('--dilation_growth', type=int, default=10)
+ parser.add_argument('--channel_growth', type=int, default=1)
+ parser.add_argument('--channel_width', type=int, default=32)
+ parser.add_argument('--stack_size', type=int, default=10)
+ parser.add_argument('--grouped', default=False, action='store_true')
+ parser.add_argument('--causal', default=False, action="store_true")
+ parser.add_argument('--skip_connections', default=False, action="store_true")
+
+ return parser
+
+
+class TCNBlock(torch.nn.Module):
+ def __init__(self,
+ in_ch,
+ out_ch,
+ kernel_size=3,
+ dilation=1,
+ cond_dim=2048,
+ grouped=False,
+ causal=False,
+ conditional=False,
+ **kwargs):
+ super(TCNBlock, self).__init__()
+
+ self.in_ch = in_ch
+ self.out_ch = out_ch
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.grouped = grouped
+ self.causal = causal
+ self.conditional = conditional
+
+ groups = out_ch if grouped and (in_ch % out_ch == 0) else 1
+
+ self.pad_length = ((kernel_size-1)*dilation) if self.causal else ((kernel_size-1)*dilation)//2
+ self.conv1 = torch.nn.Conv1d(in_ch,
+ out_ch,
+ kernel_size=kernel_size,
+ padding=self.pad_length,
+ dilation=dilation,
+ groups=groups,
+ bias=False)
+ if grouped:
+ self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)
+
+ if conditional:
+ self.film = FiLM(cond_dim, out_ch)
+ self.bn = torch.nn.BatchNorm1d(out_ch)
+
+ self.relu = torch.nn.LeakyReLU()
+ self.res = torch.nn.Conv1d(in_ch,
+ out_ch,
+ kernel_size=1,
+ groups=in_ch,
+ bias=False)
+
+ def forward(self, x, p):
+ x_in = x
+
+ x = self.relu(self.bn(self.conv1(x)))
+ x = self.film(x, p)
+
+ x_res = self.res(x_in)
+
+ if self.causal:
+ x = x[..., :-self.pad_length]
+ x += x_res
+
+ return x
+
+
+
+if __name__ == '__main__':
+ ''' check model I/O shape '''
+ import yaml
+ with open('networks/configs.yaml', 'r') as f:
+ configs = yaml.full_load(f)
+
+ batch_size = 32
+ sr = 44100
+ input_length = sr*5
+
+ input = torch.rand(batch_size, 2, input_length)
+ print(f"Input Shape : {input.shape}\n")
+
+
+ print('\n========== Audio Effects Encoder (FXencoder) ==========')
+ model_arc = "FXencoder"
+ model_options = "default"
+
+ config = configs[model_arc][model_options]
+ print(f"configuration: {config}")
+
+ network = FXencoder(config)
+ pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
+ print(f"Number of trainable parameters : {pytorch_total_params}")
+
+ # model inference
+ output_c = network(input)
+ print(f"Output Shape : {output_c.shape}")
+
+
+ print('\n========== TCN based MixFXcloner ==========')
+ model_arc = "TCN"
+ model_options = "default"
+
+ config = configs[model_arc][model_options]
+ print(f"configuration: {config}")
+
+ network = TCNModel(nparams=config["condition_dimension"], ninputs=2, noutputs=2, \
+ nblocks=config["nblocks"], \
+ dilation_growth=config["dilation_growth"], \
+ kernel_size=config["kernel_size"], \
+ channel_width=config["channel_width"], \
+ stack_size=config["stack_size"], \
+ cond_dim=config["condition_dimension"], \
+ causal=config["causal"])
+ pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
+ print(f"Number of trainable parameters : {pytorch_total_params}\tReceptive field duration : {network.compute_receptive_field() / sr:.3f}")
+
+    # the reference embedding extracted by the FXencoder is the conditioning vector
+    ref_embedding = output_c
+    # model inference
+    output = network(input, ref_embedding)
+ print(f"Output Shape : {output.shape}")
+
diff --git a/mixing_style_transfer/networks/configs.yaml b/mixing_style_transfer/networks/configs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ed02a9ec0b889824491aa2a72ce0c9a3515ace3d
--- /dev/null
+++ b/mixing_style_transfer/networks/configs.yaml
@@ -0,0 +1,30 @@
+# model architecture configurations
+
+
+# Music Effects Encoder
+Effects_Encoder:
+
+ default:
+ channels: [16, 32, 64, 128, 256, 256, 512, 512, 1024, 1024, 2048, 2048]
+ kernels: [25, 25, 15, 15, 10, 10, 10, 10, 5, 5, 5, 5]
+ strides: [4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]
+ dilation: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ bias: True
+ norm: 'batch'
+ conv_block: 'res'
+ activation: "relu"
+
+
+# TCN
+TCN:
+
+ # receptive field = 5.2 seconds
+ default:
+ condition_dimension: 2048
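+    # note: this matches the Effects_Encoder's final channel width (2048) above,
+    # so the encoder embedding can be used directly as the TCN's FiLM condition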
+ nblocks: 14
+ dilation_growth: 2
+ kernel_size: 15
+ channel_width: 128
+ stack_size: 15
+ causal: False
+
diff --git a/mixing_style_transfer/networks/network_utils.py b/mixing_style_transfer/networks/network_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dd5a4b8878dc606da3abe754910b30ca4075316
--- /dev/null
+++ b/mixing_style_transfer/networks/network_utils.py
@@ -0,0 +1,184 @@
+"""
+ Utility File
+ containing functions for neural networks
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+import torch
+import torchaudio
+
+
+
+# 1-dimensional convolutional layer
+# in the order of conv -> norm -> activation
+class Conv1d_layer(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, \
+ stride=1, \
+ padding="SAME", dilation=1, bias=True, \
+ norm="batch", activation="relu", \
+ mode="conv"):
+ super(Conv1d_layer, self).__init__()
+
+ self.conv1d = nn.Sequential()
+
+ ''' padding '''
+ if mode=="deconv":
+ padding = int(dilation * (kernel_size-1) / 2)
+ out_padding = 0 if stride==1 else 1
+ elif mode=="conv" or "alias_free" in mode:
+ if padding == "SAME":
+ pad = int((kernel_size-1) * dilation)
+ l_pad = int(pad//2)
+ r_pad = pad - l_pad
+ padding_area = (l_pad, r_pad)
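+                # e.g. kernel_size=25, dilation=1 -> pad=24, split as (12, 12),
+                # so a stride-1 convolution preserves the input length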
+ elif padding == "VALID":
+ padding_area = (0, 0)
+ else:
+ pass
+
+ ''' convolutional layer '''
+ if mode=="deconv":
+ self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(in_channels, out_channels, kernel_size, \
+ stride=stride, padding=padding, output_padding=out_padding, \
+ dilation=dilation, \
+ bias=bias))
+ elif mode=="conv":
+ self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
+ self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
+ stride=stride, padding=0, \
+ dilation=dilation, \
+ bias=bias))
+ elif "alias_free" in mode:
+ if "up" in mode:
+ up_factor = stride * 2
+ down_factor = 2
+ elif "down" in mode:
+ up_factor = 2
+ down_factor = stride * 2
+ else:
+ raise ValueError("choose alias-free method : 'up' or 'down'")
+ # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample
+            # torchaudio.transforms.Resample's default resampling_method is 'sinc_interpolation', which applies low-pass filtering as part of the resampling
+ # details at https://pytorch.org/audio/stable/transforms.html
+ self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
+ self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
+ stride=1, padding=0, \
+ dilation=dilation, \
+ bias=bias))
+ self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor))
+ self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU())
+ self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1))
+
+ ''' normalization '''
+ if norm=="batch":
+ self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels))
+ # self.conv1d.add_module("batch_norm", nn.SyncBatchNorm(out_channels))
+
+ ''' activation '''
+ if 'alias_free' not in mode:
+ if activation=="relu":
+ self.conv1d.add_module("relu", nn.ReLU())
+ elif activation=="lrelu":
+ self.conv1d.add_module("lrelu", nn.LeakyReLU())
+
+
+ def forward(self, input):
+        # expected input shape : batch x channel x length
+ output = self.conv1d(input)
+ return output
+
+
+
+# Residual Block
+    # the input is added back after the first convolutional layer, so that layer keeps the input channel size
+    # only the second convolutional layer may change the number of output channels
+class Res_ConvBlock(nn.Module):
+ def __init__(self, dimension, \
+ in_channels, out_channels, \
+ kernel_size, \
+ stride=1, padding="SAME", \
+ dilation=1, \
+ bias=True, \
+ norm="batch", \
+ activation="relu", last_activation="relu", \
+ mode="conv"):
+ super(Res_ConvBlock, self).__init__()
+
+ if dimension==1:
+ self.conv1 = Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
+ self.conv2 = Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
+ elif dimension==2:
+ self.conv1 = Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
+ self.conv2 = Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
+
+
+ def forward(self, input):
+ c1_out = self.conv1(input) + input
+ c2_out = self.conv2(c1_out)
+ return c2_out
+
+
+
+# Convolutional Block
+    # consists of multiple (layer_num) convolutional layers
+    # only the final convolutional layer outputs the desired 'out_channels'
+class ConvBlock(nn.Module):
+ def __init__(self, dimension, layer_num, \
+ in_channels, out_channels, \
+ kernel_size, \
+ stride=1, padding="SAME", \
+ dilation=1, \
+ bias=True, \
+ norm="batch", \
+ activation="relu", last_activation="relu", \
+ mode="conv"):
+ super(ConvBlock, self).__init__()
+
+ conv_block = []
+ if dimension==1:
+ for i in range(layer_num-1):
+ conv_block.append(Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
+ conv_block.append(Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
+ elif dimension==2:
+ for i in range(layer_num-1):
+ conv_block.append(Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
+ conv_block.append(Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
+ self.conv_block = nn.Sequential(*conv_block)
+
+
+ def forward(self, input):
+ return self.conv_block(input)
+
+
+
+# Feature-wise Linear Modulation
+class FiLM(nn.Module):
+ def __init__(self, condition_len=2048, feature_len=1024):
+ super(FiLM, self).__init__()
+ self.film_fc = nn.Linear(condition_len, feature_len*2)
+ self.feat_len = feature_len
+
+
+ def forward(self, feature, condition, sefa=None):
+ # SeFA
+ if sefa:
+ weight = self.film_fc.weight.T
+            weight = weight / torch.linalg.norm((weight + 1e-07), dim=0, keepdim=True)
+ eigen_values, eigen_vectors = torch.eig(torch.matmul(weight, weight.T), eigenvectors=True)
+
+ ####### custom parameters #######
+ chosen_eig_idx = sefa[0]
+ alpha = eigen_values[chosen_eig_idx][0] * sefa[1]
+ #################################
+
+ An = eigen_vectors[chosen_eig_idx].repeat(condition.shape[0], 1)
+ alpha_An = alpha * An
+
+ condition += alpha_An
+
+ film_factor = self.film_fc(condition).unsqueeze(-1)
+ r, b = torch.split(film_factor, self.feat_len, dim=1)
+ return r*feature + b
+
+
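+# minimal shape-check sketch (illustrative; not part of the training pipeline):
+# FiLM maps a (batch, condition_len) embedding to a per-channel scale and shift,
+# then modulates a (batch, feature_len, time) feature map.
+if __name__ == '__main__':
+    film = FiLM(condition_len=2048, feature_len=128)
+    feature = torch.rand(4, 128, 1000)       # batch x channel x length
+    condition = torch.rand(4, 2048)          # e.g. an FXencoder-style embedding
+    print(film(feature, condition).shape)    # expected: torch.Size([4, 128, 1000])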
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c80119118b2ffe9e6ef6e88a63faaa5430687b9f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+aubio==0.4.9
+classy_vision==0.6.0
+config==0.5.1
+librosa==0.9.2
+matplotlib==3.3.3
+numba==0.48.0
+numpy==1.23.0
+psutil==5.7.2
+pyloudnorm==0.1.0
+git+https://github.com/csteinmetz1/pymixconsole
+pypesq==1.2.4
+pytorch_lightning==1.3.2
+PyYAML==5.4
+scikit_learn==1.1.3
+scipy==1.6
+SoundFile==0.10.3.post1
+soxbindings==1.2.3
+torch==1.9.0
+torchaudio==0.9.0
+torchvision==0.10.0
+torchmetrics==0.6.0
+torchtext==0.10.0
+demucs
diff --git a/samples/interpolation/#0/input.wav b/samples/interpolation/#0/input.wav
new file mode 100644
index 0000000000000000000000000000000000000000..282442b6c90a8b7345aabc7d6452d3018facfc7f
--- /dev/null
+++ b/samples/interpolation/#0/input.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc4a7d1283666051d43d07e9a11d4c5014426b0753b316fb64d1aef30288b0bd
+size 5274396
diff --git a/samples/interpolation/#0/reference.wav b/samples/interpolation/#0/reference.wav
new file mode 100644
index 0000000000000000000000000000000000000000..389c9f5ed2e3e48e3d78b480d66245144bd5c473
--- /dev/null
+++ b/samples/interpolation/#0/reference.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:207cdf21724c640a1b7006013d2519c3ac4176604019d7738905911990036a6d
+size 3842338
diff --git a/samples/interpolation/#0/reference_B.wav b/samples/interpolation/#0/reference_B.wav
new file mode 100644
index 0000000000000000000000000000000000000000..11392d555bdbd180bcf1c8b4f10f28e69ce062b2
--- /dev/null
+++ b/samples/interpolation/#0/reference_B.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b591708b64f54c3f5c2d03c4359bfa6bc7b125041bb2792a34cd0fd18ae4961
+size 3790802
diff --git a/samples/interpolation/#0/separated/mdx_extra/input/bass.wav b/samples/interpolation/#0/separated/mdx_extra/input/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3d74c6a9f08b59e99f0e2188fea1e2685066c3f7
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/input/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be2299dc8d612104b305f7a80c93c5dcb05da62e815c3c865bdfac91e56cbedf
+size 5274396
diff --git a/samples/interpolation/#0/separated/mdx_extra/input/drums.wav b/samples/interpolation/#0/separated/mdx_extra/input/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8489761cc33eb8a0ceb7d6a1b732d63c91fdba54
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/input/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b503786e6b8503346ed0f2ceb482bc16440cd63ec3e0b83952c47998a292fb84
+size 5274396
diff --git a/samples/interpolation/#0/separated/mdx_extra/input/other.wav b/samples/interpolation/#0/separated/mdx_extra/input/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2476dac73c279fb9f9a0c95aeb1de4a3106359be
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/input/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e0f4593e841ae116e0e1a71c66e84c032c82c5780c295af3a6cb1968bc7d4dd
+size 5274396
diff --git a/samples/interpolation/#0/separated/mdx_extra/input/vocals.wav b/samples/interpolation/#0/separated/mdx_extra/input/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1d3c3417f4b28d240de49a170a2c2ecb885e48a7
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/input/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:690bedc47fca08257eba3ec411a3a9f7afb340b5ad6d8bf909f372f4ec369853
+size 5274396
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference/bass.wav b/samples/interpolation/#0/separated/mdx_extra/reference/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..039132e2fe3cd66bf8575d3f1f09652719966b08
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20023ad9e5b90ee8c338dcc77e5f1644b5131d3acc2ab17f8469bc8bfe57a353
+size 3529776
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference/drums.wav b/samples/interpolation/#0/separated/mdx_extra/reference/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..619cdcf8c2d35eea5f7ebac58c9e3b6631cc0637
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a0c050eb941aa939d0c4f78216ad3a40d5ee43d51a7ba23266d228d661785e6
+size 3529776
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference/other.wav b/samples/interpolation/#0/separated/mdx_extra/reference/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..cd328f0c5f131ea58787cb96fda7a4bd63541be6
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c405c3ed345388b809520ea5f43229210d67f73ab4b1c1903ca2c761c7467214
+size 3529776
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference/vocals.wav b/samples/interpolation/#0/separated/mdx_extra/reference/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ce39b12748ed3b42654de8c75ae2d6bba6dee426
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07e7685489ab7fe68f3e3752ba01789a8be350a67b86ecf95fa877bb6965e0b4
+size 3529776
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference_B/bass.wav b/samples/interpolation/#0/separated/mdx_extra/reference_B/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..64134d98b6755128d2207b7372f492cad35145bd
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference_B/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:237b9986f23f90c8c7a92d3f284243eaceb0b053b530e11984cc487915f96d45
+size 3482464
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference_B/drums.wav b/samples/interpolation/#0/separated/mdx_extra/reference_B/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e2a5f3fed2445646fdc541f73bbd280020f74d65
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference_B/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49e9ff173b853fcb22ccc6dee3c63def308c2a4cc5c76bf46f3bbd80f0f4bd41
+size 3482464
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference_B/other.wav b/samples/interpolation/#0/separated/mdx_extra/reference_B/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..65a09681bcf10ae2feae4243fc058d8e1afac032
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference_B/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65f48ae65fd288c052efca0fedace5ad8ad9504d664be040c3af48796689f3c1
+size 3482464
diff --git a/samples/interpolation/#0/separated/mdx_extra/reference_B/vocals.wav b/samples/interpolation/#0/separated/mdx_extra/reference_B/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..dd8f51aa4787929635e79727d33e467ff98e3cf3
--- /dev/null
+++ b/samples/interpolation/#0/separated/mdx_extra/reference_B/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ac504e751f3c601e8cb5ee5175b28e8d0606e78de66bf606f5097fec11ada4c
+size 3482464
diff --git a/samples/style_transfer/#0/input.wav b/samples/style_transfer/#0/input.wav
new file mode 100644
index 0000000000000000000000000000000000000000..4b7f2b7ed6ca9082566b3a4e2b83c6753ba01a7f
--- /dev/null
+++ b/samples/style_transfer/#0/input.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d4365f3972c0a58f01479ea90a532e319e4ebe8773cae5e88b5a13fca3de26c
+size 2646196
diff --git a/samples/style_transfer/#0/reference.wav b/samples/style_transfer/#0/reference.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3536d4c66828731a34c7cfc302bdff66c6a5f750
--- /dev/null
+++ b/samples/style_transfer/#0/reference.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01d5d2560ae9d368aab59f1c5297b34f962126a81b535ecbf9031d272b907823
+size 5421522
diff --git a/samples/style_transfer/#0/separated/mdx_extra/input/bass.wav b/samples/style_transfer/#0/separated/mdx_extra/input/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..bf21648d359d650727a5ec74073e47a24decbca6
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/input/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b46e0c2d82c11f8ff407fe89eb662834429c40f5642066eb341ca2e9b56cd264
+size 2646196
diff --git a/samples/style_transfer/#0/separated/mdx_extra/input/drums.wav b/samples/style_transfer/#0/separated/mdx_extra/input/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f37eadd1987336a4c107469902ad4908fcd082cd
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/input/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a98c5400d00ba28b7b4644c473a34dd2b027b8e5496546cfbfe76f45c554c8da
+size 2646196
diff --git a/samples/style_transfer/#0/separated/mdx_extra/input/other.wav b/samples/style_transfer/#0/separated/mdx_extra/input/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2834bf3d26add427c429bbc4eb53d6b511497988
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/input/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c26161d9fc98898100c9c229e55eab28a41787ba57253872b82f3f8e916e39fb
+size 2646196
diff --git a/samples/style_transfer/#0/separated/mdx_extra/input/vocals.wav b/samples/style_transfer/#0/separated/mdx_extra/input/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1293bbe552911ee064eeb4155efdb58fd2137b3e
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/input/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:337af4058bcdaf404fb110d5e071c24bbf77753a55aaa7e8db47726ce502bcab
+size 2646196
diff --git a/samples/style_transfer/#0/separated/mdx_extra/reference/bass.wav b/samples/style_transfer/#0/separated/mdx_extra/reference/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d871a9fe72d178c3e435cc74e69f1048038a8232
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/reference/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b79cc96acf4550bbd9c1bd79ea3ab43a1fb90ea15e65288806d379c7de6e8015
+size 5421404
diff --git a/samples/style_transfer/#0/separated/mdx_extra/reference/drums.wav b/samples/style_transfer/#0/separated/mdx_extra/reference/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..fa19af641b9ff3cebc726d16ff836a2a12fbf336
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/reference/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:395bc9f89b9e996a12ec7e57bedf9c84899b67fb0d9beabd7e979fe5ac7e8ebb
+size 5421404
diff --git a/samples/style_transfer/#0/separated/mdx_extra/reference/other.wav b/samples/style_transfer/#0/separated/mdx_extra/reference/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..55537ca7ef5280165c51ff8340d96a4192466315
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/reference/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68bab09fa8dee9591a34400ba54284255f4ab059d18e83351ed2e95a894e40e1
+size 5421404
diff --git a/samples/style_transfer/#0/separated/mdx_extra/reference/vocals.wav b/samples/style_transfer/#0/separated/mdx_extra/reference/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d3c600debfc8c164db95c6beb047059480080013
--- /dev/null
+++ b/samples/style_transfer/#0/separated/mdx_extra/reference/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f124e60a18e5b43e9c361c1a4cae9095a17f5059e917db70763d093903dc01d3
+size 5421404
diff --git a/samples/style_transfer/#2/input.wav b/samples/style_transfer/#2/input.wav
new file mode 100644
index 0000000000000000000000000000000000000000..9756d8fcd20db2962358f7e6f6126de21fbd80ff
--- /dev/null
+++ b/samples/style_transfer/#2/input.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d86b336d0e48b263f093a201aaf2d200c09da740394feb34996ed66b94193193
+size 4322250
diff --git a/samples/style_transfer/#2/reference.wav b/samples/style_transfer/#2/reference.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ec9bec238ece78e11b560212e446248c74008a5b
--- /dev/null
+++ b/samples/style_transfer/#2/reference.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d801007aa5c4ca426ccd583f7acad380b7c75f888a7725de7469d1d91020419f
+size 5912622
diff --git a/samples/style_transfer/#2/separated/mdx_extra/input/bass.wav b/samples/style_transfer/#2/separated/mdx_extra/input/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..06638eb244ff798907d755a2abcfc3779f89b2f9
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/input/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c2a109d444459edfd3e07ca6b252b1b5ec4a19c8c1e713dab919950d3b61828
+size 4322132
diff --git a/samples/style_transfer/#2/separated/mdx_extra/input/drums.wav b/samples/style_transfer/#2/separated/mdx_extra/input/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..c8e6aeb42f3afc5139d8e7b9729bd28d311b5c32
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/input/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5eb6b6d1b16263d9496260180460dc63e5bae79dd238f234cf8435d2763c259f
+size 4322132
diff --git a/samples/style_transfer/#2/separated/mdx_extra/input/other.wav b/samples/style_transfer/#2/separated/mdx_extra/input/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..5658b4c7dc53e69b7eef4f5cf0e0d03cbf5df0b9
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/input/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51e74238969312ab5e6218b2b864b2fb8635f1a6eddc0f47ad9adce81c985825
+size 4322132
diff --git a/samples/style_transfer/#2/separated/mdx_extra/input/vocals.wav b/samples/style_transfer/#2/separated/mdx_extra/input/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1dadad88dd6c4173906527f40f8ed4f2287a397e
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/input/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46453591d477875da925200b78540e35f758bb67708a89c227869c65a8236b19
+size 4322132
diff --git a/samples/style_transfer/#2/separated/mdx_extra/reference/bass.wav b/samples/style_transfer/#2/separated/mdx_extra/reference/bass.wav
new file mode 100644
index 0000000000000000000000000000000000000000..170481d704533aead763e056ee3ef900b4a5b509
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/reference/bass.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71d47d119c93ea07beb0002e1c7a441d862258111e4b14c3a805281ed80e3ade
+size 5912504
diff --git a/samples/style_transfer/#2/separated/mdx_extra/reference/drums.wav b/samples/style_transfer/#2/separated/mdx_extra/reference/drums.wav
new file mode 100644
index 0000000000000000000000000000000000000000..7969ef0952bdb82886c3445d5dbaddc8adb44c02
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/reference/drums.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31d1382c6c455b59ca15cd6eda3b11649b1b61c00c8c0e19a2df756925a0d872
+size 5912504
diff --git a/samples/style_transfer/#2/separated/mdx_extra/reference/other.wav b/samples/style_transfer/#2/separated/mdx_extra/reference/other.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d952228a686c27a00418ff59e49b186e8ed2efe1
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/reference/other.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a7a0b5706b444653038716b28b81eb448971093ea1931143fd0c3b560a7fb1b
+size 5912504
diff --git a/samples/style_transfer/#2/separated/mdx_extra/reference/vocals.wav b/samples/style_transfer/#2/separated/mdx_extra/reference/vocals.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6e4176d40f09de01bc9dce85b0dfacc408111bb6
--- /dev/null
+++ b/samples/style_transfer/#2/separated/mdx_extra/reference/vocals.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97cf038b433df7205f043523f7560fb416895664abf541987f3fb72971f7c757
+size 5912504
diff --git a/weights/musdb18_fxfeatures_eqcompimagegain.npy b/weights/musdb18_fxfeatures_eqcompimagegain.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3366d287633fc5775947e724c7ee2de817ea4420
--- /dev/null
+++ b/weights/musdb18_fxfeatures_eqcompimagegain.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:368187f0ccb11c6428c82ea071f3e2a7a1265beb99fcae624e1c0f86efed5876
+size 525467