jhtonyKoo commited on
Commit
2777fde
·
1 Parent(s): cd708f0

Upload 61 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. README.md +126 -12
  3. inference/configs.yaml +30 -0
  4. inference/feature_extraction.py +194 -0
  5. inference/style_transfer.py +400 -0
  6. mixing_style_transfer/data_loader/__init__.py +2 -0
  7. mixing_style_transfer/data_loader/data_loader.py +672 -0
  8. mixing_style_transfer/data_loader/loader_utils.py +71 -0
  9. mixing_style_transfer/mixing_manipulator/__init__.py +4 -0
  10. mixing_style_transfer/mixing_manipulator/audio_effects_chain.py +165 -0
  11. mixing_style_transfer/mixing_manipulator/common_audioeffects.py +1537 -0
  12. mixing_style_transfer/mixing_manipulator/common_dataprocessing.py +535 -0
  13. mixing_style_transfer/mixing_manipulator/common_miscellaneous.py +219 -0
  14. mixing_style_transfer/mixing_manipulator/data_normalization.py +173 -0
  15. mixing_style_transfer/mixing_manipulator/fx_utils.py +313 -0
  16. mixing_style_transfer/mixing_manipulator/normalization_imager.py +121 -0
  17. mixing_style_transfer/mixing_manipulator/utils_data_normalization.py +906 -0
  18. mixing_style_transfer/modules/__init__.py +3 -0
  19. mixing_style_transfer/modules/front_back_end.py +226 -0
  20. mixing_style_transfer/modules/loss.py +260 -0
  21. mixing_style_transfer/modules/training_utils.py +174 -0
  22. mixing_style_transfer/networks/__init__.py +2 -0
  23. mixing_style_transfer/networks/architectures.py +290 -0
  24. mixing_style_transfer/networks/configs.yaml +30 -0
  25. mixing_style_transfer/networks/network_utils.py +184 -0
  26. requirements.txt +23 -0
  27. samples/interpolation/#0/input.wav +3 -0
  28. samples/interpolation/#0/reference.wav +3 -0
  29. samples/interpolation/#0/reference_B.wav +3 -0
  30. samples/interpolation/#0/separated/mdx_extra/input/bass.wav +3 -0
  31. samples/interpolation/#0/separated/mdx_extra/input/drums.wav +3 -0
  32. samples/interpolation/#0/separated/mdx_extra/input/other.wav +3 -0
  33. samples/interpolation/#0/separated/mdx_extra/input/vocals.wav +3 -0
  34. samples/interpolation/#0/separated/mdx_extra/reference/bass.wav +3 -0
  35. samples/interpolation/#0/separated/mdx_extra/reference/drums.wav +3 -0
  36. samples/interpolation/#0/separated/mdx_extra/reference/other.wav +3 -0
  37. samples/interpolation/#0/separated/mdx_extra/reference/vocals.wav +3 -0
  38. samples/interpolation/#0/separated/mdx_extra/reference_B/bass.wav +3 -0
  39. samples/interpolation/#0/separated/mdx_extra/reference_B/drums.wav +3 -0
  40. samples/interpolation/#0/separated/mdx_extra/reference_B/other.wav +3 -0
  41. samples/interpolation/#0/separated/mdx_extra/reference_B/vocals.wav +3 -0
  42. samples/style_transfer/#0/input.wav +3 -0
  43. samples/style_transfer/#0/reference.wav +3 -0
  44. samples/style_transfer/#0/separated/mdx_extra/input/bass.wav +3 -0
  45. samples/style_transfer/#0/separated/mdx_extra/input/drums.wav +3 -0
  46. samples/style_transfer/#0/separated/mdx_extra/input/other.wav +3 -0
  47. samples/style_transfer/#0/separated/mdx_extra/input/vocals.wav +3 -0
  48. samples/style_transfer/#0/separated/mdx_extra/reference/bass.wav +3 -0
  49. samples/style_transfer/#0/separated/mdx_extra/reference/drums.wav +3 -0
  50. samples/style_transfer/#0/separated/mdx_extra/reference/other.wav +3 -0
.gitattributes CHANGED
@@ -32,3 +32,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ samples/interpolation/\#0/input.wav filter=lfs diff=lfs merge=lfs -text
36
+ samples/interpolation/\#0/reference_B.wav filter=lfs diff=lfs merge=lfs -text
37
+ samples/interpolation/\#0/reference.wav filter=lfs diff=lfs merge=lfs -text
38
+ samples/interpolation/\#0/separated/mdx_extra/input/bass.wav filter=lfs diff=lfs merge=lfs -text
39
+ samples/interpolation/\#0/separated/mdx_extra/input/drums.wav filter=lfs diff=lfs merge=lfs -text
40
+ samples/interpolation/\#0/separated/mdx_extra/input/other.wav filter=lfs diff=lfs merge=lfs -text
41
+ samples/interpolation/\#0/separated/mdx_extra/input/vocals.wav filter=lfs diff=lfs merge=lfs -text
42
+ samples/interpolation/\#0/separated/mdx_extra/reference_B/bass.wav filter=lfs diff=lfs merge=lfs -text
43
+ samples/interpolation/\#0/separated/mdx_extra/reference_B/drums.wav filter=lfs diff=lfs merge=lfs -text
44
+ samples/interpolation/\#0/separated/mdx_extra/reference_B/other.wav filter=lfs diff=lfs merge=lfs -text
45
+ samples/interpolation/\#0/separated/mdx_extra/reference_B/vocals.wav filter=lfs diff=lfs merge=lfs -text
46
+ samples/interpolation/\#0/separated/mdx_extra/reference/bass.wav filter=lfs diff=lfs merge=lfs -text
47
+ samples/interpolation/\#0/separated/mdx_extra/reference/drums.wav filter=lfs diff=lfs merge=lfs -text
48
+ samples/interpolation/\#0/separated/mdx_extra/reference/other.wav filter=lfs diff=lfs merge=lfs -text
49
+ samples/interpolation/\#0/separated/mdx_extra/reference/vocals.wav filter=lfs diff=lfs merge=lfs -text
50
+ samples/style_transfer/\#0/input.wav filter=lfs diff=lfs merge=lfs -text
51
+ samples/style_transfer/\#0/reference.wav filter=lfs diff=lfs merge=lfs -text
52
+ samples/style_transfer/\#0/separated/mdx_extra/input/bass.wav filter=lfs diff=lfs merge=lfs -text
53
+ samples/style_transfer/\#0/separated/mdx_extra/input/drums.wav filter=lfs diff=lfs merge=lfs -text
54
+ samples/style_transfer/\#0/separated/mdx_extra/input/other.wav filter=lfs diff=lfs merge=lfs -text
55
+ samples/style_transfer/\#0/separated/mdx_extra/input/vocals.wav filter=lfs diff=lfs merge=lfs -text
56
+ samples/style_transfer/\#0/separated/mdx_extra/reference/bass.wav filter=lfs diff=lfs merge=lfs -text
57
+ samples/style_transfer/\#0/separated/mdx_extra/reference/drums.wav filter=lfs diff=lfs merge=lfs -text
58
+ samples/style_transfer/\#0/separated/mdx_extra/reference/other.wav filter=lfs diff=lfs merge=lfs -text
59
+ samples/style_transfer/\#0/separated/mdx_extra/reference/vocals.wav filter=lfs diff=lfs merge=lfs -text
60
+ samples/style_transfer/\#2/input.wav filter=lfs diff=lfs merge=lfs -text
61
+ samples/style_transfer/\#2/reference.wav filter=lfs diff=lfs merge=lfs -text
62
+ samples/style_transfer/\#2/separated/mdx_extra/input/bass.wav filter=lfs diff=lfs merge=lfs -text
63
+ samples/style_transfer/\#2/separated/mdx_extra/input/drums.wav filter=lfs diff=lfs merge=lfs -text
64
+ samples/style_transfer/\#2/separated/mdx_extra/input/other.wav filter=lfs diff=lfs merge=lfs -text
65
+ samples/style_transfer/\#2/separated/mdx_extra/input/vocals.wav filter=lfs diff=lfs merge=lfs -text
66
+ samples/style_transfer/\#2/separated/mdx_extra/reference/bass.wav filter=lfs diff=lfs merge=lfs -text
67
+ samples/style_transfer/\#2/separated/mdx_extra/reference/drums.wav filter=lfs diff=lfs merge=lfs -text
68
+ samples/style_transfer/\#2/separated/mdx_extra/reference/other.wav filter=lfs diff=lfs merge=lfs -text
69
+ samples/style_transfer/\#2/separated/mdx_extra/reference/vocals.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,126 @@
1
- ---
2
- title: Music Mixing Style Transfer
3
- emoji: 🏃
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 3.21.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Music Mixing Style Transfer
2
+
3
+ This repository includes source code and pre-trained models of the work *Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects* by [Junghyun Koo](https://linkedin.com/in/junghyun-koo-525a31251), [Marco A. Martínez-Ramírez](https://m-marco.com/about/), [Wei-Hsiang Liao](https://jp.linkedin.com/in/wei-hsiang-liao-66283154), [Stefan Uhlich](https://scholar.google.de/citations?user=hja8ejYAAAAJ&hl=de), [Kyogu Lee](https://linkedin.com/in/kyogu-lee-7a93b611), and [Yuki Mitsufuji](https://www.yukimitsufuji.com/).
4
+
5
+
6
+ [![arXiv](https://img.shields.io/badge/arXiv-2211.02247-b31b1b.svg)](https://arxiv.org/abs/2211.02247)
7
+ [![Web](https://img.shields.io/badge/Web-Demo_Page-green.svg)](https://jhtonyKoo.github.io/MixingStyleTransfer/)
8
+ [![Supplementary](https://img.shields.io/badge/Supplementary-Materials-white.svg)](https://tinyurl.com/4math4pm)
9
+
10
+
11
+
12
+ ## Pre-trained Models
13
+ | Model | Configuration | Training Dataset |
14
+ |-------------|-------------|-------------|
15
+ [FXencoder (Φ<sub>p.s.</sub>)](https://drive.google.com/file/d/1BFABsJRUVgJS5UE5iuM03dbfBjmI9LT5/view?usp=sharing) | Used *FX normalization* and *probability scheduling* techniques for training | Trained with [MUSDB18](https://sigsep.github.io/datasets/musdb.html) Dataset
16
+ [MixFXcloner](https://drive.google.com/file/d/1Qu8rD7HpTNA1gJUVp2IuaeU_Nue8-VA3/view?usp=sharing) | Mixing style converter trained with Φ<sub>p.s.</sub> | Trained with [MUSDB18](https://sigsep.github.io/datasets/musdb.html) Dataset
17
+
18
+
19
+ ## Installation
20
+ ```
21
+ pip install -r "requirements.txt"
22
+ ```
23
+
24
+ # Inference
25
+
26
+ ## Mixing Style Transfer
27
+
28
+ To run the inference code for <i>mixing style transfer</i>,
29
+ 1. Download pre-trained models above and place them under the folder named 'weights' (default)
30
+ 2. Prepare input and reference tracks under the folder named 'samples/style_transfer' (default)
31
+ Target files should be organized as follow:
32
+ ```
33
+ "path_to_data_directory"/"song_name_#1"/"input_file_name".wav
34
+ "path_to_data_directory"/"song_name_#1"/"reference_file_name".wav
35
+ ...
36
+ "path_to_data_directory"/"song_name_#n"/"input_file_name".wav
37
+ "path_to_data_directory"/"song_name_#n"/"reference_file_name".wav
38
+ ```
39
+ 3. Run 'inference/style_transfer.py'
40
+ ```
41
+ python inference/style_transfer.py \
42
+ --ckpt_path_enc "path_to_checkpoint_of_FXencoder" \
43
+ --ckpt_path_conv "path_to_checkpoint_of_MixFXcloner" \
44
+ --target_dir "path_to_directory_containing_inference_samples"
45
+ ```
46
+ 4. Outputs will be stored under the same folder to inference data directory (default)
47
+
48
+ *Note: The system accepts WAV files of stereo-channeled, 44.1kHZ, and 16-bit rate. We recommend to use audio samples that are not too loud: it's better for the system to transfer these samples by reducing the loudness of mixture-wise inputs (maintaining the overall balance of each instrument).*
49
+
50
+
51
+
52
+ ## Interpolation With 2 Different Reference Tracks
53
+
54
+ Inference code for <interpolating> two reference tracks is almost the same as <i>mixing style transfer</i>.
55
+ 1. Download pre-trained models above and place them under the folder named 'weights' (default)
56
+ 2. Prepare input and 2 reference tracks under the folder named 'samples/style_transfer' (default)
57
+ Target files should be organized as follow:
58
+ ```
59
+ "path_to_data_directory"/"song_name_#1"/"input_track_name".wav
60
+ "path_to_data_directory"/"song_name_#1"/"reference_file_name".wav
61
+ "path_to_data_directory"/"song_name_#1"/"reference_file_name_2interpolate".wav
62
+ ...
63
+ "path_to_data_directory"/"song_name_#n"/"input_track_name".wav
64
+ "path_to_data_directory"/"song_name_#n"/"reference_file_name".wav
65
+ "path_to_data_directory"/"song_name_#n"/"reference_file_name_2interpolate".wav
66
+ ```
67
+ 3. Run 'inference/style_transfer.py'
68
+ ```
69
+ python inference/style_transfer.py \
70
+ --ckpt_path_enc "path_to_checkpoint_of_FXencoder" \
71
+ --ckpt_path_conv "path_to_checkpoint_of_MixFXcloner" \
72
+ --target_dir "path_to_directory_containing_inference_samples" \
73
+ --interpolation True \
74
+ --interpolate_segments "number of segments to perform interpolation"
75
+ ```
76
+ 4. Outputs will be stored under the same folder to inference data directory (default)
77
+
78
+ *Note: This example of interpolating 2 different reference tracks is not mentioned in the paper, but this example implies a potential for controllable style transfer using latent space.*
79
+
80
+
81
+
82
+ ## Feature Extraction Using *FXencoder*
83
+
84
+ This inference code will extracts audio effects-related embeddings using our proposed <i>FXencoder</i>. This code will process all the .wav files under the target directory.
85
+
86
+ 1. Download <i>FXencoder</i>'s pre-trained model above and place it under the folder named 'weights' (default)=
87
+ 2. Run 'inference/style_transfer.py'
88
+ ```
89
+ python inference/feature_extraction.py \
90
+ --ckpt_path_enc "path_to_checkpoint_of_FXencoder" \
91
+ --target_dir "path_to_directory_containing_inference_samples"
92
+ ```
93
+ 3. Outputs will be stored under the same folder to inference data directory (default)
94
+
95
+
96
+
97
+
98
+ # Implementation
99
+
100
+ All the details of our system implementation are under the folder "mixing_style_transfer".
101
+
102
+ <li><i>FXmanipulator</i></li>
103
+ &emsp;&emsp;-> mixing_style_transfer/mixing_manipulator/
104
+ <li>network architectures</li>
105
+ &emsp;&emsp;-> mixing_style_transfer/networks/
106
+ <li>configuration of each sub-networks</li>
107
+ &emsp;&emsp;-> mixing_style_transfer/networks/configs.yaml
108
+ <li>data loader</li>
109
+ &emsp;&emsp;-> mixing_style_transfer/data_loader/
110
+
111
+
112
+ # Citation
113
+
114
+ Please consider citing the work upon usage.
115
+
116
+ ```
117
+ @article{koo2022music,
118
+ title={Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects},
119
+ author={Koo, Junghyun and Martinez-Ramirez, Marco A and Liao, Wei-Hsiang and Uhlich, Stefan and Lee, Kyogu and Mitsufuji, Yuki},
120
+ journal={arXiv preprint arXiv:2211.02247},
121
+ year={2022}
122
+ }
123
+ ```
124
+
125
+
126
+
inference/configs.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model architecture configurations
2
+
3
+
4
+ # Music Effects Encoder
5
+ Effects_Encoder:
6
+
7
+ default:
8
+ channels: [16, 32, 64, 128, 256, 256, 512, 512, 1024, 1024, 2048, 2048]
9
+ kernels: [25, 25, 15, 15, 10, 10, 10, 10, 5, 5, 5, 5]
10
+ strides: [4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]
11
+ dilation: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
12
+ bias: True
13
+ norm: 'batch'
14
+ conv_block: 'res'
15
+ activation: "relu"
16
+
17
+
18
+ # TCN
19
+ TCN:
20
+
21
+ # receptive field = 5.2 seconds
22
+ default:
23
+ condition_dimension: 2048
24
+ nblocks: 14
25
+ dilation_growth: 2
26
+ kernel_size: 15
27
+ channel_width: 128
28
+ stack_size: 15
29
+ causal: False
30
+
inference/feature_extraction.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference code of extracting embeddings from music recordings using FXencoder
3
+ of the work "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
4
+
5
+ Process : extracts FX embeddings of each song inside the target directory.
6
+ """
7
+ from glob import glob
8
+ import os
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+
13
+ import sys
14
+ currentdir = os.path.dirname(os.path.realpath(__file__))
15
+ sys.path.append(os.path.join(os.path.dirname(currentdir), "mixing_style_transfer"))
16
+ from networks import FXencoder
17
+ from data_loader import *
18
+
19
+
20
+ class FXencoder_Inference:
21
+ def __init__(self, args, trained_w_ddp=True):
22
+ if args.inference_device!='cpu' and torch.cuda.is_available():
23
+ self.device = torch.device("cuda:0")
24
+ else:
25
+ self.device = torch.device("cpu")
26
+
27
+ # inference computational hyperparameters
28
+ self.segment_length = args.segment_length
29
+ self.batch_size = args.batch_size
30
+ self.sample_rate = 44100 # sampling rate should be 44100
31
+ self.time_in_seconds = int(args.segment_length // self.sample_rate)
32
+
33
+ # directory configuration
34
+ self.output_dir = args.target_dir if args.output_dir==None else args.output_dir
35
+ self.target_dir = args.target_dir
36
+
37
+ # load model and its checkpoint weights
38
+ self.models = {}
39
+ self.models['effects_encoder'] = FXencoder(args.cfg_encoder).to(self.device)
40
+ ckpt_paths = {'effects_encoder' : args.ckpt_path_enc}
41
+ # reload saved model weights
42
+ ddp = trained_w_ddp
43
+ self.reload_weights(ckpt_paths, ddp=ddp)
44
+
45
+ # save current arguments
46
+ self.save_args(args)
47
+
48
+
49
+ # reload model weights from the target checkpoint path
50
+ def reload_weights(self, ckpt_paths, ddp=True):
51
+ for cur_model_name in self.models.keys():
52
+ checkpoint = torch.load(ckpt_paths[cur_model_name], map_location=self.device)
53
+
54
+ from collections import OrderedDict
55
+ new_state_dict = OrderedDict()
56
+ for k, v in checkpoint["model"].items():
57
+ # remove `module.` if the model was trained with DDP
58
+ name = k[7:] if ddp else k
59
+ new_state_dict[name] = v
60
+
61
+ # load params
62
+ self.models[cur_model_name].load_state_dict(new_state_dict)
63
+
64
+ print(f"---reloaded checkpoint weights : {cur_model_name} ---")
65
+
66
+
67
+ # save averaged embedding from whole songs
68
+ def save_averaged_embeddings(self, ):
69
+ # embedding output directory path
70
+ emb_out_dir = f"{self.output_dir}"
71
+ print(f'\n\n=====Inference seconds : {self.time_in_seconds}=====')
72
+
73
+ # target_file_paths = glob(f"{self.target_dir}/**/*.wav", recursive=True)
74
+ target_file_paths = glob(os.path.join(self.target_dir, '**', '*.wav'), recursive=True)
75
+ for step, target_file_path in enumerate(target_file_paths):
76
+ print(f"\nInference step : {step+1}/{len(target_file_paths)}")
77
+ print(f"---current file path : {target_file_path}---")
78
+
79
+ ''' load waveform signal '''
80
+ target_song_whole = load_wav_segment(target_file_path, axis=0)
81
+ # check if mono -> convert to stereo by duplicating mono signal
82
+ if len(target_song_whole.shape)==1:
83
+ target_song_whole = np.stack((target_song_whole, target_song_whole), axis=0)
84
+ # check axis dimension
85
+ # signal shape should be : [channel, signal duration]
86
+ elif target_song_whole.shape[1]==2:
87
+ target_song_whole = target_song_whole.transpose()
88
+ target_song_whole = torch.from_numpy(target_song_whole).float()
89
+ ''' segmentize whole songs into batch '''
90
+ whole_batch_data = self.batchwise_segmentization(target_song_whole, target_file_path)
91
+
92
+ ''' inference '''
93
+ # infer whole song
94
+ infered_data_list = []
95
+ infered_c_list = []
96
+ infered_z_list = []
97
+ for cur_idx, cur_data in enumerate(whole_batch_data):
98
+ cur_data = cur_data.to(self.device)
99
+
100
+ with torch.no_grad():
101
+ self.models["effects_encoder"].eval()
102
+ # FXencoder
103
+ out_c_emb = self.models["effects_encoder"](cur_data)
104
+ infered_c_list.append(out_c_emb.cpu().detach())
105
+ avg_c_feat = torch.mean(torch.cat(infered_c_list, dim=0), dim=0).squeeze().cpu().detach().numpy()
106
+
107
+ # save outputs
108
+ cur_output_path = target_file_path.replace(self.target_dir, self.output_dir).replace('.wav', '_fx_embedding.npy')
109
+ os.makedirs(os.path.dirname(cur_output_path), exist_ok=True)
110
+ np.save(cur_output_path, avg_c_feat)
111
+
112
+
113
+ # function that segmentize an entire song into batch
114
+ def batchwise_segmentization(self, target_song, target_file_path, discard_last=False):
115
+ assert target_song.shape[-1] >= self.segment_length, \
116
+ f"Error : Insufficient duration!\n\t \
117
+ Target song's length is shorter than segment length.\n\t \
118
+ Song name : {target_file_path}\n\t \
119
+ Consider changing the 'segment_length' or song with sufficient duration"
120
+
121
+ # discard restovers (last segment)
122
+ if discard_last:
123
+ target_length = target_song.shape[-1] - target_song.shape[-1] % self.segment_length
124
+ target_song = target_song[:, :target_length]
125
+ # pad last segment
126
+ else:
127
+ pad_length = self.segment_length - target_song.shape[-1] % self.segment_length
128
+ target_song = torch.cat((target_song, torch.zeros(2, pad_length)), axis=-1)
129
+
130
+ whole_batch_data = []
131
+ batch_wise_data = []
132
+ for cur_segment_idx in range(target_song.shape[-1]//self.segment_length):
133
+ batch_wise_data.append(target_song[..., cur_segment_idx*self.segment_length:(cur_segment_idx+1)*self.segment_length])
134
+ if len(batch_wise_data)==self.batch_size:
135
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
136
+ batch_wise_data = []
137
+ if batch_wise_data:
138
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
139
+
140
+ return whole_batch_data
141
+
142
+
143
+ # save current inference arguments
144
+ def save_args(self, params):
145
+ info = '\n[args]\n'
146
+ for sub_args in parser._action_groups:
147
+ if sub_args.title in ['positional arguments', 'optional arguments', 'options']:
148
+ continue
149
+ size_sub = len(sub_args._group_actions)
150
+ info += f' {sub_args.title} ({size_sub})\n'
151
+ for i, arg in enumerate(sub_args._group_actions):
152
+ prefix = '-'
153
+ info += f' {prefix} {arg.dest:20s}: {getattr(params, arg.dest)}\n'
154
+ info += '\n'
155
+
156
+ os.makedirs(self.output_dir, exist_ok=True)
157
+ record_path = f"{self.output_dir}feature_extraction_inference_configurations.txt"
158
+ f = open(record_path, 'w')
159
+ np.savetxt(f, [info], delimiter=" ", fmt="%s")
160
+ f.close()
161
+
162
+
163
+
164
+ if __name__ == '__main__':
165
+ ''' Configurations for inferencing music effects encoder '''
166
+ currentdir = os.path.dirname(os.path.realpath(__file__))
167
+ default_ckpt_path = os.path.join(os.path.dirname(currentdir), 'weights', 'FXencoder_ps.pt')
168
+
169
+ import argparse
170
+ import yaml
171
+ parser = argparse.ArgumentParser()
172
+
173
+ directory_args = parser.add_argument_group('Directory args')
174
+ directory_args.add_argument('--target_dir', type=str, default='./samples/')
175
+ directory_args.add_argument('--output_dir', type=str, default=None, help='if no output_dir is specified (None), the results will be saved inside the target_dir')
176
+ directory_args.add_argument('--ckpt_path_enc', type=str, default=default_ckpt_path)
177
+
178
+ inference_args = parser.add_argument_group('Inference args')
179
+ inference_args.add_argument('--segment_length', type=int, default=44100*10) # segmentize input according to this duration
180
+ inference_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
181
+ inference_args.add_argument('--inference_device', type=str, default='cpu', help="if this option is not set to 'cpu', inference will happen on gpu only if there is a detected one")
182
+
183
+ args = parser.parse_args()
184
+
185
+ # load network configurations
186
+ with open(os.path.join(currentdir, 'configs.yaml'), 'r') as f:
187
+ configs = yaml.full_load(f)
188
+ args.cfg_encoder = configs['Effects_Encoder']['default']
189
+
190
+ # Extract features using pre-trained FXencoder
191
+ inference_encoder = FXencoder_Inference(args)
192
+ inference_encoder.save_averaged_embeddings()
193
+
194
+
inference/style_transfer.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference code of music style transfer
3
+ of the work "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
4
+
5
+ Process : converts the mixing style of the input music recording to that of the refernce music.
6
+ files inside the target directory should be organized as follow
7
+ "path_to_data_directory"/"song_name_#1"/input.wav
8
+ "path_to_data_directory"/"song_name_#1"/reference.wav
9
+ ...
10
+ "path_to_data_directory"/"song_name_#n"/input.wav
11
+ "path_to_data_directory"/"song_name_#n"/reference.wav
12
+ where the 'input' and 'reference' should share the same names.
13
+ """
14
+ import numpy as np
15
+ from glob import glob
16
+ import os
17
+ import torch
18
+
19
+ import sys
20
+ currentdir = os.path.dirname(os.path.realpath(__file__))
21
+ sys.path.append(os.path.join(os.path.dirname(currentdir), "mixing_style_transfer"))
22
+ from networks import FXencoder, TCNModel
23
+ from data_loader import *
24
+
25
+
26
+
27
+ class Mixing_Style_Transfer_Inference:
28
+ def __init__(self, args, trained_w_ddp=True):
29
+ if args.inference_device!='cpu' and torch.cuda.is_available():
30
+ self.device = torch.device("cuda:0")
31
+ else:
32
+ self.device = torch.device("cpu")
33
+
34
+ # inference computational hyperparameters
35
+ self.args = args
36
+ self.segment_length = args.segment_length
37
+ self.batch_size = args.batch_size
38
+ self.sample_rate = 44100 # sampling rate should be 44100
39
+ self.time_in_seconds = int(args.segment_length // self.sample_rate)
40
+
41
+ # directory configuration
42
+ self.output_dir = args.target_dir if args.output_dir==None else args.output_dir
43
+ self.target_dir = args.target_dir
44
+
45
+ # load model and its checkpoint weights
46
+ self.models = {}
47
+ self.models['effects_encoder'] = FXencoder(args.cfg_encoder).to(self.device)
48
+ self.models['mixing_converter'] = TCNModel(nparams=args.cfg_converter["condition_dimension"], \
49
+ ninputs=2, \
50
+ noutputs=2, \
51
+ nblocks=args.cfg_converter["nblocks"], \
52
+ dilation_growth=args.cfg_converter["dilation_growth"], \
53
+ kernel_size=args.cfg_converter["kernel_size"], \
54
+ channel_width=args.cfg_converter["channel_width"], \
55
+ stack_size=args.cfg_converter["stack_size"], \
56
+ cond_dim=args.cfg_converter["condition_dimension"], \
57
+ causal=args.cfg_converter["causal"]).to(self.device)
58
+
59
+ ckpt_paths = {'effects_encoder' : args.ckpt_path_enc, \
60
+ 'mixing_converter' : args.ckpt_path_conv}
61
+ # reload saved model weights
62
+ ddp = trained_w_ddp
63
+ self.reload_weights(ckpt_paths, ddp=ddp)
64
+
65
+ # load data loader for the inference procedure
66
+ inference_dataset = Song_Dataset_Inference(args)
67
+ self.data_loader = DataLoader(inference_dataset, \
68
+ batch_size=1, \
69
+ shuffle=False, \
70
+ num_workers=args.workers, \
71
+ drop_last=False)
72
+
73
+ # save current arguments
74
+ self.save_args(args)
75
+
76
+ ''' check stem-wise result '''
77
+ if not self.args.do_not_separate:
78
+ os.environ['MKL_THREADING_LAYER'] = 'GNU'
79
+ separate_file_names = [args.input_file_name, args.reference_file_name]
80
+ if self.args.interpolation:
81
+ separate_file_names.append(args.reference_file_name_2interpolate)
82
+ for cur_idx, cur_inf_dir in enumerate(sorted(glob(f"{args.target_dir}*/"))):
83
+ for cur_file_name in separate_file_names:
84
+ cur_sep_file_path = os.path.join(cur_inf_dir, cur_file_name+'.wav')
85
+ cur_sep_output_dir = os.path.join(cur_inf_dir, args.stem_level_directory_name)
86
+ if os.path.exists(os.path.join(cur_sep_output_dir, self.args.separation_model, cur_file_name, 'drums.wav')):
87
+ print(f'\talready separated current file : {cur_sep_file_path}')
88
+ else:
89
+ cur_cmd_line = f"demucs {cur_sep_file_path} -n {self.args.separation_model} -d {self.args.separation_device} -o {cur_sep_output_dir}"
90
+ os.system(cur_cmd_line)
91
+
92
+
93
+ # reload model weights from the target checkpoint path
94
+ def reload_weights(self, ckpt_paths, ddp=True):
95
+ for cur_model_name in self.models.keys():
96
+ checkpoint = torch.load(ckpt_paths[cur_model_name], map_location=self.device)
97
+
98
+ from collections import OrderedDict
99
+ new_state_dict = OrderedDict()
100
+ for k, v in checkpoint["model"].items():
101
+ # remove `module.` if the model was trained with DDP
102
+ name = k[7:] if ddp else k
103
+ new_state_dict[name] = v
104
+
105
+ # load params
106
+ self.models[cur_model_name].load_state_dict(new_state_dict)
107
+
108
+ print(f"---reloaded checkpoint weights : {cur_model_name} ---")
109
+
110
+
111
+ # Inference whole song
112
+ def inference(self, ):
113
+ print("\n======= Start to inference music mixing style transfer =======")
114
+ # normalized input
115
+ output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
116
+
117
+ for step, (input_stems, reference_stems, dir_name) in enumerate(self.data_loader):
118
+ print(f"---inference file name : {dir_name[0]}---")
119
+ cur_out_dir = dir_name[0].replace(self.target_dir, self.output_dir)
120
+ os.makedirs(cur_out_dir, exist_ok=True)
121
+ ''' stem-level inference '''
122
+ inst_outputs = []
123
+ for cur_inst_idx, cur_inst_name in enumerate(self.args.instruments):
124
+ print(f'\t{cur_inst_name}...')
125
+ ''' segmentize whole songs into batch '''
126
+ if len(input_stems[0][cur_inst_idx][0]) > self.args.segment_length:
127
+ cur_inst_input_stem = self.batchwise_segmentization(input_stems[0][cur_inst_idx], \
128
+ dir_name[0], \
129
+ segment_length=self.args.segment_length, \
130
+ discard_last=False)
131
+ else:
132
+ cur_inst_input_stem = [input_stems[:, cur_inst_idx]]
133
+ if len(reference_stems[0][cur_inst_idx][0]) > self.args.segment_length*2:
134
+ cur_inst_reference_stem = self.batchwise_segmentization(reference_stems[0][cur_inst_idx], \
135
+ dir_name[0], \
136
+ segment_length=self.args.segment_length_ref, \
137
+ discard_last=False)
138
+ else:
139
+ cur_inst_reference_stem = [reference_stems[:, cur_inst_idx]]
140
+
141
+ ''' inference '''
142
+ # first extract reference style embedding
143
+ infered_ref_data_list = []
144
+ for cur_ref_data in cur_inst_reference_stem:
145
+ cur_ref_data = cur_ref_data.to(self.device)
146
+ # Effects Encoder inference
147
+ with torch.no_grad():
148
+ self.models["effects_encoder"].eval()
149
+ reference_feature = self.models["effects_encoder"](cur_ref_data)
150
+ infered_ref_data_list.append(reference_feature)
151
+ # compute average value from the extracted exbeddings
152
+ infered_ref_data = torch.stack(infered_ref_data_list)
153
+ infered_ref_data_avg = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
154
+
155
+ # mixing style converter
156
+ infered_data_list = []
157
+ for cur_data in cur_inst_input_stem:
158
+ cur_data = cur_data.to(self.device)
159
+ with torch.no_grad():
160
+ self.models["mixing_converter"].eval()
161
+ infered_data = self.models["mixing_converter"](cur_data, infered_ref_data_avg.unsqueeze(0))
162
+ infered_data_list.append(infered_data.cpu().detach())
163
+
164
+ # combine back to whole song
165
+ for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
166
+ cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
167
+ fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
168
+ # final output of current instrument
169
+ fin_data_out_inst = fin_data_out[:, :input_stems[0][cur_inst_idx].shape[-1]].numpy()
170
+
171
+ inst_outputs.append(fin_data_out_inst)
172
+ # save output of each instrument
173
+ if self.args.save_each_inst:
174
+ sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
175
+ # remix
176
+ fin_data_out_mix = sum(inst_outputs)
177
+ sf.write(os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav"), fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
178
+
179
+
180
+ # Inference whole song
181
+ def inference_interpolation(self, ):
182
+ print("\n======= Start to inference interpolation examples =======")
183
+ # normalized input
184
+ output_name_tag = 'output_interpolation' if self.args.normalize_input else 'output_notnormed_interpolation'
185
+
186
+ for step, (input_stems, reference_stems_A, reference_stems_B, dir_name) in enumerate(self.data_loader):
187
+ print(f"---inference file name : {dir_name[0]}---")
188
+ cur_out_dir = dir_name[0].replace(self.target_dir, self.output_dir)
189
+ os.makedirs(cur_out_dir, exist_ok=True)
190
+ ''' stem-level inference '''
191
+ inst_outputs = []
192
+ for cur_inst_idx, cur_inst_name in enumerate(self.args.instruments):
193
+ print(f'\t{cur_inst_name}...')
194
+ ''' segmentize whole song '''
195
+ # segmentize input according to number of interpolating segments
196
+ interpolate_segment_length = input_stems[0][cur_inst_idx].shape[1] // self.args.interpolate_segments + 1
197
+ cur_inst_input_stem = self.batchwise_segmentization(input_stems[0][cur_inst_idx], \
198
+ dir_name[0], \
199
+ segment_length=interpolate_segment_length, \
200
+ discard_last=False)
201
+ # batchwise segmentize 2 reference tracks
202
+ if len(reference_stems_A[0][cur_inst_idx][0]) > self.args.segment_length_ref:
203
+ cur_inst_reference_stem_A = self.batchwise_segmentization(reference_stems_A[0][cur_inst_idx], \
204
+ dir_name[0], \
205
+ segment_length=self.args.segment_length_ref, \
206
+ discard_last=False)
207
+ else:
208
+ cur_inst_reference_stem_A = [reference_stems_A[:, cur_inst_idx]]
209
+ if len(reference_stems_B[0][cur_inst_idx][0]) > self.args.segment_length_ref:
210
+ cur_inst_reference_stem_B = self.batchwise_segmentization(reference_stems_B[0][cur_inst_idx], \
211
+ dir_name[0], \
212
+ segment_length=self.args.segment_length, \
213
+ discard_last=False)
214
+ else:
215
+ cur_inst_reference_stem_B = [reference_stems_B[:, cur_inst_idx]]
216
+
217
+ ''' inference '''
218
+ # first extract reference style embeddings
219
+ # reference A
220
+ infered_ref_data_list = []
221
+ for cur_ref_data in cur_inst_reference_stem_A:
222
+ cur_ref_data = cur_ref_data.to(self.device)
223
+ # Effects Encoder inference
224
+ with torch.no_grad():
225
+ self.models["effects_encoder"].eval()
226
+ reference_feature = self.models["effects_encoder"](cur_ref_data)
227
+ infered_ref_data_list.append(reference_feature)
228
+ # compute average value from the extracted exbeddings
229
+ infered_ref_data = torch.stack(infered_ref_data_list)
230
+ infered_ref_data_avg_A = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
231
+
232
+ # reference B
233
+ infered_ref_data_list = []
234
+ for cur_ref_data in cur_inst_reference_stem_B:
235
+ cur_ref_data = cur_ref_data.to(self.device)
236
+ # Effects Encoder inference
237
+ with torch.no_grad():
238
+ self.models["effects_encoder"].eval()
239
+ reference_feature = self.models["effects_encoder"](cur_ref_data)
240
+ infered_ref_data_list.append(reference_feature)
241
+ # compute average value from the extracted exbeddings
242
+ infered_ref_data = torch.stack(infered_ref_data_list)
243
+ infered_ref_data_avg_B = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
244
+
245
+ # mixing style converter
246
+ infered_data_list = []
247
+ for cur_idx, cur_data in enumerate(cur_inst_input_stem):
248
+ cur_data = cur_data.to(self.device)
249
+ # perform linear interpolation on embedding space
250
+ cur_weight = (self.args.interpolate_segments-1-cur_idx) / (self.args.interpolate_segments-1)
251
+ cur_ref_emb = cur_weight * infered_ref_data_avg_A + (1-cur_weight) * infered_ref_data_avg_B
252
+ with torch.no_grad():
253
+ self.models["mixing_converter"].eval()
254
+ infered_data = self.models["mixing_converter"](cur_data, cur_ref_emb.unsqueeze(0))
255
+ infered_data_list.append(infered_data.cpu().detach())
256
+
257
+ # combine back to whole song
258
+ for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
259
+ cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
260
+ fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
261
+ # final output of current instrument
262
+ fin_data_out_inst = fin_data_out[:, :input_stems[0][cur_inst_idx].shape[-1]].numpy()
263
+ inst_outputs.append(fin_data_out_inst)
264
+
265
+ # save output of each instrument
266
+ if self.args.save_each_inst:
267
+ sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
268
+ # remix
269
+ fin_data_out_mix = sum(inst_outputs)
270
+ sf.write(os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav"), fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
271
+
272
+
273
+ # function that segmentize an entire song into batch
274
+ def batchwise_segmentization(self, target_song, song_name, segment_length, discard_last=False):
275
+ assert target_song.shape[-1] >= self.args.segment_length, \
276
+ f"Error : Insufficient duration!\n\t \
277
+ Target song's length is shorter than segment length.\n\t \
278
+ Song name : {song_name}\n\t \
279
+ Consider changing the 'segment_length' or song with sufficient duration"
280
+
281
+ # discard restovers (last segment)
282
+ if discard_last:
283
+ target_length = target_song.shape[-1] - target_song.shape[-1] % segment_length
284
+ target_song = target_song[:, :target_length]
285
+ # pad last segment
286
+ else:
287
+ pad_length = segment_length - target_song.shape[-1] % segment_length
288
+ target_song = torch.cat((target_song, torch.zeros(2, pad_length)), axis=-1)
289
+
290
+ # segmentize according to the given segment_length
291
+ whole_batch_data = []
292
+ batch_wise_data = []
293
+ for cur_segment_idx in range(target_song.shape[-1]//segment_length):
294
+ batch_wise_data.append(target_song[..., cur_segment_idx*segment_length:(cur_segment_idx+1)*segment_length])
295
+ if len(batch_wise_data)==self.args.batch_size:
296
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
297
+ batch_wise_data = []
298
+ if batch_wise_data:
299
+ whole_batch_data.append(torch.stack(batch_wise_data, dim=0))
300
+
301
+ return whole_batch_data
302
+
303
+
304
+ # save current inference arguments
305
+ def save_args(self, params):
306
+ info = '\n[args]\n'
307
+ for sub_args in parser._action_groups:
308
+ if sub_args.title in ['positional arguments', 'optional arguments', 'options']:
309
+ continue
310
+ size_sub = len(sub_args._group_actions)
311
+ info += f' {sub_args.title} ({size_sub})\n'
312
+ for i, arg in enumerate(sub_args._group_actions):
313
+ prefix = '-'
314
+ info += f' {prefix} {arg.dest:20s}: {getattr(params, arg.dest)}\n'
315
+ info += '\n'
316
+
317
+ os.makedirs(self.output_dir, exist_ok=True)
318
+ record_path = f"{self.output_dir}style_transfer_inference_configurations.txt"
319
+ f = open(record_path, 'w')
320
+ np.savetxt(f, [info], delimiter=" ", fmt="%s")
321
+ f.close()
322
+
323
+
324
+
325
+ if __name__ == '__main__':
326
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
327
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
328
+ os.environ['MASTER_PORT'] = '8888'
329
+
330
+ def str2bool(v):
331
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
332
+ return True
333
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
334
+ return False
335
+ else:
336
+ raise argparse.ArgumentTypeError('Boolean value expected.')
337
+
338
+ ''' Configurations for music mixing style transfer '''
339
+ currentdir = os.path.dirname(os.path.realpath(__file__))
340
+ default_ckpt_path_enc = os.path.join(os.path.dirname(currentdir), 'weights', 'FXencoder_ps.pt')
341
+ default_ckpt_path_conv = os.path.join(os.path.dirname(currentdir), 'weights', 'MixFXcloner_ps.pt')
342
+ default_norm_feature_path = os.path.join(os.path.dirname(currentdir), 'weights', 'musdb18_fxfeatures_eqcompimagegain.npy')
343
+
344
+ import argparse
345
+ import yaml
346
+ parser = argparse.ArgumentParser()
347
+
348
+ directory_args = parser.add_argument_group('Directory args')
349
+ # directory paths
350
+ directory_args.add_argument('--target_dir', type=str, default='./samples/style_transfer/')
351
+ directory_args.add_argument('--output_dir', type=str, default=None, help='if no output_dir is specified (None), the results will be saved inside the target_dir')
352
+ directory_args.add_argument('--input_file_name', type=str, default='input')
353
+ directory_args.add_argument('--reference_file_name', type=str, default='reference')
354
+ directory_args.add_argument('--reference_file_name_2interpolate', type=str, default='reference_B')
355
+ # saved weights
356
+ directory_args.add_argument('--ckpt_path_enc', type=str, default=default_ckpt_path_enc)
357
+ directory_args.add_argument('--ckpt_path_conv', type=str, default=default_ckpt_path_conv)
358
+ directory_args.add_argument('--precomputed_normalization_feature', type=str, default=default_norm_feature_path)
359
+
360
+ inference_args = parser.add_argument_group('Inference args')
361
+ inference_args.add_argument('--sample_rate', type=int, default=44100)
362
+ inference_args.add_argument('--segment_length', type=int, default=2**19) # segmentize input according to this duration
363
+ inference_args.add_argument('--segment_length_ref', type=int, default=2**19) # segmentize reference according to this duration
364
+ # stem-level instruments & separation
365
+ inference_args.add_argument('--instruments', type=str2bool, default=["drums", "bass", "other", "vocals"], help='instrumental tracks to perform style transfer')
366
+ inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
367
+ inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
368
+ inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
369
+ inference_args.add_argument('--separation_model', type=str, default='mdx_extra')
370
+ # FX normalization
371
+ inference_args.add_argument('--normalize_input', type=str2bool, default=True)
372
+ inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
373
+ # interpolation
374
+ inference_args.add_argument('--interpolation', type=str2bool, default=False)
375
+ inference_args.add_argument('--interpolate_segments', type=int, default=30)
376
+
377
+ device_args = parser.add_argument_group('Device args')
378
+ device_args.add_argument('--workers', type=int, default=1)
379
+ device_args.add_argument('--inference_device', type=str, default='gpu', help="if this option is not set to 'cpu', inference will happen on gpu only if there is a detected one")
380
+ device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
381
+ device_args.add_argument('--separation_device', type=str, default='cpu', help="device for performing source separation using Demucs")
382
+
383
+ args = parser.parse_args()
384
+
385
+ # load network configurations
386
+ with open(os.path.join(currentdir, 'configs.yaml'), 'r') as f:
387
+ configs = yaml.full_load(f)
388
+ args.cfg_encoder = configs['Effects_Encoder']['default']
389
+ args.cfg_converter = configs['TCN']['default']
390
+
391
+
392
+ # Perform music mixing style transfer
393
+ inference_style_transfer = Mixing_Style_Transfer_Inference(args)
394
+ if args.interpolation:
395
+ inference_style_transfer.inference_interpolation()
396
+ else:
397
+ inference_style_transfer.inference()
398
+
399
+
400
+
mixing_style_transfer/data_loader/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .data_loader import *
2
+ from .loader_utils import *
mixing_style_transfer/data_loader/data_loader.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Loaders for
3
+ 1. contrastive learning of audio effects
4
+ 2. music mixing style transfer
5
+ introduced in "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
6
+ """
7
+ import numpy as np
8
+ import wave
9
+ import soundfile as sf
10
+ import time
11
+ import random
12
+ from glob import glob
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.data import Dataset
18
+
19
+ import os
20
+ import sys
21
+ currentdir = os.path.dirname(os.path.realpath(__file__))
22
+ sys.path.append(currentdir)
23
+ sys.path.append(os.path.dirname(currentdir))
24
+ sys.path.append(os.path.dirname(os.path.dirname(currentdir)))
25
+ from loader_utils import *
26
+ from mixing_manipulator import *
27
+
28
+
29
+
30
+ '''
31
+ Collate Functions
32
+ '''
33
+ class Collate_Variable_Length_Segments:
34
+ def __init__(self, args):
35
+ self.segment_length = args.segment_length
36
+ self.random_length = args.reference_length
37
+ self.num_strong_negatives = args.num_strong_negatives
38
+ if 'musdb' in args.using_dataset.lower():
39
+ self.instruments = ["drums", "bass", "other", "vocals"]
40
+ else:
41
+ raise NotImplementedError
42
+
43
+
44
+ # collate function to trim segments A and B to random duration
45
+ # this function can handle different number of 'strong negative' inputs
46
+ def random_duration_segments_strong_negatives(self, batch):
47
+ num_inst = len(self.instruments)
48
+ # randomize current input length
49
+ max_length = batch[0][0].shape[-1]
50
+ min_length = max_length//2
51
+ input_length_a, input_length_b = torch.randint(low=min_length, high=max_length, size=(2,))
52
+
53
+ output_dict_A = {}
54
+ output_dict_B = {}
55
+ for cur_inst in self.instruments:
56
+ output_dict_A[cur_inst] = []
57
+ output_dict_B[cur_inst] = []
58
+ for cur_item in batch:
59
+ # set starting points
60
+ start_point_a = torch.randint(low=0, high=max_length-input_length_a, size=(1,))[0]
61
+ start_point_b = torch.randint(low=0, high=max_length-input_length_b, size=(1,))[0]
62
+ # append to output dictionary
63
+ for cur_i, cur_inst in enumerate(self.instruments):
64
+ # append A# and B# with its strong negative samples
65
+ for cur_neg_idx in range(self.num_strong_negatives+1):
66
+ output_dict_A[cur_inst].append(cur_item[cur_i*(self.num_strong_negatives+1)*2+2*cur_neg_idx][:, start_point_a : start_point_a+input_length_a])
67
+ output_dict_B[cur_inst].append(cur_item[cur_i*(self.num_strong_negatives+1)*2+1+2*cur_neg_idx][:, start_point_b : start_point_b+input_length_b])
68
+
69
+ '''
70
+ Output format :
71
+ [drums_A, bass_A, other_A, vocals_A],
72
+ [drums_B, bass_B, other_B, vocals_B]
73
+ '''
74
+ return [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_A.items()], \
75
+ [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_B.items()]
76
+
77
+
78
+ # collate function for training mixing style transfer
79
+ def style_transfer_collate(self, batch):
80
+ output_dict_A1 = {}
81
+ output_dict_A2 = {}
82
+ output_dict_B2 = {}
83
+ for cur_inst in self.instruments:
84
+ output_dict_A1[cur_inst] = []
85
+ output_dict_A2[cur_inst] = []
86
+ output_dict_B2[cur_inst] = []
87
+ for cur_item in batch:
88
+ # append to output dictionary
89
+ for cur_i, cur_inst in enumerate(self.instruments):
90
+ output_dict_A1[cur_inst].append(cur_item[cur_i*3])
91
+ output_dict_A2[cur_inst].append(cur_item[cur_i*3+1])
92
+ output_dict_B2[cur_inst].append(cur_item[cur_i*3+2])
93
+
94
+ '''
95
+ Output format :
96
+ [drums_A1, bass_A1, other_A1, vocals_A1],
97
+ [drums_A2, bass_A2, other_A2, vocals_A2],
98
+ [drums_B2, bass_B2, other_B2, vocals_B2]
99
+ '''
100
+ return [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_A1.items()], \
101
+ [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_A2.items()], \
102
+ [torch.stack(cur_segments, dim=0) for cur_inst, cur_segments in output_dict_B2.items()]
103
+
104
+
105
+ '''
106
+ Data Loaders
107
+ '''
108
+
109
+ # Data loader for training the 'FXencoder'
110
+ # randomly loads two segments (A and B) from the dataset
111
+ # both segments are manipulated via FXmanipulator using (1+number of strong negative samples) sets of parameters (resulting A1, A2, ..., A#, and B1, B2, ..., B#) (# = number of strong negative samples)
112
+ # segments with the same effects applied (A1 and B1) are assigned as the positive pair during the training
113
+ # segments with the same content but with different effects applied (A2, A3, ..., A3 for A1) are also formed in a batch as 'strong negative' samples
114
+ # in the paper, we use strong negative samples = 1
115
+ class MUSDB_Dataset_Mixing_Manipulated_FXencoder(Dataset):
116
+ def __init__(self, args, \
117
+ mode, \
118
+ applying_effects='full', \
119
+ apply_prob_dict=None):
120
+ self.args = args
121
+ self.data_dir = args.data_dir + mode + "/"
122
+ self.mode = mode
123
+ self.applying_effects = applying_effects
124
+ self.normalization_order = args.normalization_order
125
+ self.fixed_random_seed = args.random_seed
126
+ self.pad_b4_manipulation = args.pad_b4_manipulation
127
+ self.pad_length = 2048
128
+
129
+ if 'musdb' in args.using_dataset.lower():
130
+ self.instruments = ["drums", "bass", "other", "vocals"]
131
+ else:
132
+ raise NotImplementedError
133
+
134
+ # path to contents
135
+ self.data_paths = {}
136
+ self.data_length_ratio_list = {}
137
+ # load data paths for each instrument
138
+ for cur_inst in self.instruments:
139
+ self.data_paths[cur_inst] = glob(f'{self.data_dir}{cur_inst}_normalized_{self.normalization_order}_silence_trimmed*.wav') \
140
+ if args.use_normalized else glob(f'{self.data_dir}{cur_inst}_silence_trimmed*.wav')
141
+ self.data_length_ratio_list[cur_inst] = []
142
+ # compute audio duration and its ratio
143
+ for cur_file_path in self.data_paths[cur_inst]:
144
+ cur_wav_length = load_wav_length(cur_file_path)
145
+ cur_inst_length_ratio = cur_wav_length / get_total_audio_length(self.data_paths[cur_inst])
146
+ self.data_length_ratio_list[cur_inst].append(cur_inst_length_ratio)
147
+
148
+ # load effects chain
149
+ if applying_effects=='full':
150
+ if apply_prob_dict==None:
151
+ # initial (default) applying probabilities of each FX
152
+ apply_prob_dict = {'eq' : 0.9, \
153
+ 'comp' : 0.9, \
154
+ 'pan' : 0.3, \
155
+ 'imager' : 0.8, \
156
+ 'gain': 0.5}
157
+ reverb_prob = {'drums' : 0.5, \
158
+ 'bass' : 0.01, \
159
+ 'vocals' : 0.9, \
160
+ 'other' : 0.7}
161
+
162
+ self.mixing_manipulator = {}
163
+ for cur_inst in self.data_paths.keys():
164
+ if 'reverb' in apply_prob_dict.keys():
165
+ if cur_inst=='drums':
166
+ cur_reverb_weight = 0.5
167
+ elif cur_inst=='bass':
168
+ cur_reverb_weight = 0.1
169
+ else:
170
+ cur_reverb_weight = 1.0
171
+ apply_prob_dict['reverb'] *= cur_reverb_weight
172
+ else:
173
+ apply_prob_dict['reverb'] = reverb_prob[cur_inst]
174
+ # create FXmanipulator for current instrument
175
+ self.mixing_manipulator[cur_inst] = create_inst_effects_augmentation_chain_(cur_inst, \
176
+ apply_prob_dict=apply_prob_dict, \
177
+ ir_dir_path=args.ir_dir_path, \
178
+ sample_rate=args.sample_rate)
179
+ # for single effects
180
+ else:
181
+ self.mixing_manipulator = {}
182
+ if not isinstance(applying_effects, list):
183
+ applying_effects = [applying_effects]
184
+ for cur_inst in self.data_paths.keys():
185
+ self.mixing_manipulator[cur_inst] = create_effects_augmentation_chain(applying_effects, \
186
+ ir_dir_path=args.ir_dir_path)
187
+
188
+
189
+ def __len__(self):
190
+ if self.mode=='train':
191
+ return self.args.batch_size_total * 40
192
+ else:
193
+ return self.args.batch_size_total
194
+
195
+
196
+ def __getitem__(self, idx):
197
+ if self.mode=="train":
198
+ torch.manual_seed(int(time.time())*(idx+1) % (2**32-1))
199
+ np.random.seed(int(time.time())*(idx+1) % (2**32-1))
200
+ random.seed(int(time.time())*(idx+1) % (2**32-1))
201
+ else:
202
+ # fixed random seed for evaluation
203
+ torch.manual_seed(idx*self.fixed_random_seed)
204
+ np.random.seed(idx*self.fixed_random_seed)
205
+ random.seed(idx*self.fixed_random_seed)
206
+
207
+ manipulated_segments = {}
208
+ for cur_neg_idx in range(self.args.num_strong_negatives+1):
209
+ manipulated_segments[cur_neg_idx] = {}
210
+
211
+ # load already-saved data to save time for on-the-fly manipulation
212
+ cur_data_dir_path = f"{self.data_dir}manipulated_encoder/{self.args.data_save_name}/{self.applying_effects}/{idx}/"
213
+ if self.mode=="val" and os.path.exists(cur_data_dir_path):
214
+ for cur_inst in self.instruments:
215
+ for cur_neg_idx in range(self.args.num_strong_negatives+1):
216
+ cur_A_file_path = f"{cur_data_dir_path}{cur_inst}_A{cur_neg_idx+1}.wav"
217
+ cur_B_file_path = f"{cur_data_dir_path}{cur_inst}_B{cur_neg_idx+1}.wav"
218
+ cur_A = load_wav_segment(cur_A_file_path, axis=0, sample_rate=self.args.sample_rate)
219
+ cur_B = load_wav_segment(cur_B_file_path, axis=0, sample_rate=self.args.sample_rate)
220
+ manipulated_segments[cur_neg_idx][cur_inst] = [torch.from_numpy(cur_A).float(), torch.from_numpy(cur_B).float()]
221
+ else:
222
+ # repeat for number of instruments
223
+ for cur_inst, cur_paths in self.data_paths.items():
224
+ # choose file_path to be loaded
225
+ cur_chosen_paths = np.random.choice(cur_paths, 2, p = self.data_length_ratio_list[cur_inst])
226
+ # get random 2 starting points for each instrument
227
+ last_point_A = load_wav_length(cur_chosen_paths[0])-self.args.segment_length_ref
228
+ last_point_B = load_wav_length(cur_chosen_paths[1])-self.args.segment_length_ref
229
+ # simply load more data to prevent artifacts likely to be caused by the manipulator
230
+ if self.pad_b4_manipulation:
231
+ last_point_A -= self.pad_length*2
232
+ last_point_B -= self.pad_length*2
233
+ cur_inst_start_point_A = torch.randint(low=0, \
234
+ high=last_point_A, \
235
+ size=(1,))[0]
236
+ cur_inst_start_point_B = torch.randint(low=0, \
237
+ high=last_point_B, \
238
+ size=(1,))[0]
239
+ # load wav segments from the selected starting points
240
+ load_duration = self.args.segment_length_ref+self.pad_length*2 if self.pad_b4_manipulation else self.args.segment_length_ref
241
+ cur_inst_segment_A = load_wav_segment(cur_chosen_paths[0], \
242
+ start_point=cur_inst_start_point_A, \
243
+ duration=load_duration, \
244
+ axis=1, \
245
+ sample_rate=self.args.sample_rate)
246
+ cur_inst_segment_B = load_wav_segment(cur_chosen_paths[1], \
247
+ start_point=cur_inst_start_point_B, \
248
+ duration=load_duration, \
249
+ axis=1, \
250
+ sample_rate=self.args.sample_rate)
251
+ # mixing manipulation
252
+ # append A# and B# with its strong negative samples
253
+ for cur_neg_idx in range(self.args.num_strong_negatives+1):
254
+ cur_manipulated_segment_A, cur_manipulated_segment_B = self.mixing_manipulator[cur_inst]([cur_inst_segment_A, cur_inst_segment_B])
255
+
256
+ # remove over-loaded area
257
+ if self.pad_b4_manipulation:
258
+ cur_manipulated_segment_A = cur_manipulated_segment_A[self.pad_length:-self.pad_length]
259
+ cur_manipulated_segment_B = cur_manipulated_segment_B[self.pad_length:-self.pad_length]
260
+ manipulated_segments[cur_neg_idx][cur_inst] = [torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_A).float(), 1, 0), min=-1, max=1), \
261
+ torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_B).float(), 1, 0), min=-1, max=1)]
262
+
263
+ # check manipulated data by saving them
264
+ if self.mode=="val" and not os.path.exists(cur_data_dir_path):
265
+ os.makedirs(cur_dir_path, exist_ok=True)
266
+ for cur_inst in manipulated_segments[0].keys():
267
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
268
+ sf.write(f"{cur_dir_path}{cur_inst}_A{cur_manipulated_key+1}.wav", torch.transpose(cur_manipualted_dict[cur_inst][0], 1, 0), self.args.sample_rate, 'PCM_16')
269
+ sf.write(f"{cur_dir_path}{cur_inst}_B{cur_manipulated_key+1}.wav", torch.transpose(cur_manipualted_dict[cur_inst][1], 1, 0), self.args.sample_rate, 'PCM_16')
270
+
271
+ output_list = []
272
+ output_list_param = []
273
+ for cur_inst in manipulated_segments[0].keys():
274
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
275
+ output_list.extend(cur_manipualted_dict[cur_inst])
276
+
277
+ '''
278
+ Output format:
279
+ list of effects manipulated stems of each instrument
280
+ drums_A1, drums_B1, drums_A2, drums_B2, drums_A3, drums_B3, ... ,
281
+ bass_A1, bass_B1, bass_A2, bass_B2, bass_A3, bass_B3, ... ,
282
+ other_A1, other_B1, other_A2, other_B2, other_A3, other_B3, ... ,
283
+ vocals_A1, vocals_B1, vocals_A2, vocals_B2, vocals_A3, vocals_B3, ...
284
+ each stem has the shape of (number of channels, segment duration)
285
+ '''
286
+ return output_list
287
+
288
+
289
+ # generate random manipulated results for evaluation
290
+ def generate_contents_w_effects(self, num_content, num_effects, out_dir):
291
+ print(f"start generating random effects of {self.applying_effects} applied contents")
292
+ os.makedirs(out_dir, exist_ok=True)
293
+
294
+ manipulated_segments = {}
295
+ for cur_fx_idx in range(num_effects):
296
+ manipulated_segments[cur_fx_idx] = {}
297
+ # repeat for number of instruments
298
+ for cur_inst, cur_paths in self.data_paths.items():
299
+ # choose file_path to be loaded
300
+ cur_path = np.random.choice(cur_paths, 1, p = self.data_length_ratio_list[cur_inst])[0]
301
+ print(f"\tgenerating instrument : {cur_inst}")
302
+ # get random 2 starting points for each instrument
303
+ last_point = load_wav_length(cur_path)-self.args.segment_length_ref
304
+ # simply load more data to prevent artifacts likely to be caused by the manipulator
305
+ if self.pad_b4_manipulation:
306
+ last_point -= self.pad_length*2
307
+ cur_inst_start_points = torch.randint(low=0, \
308
+ high=last_point, \
309
+ size=(num_content,))
310
+ # load wav segments from the selected starting points
311
+ cur_inst_segments = []
312
+ for cur_num_content in range(num_content):
313
+ cur_ori_sample = load_wav_segment(cur_path, \
314
+ start_point=cur_inst_start_points[cur_num_content], \
315
+ duration=self.args.segment_length_ref, \
316
+ axis=1, \
317
+ sample_rate=self.args.sample_rate)
318
+ cur_inst_segments.append(cur_ori_sample)
319
+
320
+ sf.write(f"{out_dir}{cur_inst}_ori_{cur_num_content}.wav", cur_ori_sample, self.args.sample_rate, 'PCM_16')
321
+
322
+ # mixing manipulation
323
+ for cur_fx_idx in range(num_effects):
324
+ cur_manipulated_segments = self.mixing_manipulator[cur_inst](cur_inst_segments)
325
+ # remove over-loaded area
326
+ if self.pad_b4_manipulation:
327
+ for cur_man_idx in range(len(cur_manipulated_segments)):
328
+ cur_segment_trimmed = cur_manipulated_segments[cur_man_idx][self.pad_length:-self.pad_length]
329
+ cur_manipulated_segments[cur_man_idx] = torch.clamp(torch.transpose(torch.from_numpy(cur_segment_trimmed).float(), 1, 0), min=-1, max=1)
330
+ manipulated_segments[cur_fx_idx][cur_inst] = cur_manipulated_segments
331
+
332
+ # write generated data
333
+ # save each instruments
334
+ for cur_inst in manipulated_segments[0].keys():
335
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
336
+ for cur_content_idx in range(num_content):
337
+ sf.write(f"{out_dir}{cur_inst}_{chr(65+cur_content_idx//26)}{chr(65+cur_content_idx%26)}{cur_manipulated_key+1}.wav", torch.transpose(cur_manipualted_dict[cur_inst][cur_content_idx], 1, 0), self.args.sample_rate, 'PCM_16')
338
+ # save mixture
339
+ for cur_manipulated_key, cur_manipualted_dict in manipulated_segments.items():
340
+ for cur_content_idx in range(num_content):
341
+ for cur_idx, cur_inst in enumerate(manipulated_segments[0].keys()):
342
+ if cur_idx==0:
343
+ cur_mixture = cur_manipualted_dict[cur_inst][cur_content_idx]
344
+ else:
345
+ cur_mixture += cur_manipualted_dict[cur_inst][cur_content_idx]
346
+ sf.write(f"{out_dir}mixture_{chr(65+cur_content_idx//26)}{chr(65+cur_content_idx%26)}{cur_manipulated_key+1}.wav", torch.transpose(cur_mixture, 1, 0), self.args.sample_rate, 'PCM_16')
347
+
348
+ return
349
+
350
+
351
+
352
+ # Data loader for training the 'Mastering Style Converter'
353
+ # loads two segments (A and B) from the dataset
354
+ # both segments are manipulated via Mastering Effects Manipulator (resulting A1, A2, and B2)
355
+ # one of the manipulated segment is used as a reference segment (B2), which is randomly manipulated the same as the ground truth segment (A2)
356
+ class MUSDB_Dataset_Mixing_Manipulated_Style_Transfer(Dataset):
357
+ def __init__(self, args, \
358
+ mode, \
359
+ applying_effects='full', \
360
+ apply_prob_dict=None):
361
+ self.args = args
362
+ self.data_dir = args.data_dir + mode + "/"
363
+ self.mode = mode
364
+ self.applying_effects = applying_effects
365
+ self.fixed_random_seed = args.random_seed
366
+ self.pad_b4_manipulation = args.pad_b4_manipulation
367
+ self.pad_length = 2048
368
+
369
+ if 'musdb' in args.using_dataset.lower():
370
+ self.instruments = ["drums", "bass", "other", "vocals"]
371
+ else:
372
+ raise NotImplementedError
373
+
374
+ # load data paths for each instrument
375
+ self.data_paths = {}
376
+ self.data_length_ratio_list = {}
377
+ for cur_inst in self.instruments:
378
+ self.data_paths[cur_inst] = glob(f'{self.data_dir}{cur_inst}_normalized_{self.args.normalization_order}_silence_trimmed*.wav') \
379
+ if args.use_normalized else glob(f'{self.data_dir}{cur_inst}_silence_trimmed.wav')
380
+ self.data_length_ratio_list[cur_inst] = []
381
+ # compute audio duration and its ratio
382
+ for cur_file_path in self.data_paths[cur_inst]:
383
+ cur_wav_length = load_wav_length(cur_file_path)
384
+ cur_inst_length_ratio = cur_wav_length / get_total_audio_length(self.data_paths[cur_inst])
385
+ self.data_length_ratio_list[cur_inst].append(cur_inst_length_ratio)
386
+
387
+ self.mixing_manipulator = {}
388
+ if applying_effects=='full':
389
+ if apply_prob_dict==None:
390
+ # initial (default) applying probabilities of each FX
391
+ # we don't update these probabilities for training the MixFXcloner
392
+ apply_prob_dict = {'eq' : 0.9, \
393
+ 'comp' : 0.9, \
394
+ 'pan' : 0.3, \
395
+ 'imager' : 0.8, \
396
+ 'gain': 0.5}
397
+ reverb_prob = {'drums' : 0.5, \
398
+ 'bass' : 0.01, \
399
+ 'vocals' : 0.9, \
400
+ 'other' : 0.7}
401
+ for cur_inst in self.data_paths.keys():
402
+ if 'reverb' in apply_prob_dict.keys():
403
+ if cur_inst=='drums':
404
+ cur_reverb_weight = 0.5
405
+ elif cur_inst=='bass':
406
+ cur_reverb_weight = 0.1
407
+ else:
408
+ cur_reverb_weight = 1.0
409
+ apply_prob_dict['reverb'] *= cur_reverb_weight
410
+ else:
411
+ apply_prob_dict['reverb'] = reverb_prob[cur_inst]
412
+ self.mixing_manipulator[cur_inst] = create_inst_effects_augmentation_chain(cur_inst, \
413
+ apply_prob_dict=apply_prob_dict, \
414
+ ir_dir_path=args.ir_dir_path, \
415
+ sample_rate=args.sample_rate)
416
+ # for single effects
417
+ else:
418
+ if not isinstance(applying_effects, list):
419
+ applying_effects = [applying_effects]
420
+ for cur_inst in self.data_paths.keys():
421
+ self.mixing_manipulator[cur_inst] = create_effects_augmentation_chain(applying_effects, \
422
+ ir_dir_path=args.ir_dir_path)
423
+
424
+
425
+ def __len__(self):
426
+ min_length = get_total_audio_length(glob(f'{self.data_dir}vocals_normalized_{self.args.normalization_order}*.wav'))
427
+ data_len = min_length // self.args.segment_length
428
+ return data_len
429
+
430
+
431
+ def __getitem__(self, idx):
432
+ if self.mode=="train":
433
+ torch.manual_seed(int(time.time())*(idx+1) % (2**32-1))
434
+ np.random.seed(int(time.time())*(idx+1) % (2**32-1))
435
+ random.seed(int(time.time())*(idx+1) % (2**32-1))
436
+ else:
437
+ # fixed random seed for evaluation
438
+ torch.manual_seed(idx*self.fixed_random_seed)
439
+ np.random.seed(idx*self.fixed_random_seed)
440
+ random.seed(idx*self.fixed_random_seed)
441
+
442
+ manipulated_segments = {}
443
+
444
+ # load already-saved data to save time for on-the-fly manipulation
445
+ cur_data_dir_path = f"{self.data_dir}manipulated_converter/{self.args.data_save_name}/{self.applying_effects}/{idx}/"
446
+ if self.mode=="val" and os.path.exists(cur_data_dir_path):
447
+ for cur_inst in self.instruments:
448
+ cur_A1_file_path = f"{cur_data_dir_path}{cur_inst}_A1.wav"
449
+ cur_A2_file_path = f"{cur_data_dir_path}{cur_inst}_A2.wav"
450
+ cur_B2_file_path = f"{cur_data_dir_path}{cur_inst}_B2.wav"
451
+ cur_manipulated_segment_A1 = load_wav_segment(cur_A1_file_path, axis=0, sample_rate=self.args.sample_rate)
452
+ cur_manipulated_segment_A2 = load_wav_segment(cur_A2_file_path, axis=0, sample_rate=self.args.sample_rate)
453
+ cur_manipulated_segment_B2 = load_wav_segment(cur_B2_file_path, axis=0, sample_rate=self.args.sample_rate)
454
+ manipulated_segments[cur_inst] = [torch.from_numpy(cur_manipulated_segment_A1).float(), \
455
+ torch.from_numpy(cur_manipulated_segment_A2).float(), \
456
+ torch.from_numpy(cur_manipulated_segment_B2).float()]
457
+ else:
458
+ # repeat for number of instruments
459
+ for cur_inst, cur_paths in self.data_paths.items():
460
+ # choose file_path to be loaded
461
+ cur_chosen_paths = np.random.choice(cur_paths, 2, p = self.data_length_ratio_list[cur_inst])
462
+ # cur_chosen_paths = [cur_paths[idx], cur_paths[idx+1]]
463
+ # get random 2 starting points for each instrument
464
+ last_point_A = load_wav_length(cur_chosen_paths[0])-self.args.segment_length_ref
465
+ last_point_B = load_wav_length(cur_chosen_paths[1])-self.args.segment_length_ref
466
+ # simply load more data to prevent artifacts likely to be caused by the manipulator
467
+ if self.pad_b4_manipulation:
468
+ last_point_A -= self.pad_length*2
469
+ last_point_B -= self.pad_length*2
470
+ cur_inst_start_point_A = torch.randint(low=0, \
471
+ high=last_point_A, \
472
+ size=(1,))[0]
473
+ cur_inst_start_point_B = torch.randint(low=0, \
474
+ high=last_point_B, \
475
+ size=(1,))[0]
476
+ # load wav segments from the selected starting points
477
+ load_duration = self.args.segment_length_ref+self.pad_length*2 if self.pad_b4_manipulation else self.args.segment_length_ref
478
+ cur_inst_segment_A = load_wav_segment(cur_chosen_paths[0], \
479
+ start_point=cur_inst_start_point_A, \
480
+ duration=load_duration, \
481
+ axis=1, \
482
+ sample_rate=self.args.sample_rate)
483
+ cur_inst_segment_B = load_wav_segment(cur_chosen_paths[1], \
484
+ start_point=cur_inst_start_point_B, \
485
+ duration=load_duration, \
486
+ axis=1, \
487
+ sample_rate=self.args.sample_rate)
488
+ ''' mixing manipulation '''
489
+ # manipulate segment A and B to produce
490
+ # input : A1 (normalized sample)
491
+ # ground truth : A2
492
+ # reference : B2
493
+ cur_manipulated_segment_A1 = cur_inst_segment_A
494
+ cur_manipulated_segment_A2, cur_manipulated_segment_B2 = self.mixing_manipulator[cur_inst]([cur_inst_segment_A, cur_inst_segment_B])
495
+ # remove over-loaded area
496
+ if self.pad_b4_manipulation:
497
+ cur_manipulated_segment_A1 = cur_manipulated_segment_A1[self.pad_length:-self.pad_length]
498
+ cur_manipulated_segment_A2 = cur_manipulated_segment_A2[self.pad_length:-self.pad_length]
499
+ cur_manipulated_segment_B2 = cur_manipulated_segment_B2[self.pad_length:-self.pad_length]
500
+ manipulated_segments[cur_inst] = [torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_A1).float(), 1, 0), min=-1, max=1), \
501
+ torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_A2).float(), 1, 0), min=-1, max=1), \
502
+ torch.clamp(torch.transpose(torch.from_numpy(cur_manipulated_segment_B2).float(), 1, 0), min=-1, max=1)]
503
+
504
+ # check manipulated data by saving them
505
+ if (self.mode=="val" and not os.path.exists(cur_data_dir_path)):
506
+ mixture_dict = {}
507
+ for cur_inst in manipulated_segments.keys():
508
+ cur_inst_dir_path = f"{cur_data_dir_path}{idx}/{cur_inst}/"
509
+ os.makedirs(cur_inst_dir_path, exist_ok=True)
510
+ sf.write(f"{cur_inst_dir_path}A1.wav", torch.transpose(manipulated_segments[cur_inst][0], 1, 0), self.args.sample_rate, 'PCM_16')
511
+ sf.write(f"{cur_inst_dir_path}A2.wav", torch.transpose(manipulated_segments[cur_inst][1], 1, 0), self.args.sample_rate, 'PCM_16')
512
+ sf.write(f"{cur_inst_dir_path}B2.wav", torch.transpose(manipulated_segments[cur_inst][2], 1, 0), self.args.sample_rate, 'PCM_16')
513
+ mixture_dict['A1'] = torch.transpose(manipulated_segments[cur_inst][0], 1, 0)
514
+ mixture_dict['A2'] = torch.transpose(manipulated_segments[cur_inst][1], 1, 0)
515
+ mixture_dict['B2'] = torch.transpose(manipulated_segments[cur_inst][2], 1, 0)
516
+ cur_mix_dir_path = f"{cur_data_dir_path}{idx}/mixture/"
517
+ os.makedirs(cur_mix_dir_path, exist_ok=True)
518
+ sf.write(f"{cur_mix_dir_path}A1.wav", mixture_dict['A1'], self.args.sample_rate, 'PCM_16')
519
+ sf.write(f"{cur_mix_dir_path}A2.wav", mixture_dict['A2'], self.args.sample_rate, 'PCM_16')
520
+ sf.write(f"{cur_mix_dir_path}B2.wav", mixture_dict['B2'], self.args.sample_rate, 'PCM_16')
521
+
522
+ output_list = []
523
+ for cur_inst in manipulated_segments.keys():
524
+ output_list.extend(manipulated_segments[cur_inst])
525
+
526
+ '''
527
+ Output format:
528
+ list of effects manipulated stems of each instrument
529
+ drums_A1, drums_A2, drums_B2,
530
+ bass_A1, bass_A2, bass_B2,
531
+ other_A1, other_A2, other_B2,
532
+ vocals_A1, vocals_A2, vocals_B2,
533
+ each stem has the shape of (number of channels, segment duration)
534
+ Notation :
535
+ A1 = input of the network
536
+ A2 = ground truth
537
+ B2 = reference track
538
+ '''
539
+ return output_list
540
+
541
+
542
+
543
+ # Data loader for inferencing the task 'Mixing Style Transfer'
544
+ ### loads whole mixture or stems from the target directory
545
+ class Song_Dataset_Inference(Dataset):
546
+ def __init__(self, args):
547
+ self.args = args
548
+ self.data_dir = args.target_dir
549
+ self.interpolate = args.interpolation
550
+
551
+ self.instruments = args.instruments
552
+
553
+ self.data_dir_paths = sorted(glob(f"{self.data_dir}*/"))
554
+
555
+ self.input_name = args.input_file_name
556
+ self.reference_name = args.reference_file_name
557
+ self.stem_level_directory_name = args.stem_level_directory_name \
558
+ if self.args.do_not_separate else os.path.join(args.stem_level_directory_name, args.separation_model)
559
+ if self.interpolate:
560
+ self.reference_name_B = args.reference_file_name_2interpolate
561
+
562
+ # audio effects normalizer
563
+ if args.normalize_input:
564
+ self.normalization_chain = Audio_Effects_Normalizer(precomputed_feature_path=args.precomputed_normalization_feature, \
565
+ STEMS=args.instruments, \
566
+ EFFECTS=args.normalization_order)
567
+
568
+
569
+ def __len__(self):
570
+ return len(self.data_dir_paths)
571
+
572
+
573
+ def __getitem__(self, idx):
574
+ ''' stem-level conversion '''
575
+ input_stems = []
576
+ reference_stems = []
577
+ reference_B_stems = []
578
+ for cur_inst in self.instruments:
579
+ cur_input_file_path = os.path.join(self.data_dir_paths[idx], self.stem_level_directory_name, self.input_name, cur_inst+'.wav')
580
+ cur_reference_file_path = os.path.join(self.data_dir_paths[idx], self.stem_level_directory_name, self.reference_name, cur_inst+'.wav')
581
+
582
+ # load wav
583
+ cur_input_wav = load_wav_segment(cur_input_file_path, axis=0, sample_rate=self.args.sample_rate)
584
+ cur_reference_wav = load_wav_segment(cur_reference_file_path, axis=0, sample_rate=self.args.sample_rate)
585
+
586
+ if self.args.normalize_input:
587
+ cur_input_wav = self.normalization_chain.normalize_audio(cur_input_wav.transpose(), src=cur_inst).transpose()
588
+
589
+ input_stems.append(torch.clamp(torch.from_numpy(cur_input_wav).float(), min=-1, max=1))
590
+ reference_stems.append(torch.clamp(torch.from_numpy(cur_reference_wav).float(), min=-1, max=1))
591
+
592
+ # for interpolation
593
+ if self.interpolate:
594
+ cur_reference_B_file_path = os.path.join(self.data_dir_paths[idx], self.stem_level_directory_name, self.reference_name_B, cur_inst+'.wav')
595
+ cur_reference_B_wav = load_wav_segment(cur_reference_B_file_path, axis=0, sample_rate=self.args.sample_rate)
596
+ reference_B_stems.append(torch.clamp(torch.from_numpy(cur_reference_B_wav).float(), min=-1, max=1))
597
+
598
+ dir_name = os.path.dirname(self.data_dir_paths[idx])
599
+
600
+ if self.interpolate:
601
+ return torch.stack(input_stems, dim=0), torch.stack(reference_stems, dim=0), torch.stack(reference_B_stems, dim=0), dir_name
602
+ else:
603
+ return torch.stack(input_stems, dim=0), torch.stack(reference_stems, dim=0), dir_name
604
+
605
+
606
+
607
+ # check dataset
608
+ if __name__ == '__main__':
609
+ """
610
+ Test code of data loaders
611
+ """
612
+ import time
613
+ print('checking dataset...')
614
+
615
+ total_epochs = 1
616
+ bs = 5
617
+ check_step_size = 3
618
+ collate_class = Collate_Variable_Length_Segments(args)
619
+
620
+
621
+ print('\n========== Effects Encoder ==========')
622
+ from config import args
623
+ ##### generate samples with ranfom configuration
624
+ # args.normalization_order = 'eqcompimagegain'
625
+ # for cur_effect in ['full', 'gain', 'comp', 'reverb', 'eq', 'imager', 'pan']:
626
+ # start_time = time.time()
627
+ # dataset = MUSDB_Dataset_Mixing_Manipulated_FXencoder(args, mode='val', applying_effects=cur_effect, check_data=True)
628
+ # dataset.generate_contents_w_effects(num_content=25, num_effects=10)
629
+ # print(f'\t---time taken : {time.time()-start_time}---')
630
+
631
+ ### training data loder
632
+ dataset = MUSDB_Dataset_Mixing_Manipulated_FXencoder(args, mode='train', applying_effects=['comp'])
633
+ data_loader = DataLoader(dataset, \
634
+ batch_size=bs, \
635
+ shuffle=False, \
636
+ collate_fn=collate_class.random_duration_segments_strong_negatives, \
637
+ drop_last=False, \
638
+ num_workers=0)
639
+
640
+ for epoch in range(total_epochs):
641
+ start_time_loader = time.time()
642
+ for step, output_list in enumerate(data_loader):
643
+ if step==check_step_size:
644
+ break
645
+ print(f'Epoch {epoch+1}/{total_epochs}\tStep {step+1}/{len(data_loader)}')
646
+ print(f'num contents : {len(output_list)}\tnum instruments : {len(output_list[0])}\tcontent A shape : {output_list[0][0].shape}\t content B shape : {output_list[1][0].shape} \ttime taken: {time.time()-start_time_loader:.4f}')
647
+ start_time_loader = time.time()
648
+
649
+
650
+ print('\n========== Mixing Style Transfer ==========')
651
+ from trainer_mixing_transfer.config_conv import args
652
+ ### training data loder
653
+ dataset = MUSDB_Dataset_Mixing_Manipulated_Style_Transfer(args, mode='train')
654
+ data_loader = DataLoader(dataset, \
655
+ batch_size=bs, \
656
+ shuffle=False, \
657
+ collate_fn=collate_class.style_transfer_collate, \
658
+ drop_last=False, \
659
+ num_workers=0)
660
+
661
+ for epoch in range(total_epochs):
662
+ start_time_loader = time.time()
663
+ for step, output_list in enumerate(data_loader):
664
+ if step==check_step_size:
665
+ break
666
+ print(f'Epoch {epoch+1}/{total_epochs}\tStep {step+1}/{len(data_loader)}')
667
+ print(f'num contents : {len(output_list)}\tnum instruments : {len(output_list[0])}\tA1 shape : {output_list[0][0].shape}\tA2 shape : {output_list[1][0].shape}\tA3 shape : {output_list[2][0].shape}\ttime taken: {time.time()-start_time_loader:.4f}')
668
+ start_time_loader = time.time()
669
+
670
+
671
+ print('\n--- checking dataset completed ---')
672
+
mixing_style_transfer/data_loader/loader_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Utility file for loaders """
2
+
3
+ import numpy as np
4
+ import soundfile as sf
5
+ import wave
6
+
7
+
8
+
9
+ # Function to convert frame level audio into atomic time
10
+ def frames_to_time(total_length, sr=44100):
11
+ in_time = total_length / sr
12
+ hour = int(in_time / 3600)
13
+ minute = int((in_time - hour*3600) / 60)
14
+ second = int(in_time - hour*3600 - minute*60)
15
+ return f"{hour:02d}:{minute:02d}:{second:02d}"
16
+
17
+
18
+ # Function to convert atomic labeled time into frames or seconds
19
+ def time_to_frames(input_time, to_frames=True, sr=44100):
20
+ hour, minute, second = input_time.split(':')
21
+ total_seconds = int(hour)*3600 + int(minute)*60 + int(second)
22
+ return total_seconds*sr if to_frames else total_seconds
23
+
24
+
25
+ # Function to convert seconds to atomic labeled time
26
+ def sec_to_time(input_time):
27
+ return frames_to_time(input_time, sr=1)
28
+
29
+
30
+ # Function to load total trainable raw audio lengths
31
+ def get_total_audio_length(audio_paths):
32
+ total_length = 0
33
+ for cur_audio_path in audio_paths:
34
+ cur_wav = wave.open(cur_audio_path, 'r')
35
+ total_length += cur_wav.getnframes() # here, length = # of frames
36
+ return total_length
37
+
38
+
39
+ # Function to load length of an input wav audio
40
+ def load_wav_length(audio_path):
41
+ pt_wav = wave.open(audio_path, 'r')
42
+ length = pt_wav.getnframes()
43
+ return length
44
+
45
+
46
+ # Function to load only selected 16 bit, stereo wav audio segment from an input wav audio
47
+ def load_wav_segment(audio_path, start_point=None, duration=None, axis=1, sample_rate=44100):
48
+ start_point = 0 if start_point==None else start_point
49
+ duration = load_wav_length(audio_path) if duration==None else duration
50
+ pt_wav = wave.open(audio_path, 'r')
51
+
52
+ if pt_wav.getframerate()!=sample_rate:
53
+ raise ValueError(f"ValueError: input audio's sample rate should be {sample_rate}")
54
+ pt_wav.setpos(start_point)
55
+ x = pt_wav.readframes(duration)
56
+ if pt_wav.getsampwidth()==2:
57
+ x = np.frombuffer(x, dtype=np.int16)
58
+ X = x / float(2**15) # needs to be 16 bit format
59
+ elif pt_wav.getsampwidth()==4:
60
+ x = np.frombuffer(x, dtype=np.int32)
61
+ X = x / float(2**31) # needs to be 32 bit format
62
+ else:
63
+ raise ValueError("ValueError: input audio's bit depth should be 16 or 32-bit")
64
+
65
+ # exception for stereo channels
66
+ if pt_wav.getnchannels()==2:
67
+ X_l = np.expand_dims(X[::2], axis=axis)
68
+ X_r = np.expand_dims(X[1::2], axis=axis)
69
+ X = np.concatenate((X_l, X_r), axis=axis)
70
+ return X
71
+
mixing_style_transfer/mixing_manipulator/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .audio_effects_chain import *
2
+ from .common_audioeffects import *
3
+ from .common_dataprocessing import create_dataset
4
+ from data_normalization import Audio_Effects_Normalizer
mixing_style_transfer/mixing_manipulator/audio_effects_chain.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of Audio Effects Chain Manipulation for the task 'Mixing Style Transfer'
3
+ """
4
+ from glob import glob
5
+ import os
6
+ import sys
7
+
8
+ currentdir = os.path.dirname(os.path.realpath(__file__))
9
+ sys.path.append(currentdir)
10
+ sys.path.append(os.path.dirname(currentdir))
11
+ from common_audioeffects import *
12
+ from common_dataprocessing import create_dataset
13
+
14
+
15
+
16
+ # create augmentation effects chain according to targeted effects with their applying probability
17
+ def create_effects_augmentation_chain(effects, \
18
+ ir_dir_path=None, \
19
+ sample_rate=44100, \
20
+ shuffle=False, \
21
+ parallel=False, \
22
+ parallel_weight_factor=None):
23
+ '''
24
+ Args:
25
+ effects (list of tuples or string) : First tuple element is string denoting the target effects.
26
+ Second tuple element is probability of applying current effects.
27
+ ir_dir_path (string) : directory path that contains directories of impulse responses organized according to RT60
28
+ sample_rate (int) : using sampling rate
29
+ shuffle (boolean) : shuffle FXs inside current FX chain
30
+ parallel (boolean) : compute parallel FX computation (alpha * input + (1-alpha) * manipulated output)
31
+ parallel_weight_factor : the value of alpha for parallel FX computation. default=None : random value in between (0.0, 0.5)
32
+ '''
33
+ fx_list = []
34
+ apply_prob = []
35
+ for cur_fx in effects:
36
+ # store probability to apply current effects. default is to set as 100%
37
+ if isinstance(cur_fx, tuple):
38
+ apply_prob.append(cur_fx[1])
39
+ cur_fx = cur_fx[0]
40
+ else:
41
+ apply_prob.append(1)
42
+
43
+ # processors of each audio effects
44
+ if isinstance(cur_fx, AugmentationChain) or isinstance(cur_fx, Processor):
45
+ fx_list.append(cur_fx)
46
+ elif cur_fx.lower()=='gain':
47
+ fx_list.append(Gain())
48
+ elif 'eq' in cur_fx.lower():
49
+ fx_list.append(Equaliser(n_channels=2, sample_rate=sample_rate))
50
+ elif 'comp' in cur_fx.lower():
51
+ fx_list.append(Compressor(sample_rate=sample_rate))
52
+ elif 'expand' in cur_fx.lower():
53
+ fx_list.append(Expander(sample_rate=sample_rate))
54
+ elif 'pan' in cur_fx.lower():
55
+ fx_list.append(Panner())
56
+ elif 'image'in cur_fx.lower():
57
+ fx_list.append(MidSideImager())
58
+ elif 'algorithmic' in cur_fx.lower():
59
+ fx_list.append(AlgorithmicReverb(sample_rate=sample_rate))
60
+ elif 'reverb' in cur_fx.lower():
61
+ # apply algorithmic reverberation if ir_dir_path is not defined
62
+ if ir_dir_path==None:
63
+ fx_list.append(AlgorithmicReverb(sample_rate=sample_rate))
64
+ # apply convolution reverberation
65
+ else:
66
+ IR_paths = glob(f"{ir_dir_path}*/RT60_avg/[!0-]*")
67
+ IR_list = []
68
+ IR_dict = {}
69
+ for IR_path in IR_paths:
70
+ cur_rt = IR_path.split('/')[-1]
71
+ if cur_rt not in IR_dict:
72
+ IR_dict[cur_rt] = []
73
+ IR_dict[cur_rt].extend(create_dataset(path=IR_path, \
74
+ accepted_sampling_rates=[sample_rate], \
75
+ sources=['impulse_response'], \
76
+ mapped_sources={}, load_to_memory=True, debug=False)[0])
77
+ long_ir_list = []
78
+ for cur_rt in IR_dict:
79
+ cur_rt_len = int(cur_rt.split('-')[0])
80
+ if cur_rt_len < 3000:
81
+ IR_list.append(IR_dict[cur_rt])
82
+ else:
83
+ long_ir_list.extend(IR_dict[cur_rt])
84
+
85
+ IR_list.append(long_ir_list)
86
+ fx_list.append(ConvolutionalReverb(IR_list, sample_rate))
87
+ else:
88
+ raise ValueError(f"make sure the target effects are in the Augment FX chain : received fx called {cur_fx}")
89
+
90
+ aug_chain_in = []
91
+ for cur_i, cur_fx in enumerate(fx_list):
92
+ normalize = False if isinstance(cur_fx, AugmentationChain) or cur_fx.name=='Gain' else True
93
+ aug_chain_in.append((cur_fx, apply_prob[cur_i], normalize))
94
+
95
+ return AugmentationChain(fxs=aug_chain_in, shuffle=shuffle, parallel=parallel, parallel_weight_factor=parallel_weight_factor)
96
+
97
+
98
+ # create audio FX-chain according to input instrument
99
+ def create_inst_effects_augmentation_chain(inst, \
100
+ apply_prob_dict, \
101
+ ir_dir_path=None, \
102
+ algorithmic=False, \
103
+ sample_rate=44100):
104
+ '''
105
+ Args:
106
+ inst (string) : FXmanipulator for target instrument. Current version only distinguishes 'drums' for applying reverberation
107
+ apply_prob_dict (dictionary of (FX name, probability)) : applying proababilities for each FX
108
+ ir_dir_path (string) : directory path that contains directories of impulse responses organized according to RT60
109
+ algorithmic (boolean) : rather to use algorithmic reverberation (True) or convolution reverberation (False)
110
+ sample_rate (int) : using sampling rate
111
+ '''
112
+ reverb_type = 'algorithmic' if algorithmic else 'reverb'
113
+ eq_comp_rand = create_effects_augmentation_chain([('eq', apply_prob_dict['eq']), ('comp', apply_prob_dict['comp'])], \
114
+ ir_dir_path=ir_dir_path, \
115
+ sample_rate=sample_rate, \
116
+ shuffle=True)
117
+ pan_image_rand = create_effects_augmentation_chain([('pan', apply_prob_dict['pan']), ('imager', apply_prob_dict['imager'])], \
118
+ ir_dir_path=ir_dir_path, \
119
+ sample_rate=sample_rate, \
120
+ shuffle=True)
121
+ if inst=='drums':
122
+ # apply reverberation to low frequency with little probability
123
+ low_pass_eq_params = ParameterList()
124
+ low_pass_eq_params.add(Parameter('high_shelf_gain', -50.0, 'float', minimum=-50.0, maximum=-50.0))
125
+ low_pass_eq_params.add(Parameter('high_shelf_freq', 100.0, 'float', minimum=100.0, maximum=100.0))
126
+ low_pass_eq = Equaliser(n_channels=2, \
127
+ sample_rate=sample_rate, \
128
+ bands=['high_shelf'], \
129
+ parameters=low_pass_eq_params)
130
+ reverb_parallel_low = create_effects_augmentation_chain([low_pass_eq, (reverb_type, apply_prob_dict['reverb']*0.01)], \
131
+ ir_dir_path=ir_dir_path, \
132
+ sample_rate=sample_rate, \
133
+ parallel=True, \
134
+ parallel_weight_factor=0.8)
135
+ # high pass eq for drums reverberation
136
+ high_pass_eq_params = ParameterList()
137
+ high_pass_eq_params.add(Parameter('low_shelf_gain', -50.0, 'float', minimum=-50.0, maximum=-50.0))
138
+ high_pass_eq_params.add(Parameter('low_shelf_freq', 100.0, 'float', minimum=100.0, maximum=100.0))
139
+ high_pass_eq = Equaliser(n_channels=2, \
140
+ sample_rate=sample_rate, \
141
+ bands=['low_shelf'], \
142
+ parameters=high_pass_eq_params)
143
+ reverb_parallel_high = create_effects_augmentation_chain([high_pass_eq, (reverb_type, apply_prob_dict['reverb'])], \
144
+ ir_dir_path=ir_dir_path, \
145
+ sample_rate=sample_rate, \
146
+ parallel=True, \
147
+ parallel_weight_factor=0.6)
148
+ reverb_parallel = create_effects_augmentation_chain([reverb_parallel_low, reverb_parallel_high], \
149
+ ir_dir_path=ir_dir_path, \
150
+ sample_rate=sample_rate)
151
+ else:
152
+ reverb_parallel = create_effects_augmentation_chain([(reverb_type, apply_prob_dict['reverb'])], \
153
+ ir_dir_path=ir_dir_path, \
154
+ sample_rate=sample_rate, \
155
+ parallel=True)
156
+ # full effects chain
157
+ effects_chain = create_effects_augmentation_chain([eq_comp_rand, \
158
+ pan_image_rand, \
159
+ reverb_parallel, \
160
+ ('gain', apply_prob_dict['gain'])], \
161
+ ir_dir_path=ir_dir_path, \
162
+ sample_rate=sample_rate)
163
+
164
+ return effects_chain
165
+
mixing_style_transfer/mixing_manipulator/common_audioeffects.py ADDED
@@ -0,0 +1,1537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio effects for data augmentation.
3
+
4
+ Several audio effects can be combined into an augmentation chain.
5
+
6
+ Important note: We assume that the parallelization during training is done using
7
+ multi-processing and not multi-threading. Hence, we do not need the
8
+ `@sox.sox_context()` decorators as discussed in this
9
+ [thread](https://github.com/pseeth/soxbindings/issues/4).
10
+
11
+ AI Music Technology Group, Sony Group Corporation
12
+ AI Speech and Sound Group, Sony Europe
13
+
14
+
15
+ This implementation originally belongs to Sony Group Corporation,
16
+ which has been introduced in the work "Automatic music mixing with deep learning and out-of-domain data".
17
+ Original repo link: https://github.com/sony/FxNorm-automix
18
+ This work modifies a few implementations from the original repo to suit the task.
19
+ """
20
+
21
+ from itertools import permutations
22
+ import logging
23
+ import numpy as np
24
+ import pymixconsole as pymc
25
+ from pymixconsole.parameter import Parameter
26
+ from pymixconsole.parameter_list import ParameterList
27
+ from pymixconsole.processor import Processor
28
+ from random import shuffle
29
+ from scipy.signal import oaconvolve
30
+ import soxbindings as sox
31
+ from typing import List, Optional, Tuple, Union
32
+ from numba import jit
33
+
34
+ # prevent pysox from logging warnings regarding non-opimal timestretch factors
35
+ logging.getLogger('sox').setLevel(logging.ERROR)
36
+
37
+
38
+ # Monkey-Patch `Processor` for convenience
39
+ # (a) Allow `None` as blocksize if processor can work on variable-length audio
40
+ def new_init(self, name, parameters, block_size, sample_rate, dtype='float32'):
41
+ """
42
+ Initialize processor.
43
+
44
+ Args:
45
+ self: Reference to object
46
+ name (str): Name of processor.
47
+ parameters (parameter_list): Parameters for this processor.
48
+ block_size (int): Size of blocks for blockwise processing.
49
+ Can also be `None` if full audio can be processed at once.
50
+ sample_rate (int): Sample rate of input audio. Use `None` if effect is independent of this value.
51
+ dtype (str): data type of samples
52
+ """
53
+ self.name = name
54
+ self.parameters = parameters
55
+ self.block_size = block_size
56
+ self.sample_rate = sample_rate
57
+ self.dtype = dtype
58
+
59
+
60
+ # (b) make code simpler
61
+ def new_update(self, parameter_name):
62
+ """
63
+ Update processor after randomization of parameters.
64
+
65
+ Args:
66
+ self: Reference to object.
67
+ parameter_name (str): Parameter whose value has changed.
68
+ """
69
+ pass
70
+
71
+
72
+ # (c) representation for nice print
73
+ def new_repr(self):
74
+ """
75
+ Create human-readable representation.
76
+
77
+ Args:
78
+ self: Reference to object.
79
+
80
+ Returns:
81
+ string representation of object.
82
+ """
83
+ return f'Processor(name={self.name!r}, parameters={self.parameters!r}'
84
+
85
+
86
+ Processor.__init__ = new_init
87
+ Processor.__repr__ = new_repr
88
+ Processor.update = new_update
89
+
90
+
91
+ class AugmentationChain:
92
+ """Basic audio Fx chain which is used for data augmentation."""
93
+
94
+ def __init__(self,
95
+ fxs: Optional[List[Tuple[Union[Processor, 'AugmentationChain'], float, bool]]] = [],
96
+ shuffle: Optional[bool] = False,
97
+ parallel: Optional[bool] = False,
98
+ parallel_weight_factor = None,
99
+ randomize_param_value=True):
100
+ """
101
+ Create augmentation chain from the dictionary `fxs`.
102
+
103
+ Args:
104
+ fxs (list of tuples): First tuple element is an instances of `pymc.processor` or `AugmentationChain` that
105
+ we want to use for data augmentation. Second element gives probability that effect should be applied.
106
+ Third element defines, whether the processed signal is normalized by the RMS of the input.
107
+ shuffle (bool): If `True` then order of Fx are changed whenever chain is applied.
108
+ """
109
+ self.fxs = fxs
110
+ self.shuffle = shuffle
111
+ self.parallel = parallel
112
+ self.parallel_weight_factor = parallel_weight_factor
113
+ self.randomize_param_value = randomize_param_value
114
+
115
+ def apply_processor(self, x, processor: Processor, rms_normalize):
116
+ """
117
+ Pass audio in `x` through `processor` and output the respective processed audio.
118
+
119
+ Args:
120
+ x (Numpy array): Input audio of shape `n_samples` x `n_channels`.
121
+ processor (Processor): Audio effect that we want to apply.
122
+ rms_normalize (bool): If `True`, the processed signal is normalized by the RMS of the signal.
123
+
124
+ Returns:
125
+ Numpy array: Processed audio of shape `n_samples` x `n_channels` (same size as `x')
126
+ """
127
+
128
+ n_samples_input = x.shape[0]
129
+
130
+ if processor.block_size is None:
131
+ y = processor.process(x)
132
+ else:
133
+ # make sure that n_samples is a multiple of `processor.block_size`
134
+ if x.shape[0] % processor.block_size != 0:
135
+ n_pad = processor.block_size - x.shape[0] % processor.block_size
136
+ x = np.pad(x, ((0, n_pad), (0, 0)), mode='reflective')
137
+
138
+ y = np.zeros_like(x)
139
+ for idx in range(0, x.shape[0], processor.block_size):
140
+ y[idx:idx+processor.block_size, :] = processor.process(x[idx:idx+processor.block_size, :])
141
+
142
+ if rms_normalize:
143
+ # normalize output energy such that it is the same as the input energy
144
+ scale = np.sqrt(np.mean(np.square(x)) / np.maximum(1e-7, np.mean(np.square(y))))
145
+ y *= scale
146
+
147
+ # return audio of same length as x
148
+ return y[:n_samples_input, :]
149
+
150
+ def apply_same_processor(self, x_list, processor: Processor, rms_normalize):
151
+ for i in range(len(x_list)):
152
+ x_list[i] = self.apply_processor(x_list[i], processor, rms_normalize)
153
+
154
+ return x_list
155
+
156
+ def __call__(self, x_list):
157
+ """
158
+ Apply the same augmentation chain to audio tracks in list `x_list`.
159
+
160
+ Args:
161
+ x_list (list of Numpy array) : List of audio samples of shape `n_samples` x `n_channels`.
162
+
163
+ Returns:
164
+ y_list (list of Numpy array) : List of processed audio of same shape as `x_list` where the same effects have been applied.
165
+ """
166
+ # randomly shuffle effect order if `self.shuffle` is True
167
+ if self.shuffle:
168
+ shuffle(self.fxs)
169
+
170
+ # apply effects with probabilities given in `self.fxs`
171
+ y_list = x_list.copy()
172
+ for fx, p, rms_normalize in self.fxs:
173
+ if np.random.rand() < p:
174
+ if isinstance(fx, Processor):
175
+ # randomize all effect parameters (also calls `update()` for each processor)
176
+ if self.randomize_param_value:
177
+ fx.randomize()
178
+ else:
179
+ fx.update(None)
180
+
181
+ # apply processor
182
+ y_list = self.apply_same_processor(y_list, fx, rms_normalize)
183
+ else:
184
+ y_list = fx(y_list)
185
+
186
+ if self.parallel:
187
+ # weighting factor of input signal in the range of (0.0 ~ 0.5)
188
+ weight_in = self.parallel_weight_factor if self.parallel_weight_factor else np.random.rand() / 2.
189
+ for i in range(len(y_list)):
190
+ y_list[i] = weight_in*x_list[i] + (1-weight_in)*y_list[i]
191
+
192
+ return y_list
193
+
194
+ def __repr__(self):
195
+ """
196
+ Human-readable representation.
197
+
198
+ Returns:
199
+ string representation of object.
200
+ """
201
+ return f'AugmentationChain(fxs={self.fxs!r}, shuffle={self.shuffle!r})'
202
+
203
+
204
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% DISTORTION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
205
+ def hard_clip(x, threshold_dB, drive):
206
+ """
207
+ Hard clip distortion.
208
+
209
+ Args:
210
+ x: input audio
211
+ threshold_dB: threshold
212
+ drive: drive
213
+
214
+ Returns:
215
+ (Numpy array): distorted audio
216
+ """
217
+ drive_linear = np.power(10., drive / 20.).astype(np.float32)
218
+ threshold_linear = 10. ** (threshold_dB / 20.)
219
+ return np.clip(x * drive_linear, -threshold_linear, threshold_linear)
220
+
221
+
222
+ def overdrive(x, drive, colour, sample_rate):
223
+ """
224
+ Overdrive distortion.
225
+
226
+ Args:
227
+ x: input audio
228
+ drive: Controls the amount of distortion (dB).
229
+ colour: Controls the amount of even harmonic content in the output(dB)
230
+ sample_rate: sampling rate
231
+
232
+ Returns:
233
+ (Numpy array): distorted audio
234
+ """
235
+ scale = np.max(np.abs(x))
236
+ if scale > 0.9:
237
+ clips = True
238
+ x = x * (0.9 / scale)
239
+ else:
240
+ clips = False
241
+
242
+ tfm = sox.Transformer()
243
+ tfm.overdrive(gain_db=drive, colour=colour)
244
+ y = tfm.build_array(input_array=x, sample_rate_in=sample_rate).astype(np.float32)
245
+
246
+ if clips:
247
+ y *= scale / 0.9 # rescale output to original scale
248
+ return y
249
+
250
+
251
+ def hyperbolic_tangent(x, drive):
252
+ """
253
+ Hyperbolic Tanh distortion.
254
+
255
+ Args:
256
+ x: input audio
257
+ drive: drive
258
+
259
+ Returns:
260
+ (Numpy array): distorted audio
261
+ """
262
+ drive_linear = np.power(10., drive / 20.).astype(np.float32)
263
+ return np.tanh(2. * x * drive_linear)
264
+
265
+
266
+ def soft_sine(x, drive):
267
+ """
268
+ Soft sine distortion.
269
+
270
+ Args:
271
+ x: input audio
272
+ drive: drive
273
+
274
+ Returns:
275
+ (Numpy array): distorted audio
276
+ """
277
+ drive_linear = np.power(10., drive / 20.).astype(np.float32)
278
+ y = np.clip(x * drive_linear, -np.pi/4.0, np.pi/4.0)
279
+ return np.sin(2. * y)
280
+
281
+
282
+ def bit_crusher(x, bits):
283
+ """
284
+ Bit crusher distortion.
285
+
286
+ Args:
287
+ x: input audio
288
+ bits: bits
289
+
290
+ Returns:
291
+ (Numpy array): distorted audio
292
+ """
293
+ return np.rint(x * (2 ** bits)) / (2 ** bits)
294
+
295
+
296
+ class Distortion(Processor):
297
+ """
298
+ Distortion processor.
299
+
300
+ Processor parameters:
301
+ mode (str): Currently supports the following five modes: hard_clip, waveshaper, soft_sine, tanh, bit_crusher.
302
+ Each mode has different parameters such as threshold, factor, or bits.
303
+ threshold (float): threshold
304
+ drive (float): drive
305
+ factor (float): factor
306
+ limit_range (float): limit range
307
+ bits (int): bits
308
+ """
309
+
310
+ def __init__(self, sample_rate, name='Distortion', parameters=None):
311
+ """
312
+ Initialize processor.
313
+
314
+ Args:
315
+ sample_rate (int): sample rate.
316
+ name (str): Name of processor.
317
+ parameters (parameter_list): Parameters for this processor.
318
+ """
319
+ super().__init__(name, None, block_size=None, sample_rate=sample_rate)
320
+ if not parameters:
321
+ self.parameters = ParameterList()
322
+ self.parameters.add(Parameter('mode', 'hard_clip', 'string',
323
+ options=['hard_clip',
324
+ 'overdrive',
325
+ 'soft_sine',
326
+ 'tanh',
327
+ 'bit_crusher']))
328
+ self.parameters.add(Parameter('threshold', 0.0, 'float',
329
+ units='dB', maximum=0.0, minimum=-20.0))
330
+ self.parameters.add(Parameter('drive', 0.0, 'float',
331
+ units='dB', maximum=20.0, minimum=0.0))
332
+ self.parameters.add(Parameter('colour', 20.0, 'float',
333
+ maximum=100.0, minimum=0.0))
334
+ self.parameters.add(Parameter('bits', 12, 'int',
335
+ maximum=12, minimum=8))
336
+
337
+ def process(self, x):
338
+ """
339
+ Process audio.
340
+
341
+ Args:
342
+ x (Numpy array): input audio of size `n_samples x n_channels`.
343
+
344
+ Returns:
345
+ (Numpy array): distorted audio of size `n_samples x n_channels`.
346
+ """
347
+ if self.parameters.mode.value == 'hard_clip':
348
+ y = hard_clip(x, self.parameters.threshold.value, self.parameters.drive.value)
349
+ elif self.parameters.mode.value == 'overdrive':
350
+ y = overdrive(x, self.parameters.drive.value,
351
+ self.parameters.colour.value, self.sample_rate)
352
+ elif self.parameters.mode.value == 'soft_sine':
353
+ y = soft_sine(x, self.parameters.drive.value)
354
+ elif self.parameters.mode.value == 'tanh':
355
+ y = hyperbolic_tangent(x, self.parameters.drive.value)
356
+ elif self.parameters.mode.value == 'bit_crusher':
357
+ y = bit_crusher(x, self.parameters.bits.value)
358
+
359
+ # If the output has low amplitude, (some distortion settigns can "crush" down the amplitude)
360
+ # Then it`s normalised to the input's amplitude
361
+ x_max = np.max(np.abs(x)) + 1e-8
362
+ o_max = np.max(np.abs(y)) + 1e-8
363
+ if x_max > o_max:
364
+ y = y*(x_max/o_max)
365
+
366
+ return y
367
+
368
+
369
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% EQUALISER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
370
+ class Equaliser(Processor):
371
+ """
372
+ Five band parametric equaliser (two shelves and three central bands).
373
+
374
+ All gains are set in dB values and range from `MIN_GAIN` dB to `MAX_GAIN` dB.
375
+ This processor is implemented as cascade of five biquad IIR filters
376
+ that are implemented using the infamous cookbook formulae from RBJ.
377
+
378
+ Processor parameters:
379
+ low_shelf_gain (float), low_shelf_freq (float)
380
+ first_band_gain (float), first_band_freq (float), first_band_q (float)
381
+ second_band_gain (float), second_band_freq (float), second_band_q (float)
382
+ third_band_gain (float), third_band_freq (float), third_band_q (float)
383
+
384
+ original from https://github.com/csteinmetz1/pymixconsole/blob/master/pymixconsole/processors/equaliser.py
385
+ """
386
+
387
+ def __init__(self, n_channels,
388
+ sample_rate,
389
+ gain_range=(-15.0, 15.0),
390
+ q_range=(0.1, 2.0),
391
+ bands=['low_shelf', 'first_band', 'second_band', 'third_band', 'high_shelf'],
392
+ hard_clip=False,
393
+ name='Equaliser', parameters=None):
394
+ """
395
+ Initialize processor.
396
+
397
+ Args:
398
+ n_channels (int): Number of audio channels.
399
+ sample_rate (int): Sample rate of audio.
400
+ gain_range (tuple of floats): minimum and maximum gain that can be used.
401
+ q_range (tuple of floats): minimum and maximum q value.
402
+ hard_clip (bool): Whether we clip to [-1, 1.] after processing.
403
+ name (str): Name of processor.
404
+ parameters (parameter_list): Parameters for this processor.
405
+ """
406
+ super().__init__(name, parameters=parameters, block_size=None, sample_rate=sample_rate)
407
+
408
+ self.n_channels = n_channels
409
+
410
+ MIN_GAIN, MAX_GAIN = gain_range
411
+ MIN_Q, MAX_Q = q_range
412
+
413
+ if not parameters:
414
+ self.parameters = ParameterList()
415
+ # low shelf parameters -------
416
+ self.parameters.add(Parameter('low_shelf_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
417
+ self.parameters.add(Parameter('low_shelf_freq', 80.0, 'float', minimum=30.0, maximum=200.0))
418
+ # first band parameters ------
419
+ self.parameters.add(Parameter('first_band_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
420
+ self.parameters.add(Parameter('first_band_freq', 400.0, 'float', minimum=200.0, maximum=1000.0))
421
+ self.parameters.add(Parameter('first_band_q', 0.7, 'float', minimum=MIN_Q, maximum=MAX_Q))
422
+ # second band parameters -----
423
+ self.parameters.add(Parameter('second_band_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
424
+ self.parameters.add(Parameter('second_band_freq', 2000.0, 'float', minimum=1000.0, maximum=3000.0))
425
+ self.parameters.add(Parameter('second_band_q', 0.7, 'float', minimum=MIN_Q, maximum=MAX_Q))
426
+ # third band parameters ------
427
+ self.parameters.add(Parameter('third_band_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
428
+ self.parameters.add(Parameter('third_band_freq', 4000.0, 'float', minimum=3000.0, maximum=8000.0))
429
+ self.parameters.add(Parameter('third_band_q', 0.7, 'float', minimum=MIN_Q, maximum=MAX_Q))
430
+ # high shelf parameters ------
431
+ self.parameters.add(Parameter('high_shelf_gain', 0.0, 'float', minimum=MIN_GAIN, maximum=MAX_GAIN))
432
+ self.parameters.add(Parameter('high_shelf_freq', 8000.0, 'float', minimum=5000.0, maximum=10000.0))
433
+
434
+ self.bands = bands
435
+ self.filters = self.setup_filters()
436
+ self.hard_clip = hard_clip
437
+
438
+ def setup_filters(self):
439
+ """
440
+ Create IIR filters.
441
+
442
+ Returns:
443
+ IIR filters
444
+ """
445
+ filters = {}
446
+
447
+ for band in self.bands:
448
+
449
+ G = getattr(self.parameters, band + '_gain').value
450
+ fc = getattr(self.parameters, band + '_freq').value
451
+ rate = self.sample_rate
452
+
453
+ if band in ['low_shelf', 'high_shelf']:
454
+ Q = 0.707
455
+ filter_type = band
456
+ else:
457
+ Q = getattr(self.parameters, band + '_q').value
458
+ filter_type = 'peaking'
459
+
460
+ filters[band] = pymc.components.iirfilter.IIRfilter(G, Q, fc, rate, filter_type, n_channels=self.n_channels)
461
+
462
+ return filters
463
+
464
+ def update_filter(self, band):
465
+ """
466
+ Update filters.
467
+
468
+ Args:
469
+ band (str): Band that should be updated.
470
+ """
471
+ self.filters[band].G = getattr(self.parameters, band + '_gain').value
472
+ self.filters[band].fc = getattr(self.parameters, band + '_freq').value
473
+ self.filters[band].rate = self.sample_rate
474
+
475
+ if band in ['first_band', 'second_band', 'third_band']:
476
+ self.filters[band].Q = getattr(self.parameters, band + '_q').value
477
+
478
+ def update(self, parameter_name=None):
479
+ """
480
+ Update processor after randomization of parameters.
481
+
482
+ Args:
483
+ parameter_name (str): Parameter whose value has changed.
484
+ """
485
+ if parameter_name is not None:
486
+ bands = ['_'.join(parameter_name.split('_')[:2])]
487
+ else:
488
+ bands = self.bands
489
+
490
+ for band in bands:
491
+ self.update_filter(band)
492
+
493
+ for _band, iirfilter in self.filters.items():
494
+ iirfilter.reset_state()
495
+
496
+ def reset_state(self):
497
+ """Reset state."""
498
+ for _band, iirfilter in self.filters.items():
499
+ iirfilter.reset_state()
500
+
501
+ def process(self, x):
502
+ """
503
+ Process audio.
504
+
505
+ Args:
506
+ x (Numpy array): input audio of size `n_samples x n_channels`.
507
+
508
+ Returns:
509
+ (Numpy array): equalized audio of size `n_samples x n_channels`.
510
+ """
511
+ for _band, iirfilter in self.filters.items():
512
+ iirfilter.reset_state()
513
+ x = iirfilter.apply_filter(x)
514
+
515
+ if self.hard_clip:
516
+ x = np.clip(x, -1.0, 1.0)
517
+
518
+ # make sure that we have float32 as IIR filtering returns float64
519
+ x = x.astype(np.float32)
520
+
521
+ # make sure that we have two dimensions (if `n_channels == 1`)
522
+ if x.ndim == 1:
523
+ x = x[:, np.newaxis]
524
+
525
+ return x
526
+
527
+
528
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% COMPRESSOR %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
529
+ @jit(nopython=True)
530
+ def compressor_process(x, threshold, attack_time, release_time, ratio, makeup_gain, sample_rate, yL_prev):
531
+ """
532
+ Apply compressor.
533
+
534
+ Args:
535
+ x (Numpy array): audio data.
536
+ threshold: threshold in dB.
537
+ attack_time: attack_time in ms.
538
+ release_time: release_time in ms.
539
+ ratio: ratio.
540
+ makeup_gain: makeup_gain.
541
+ sample_rate: sample rate.
542
+ yL_prev: internal state of the envelop gain.
543
+
544
+ Returns:
545
+ compressed audio.
546
+ """
547
+ M = x.shape[0]
548
+ x_g = np.zeros(M)
549
+ x_l = np.zeros(M)
550
+ y_g = np.zeros(M)
551
+ y_l = np.zeros(M)
552
+ c = np.zeros(M)
553
+ yL_prev = 0.
554
+
555
+ alpha_attack = np.exp(-1/(0.001 * sample_rate * attack_time))
556
+ alpha_release = np.exp(-1/(0.001 * sample_rate * release_time))
557
+
558
+ for i in np.arange(M):
559
+ if np.abs(x[i]) < 0.000001:
560
+ x_g[i] = -120.0
561
+ else:
562
+ x_g[i] = 20 * np.log10(np.abs(x[i]))
563
+
564
+ if ratio > 1:
565
+ if x_g[i] >= threshold:
566
+ y_g[i] = threshold + (x_g[i] - threshold) / ratio
567
+ else:
568
+ y_g[i] = x_g[i]
569
+ elif ratio < 1:
570
+ if x_g[i] <= threshold:
571
+ y_g[i] = threshold + (x_g[i] - threshold) / (1/ratio)
572
+ else:
573
+ y_g[i] = x_g[i]
574
+
575
+ x_l[i] = x_g[i] - y_g[i]
576
+
577
+ if x_l[i] > yL_prev:
578
+ y_l[i] = alpha_attack * yL_prev + (1 - alpha_attack) * x_l[i]
579
+ else:
580
+ y_l[i] = alpha_release * yL_prev + (1 - alpha_release) * x_l[i]
581
+
582
+ c[i] = np.power(10.0, (makeup_gain - y_l[i]) / 20.0)
583
+ yL_prev = y_l[i]
584
+
585
+ y = x * c
586
+
587
+ return y, yL_prev
588
+
589
+
590
+ class Compressor(Processor):
591
+ """
592
+ Single band stereo dynamic range compressor.
593
+
594
+ Processor parameters:
595
+ threshold (float)
596
+ attack_time (float)
597
+ release_time (float)
598
+ ratio (float)
599
+ makeup_gain (float)
600
+ """
601
+
602
+ def __init__(self, sample_rate, name='Compressor', parameters=None):
603
+ """
604
+ Initialize processor.
605
+
606
+ Args:
607
+ sample_rate (int): Sample rate of input audio.
608
+ name (str): Name of processor.
609
+ parameters (parameter_list): Parameters for this processor.
610
+ """
611
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
612
+
613
+ if not parameters:
614
+ self.parameters = ParameterList()
615
+ self.parameters.add(Parameter('threshold', -20.0, 'float', units='dB', minimum=-80.0, maximum=-5.0))
616
+ self.parameters.add(Parameter('attack_time', 2.0, 'float', units='ms', minimum=1., maximum=20.0))
617
+ self.parameters.add(Parameter('release_time', 100.0, 'float', units='ms', minimum=50.0, maximum=500.0))
618
+ self.parameters.add(Parameter('ratio', 4.0, 'float', minimum=4., maximum=40.0))
619
+ # we remove makeup_gain parameter inside the Compressor
620
+
621
+ # store internal state (for block-wise processing)
622
+ self.yL_prev = None
623
+
624
+ def process(self, x):
625
+ """
626
+ Process audio.
627
+
628
+ Args:
629
+ x (Numpy array): input audio of size `n_samples x n_channels`.
630
+
631
+ Returns:
632
+ (Numpy array): compressed audio of size `n_samples x n_channels`.
633
+ """
634
+ if self.yL_prev is None:
635
+ self.yL_prev = [0.] * x.shape[1]
636
+
637
+ if not self.parameters.threshold.value == 0.0 or not self.parameters.ratio.value == 1.0:
638
+ y = np.zeros_like(x)
639
+
640
+ for ch in range(x.shape[1]):
641
+ y[:, ch], self.yL_prev[ch] = compressor_process(x[:, ch],
642
+ self.parameters.threshold.value,
643
+ self.parameters.attack_time.value,
644
+ self.parameters.release_time.value,
645
+ self.parameters.ratio.value,
646
+ 0.0, # makeup_gain = 0
647
+ self.sample_rate,
648
+ self.yL_prev[ch])
649
+ else:
650
+ y = x
651
+
652
+ return y
653
+
654
+ def update(self, parameter_name=None):
655
+ """
656
+ Update processor after randomization of parameters.
657
+
658
+ Args:
659
+ parameter_name (str): Parameter whose value has changed.
660
+ """
661
+ self.yL_prev = None
662
+
663
+
664
+ # %%%%%%%%%%%%%%%%%%%%%%%%%% CONVOLUTIONAL REVERB %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
665
+ class ConvolutionalReverb(Processor):
666
+ """
667
+ Convolutional Reverb.
668
+
669
+ Processor parameters:
670
+ wet_dry (float): Wet/dry ratio.
671
+ decay (float): Applies a fade out to the impulse response.
672
+ pre_delay (float): Value in ms. Shifts the IR in time and allows.
673
+ A positive value produces a traditional delay between the dry signal and the wet.
674
+ A negative delay is, in reality, zero delay, but effectively trims off the start of IR,
675
+ so the reverb response begins at a point further in.
676
+ """
677
+
678
+ def __init__(self, impulse_responses, sample_rate, name='ConvolutionalReverb', parameters=None):
679
+ """
680
+ Initialize processor.
681
+
682
+ Args:
683
+ impulse_responses (list): List with impulse responses created by `common_dataprocessing.create_dataset`
684
+ sample_rate (int): Sample rate that we should assume (used for fade-out computation)
685
+ name (str): Name of processor.
686
+ parameters (parameter_list): Parameters for this processor.
687
+
688
+ Raises:
689
+ ValueError: if no impulse responses are provided.
690
+ """
691
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
692
+
693
+ if impulse_responses is None:
694
+ raise ValueError('List of impulse responses must be provided for ConvolutionalReverb processor.')
695
+ self.impulse_responses = impulse_responses
696
+
697
+ if not parameters:
698
+ self.parameters = ParameterList()
699
+ self.max_ir_num = len(max(impulse_responses, key=len))
700
+ self.parameters.add(Parameter('index', 0, 'int', minimum=0, maximum=len(impulse_responses)))
701
+ self.parameters.add(Parameter('index_ir', 0, 'int', minimum=0, maximum=self.max_ir_num))
702
+ self.parameters.add(Parameter('wet', 1.0, 'float', minimum=1.0, maximum=1.0))
703
+ self.parameters.add(Parameter('dry', 0.0, 'float', minimum=0.0, maximum=0.0))
704
+ self.parameters.add(Parameter('decay', 1.0, 'float', minimum=1.0, maximum=1.0))
705
+ self.parameters.add(Parameter('pre_delay', 0, 'int', units='ms', minimum=0, maximum=0))
706
+
707
+ def update(self, parameter_name=None):
708
+ """
709
+ Update processor after randomization of parameters.
710
+
711
+ Args:
712
+ parameter_name (str): Parameter whose value has changed.
713
+ """
714
+ # we sample IR with a uniform random distribution according to RT60 values
715
+ chosen_ir_duration = self.impulse_responses[self.parameters.index.value]
716
+ chosen_ir_idx = self.parameters.index_ir.value % len(chosen_ir_duration)
717
+ self.h = np.copy(chosen_ir_duration[chosen_ir_idx]['impulse_response']())
718
+
719
+ # fade out the impulse based on the decay setting (starting from peak value)
720
+ if self.parameters.decay.value < 1.:
721
+ idx_peak = np.argmax(np.max(np.abs(self.h), axis=1), axis=0)
722
+ fstart = np.minimum(self.h.shape[0],
723
+ idx_peak + int(self.parameters.decay.value * (self.h.shape[0] - idx_peak)))
724
+ fstop = np.minimum(self.h.shape[0], fstart + int(0.020*self.sample_rate)) # constant 20 ms fade out
725
+ flen = fstop - fstart
726
+
727
+ fade = np.arange(1, flen+1, dtype=self.dtype)/flen
728
+ fade = np.power(0.1, fade * 5)
729
+ self.h[fstart:fstop, :] *= fade[:, np.newaxis]
730
+ self.h = self.h[:fstop]
731
+
732
+ def process(self, x):
733
+ """
734
+ Process audio.
735
+
736
+ Args:
737
+ x (Numpy array): input audio of size `n_samples x n_channels`.
738
+
739
+ Returns:
740
+ (Numpy array): reverbed audio of size `n_samples x n_channels`.
741
+ """
742
+ # reshape IR to the correct size
743
+ n_channels = x.shape[1]
744
+ if self.h.shape[1] == 1 and n_channels > 1:
745
+ self.h = np.hstack([self.h] * n_channels) # repeat mono IR for multi-channel input
746
+ if self.h.shape[1] > 1 and n_channels == 1:
747
+ self.h = self.h[:, np.random.randint(self.h.shape[1]), np.newaxis] # randomly choose one IR channel
748
+
749
+ if self.parameters.wet.value == 0.0:
750
+ return x
751
+ else:
752
+ # perform convolution to get wet signal
753
+ y = oaconvolve(x, self.h, mode='full', axes=0)
754
+
755
+ # cut out wet signal (compensating for the delay that the IR is introducing + predelay)
756
+ idx = np.argmax(np.max(np.abs(self.h), axis=1), axis=0)
757
+ idx += int(0.001 * np.abs(self.parameters.pre_delay.value) * self.sample_rate)
758
+
759
+ idx = np.clip(idx, 0, self.h.shape[0]-1)
760
+
761
+ y = y[idx:idx+x.shape[0], :]
762
+
763
+ # return weighted sum of dry and wet signal
764
+ return self.parameters.dry.value * x + self.parameters.wet.value * y
765
+
766
+
767
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%% HAAS EFFECT %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
768
+ def haas_process(x, delay, feedback, wet_channel):
769
+ """
770
+ Add Haas effect to audio.
771
+
772
+ Args:
773
+ x (Numpy array): input audio.
774
+ delay: Delay that we apply to one of the channels (in samples).
775
+ feedback: Feedback value.
776
+ wet_channel: Which channel we process (`left` or `right`).
777
+
778
+ Returns:
779
+ (Numpy array): Audio with Haas effect.
780
+ """
781
+ y = np.copy(x)
782
+ if wet_channel == 'left':
783
+ y[:, 0] += feedback * np.roll(x[:, 0], delay)
784
+ elif wet_channel == 'right':
785
+ y[:, 1] += feedback * np.roll(x[:, 1], delay)
786
+
787
+ return y
788
+
789
+
790
+ class Haas(Processor):
791
+ """
792
+ Haas Effect Processor.
793
+
794
+ Randomly selects one channel and applies a short delay to it.
795
+
796
+ Processor parameters:
797
+ delay (int)
798
+ feedback (float)
799
+ wet_channel (string)
800
+ """
801
+
802
+ def __init__(self, sample_rate, delay_range=(-0.040, 0.040), name='Haas', parameters=None,
803
+ ):
804
+ """
805
+ Initialize processor.
806
+
807
+ Args:
808
+ sample_rate (int): Sample rate of input audio.
809
+ delay_range (tuple of floats): minimum/maximum delay for Haas effect.
810
+ name (str): Name of processor.
811
+ parameters (parameter_list): Parameters for this processor.
812
+ """
813
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
814
+
815
+ if not parameters:
816
+ self.parameters = ParameterList()
817
+ self.parameters.add(Parameter('delay', int(delay_range[1] * sample_rate), 'int', units='samples',
818
+ minimum=int(delay_range[0] * sample_rate),
819
+ maximum=int(delay_range[1] * sample_rate)))
820
+ self.parameters.add(Parameter('feedback', 0.35, 'float', minimum=0.33, maximum=0.66))
821
+ self.parameters.add(Parameter('wet_channel', 'left', 'string', options=['left', 'right']))
822
+
823
+ def process(self, x):
824
+ """
825
+ Process audio.
826
+
827
+ Args:
828
+ x (Numpy array): input audio of size `n_samples x n_channels`.
829
+
830
+ Returns:
831
+ (Numpy array): audio with Haas effect of size `n_samples x n_channels`.
832
+ """
833
+ assert x.shape[1] == 1 or x.shape[1] == 2, 'Haas effect only works with monaural or stereo audio.'
834
+
835
+ if x.shape[1] < 2:
836
+ x = np.repeat(x, 2, axis=1)
837
+
838
+ y = haas_process(x, self.parameters.delay.value,
839
+ self.parameters.feedback.value, self.parameters.wet_channel.value)
840
+
841
+ return y
842
+
843
+ def update(self, parameter_name=None):
844
+ """
845
+ Update processor after randomization of parameters.
846
+
847
+ Args:
848
+ parameter_name (str): Parameter whose value has changed.
849
+ """
850
+ self.reset_state()
851
+
852
+ def reset_state(self):
853
+ """Reset state."""
854
+ self.read_idx = 0
855
+ self.write_idx = self.parameters.delay.value
856
+ self.buffer = np.zeros((65536, 2))
857
+
858
+
859
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% PANNER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
860
+ class Panner(Processor):
861
+ """
862
+ Simple stereo panner.
863
+
864
+ If input is mono, output is stereo.
865
+ Original edited from https://github.com/csteinmetz1/pymixconsole/blob/master/pymixconsole/processors/panner.py
866
+ """
867
+
868
+ def __init__(self, name='Panner', parameters=None):
869
+ """
870
+ Initialize processor.
871
+
872
+ Args:
873
+ name (str): Name of processor.
874
+ parameters (parameter_list): Parameters for this processor.
875
+ """
876
+ # default processor class constructor
877
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=None)
878
+
879
+ if not parameters:
880
+ self.parameters = ParameterList()
881
+ self.parameters.add(Parameter('pan', 0.5, 'float', minimum=0., maximum=1.))
882
+ self.parameters.add(Parameter('pan_law', '-4.5dB', 'string',
883
+ options=['-4.5dB', 'linear', 'constant_power']))
884
+
885
+ # setup the coefficents based on default params
886
+ self.update()
887
+
888
+ def _calculate_pan_coefficents(self):
889
+ """
890
+ Calculate panning coefficients from the chosen pan law.
891
+
892
+ Based on the set pan law determine the gain value
893
+ to apply for the left and right channel to achieve panning effect.
894
+ This operates on the assumption that the input channel is mono.
895
+ The output data will be stereo at the moment, but could be expanded
896
+ to a higher channel count format.
897
+ The panning value is in the range [0, 1], where
898
+ 0 means the signal is panned completely to the left, and
899
+ 1 means the signal is apanned copletely to the right.
900
+
901
+ Raises:
902
+ ValueError: `self.parameters.pan_law` is not supported.
903
+ """
904
+ self.gains = np.zeros(2, dtype=self.dtype)
905
+
906
+ # first scale the linear [0, 1] to [0, pi/2]
907
+ theta = self.parameters.pan.value * (np.pi/2)
908
+
909
+ if self.parameters.pan_law.value == 'linear':
910
+ self.gains[0] = ((np.pi/2) - theta) * (2/np.pi)
911
+ self.gains[1] = theta * (2/np.pi)
912
+ elif self.parameters.pan_law.value == 'constant_power':
913
+ self.gains[0] = np.cos(theta)
914
+ self.gains[1] = np.sin(theta)
915
+ elif self.parameters.pan_law.value == '-4.5dB':
916
+ self.gains[0] = np.sqrt(((np.pi/2) - theta) * (2/np.pi) * np.cos(theta))
917
+ self.gains[1] = np.sqrt(theta * (2/np.pi) * np.sin(theta))
918
+ else:
919
+ raise ValueError(f'Invalid pan_law {self.parameters.pan_law.value}.')
920
+
921
+
922
+ def process(self, x):
923
+ """
924
+ Process audio.
925
+
926
+ Args:
927
+ x (Numpy array): input audio of size `n_samples x n_channels`.
928
+
929
+ Returns:
930
+ (Numpy array): panned audio of size `n_samples x n_channels`.
931
+ """
932
+ assert x.shape[1] == 1 or x.shape[1] == 2, 'Panner only works with monaural or stereo audio.'
933
+
934
+ if x.shape[1] < 2:
935
+ x = np.repeat(x, 2, axis=1)
936
+
937
+
938
+ return x * self.gains
939
+
940
+ def update(self, parameter_name=None):
941
+ """
942
+ Update processor after randomization of parameters.
943
+
944
+ Args:
945
+ parameter_name (str): Parameter whose value has changed.
946
+ """
947
+ self._calculate_pan_coefficents()
948
+
949
+ def reset_state(self):
950
+ """Reset state."""
951
+ self._output_buffer = np.empty([self.block_size, 2])
952
+ self.update()
953
+
954
+
955
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% STEREO IMAGER %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
956
+ class MidSideImager(Processor):
957
+ def __init__(self, name='IMAGER', parameters=None):
958
+ super().__init__(name, parameters=parameters, block_size=None, sample_rate=None)
959
+
960
+ if not parameters:
961
+ self.parameters = ParameterList()
962
+ # values of 0.0~1.0 indicate making the signal more centered while 1.0~2.0 means making the signal more wider
963
+ self.parameters.add(Parameter("bal", 0.0, "float", processor=self, minimum=0.0, maximum=2.0))
964
+
965
+ def process(self, data):
966
+ """
967
+ # input shape : [signal length, 2]
968
+ ### note! stereo imager won't work if the input signal is a mono signal (left==right)
969
+ ### if you want to apply stereo imager to a mono signal, first stereoize it with Haas effects
970
+ """
971
+
972
+ # to mid-side channels
973
+ mid, side = self.lr_to_ms(data[:,0], data[:,1])
974
+ # apply mid-side weights according to energy
975
+ mid_e, side_e = np.sum(mid**2), np.sum(side**2)
976
+ total_e = mid_e + side_e
977
+ # apply weights
978
+ max_side_multiplier = np.sqrt(total_e / (side_e + 1e-3))
979
+ # compute current multiply factor
980
+ cur_bal = round(getattr(self.parameters, "bal").value, 3)
981
+ side_gain = cur_bal if cur_bal <= 1. else max_side_multiplier * (cur_bal-1)
982
+ # multiply weighting factor
983
+ new_side = side * side_gain
984
+ new_side_e = side_e * (side_gain ** 2)
985
+ left_mid_e = total_e - new_side_e
986
+ mid_gain = np.sqrt(left_mid_e / (mid_e + 1e-3))
987
+ new_mid = mid * mid_gain
988
+ # convert back to left-right channels
989
+ left, right = self.ms_to_lr(new_mid, new_side)
990
+ imaged = np.stack([left, right], 1)
991
+
992
+ return imaged
993
+
994
+ # left-right channeled signal to mid-side signal
995
+ def lr_to_ms(self, left, right):
996
+ mid = left + right
997
+ side = left - right
998
+ return mid, side
999
+
1000
+ # mid-side channeled signal to left-right signal
1001
+ def ms_to_lr(self, mid, side):
1002
+ left = (mid + side) / 2
1003
+ right = (mid - side) / 2
1004
+ return left, right
1005
+
1006
+ def update(self, parameter_name=None):
1007
+ return parameter_name
1008
+
1009
+
1010
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% GAIN %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1011
+ class Gain(Processor):
1012
+ """
1013
+ Gain Processor.
1014
+
1015
+ Applies gain in dB and can also randomly inverts polarity.
1016
+
1017
+ Processor parameters:
1018
+ gain (float): Gain that should be applied (dB scale).
1019
+ invert (bool): If True, then we also invert the waveform.
1020
+ """
1021
+
1022
+ def __init__(self, name='Gain', parameters=None):
1023
+ """
1024
+ Initialize processor.
1025
+
1026
+ Args:
1027
+ name (str): Name of processor.
1028
+ parameters (parameter_list): Parameters for this processor.
1029
+ """
1030
+ super().__init__(name, parameters=parameters, block_size=None, sample_rate=None)
1031
+
1032
+ if not parameters:
1033
+ self.parameters = ParameterList()
1034
+ # self.parameters.add(Parameter('gain', 1.0, 'float', units='dB', minimum=-12.0, maximum=6.0))
1035
+ self.parameters.add(Parameter('gain', 1.0, 'float', units='dB', minimum=-6.0, maximum=9.0))
1036
+ self.parameters.add(Parameter('invert', False, 'bool'))
1037
+
1038
+ def process(self, x):
1039
+ """
1040
+ Process audio.
1041
+
1042
+ Args:
1043
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1044
+
1045
+ Returns:
1046
+ (Numpy array): gain-augmented audio of size `n_samples x n_channels`.
1047
+ """
1048
+ gain = 10 ** (self.parameters.gain.value / 20.)
1049
+ if self.parameters.invert.value:
1050
+ gain = -gain
1051
+ return gain * x
1052
+
1053
+
1054
+ # %%%%%%%%%%%%%%%%%%%%%%% SIMPLE CHANNEL SWAP %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1055
+ class SwapChannels(Processor):
1056
+ """
1057
+ Swap channels in multi-channel audio.
1058
+
1059
+ Processor parameters:
1060
+ index (int) Selects the permutation that we are using.
1061
+ Please note that "no permutation" is one of the permutations in `self.permutations` at index `0`.
1062
+ """
1063
+
1064
+ def __init__(self, n_channels, name='SwapChannels', parameters=None):
1065
+ """
1066
+ Initialize processor.
1067
+
1068
+ Args:
1069
+ n_channels (int): Number of channels in audio that we want to process.
1070
+ name (str): Name of processor.
1071
+ parameters (parameter_list): Parameters for this processor.
1072
+ """
1073
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=None)
1074
+
1075
+ self.permutations = tuple(permutations(range(n_channels), n_channels))
1076
+
1077
+ if not parameters:
1078
+ self.parameters = ParameterList()
1079
+ self.parameters.add(Parameter('index', 0, 'int', minimum=0, maximum=len(self.permutations)))
1080
+
1081
+ def process(self, x):
1082
+ """
1083
+ Process audio.
1084
+
1085
+ Args:
1086
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1087
+
1088
+ Returns:
1089
+ (Numpy array): channel-swapped audio of size `n_samples x n_channels`.
1090
+ """
1091
+ return x[:, self.permutations[self.parameters.index.value]]
1092
+
1093
+
1094
+ # %%%%%%%%%%%%%%%%%%%%%%% Monauralize %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1095
+ class Monauralize(Processor):
1096
+ """
1097
+ Monauralizes audio (i.e., removes spatial information).
1098
+
1099
+ Process parameters:
1100
+ seed_channel (int): channel that we use for overwriting the others.
1101
+ """
1102
+
1103
+ def __init__(self, n_channels, name='Monauralize', parameters=None):
1104
+ """
1105
+ Initialize processor.
1106
+
1107
+ Args:
1108
+ n_channels (int): Number of channels in audio that we want to process.
1109
+ name (str): Name of processor.
1110
+ parameters (parameter_list): Parameters for this processor.
1111
+ """
1112
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=None)
1113
+
1114
+ if not parameters:
1115
+ self.parameters = ParameterList()
1116
+ self.parameters.add(Parameter('seed_channel', 0, 'int', minimum=0, maximum=n_channels))
1117
+
1118
+ def process(self, x):
1119
+ """
1120
+ Process audio.
1121
+
1122
+ Args:
1123
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1124
+
1125
+ Returns:
1126
+ (Numpy array): monauralized audio of size `n_samples x n_channels`.
1127
+ """
1128
+ return np.tile(x[:, [self.parameters.seed_channel.value]], (1, x.shape[1]))
1129
+
1130
+
1131
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% PITCH SHIFT %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1132
+ class PitchShift(Processor):
1133
+ """
1134
+ Simple pitch shifter using SoX and soxbindings (https://github.com/pseeth/soxbindings).
1135
+
1136
+ Processor parameters:
1137
+ steps (float): Pitch shift as positive/negative semitones
1138
+ quick (bool): If True, this effect will run faster but with lower sound quality.
1139
+ """
1140
+
1141
+ def __init__(self, sample_rate, fix_length=True, name='PitchShift', parameters=None):
1142
+ """
1143
+ Initialize processor.
1144
+
1145
+ Args:
1146
+ sample_rate (int): Sample rate of input audio.
1147
+ fix_length (bool): If True, then output has same length as input.
1148
+ name (str): Name of processor.
1149
+ parameters (parameter_list): Parameters for this processor.
1150
+ """
1151
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
1152
+
1153
+ if not parameters:
1154
+ self.parameters = ParameterList()
1155
+ self.parameters.add(Parameter('steps', 0.0, 'float', minimum=-6., maximum=6.))
1156
+ self.parameters.add(Parameter('quick', False, 'bool'))
1157
+
1158
+ self.fix_length = fix_length
1159
+ self.clips = False
1160
+
1161
+ def process(self, x):
1162
+ """
1163
+ Process audio.
1164
+
1165
+ Args:
1166
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1167
+
1168
+ Returns:
1169
+ (Numpy array): pitch-shifted audio of size `n_samples x n_channels`.
1170
+ """
1171
+ if self.parameters.steps.value == 0.0:
1172
+ y = x
1173
+ else:
1174
+ scale = np.max(np.abs(x))
1175
+ if scale > 0.9:
1176
+ clips = True
1177
+ x = x * (0.9 / scale)
1178
+ else:
1179
+ clips = False
1180
+
1181
+ tfm = sox.Transformer()
1182
+ tfm.pitch(self.parameters.steps.value, quick=bool(self.parameters.quick.value))
1183
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
1184
+
1185
+ if clips:
1186
+ y *= scale / 0.9 # rescale output to original scale
1187
+
1188
+ if self.fix_length:
1189
+ n_samples_input = x.shape[0]
1190
+ n_samples_output = y.shape[0]
1191
+ if n_samples_input < n_samples_output:
1192
+ idx1 = (n_samples_output - n_samples_input) // 2
1193
+ idx2 = idx1 + n_samples_input
1194
+ y = y[idx1:idx2]
1195
+ elif n_samples_input > n_samples_output:
1196
+ n_pad = n_samples_input - n_samples_output
1197
+ y = np.pad(y, ((n_pad//2, n_pad - n_pad//2), (0, 0)))
1198
+
1199
+ return y
1200
+
1201
+
1202
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% TIME STRETCH %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1203
+ class TimeStretch(Processor):
1204
+ """
1205
+ Simple time stretcher using SoX and soxbindings (https://github.com/pseeth/soxbindings).
1206
+
1207
+ Processor parameters:
1208
+ factor (float): Time stretch factor.
1209
+ quick (bool): If True, this effect will run faster but with lower sound quality.
1210
+ stretch_type (str): Algorithm used for stretching (`tempo` or `stretch`).
1211
+ audio_type (str): Sets which time segments are most optmial when finding
1212
+ the best overlapping points for time stretching.
1213
+ """
1214
+
1215
+ def __init__(self, sample_rate, fix_length=True, name='TimeStretch', parameters=None):
1216
+ """
1217
+ Initialize processor.
1218
+
1219
+ Args:
1220
+ sample_rate (int): Sample rate of input audio.
1221
+ fix_length (bool): If True, then output has same length as input.
1222
+ name (str): Name of processor.
1223
+ parameters (parameter_list): Parameters for this processor.
1224
+ """
1225
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
1226
+
1227
+ if not parameters:
1228
+ self.parameters = ParameterList()
1229
+ self.parameters.add(Parameter('factor', 1.0, 'float', minimum=1/1.33, maximum=1.33))
1230
+ self.parameters.add(Parameter('quick', False, 'bool'))
1231
+ self.parameters.add(Parameter('stretch_type', 'tempo', 'string', options=['tempo', 'stretch']))
1232
+ self.parameters.add(Parameter('audio_type', 'l', 'string', options=['m', 's', 'l']))
1233
+
1234
+ self.fix_length = fix_length
1235
+
1236
+ def process(self, x):
1237
+ """
1238
+ Process audio.
1239
+
1240
+ Args:
1241
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1242
+
1243
+ Returns:
1244
+ (Numpy array): time-stretched audio of size `n_samples x n_channels`.
1245
+ """
1246
+ if self.parameters.factor.value == 1.0:
1247
+ y = x
1248
+ else:
1249
+ scale = np.max(np.abs(x))
1250
+ if scale > 0.9:
1251
+ clips = True
1252
+ x = x * (0.9 / scale)
1253
+ else:
1254
+ clips = False
1255
+
1256
+ tfm = sox.Transformer()
1257
+ if self.parameters.stretch_type.value == 'stretch':
1258
+ tfm.stretch(self.parameters.factor.value)
1259
+ elif self.parameters.stretch_type.value == 'tempo':
1260
+ tfm.tempo(self.parameters.factor.value,
1261
+ audio_type=self.parameters.audio_type.value,
1262
+ quick=bool(self.parameters.quick.value))
1263
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
1264
+
1265
+ if clips:
1266
+ y *= scale / 0.9 # rescale output to original scale
1267
+
1268
+ if self.fix_length:
1269
+ n_samples_input = x.shape[0]
1270
+ n_samples_output = y.shape[0]
1271
+ if n_samples_input < n_samples_output:
1272
+ idx1 = (n_samples_output - n_samples_input) // 2
1273
+ idx2 = idx1 + n_samples_input
1274
+ y = y[idx1:idx2]
1275
+ elif n_samples_input > n_samples_output:
1276
+ n_pad = n_samples_input - n_samples_output
1277
+ y = np.pad(y, ((n_pad//2, n_pad - n_pad//2), (0, 0)))
1278
+
1279
+ return y
1280
+
1281
+
1282
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% PLAYBACK SPEED %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1283
+ class PlaybackSpeed(Processor):
1284
+ """
1285
+ Simple playback speed effect using SoX and soxbindings (https://github.com/pseeth/soxbindings).
1286
+
1287
+ Processor parameters:
1288
+ factor (float): Playback speed factor.
1289
+ """
1290
+
1291
+ def __init__(self, sample_rate, fix_length=True, name='PlaybackSpeed', parameters=None):
1292
+ """
1293
+ Initialize processor.
1294
+
1295
+ Args:
1296
+ sample_rate (int): Sample rate of input audio.
1297
+ fix_length (bool): If True, then output has same length as input.
1298
+ name (str): Name of processor.
1299
+ parameters (parameter_list): Parameters for this processor.
1300
+ """
1301
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
1302
+
1303
+ if not parameters:
1304
+ self.parameters = ParameterList()
1305
+ self.parameters.add(Parameter('factor', 1.0, 'float', minimum=1./1.33, maximum=1.33))
1306
+
1307
+ self.fix_length = fix_length
1308
+
1309
+ def process(self, x):
1310
+ """
1311
+ Process audio.
1312
+
1313
+ Args:
1314
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1315
+
1316
+ Returns:
1317
+ (Numpy array): resampled audio of size `n_samples x n_channels`.
1318
+ """
1319
+ if self.parameters.factor.value == 1.0:
1320
+ y = x
1321
+ else:
1322
+ scale = np.max(np.abs(x))
1323
+ if scale > 0.9:
1324
+ clips = True
1325
+ x = x * (0.9 / scale)
1326
+ else:
1327
+ clips = False
1328
+
1329
+ tfm = sox.Transformer()
1330
+ tfm.speed(self.parameters.factor.value)
1331
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
1332
+
1333
+ if clips:
1334
+ y *= scale / 0.9 # rescale output to original scale
1335
+
1336
+ if self.fix_length:
1337
+ n_samples_input = x.shape[0]
1338
+ n_samples_output = y.shape[0]
1339
+ if n_samples_input < n_samples_output:
1340
+ idx1 = (n_samples_output - n_samples_input) // 2
1341
+ idx2 = idx1 + n_samples_input
1342
+ y = y[idx1:idx2]
1343
+ elif n_samples_input > n_samples_output:
1344
+ n_pad = n_samples_input - n_samples_output
1345
+ y = np.pad(y, ((n_pad//2, n_pad - n_pad//2), (0, 0)))
1346
+
1347
+ return y
1348
+
1349
+
1350
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% BEND %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1351
+ class Bend(Processor):
1352
+ """
1353
+ Simple bend effect using SoX and soxbindings (https://github.com/pseeth/soxbindings).
1354
+
1355
+ Processor parameters:
1356
+ n_bends (int): Number of segments or intervals to pitch shift
1357
+ """
1358
+
1359
+ def __init__(self, sample_rate, pitch_range=(-600, 600), fix_length=True, name='Bend', parameters=None):
1360
+ """
1361
+ Initialize processor.
1362
+
1363
+ Args:
1364
+ sample_rate (int): Sample rate of input audio.
1365
+ pitch_range (tuple of ints): min and max pitch bending ranges in cents
1366
+ fix_length (bool): If True, then output has same length as input.
1367
+ name (str): Name of processor.
1368
+ parameters (parameter_list): Parameters for this processor.
1369
+ """
1370
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate)
1371
+
1372
+ if not parameters:
1373
+ self.parameters = ParameterList()
1374
+ self.parameters.add(Parameter('n_bends', 2, 'int', minimum=2, maximum=10))
1375
+ self.pitch_range_min, self.pitch_range_max = pitch_range
1376
+
1377
+ def process(self, x):
1378
+ """
1379
+ Process audio.
1380
+
1381
+ Args:
1382
+ x (Numpy array): input audio of size `n_samples x n_channels`.
1383
+
1384
+ Returns:
1385
+ (Numpy array): pitch-bended audio of size `n_samples x n_channels`.
1386
+ """
1387
+ n_bends = self.parameters.n_bends.value
1388
+ max_length = x.shape[0] / self.sample_rate
1389
+
1390
+ # Generates random non-overlapping segments
1391
+ delta = 1. / self.sample_rate
1392
+ boundaries = np.sort(delta + np.random.rand(n_bends-1) * (max_length - delta))
1393
+
1394
+ start, end = np.zeros(n_bends), np.zeros(n_bends)
1395
+ start[0] = delta
1396
+ for i, b in enumerate(boundaries):
1397
+ end[i] = b
1398
+ start[i+1] = b
1399
+ end[-1] = max_length
1400
+
1401
+ # randomly sample pitch-shifts in cents
1402
+ cents = np.random.randint(self.pitch_range_min, self.pitch_range_max+1, n_bends)
1403
+
1404
+ # remove segment if cent value is zero or start == end (as SoX does not allow such values)
1405
+ idx_keep = np.logical_and(cents != 0, start != end)
1406
+ n_bends, start, end, cents = sum(idx_keep), start[idx_keep], end[idx_keep], cents[idx_keep]
1407
+
1408
+ scale = np.max(np.abs(x))
1409
+ if scale > 0.9:
1410
+ clips = True
1411
+ x = x * (0.9 / scale)
1412
+ else:
1413
+ clips = False
1414
+
1415
+ tfm = sox.Transformer()
1416
+ tfm.bend(n_bends=int(n_bends), start_times=list(start), end_times=list(end), cents=list(cents))
1417
+ y = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate).astype(np.float32)
1418
+
1419
+ if clips:
1420
+ y *= scale / 0.9 # rescale output to original scale
1421
+
1422
+ return y
1423
+
1424
+
1425
+
1426
+
1427
+
1428
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ALGORITHMIC REVERB %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
1429
+ class AlgorithmicReverb(Processor):
1430
+ def __init__(self, name="algoreverb", parameters=None, sample_rate=44100, **kwargs):
1431
+
1432
+ super().__init__(name=name, parameters=parameters, block_size=None, sample_rate=sample_rate, **kwargs)
1433
+
1434
+ if not parameters:
1435
+ self.parameters = ParameterList()
1436
+ self.parameters.add(Parameter("room_size", 0.5, "float", minimum=0.05, maximum=0.85))
1437
+ self.parameters.add(Parameter("damping", 0.1, "float", minimum=0.0, maximum=1.0))
1438
+ self.parameters.add(Parameter("dry_mix", 0.9, "float", minimum=0.0, maximum=1.0))
1439
+ self.parameters.add(Parameter("wet_mix", 0.1, "float", minimum=0.0, maximum=1.0))
1440
+ self.parameters.add(Parameter("width", 0.7, "float", minimum=0.0, maximum=1.0))
1441
+
1442
+ # Tuning
1443
+ self.stereospread = 23
1444
+ self.scalegain = 0.2
1445
+
1446
+
1447
+ def process(self, data):
1448
+
1449
+ if data.ndim >= 2:
1450
+ dataL = data[:,0]
1451
+ if data.shape[1] == 2:
1452
+ dataR = data[:,1]
1453
+ else:
1454
+ dataR = data[:,0]
1455
+ else:
1456
+ dataL = data
1457
+ dataR = data
1458
+
1459
+ output = np.zeros((data.shape[0], 2))
1460
+
1461
+ xL, xR = self.process_filters(dataL.copy(), dataR.copy())
1462
+
1463
+ wet1_g = self.parameters.wet_mix.value * ((self.parameters.width.value/2) + 0.5)
1464
+ wet2_g = self.parameters.wet_mix.value * ((1-self.parameters.width.value)/2)
1465
+ dry_g = self.parameters.dry_mix.value
1466
+
1467
+ output[:,0] = (wet1_g * xL) + (wet2_g * xR) + (dry_g * dataL)
1468
+ output[:,1] = (wet1_g * xR) + (wet2_g * xL) + (dry_g * dataR)
1469
+
1470
+ return output
1471
+
1472
+ def process_filters(self, dataL, dataR):
1473
+
1474
+ xL = self.combL1.process(dataL.copy() * self.scalegain)
1475
+ xL += self.combL2.process(dataL.copy() * self.scalegain)
1476
+ xL += self.combL3.process(dataL.copy() * self.scalegain)
1477
+ xL += self.combL4.process(dataL.copy() * self.scalegain)
1478
+ xL = self.combL5.process(dataL.copy() * self.scalegain)
1479
+ xL += self.combL6.process(dataL.copy() * self.scalegain)
1480
+ xL += self.combL7.process(dataL.copy() * self.scalegain)
1481
+ xL += self.combL8.process(dataL.copy() * self.scalegain)
1482
+
1483
+ xR = self.combR1.process(dataR.copy() * self.scalegain)
1484
+ xR += self.combR2.process(dataR.copy() * self.scalegain)
1485
+ xR += self.combR3.process(dataR.copy() * self.scalegain)
1486
+ xR += self.combR4.process(dataR.copy() * self.scalegain)
1487
+ xR = self.combR5.process(dataR.copy() * self.scalegain)
1488
+ xR += self.combR6.process(dataR.copy() * self.scalegain)
1489
+ xR += self.combR7.process(dataR.copy() * self.scalegain)
1490
+ xR += self.combR8.process(dataR.copy() * self.scalegain)
1491
+
1492
+ yL1 = self.allpassL1.process(xL)
1493
+ yL2 = self.allpassL2.process(yL1)
1494
+ yL3 = self.allpassL3.process(yL2)
1495
+ yL4 = self.allpassL4.process(yL3)
1496
+
1497
+ yR1 = self.allpassR1.process(xR)
1498
+ yR2 = self.allpassR2.process(yR1)
1499
+ yR3 = self.allpassR3.process(yR2)
1500
+ yR4 = self.allpassR4.process(yR3)
1501
+
1502
+ return yL4, yR4
1503
+
1504
+ def update(self, parameter_name):
1505
+
1506
+ rs = self.parameters.room_size.value
1507
+ dp = self.parameters.damping.value
1508
+ ss = self.stereospread
1509
+
1510
+ # initialize allpass and feedback comb-filters
1511
+ # (with coefficients optimized for fs=44.1kHz)
1512
+ self.allpassL1 = pymc.components.allpass.Allpass(556, rs, self.block_size)
1513
+ self.allpassR1 = pymc.components.allpass.Allpass(556+ss, rs, self.block_size)
1514
+ self.allpassL2 = pymc.components.allpass.Allpass(441, rs, self.block_size)
1515
+ self.allpassR2 = pymc.components.allpass.Allpass(441+ss, rs, self.block_size)
1516
+ self.allpassL3 = pymc.components.allpass.Allpass(341, rs, self.block_size)
1517
+ self.allpassR3 = pymc.components.allpass.Allpass(341+ss, rs, self.block_size)
1518
+ self.allpassL4 = pymc.components.allpass.Allpass(225, rs, self.block_size)
1519
+ self.allpassR4 = pymc.components.allpass.Allpass(255+ss, rs, self.block_size)
1520
+
1521
+ self.combL1 = pymc.components.comb.Comb(1116, dp, rs, self.block_size)
1522
+ self.combR1 = pymc.components.comb.Comb(1116+ss, dp, rs, self.block_size)
1523
+ self.combL2 = pymc.components.comb.Comb(1188, dp, rs, self.block_size)
1524
+ self.combR2 = pymc.components.comb.Comb(1188+ss, dp, rs, self.block_size)
1525
+ self.combL3 = pymc.components.comb.Comb(1277, dp, rs, self.block_size)
1526
+ self.combR3 = pymc.components.comb.Comb(1277+ss, dp, rs, self.block_size)
1527
+ self.combL4 = pymc.components.comb.Comb(1356, dp, rs, self.block_size)
1528
+ self.combR4 = pymc.components.comb.Comb(1356+ss, dp, rs, self.block_size)
1529
+ self.combL5 = pymc.components.comb.Comb(1422, dp, rs, self.block_size)
1530
+ self.combR5 = pymc.components.comb.Comb(1422+ss, dp, rs, self.block_size)
1531
+ self.combL6 = pymc.components.comb.Comb(1491, dp, rs, self.block_size)
1532
+ self.combR6 = pymc.components.comb.Comb(1491+ss, dp, rs, self.block_size)
1533
+ self.combL7 = pymc.components.comb.Comb(1557, dp, rs, self.block_size)
1534
+ self.combR7 = pymc.components.comb.Comb(1557+ss, dp, rs, self.block_size)
1535
+ self.combL8 = pymc.components.comb.Comb(1617, dp, rs, self.block_size)
1536
+ self.combR8 = pymc.components.comb.Comb(1617+ss, dp, rs, self.block_size)
1537
+
mixing_style_transfer/mixing_manipulator/common_dataprocessing.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module with common functions for loading training data and preparing minibatches.
3
+
4
+ AI Music Technology Group, Sony Group Corporation
5
+ AI Speech and Sound Group, Sony Europe
6
+
7
+ This implementation originally belongs to Sony Group Corporation,
8
+ which has been introduced in the work "Automatic music mixing with deep learning and out-of-domain data".
9
+ Original repo link: https://github.com/sony/FxNorm-automix
10
+ """
11
+
12
+ import numpy as np
13
+ import os
14
+ import sys
15
+ import functools
16
+ import scipy.io.wavfile as wav
17
+ import soundfile as sf
18
+ from typing import Tuple
19
+
20
+ currentdir = os.path.dirname(os.path.realpath(__file__))
21
+ sys.path.append(currentdir)
22
+ from common_audioeffects import AugmentationChain
23
+ from common_miscellaneous import uprint
24
+
25
+
26
+ def load_wav(file_path, mmap=False, convert_float=False):
27
+ """
28
+ Load a WAV file in C_CONTIGUOUS format.
29
+
30
+ Args:
31
+ file_path: Path to WAV file (16bit, 24bit or 32bit PCM supported)
32
+ mmap: If `True`, then we do not load the WAV data into memory but use a memory-mapped representation
33
+
34
+ Returns:
35
+ fs: Sample rate
36
+ samples: Numpy array (np.int16 or np.int32) with audio [n_samples x n_channels]
37
+ """
38
+ fs, samples = wav.read(file_path, mmap=mmap)
39
+
40
+ # ensure that we have a 2d array (monaural files are just loaded as vectors)
41
+ if samples.ndim == 1:
42
+ samples = samples[:, np.newaxis]
43
+
44
+ # make sure that we have loaded an integer PCM WAV file as we assume this later
45
+ # when we scale the amplitude
46
+ assert(samples.dtype == np.int16 or samples.dtype == np.int32)
47
+
48
+ if convert_float:
49
+ conversion_scale = 1. / (1. + np.iinfo(samples.dtype).max)
50
+ samples = samples.astype(dtype=np.float32) * conversion_scale
51
+
52
+ return fs, samples
53
+
54
+
55
+ def save_wav(file_path, fs, samples, subtype='PCM_16'):
56
+ """
57
+ Save a WAV file (16bit or 32bit PCM).
58
+
59
+ Important note: We save here using the same conversion as is used in
60
+ `generate_data`, i.e., we multiply by `1 + np.iinfo(np.int16).max`
61
+ or `1 + np.iinfo(np.int32).max` which is a different behavior
62
+ than `libsndfile` as described here:
63
+ http://www.mega-nerd.com/libsndfile/FAQ.html#Q010
64
+
65
+ Args:
66
+ file_path: Path where to store the WAV file
67
+ fs: Sample rate
68
+ samples: Numpy array (float32 with values in [-1, 1) and shape [n_samples x n_channels])
69
+ subtype: Either `PCM_16` or `PCM_24` or `PCM_32` in order to store as 16bit, 24bit or 32bit PCM file
70
+ """
71
+ assert subtype in ['PCM_16', 'PCM_24', 'PCM_32'], subtype
72
+
73
+ if subtype == 'PCM_16':
74
+ dtype = np.int16
75
+ else:
76
+ dtype = np.int32
77
+
78
+ # convert to int16 (check for clipping)
79
+
80
+ samples = samples * (1 + np.iinfo(dtype).max)
81
+ if np.min(samples) < np.iinfo(dtype).min or np.max(samples) > np.iinfo(dtype).max:
82
+ uprint(f'WARNING: Clipping occurs for {file_path}.')
83
+ samples_ = samples / (1 + np.iinfo(dtype).max)
84
+ print('max value ', np.max(np.abs(samples_)))
85
+ samples = np.clip(samples, np.iinfo(dtype).min, np.iinfo(dtype).max)
86
+ samples = samples.astype(dtype)
87
+
88
+ # store WAV file
89
+ sf.write(file_path, samples, fs, subtype=subtype)
90
+
91
+
92
+ def load_files_lists(path):
93
+ """
94
+ Auxiliary function to find the paths for all mixtures in a database.
95
+
96
+ Args:
97
+ path: path to the folder containing the files to list
98
+
99
+ Returns:
100
+ list_of_directories: list of directories (= list of songs) in `path`
101
+ """
102
+ # get directories in `path`
103
+ list_of_directories = []
104
+ for folder in os.listdir(path):
105
+ list_of_directories.append(folder)
106
+
107
+ return list_of_directories
108
+
109
+
110
+ def create_dataset(path, accepted_sampling_rates, sources, mapped_sources, n_channels=-1, load_to_memory=False,
111
+ debug=False, verbose=False):
112
+ """
113
+ Prepare data in `path` for training/validation/test set generation.
114
+
115
+ Args:
116
+ path: path to the dataset
117
+ accepted_sampling_rates: list of accepted sampling rates
118
+ sources: list of sources
119
+ mapped_sources: list of mapped sources
120
+ n_channels: number of channels
121
+ load_to_memory: whether to load to main memory
122
+ debug: if `True`, then we load only `NUM_SAMPLES_SMALL_DATASET`
123
+
124
+ Raises:
125
+ ValueError: mapping of sources not possible is data is not loaded into memory
126
+
127
+ Returns:
128
+ data: list of dictionaries with function handles (to load the data)
129
+ directories: list of directories
130
+ """
131
+ NUM_SAMPLES_SMALL_DATASET = 16
132
+
133
+ # source mapping currently only works if we load everything into the memory
134
+ if mapped_sources and not load_to_memory:
135
+ raise ValueError('Mapping of sources only supported if data is loaded into the memory.')
136
+
137
+ # get directories for dataset
138
+ directories = load_files_lists(path)
139
+
140
+ # load all songs for dataset
141
+ if debug:
142
+ data = [dict() for _x in range(np.minimum(NUM_SAMPLES_SMALL_DATASET, len(directories)))]
143
+ else:
144
+ data = [dict() for _x in range(len(directories))]
145
+
146
+ material_length = {} # in seconds
147
+ for i, d in enumerate(directories):
148
+ if verbose:
149
+ uprint(f'Processing mixture ({i+1} of {len(directories)}): {d}')
150
+
151
+ # add names of all files in this folder
152
+ files = os.listdir(os.path.join(path, d))
153
+ for f in files:
154
+ src_name = os.path.splitext(f)[0]
155
+ if ((src_name not in sources
156
+ and src_name not in mapped_sources)):
157
+ if verbose:
158
+ uprint(f'\tIgnoring unknown source from file {f}')
159
+ else:
160
+ if src_name not in sources:
161
+ src_name = mapped_sources[src_name]
162
+ if verbose:
163
+ uprint(f'\tAdding function handle for "{src_name}" from file {f}')
164
+
165
+ _data = load_wav(os.path.join(path, d, f), mmap=not load_to_memory)
166
+
167
+ # determine properties from loaded data
168
+ _samplingrate = _data[0]
169
+ _n_channels = _data[1].shape[1]
170
+ _duration = _data[1].shape[0] / _samplingrate
171
+
172
+ # collect statistics about data for each source
173
+ if src_name in material_length:
174
+ material_length[src_name] += _duration
175
+ else:
176
+ material_length[src_name] = _duration
177
+
178
+ # make sure that sample rate and number of channels matches
179
+ if n_channels != -1 and _n_channels != n_channels:
180
+ raise ValueError(f'File has {_n_channels} '
181
+ f'channels but expected {n_channels}.')
182
+
183
+ if _samplingrate not in accepted_sampling_rates:
184
+ raise ValueError(f'File has fs = {_samplingrate}Hz '
185
+ f'but expected {accepted_sampling_rates}Hz.')
186
+
187
+ # if we already loaded data for this source then append data
188
+ if src_name in data[i]:
189
+ _data = (_data[0], np.vstack((_data[1],
190
+ data[i][src_name].keywords['file_path_or_data'][1])))
191
+ data[i][src_name] = functools.partial(generate_data,
192
+ file_path_or_data=_data)
193
+
194
+ if debug and i == NUM_SAMPLES_SMALL_DATASET-1:
195
+ # load only first `NUM_SAMPLES_SMALL_DATASET` songs
196
+ break
197
+
198
+ # delete all entries where we did not find an source file
199
+ idx_empty = [_ for _ in range(len(data)) if len(data[_]) == 0]
200
+ for idx in sorted(idx_empty, reverse=True):
201
+ del data[idx]
202
+
203
+ return data, directories
204
+
205
+ def create_dataset_mixing(path, accepted_sampling_rates, sources, mapped_sources, n_channels=-1, load_to_memory=False,
206
+ debug=False, pad_wrap_samples=None):
207
+ """
208
+ Prepare data in `path` for training/validation/test set generation.
209
+
210
+ Args:
211
+ path: path to the dataset
212
+ accepted_sampling_rates: list of accepted sampling rates
213
+ sources: list of sources
214
+ mapped_sources: list of mapped sources
215
+ n_channels: number of channels
216
+ load_to_memory: whether to load to main memory
217
+ debug: if `True`, then we load only `NUM_SAMPLES_SMALL_DATASET`
218
+
219
+ Raises:
220
+ ValueError: mapping of sources not possible is data is not loaded into memory
221
+
222
+ Returns:
223
+ data: list of dictionaries with function handles (to load the data)
224
+ directories: list of directories
225
+ """
226
+ NUM_SAMPLES_SMALL_DATASET = 16
227
+
228
+ # source mapping currently only works if we load everything into the memory
229
+ if mapped_sources and not load_to_memory:
230
+ raise ValueError('Mapping of sources only supported if data is loaded into the memory.')
231
+
232
+ # get directories for dataset
233
+ directories = load_files_lists(path)
234
+ directories.sort()
235
+
236
+ # load all songs for dataset
237
+ uprint(f'\nCreating dataset for path={path} ...')
238
+
239
+ if debug:
240
+ data = [dict() for _x in range(np.minimum(NUM_SAMPLES_SMALL_DATASET, len(directories)))]
241
+ else:
242
+ data = [dict() for _x in range(len(directories))]
243
+
244
+ material_length = {} # in seconds
245
+ for i, d in enumerate(directories):
246
+ uprint(f'Processing mixture ({i+1} of {len(directories)}): {d}')
247
+
248
+ # add names of all files in this folder
249
+ files = os.listdir(os.path.join(path, d))
250
+ _data_mix = []
251
+ _stems_name = []
252
+ for f in files:
253
+ src_name = os.path.splitext(f)[0]
254
+ if ((src_name not in sources
255
+ and src_name not in mapped_sources)):
256
+ uprint(f'\tIgnoring unknown source from file {f}')
257
+ else:
258
+ if src_name not in sources:
259
+ src_name = mapped_sources[src_name]
260
+ uprint(f'\tAdding function handle for "{src_name}" from file {f}')
261
+
262
+ _data = load_wav(os.path.join(path, d, f), mmap=not load_to_memory)
263
+
264
+ if pad_wrap_samples:
265
+ _data = (_data[0], np.pad(_data[1], [(pad_wrap_samples, 0), (0,0)], 'wrap'))
266
+
267
+ # determine properties from loaded data
268
+ _samplingrate = _data[0]
269
+ _n_channels = _data[1].shape[1]
270
+ _duration = _data[1].shape[0] / _samplingrate
271
+
272
+ # collect statistics about data for each source
273
+ if src_name in material_length:
274
+ material_length[src_name] += _duration
275
+ else:
276
+ material_length[src_name] = _duration
277
+
278
+ # make sure that sample rate and number of channels matches
279
+ if n_channels != -1 and _n_channels != n_channels:
280
+ if _n_channels == 1: # Converts mono to stereo with repeated channels
281
+ _data = (_data[0], np.repeat(_data[1], 2, axis=-1))
282
+ print("Converted file to stereo by repeating mono channel")
283
+ else:
284
+ raise ValueError(f'File has {_n_channels} '
285
+ f'channels but expected {n_channels}.')
286
+
287
+ if _samplingrate not in accepted_sampling_rates:
288
+ raise ValueError(f'File has fs = {_samplingrate}Hz '
289
+ f'but expected {accepted_sampling_rates}Hz.')
290
+
291
+ # if we already loaded data for this source then append data
292
+ if src_name in data[i]:
293
+ _data = (_data[0], np.vstack((_data[1],
294
+ data[i][src_name].keywords['file_path_or_data'][1])))
295
+
296
+ _data_mix.append(_data)
297
+ _stems_name.append(src_name)
298
+
299
+ data[i]["-".join(_stems_name)] = functools.partial(generate_data,
300
+ file_path_or_data=_data_mix)
301
+
302
+ if debug and i == NUM_SAMPLES_SMALL_DATASET-1:
303
+ # load only first `NUM_SAMPLES_SMALL_DATASET` songs
304
+ break
305
+
306
+ # delete all entries where we did not find an source file
307
+ idx_empty = [_ for _ in range(len(data)) if len(data[_]) == 0]
308
+ for idx in sorted(idx_empty, reverse=True):
309
+ del data[idx]
310
+
311
+ uprint(f'Finished preparation of dataset. '
312
+ f'Found in total the following material (in {len(data)} directories):')
313
+ for src in material_length:
314
+ uprint(f'\t{src}: {material_length[src] / 60.0 / 60.0:.2f} hours')
315
+ return data, directories
316
+
317
+
318
+ def generate_data(file_path_or_data, random_sample_size=None):
319
+ """
320
+ Load one stem/several stems specified by `file_path_or_data`.
321
+
322
+ Alternatively, can also be the result of `wav.read()` if the data has already been loaded previously.
323
+
324
+ If `file_path_or_data` is a tuple/list, then we load several files and will return also a tuple/list.
325
+ This is useful for cases where we want to make sure to have the same random chunk for several stems.
326
+
327
+ If `random_sample_chunk_size` is not None, then only `random_sample_chunk_size` samples are randomly selected.
328
+
329
+ Args:
330
+ file_path_or_data: either path to data or the data itself
331
+ random_sample_size: if `random_sample_size` is not None, only `random_sample_size` samples are randomly selected
332
+
333
+ Returns:
334
+ samples: data with size `num_samples x num_channels` or a list of samples
335
+ """
336
+ needs_wrapping = False
337
+ if isinstance(file_path_or_data, str):
338
+ needs_wrapping = True # single file path -> wrap
339
+ if ((type(file_path_or_data[0]) is not list
340
+ and type(file_path_or_data[0]) is not tuple)):
341
+ needs_wrapping = True # single data -> wrap
342
+ if needs_wrapping:
343
+ file_path_or_data = (file_path_or_data,)
344
+
345
+ # create list where we store all samples
346
+ samples = [None] * len(file_path_or_data)
347
+
348
+ # load samples from wav file
349
+ for i, fpod in enumerate(file_path_or_data):
350
+ if isinstance(fpod, str):
351
+ _fs, samples[i] = load_wav(fpod)
352
+ else:
353
+ _fs, samples[i] = fpod
354
+
355
+ # if `random_sample_chunk_size` is not None, then only select subset
356
+ if random_sample_size is not None:
357
+ # get maximum length of all stems (at least `random_sample_chunk_size`)
358
+ max_length = random_sample_size
359
+ for s in samples:
360
+ max_length = np.maximum(max_length, s.shape[0])
361
+
362
+ # make sure that we can select enough audio and that all have the same length `max_length`
363
+ # (for short loops, `random_sample_chunk_size` can be larger than `s.shape[0]`)
364
+ for i, s in enumerate(samples):
365
+ if s.shape[0] < max_length:
366
+ required_padding = max_length - s.shape[0]
367
+ zeros = np.zeros((required_padding // 2 + 1, s.shape[1]),
368
+ dtype=s.dtype, order='F')
369
+ samples[i] = np.concatenate([zeros, s, zeros])
370
+
371
+ # select random part of audio
372
+ idx_start = np.random.randint(max_length)
373
+
374
+ for i, s in enumerate(samples):
375
+ if idx_start + random_sample_size < s.shape[0]:
376
+ samples[i] = s[idx_start:idx_start + random_sample_size]
377
+ else:
378
+ samples[i] = np.concatenate([s[idx_start:],
379
+ s[:random_sample_size - (s.shape[0] - idx_start)]])
380
+
381
+ # convert from `int16/int32` to `float32` precision (this will also make a copy)
382
+ for i, s in enumerate(samples):
383
+ conversion_scale = 1. / (1. + np.iinfo(s.dtype).max)
384
+ samples[i] = s.astype(dtype=np.float32) * conversion_scale
385
+
386
+ if len(samples) == 1:
387
+ return samples[0]
388
+ else:
389
+ return samples
390
+
391
+
392
+ def create_minibatch(data: list, sources: list,
393
+ present_prob: dict, overlap_prob: dict,
394
+ augmenter: AugmentationChain, augmenter_padding: Tuple[int],
395
+ batch_size: int, n_samples: int, n_channels: int, idx_songs: dict):
396
+ """
397
+ Create a minibatch.
398
+
399
+ This function also handles the case that we do not have a source in one mixture.
400
+ This can, e.g., happen for instrumental pieces that do not have vocals.
401
+
402
+ Args:
403
+ data (list): data to create the minibatch from.
404
+ sources (list): list of sources.
405
+ present_prob (dict): probability of a source to be present.
406
+ overlap_prob (dict): probability of overlap.
407
+ augmenter (AugmentationChain): audio effect chain that we want to apply for data augmentation
408
+ augmenter_padding (tuple of ints): padding that we should apply to left/right side of data to avoid
409
+ boundary effects of `augmenter`.
410
+ batch_size (int): number of training samples in one minibatch.
411
+ n_samples (int): number of time samples.
412
+ n_channels (int): number of channels.
413
+ idx_songs (dict): index of songs.
414
+
415
+ Returns:
416
+ inp (Numpy array): minibatch, input to the network (i.e. the mixture) of size
417
+ `batch_size x n_samples x n_channels`
418
+ tar (dict with Numpy arrays): dictionary which contains for each source the targets,
419
+ each of the `c_contiguous` ndarrays is `batch_size x n_samples x n_channels`
420
+ """
421
+ # initialize numpy arrays which keep input/targets
422
+ shp = (batch_size, n_samples, n_channels)
423
+ inp = np.zeros(shape=shp, dtype=np.float32, order='C')
424
+ tar = {src: np.zeros(shape=shp, dtype=np.float32, order='C') for src in sources}
425
+
426
+ # use padding to avoid boundary effects of augmenter
427
+ pad_left = None if augmenter_padding[0] == 0 else augmenter_padding[0]
428
+ pad_right = None if augmenter_padding[1] == 0 else -augmenter_padding[1]
429
+
430
+ def augm(i, s, n):
431
+ return augmenter(data[i][s](random_sample_size=n+sum(augmenter_padding)))[pad_left:pad_right]
432
+
433
+ # create mini-batch
434
+ for src in sources:
435
+
436
+ for j in range(batch_size):
437
+ # get song index for this source
438
+ _idx_song = idx_songs[src][j]
439
+
440
+ # determine whether this source is present/whether we overlap
441
+ is_present = src not in present_prob or np.random.rand() < present_prob[src]
442
+ is_overlap = src in overlap_prob and np.random.rand() < overlap_prob[src]
443
+
444
+ # if song contains source, then add it to input/targetg]
445
+ if src in data[_idx_song] and is_present:
446
+ tar[src][j, ...] = augm(_idx_song, src, n_samples)
447
+
448
+ # overlap source with same source from randomly choosen other song
449
+ if is_overlap:
450
+ idx_overlap_ = np.random.randint(len(data))
451
+ if idx_overlap_ != _idx_song and src in data[idx_overlap_]:
452
+ tar[src][j, ...] += augm(idx_overlap_, src, n_samples)
453
+
454
+ # compute input
455
+ inp += tar[src]
456
+
457
+ # make sure that all have not too large amplitude (check only mixture)
458
+ maxabs_amp = np.maximum(1.0, 1e-6 + np.max(np.abs(inp), axis=(1, 2), keepdims=True))
459
+ inp /= maxabs_amp
460
+ for src in sources:
461
+ tar[src] /= maxabs_amp
462
+
463
+ return inp, tar
464
+
465
+ def create_minibatch_mixing(data: list, sources: list, inputs: list, outputs: list,
466
+ present_prob: dict, overlap_prob: dict,
467
+ augmenter: AugmentationChain, augmenter_padding: Tuple[int], augmenter_sources: list,
468
+ batch_size: int, n_samples: int, n_channels: int, idx_songs: dict):
469
+ """
470
+ Create a minibatch.
471
+
472
+ This function also handles the case that we do not have a source in one mixture.
473
+ This can, e.g., happen for instrumental pieces that do not have vocals.
474
+
475
+ Args:
476
+ data (list): data to create the minibatch from.
477
+ sources (list): list of sources.
478
+ present_prob (dict): probability of a source to be present.
479
+ overlap_prob (dict): probability of overlap.
480
+ augmenter (AugmentationChain): audio effect chain that we want to apply for data augmentation
481
+ augmenter_padding (tuple of ints): padding that we should apply to left/right side of data to avoid
482
+ boundary effects of `augmenter`.
483
+ augmenter_sources (list): list of sources to augment
484
+ batch_size (int): number of training samples in one minibatch.
485
+ n_samples (int): number of time samples.
486
+ n_channels (int): number of channels.
487
+ idx_songs (dict): index of songs.
488
+
489
+ Returns:
490
+ inp (Numpy array): minibatch, input to the network (i.e. the mixture) of size
491
+ `batch_size x n_samples x n_channels`
492
+ tar (dict with Numpy arrays): dictionary which contains for each source the targets,
493
+ each of the `c_contiguous` ndarrays is `batch_size x n_samples x n_channels`
494
+ """
495
+ # initialize numpy arrays which keep input/targets
496
+ shp = (batch_size, n_samples, n_channels)
497
+ stems = {src: np.zeros(shape=shp, dtype=np.float32, order='C') for src in inputs}
498
+ mix = {src: np.zeros(shape=shp, dtype=np.float32, order='C') for src in outputs}
499
+
500
+ # use padding to avoid boundary effects of augmenter
501
+ pad_left = None if augmenter_padding[0] == 0 else augmenter_padding[0]
502
+ pad_right = None if augmenter_padding[1] == 0 else -augmenter_padding[1]
503
+
504
+ def augm(i, n):
505
+ s = list(data[i])[0]
506
+ input_multitracks = data[i][s](random_sample_size=n+sum(augmenter_padding))
507
+ audio_tags = list(data[i])[0].split("-")
508
+
509
+ # Only applies augmentation to inputs, not output.
510
+ for k, tag in enumerate(audio_tags):
511
+ if tag in augmenter_sources:
512
+ input_multitracks[k] = augmenter(input_multitracks[k])[pad_left:pad_right]
513
+ else:
514
+ input_multitracks[k] = input_multitracks[k][pad_left:pad_right]
515
+ return input_multitracks
516
+
517
+ # create mini-batch
518
+ for src in outputs:
519
+
520
+ for j in range(batch_size):
521
+ # get song index for this source
522
+ _idx_song = idx_songs[src][j]
523
+
524
+ multitrack_audio = augm(_idx_song, n_samples)
525
+
526
+ audio_tags = list(data[_idx_song])[0].split("-")
527
+
528
+ for i, tag in enumerate(audio_tags):
529
+ if tag in inputs:
530
+ stems[tag][j, ...] = multitrack_audio[i]
531
+ if tag in outputs:
532
+ mix[tag][j, ...] = multitrack_audio[i]
533
+
534
+ return stems, mix
535
+
mixing_style_transfer/mixing_manipulator/common_miscellaneous.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Common miscellaneous functions.
3
+
4
+ AI Music Technology Group, Sony Group Corporation
5
+ AI Speech and Sound Group, Sony Europe
6
+
7
+ This implementation originally belongs to Sony Group Corporation,
8
+ which has been introduced in the work "Automatic music mixing with deep learning and out-of-domain data".
9
+ Original repo link: https://github.com/sony/FxNorm-automix
10
+ """
11
+ import os
12
+ import psutil
13
+ import sys
14
+ import numpy as np
15
+ import librosa
16
+ import torch
17
+ import math
18
+
19
+
20
+ def uprint(s):
21
+ """
22
+ Unbuffered print to stdout.
23
+
24
+ We also flush stderr to have the log-file in sync.
25
+
26
+ Args:
27
+ s: string to print
28
+ """
29
+ print(s)
30
+ sys.stdout.flush()
31
+ sys.stderr.flush()
32
+
33
+
34
+ def recursive_getattr(obj, attr):
35
+ """
36
+ Run `getattr` recursively (e.g., for `fc1.weight`).
37
+
38
+ Args:
39
+ obj: object
40
+ attr: attribute to get
41
+
42
+ Returns:
43
+ object
44
+ """
45
+ for a in attr.split('.'):
46
+ obj = getattr(obj, a)
47
+ return obj
48
+
49
+
50
+ def compute_stft(samples, hop_length, fft_size, stft_window):
51
+ """
52
+ Compute the STFT of `samples` applying a Hann window of size `FFT_SIZE`, shifted for each frame by `hop_length`.
53
+
54
+ Args:
55
+ samples: num samples x channels
56
+ hop_length: window shift in samples
57
+ fft_size: FFT size which is also the window size
58
+ stft_window: STFT analysis window
59
+
60
+ Returns:
61
+ stft: frames x channels x freqbins
62
+ """
63
+ n_channels = samples.shape[1]
64
+ n_frames = 1+int((samples.shape[0] - fft_size)/hop_length)
65
+ stft = np.empty((n_frames, n_channels, fft_size//2+1), dtype=np.complex64)
66
+
67
+ # convert into f_contiguous (such that [:,n] slicing is c_contiguous)
68
+ samples = np.asfortranarray(samples)
69
+
70
+ for n in range(n_channels):
71
+ # compute STFT (output has size `n_frames x N_BINS`)
72
+ stft[:, n, :] = librosa.stft(samples[:, n],
73
+ n_fft=fft_size,
74
+ hop_length=hop_length,
75
+ window=stft_window,
76
+ center=False).transpose()
77
+ return stft
78
+
79
+
80
+ def compute_istft(stft, hop_length, stft_window):
81
+ """
82
+ Compute the inverse STFT of `stft`.
83
+
84
+ Args:
85
+ stft: frames x channels x freqbins
86
+ hop_length: window shift in samples
87
+ stft_window: STFT synthesis window
88
+
89
+ Returns:
90
+ samples: num samples x channels
91
+ """
92
+ for n in range(stft.shape[1]):
93
+ s = librosa.istft(stft[:, n, :].transpose(),
94
+ hop_length=hop_length, window=stft_window, center=False)
95
+ if n == 0:
96
+ samples = s
97
+ else:
98
+ samples = np.column_stack((samples, s))
99
+
100
+ # ensure that we have a 2d array (monaural files are just loaded as vectors)
101
+ if samples.ndim == 1:
102
+ samples = samples[:, np.newaxis]
103
+
104
+ return samples
105
+
106
+
107
+ def get_size(obj):
108
+ """
109
+ Recursively find size of objects (in bytes).
110
+
111
+ Args:
112
+ obj: object
113
+
114
+ Returns:
115
+ size of object
116
+ """
117
+ size = sys.getsizeof(obj)
118
+
119
+ import functools
120
+
121
+ if isinstance(obj, dict):
122
+ size += sum([get_size(v) for v in obj.values()])
123
+ size += sum([get_size(k) for k in obj.keys()])
124
+ elif isinstance(obj, functools.partial):
125
+ size += sum([get_size(v) for v in obj.keywords.values()])
126
+ size += sum([get_size(k) for k in obj.keywords.keys()])
127
+ elif isinstance(obj, list):
128
+ size += sum([get_size(i) for i in obj])
129
+ elif isinstance(obj, tuple):
130
+ size += sum([get_size(i) for i in obj])
131
+ return size
132
+
133
+
134
+ def get_process_memory():
135
+ """
136
+ Return memory consumption in GBytes.
137
+
138
+ Returns:
139
+ memory used by the process
140
+ """
141
+ return psutil.Process(os.getpid()).memory_info()[0] / (2 ** 30)
142
+
143
+
144
+ def check_complete_convolution(input_size, kernel_size, stride=1,
145
+ padding=0, dilation=1, note=''):
146
+ """
147
+ Check where the convolution is complete.
148
+
149
+ Returns true if no time steps left over in a Conv1d
150
+
151
+ Args:
152
+ input_size: size of input
153
+ kernel_size: size of kernel
154
+ stride: stride
155
+ padding: padding
156
+ dilation: dilation
157
+ note: string for additional notes
158
+ """
159
+ is_complete = ((input_size + 2*padding - dilation * (kernel_size - 1) - 1)
160
+ / stride + 1).is_integer()
161
+ uprint(f'{note} {is_complete}')
162
+
163
+
164
+ def pad_to_shape(x: torch.Tensor, y: int) -> torch.Tensor:
165
+ """
166
+ Right-pad or right-trim first argument last dimension to have same size as second argument.
167
+
168
+ Args:
169
+ x: Tensor to be padded.
170
+ y: Size to pad/trim x last dimension to
171
+
172
+ Returns:
173
+ `x` padded to match `y`'s dimension.
174
+ """
175
+ inp_len = y
176
+ output_len = x.shape[-1]
177
+ return torch.nn.functional.pad(x, [0, inp_len - output_len])
178
+
179
+
180
+ def valid_length(input_size, kernel_size, stride=1, padding=0, dilation=1):
181
+ """
182
+ Return the nearest valid upper length to use with the model so that there is no time steps left over in a 1DConv.
183
+
184
+ For all layers, size of the (input - kernel_size) % stride = 0.
185
+ Here valid means that there is no left over frame neglected and discarded.
186
+
187
+ Args:
188
+ input_size: size of input
189
+ kernel_size: size of kernel
190
+ stride: stride
191
+ padding: padding
192
+ dilation: dilation
193
+
194
+ Returns:
195
+ valid length for convolution
196
+ """
197
+ length = math.ceil((input_size + 2*padding - dilation * (kernel_size - 1) - 1)/stride) + 1
198
+ length = (length - 1) * stride - 2*padding + dilation * (kernel_size - 1) + 1
199
+
200
+ return int(length)
201
+
202
+
203
+ def td_length_from_fd(fd_length: int, fft_size: int, fft_hop: int) -> int:
204
+ """
205
+ Return the length in time domain, given the length in frequency domain.
206
+
207
+ Return the necessary length in the time domain of a signal to be transformed into
208
+ a signal of length `fd_length` in time-frequency domain with the given STFT
209
+ parameters `fft_size` and `fft_hop`. No padding is assumed.
210
+
211
+ Args:
212
+ fd_length: length in frequency domain
213
+ fft_size: size of FFT
214
+ fft_hop: hop length
215
+
216
+ Returns:
217
+ length in time domain
218
+ """
219
+ return (fd_length - 1) * fft_hop + fft_size
mixing_style_transfer/mixing_manipulator/data_normalization.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of the 'audio effects chain normalization'
3
+ """
4
+ import numpy as np
5
+ import scipy
6
+
7
+ import os
8
+ import sys
9
+ currentdir = os.path.dirname(os.path.realpath(__file__))
10
+ sys.path.append(currentdir)
11
+ from utils_data_normalization import *
12
+ from normalization_imager import *
13
+
14
+
15
+ '''
16
+ Audio Effects Chain Normalization
17
+ process: normalizes input stems according to given precomputed features
18
+ '''
19
+ class Audio_Effects_Normalizer:
20
+ def __init__(self, precomputed_feature_path, \
21
+ STEMS=['drums', 'bass', 'other', 'vocals'], \
22
+ EFFECTS=['eq', 'compression', 'imager', 'loudness']):
23
+ self.STEMS = STEMS # Stems to be normalized
24
+ self.EFFECTS = EFFECTS # Effects to be normalized, order matters
25
+
26
+ # Audio settings
27
+ self.SR = 44100
28
+ self.SUBTYPE = 'PCM_16'
29
+
30
+ # General Settings
31
+ self.FFT_SIZE = 2**16
32
+ self.HOP_LENGTH = self.FFT_SIZE//4
33
+
34
+ # Loudness
35
+ self.NTAPS = 1001
36
+ self.LUFS = -30
37
+ self.MIN_DB = -40 # Min amplitude to apply EQ matching
38
+
39
+ # Compressor
40
+ self.COMP_USE_EXPANDER = False
41
+ self.COMP_PEAK_NORM = -10.0
42
+ self.COMP_TRUE_PEAK = False
43
+ self.COMP_PERCENTILE = 75 # features_mean (v1) was done with 25
44
+ self.COMP_MIN_TH = -40
45
+ self.COMP_MAX_RATIO = 20
46
+ comp_settings = {key:{} for key in self.STEMS}
47
+ for key in comp_settings:
48
+ if key == 'vocals':
49
+ comp_settings[key]['attack'] = 7.5
50
+ comp_settings[key]['release'] = 400.0
51
+ comp_settings[key]['ratio'] = 4
52
+ comp_settings[key]['n_mels'] = 128
53
+ elif key == 'drums':
54
+ comp_settings[key]['attack'] = 10.0
55
+ comp_settings[key]['release'] = 180.0
56
+ comp_settings[key]['ratio'] = 6
57
+ comp_settings[key]['n_mels'] = 128
58
+ elif key == 'bass':
59
+ comp_settings[key]['attack'] = 10.0
60
+ comp_settings[key]['release'] = 500.0
61
+ comp_settings[key]['ratio'] = 5
62
+ comp_settings[key]['n_mels'] = 16
63
+ elif key == 'other':
64
+ comp_settings[key]['attack'] = 15.0
65
+ comp_settings[key]['release'] = 666.0
66
+ comp_settings[key]['ratio'] = 4
67
+ comp_settings[key]['n_mels'] = 128
68
+ self.comp_settings = comp_settings
69
+
70
+ # Load Pre-computed Audio Effects Features
71
+ features_mean = np.load(precomputed_feature_path, allow_pickle='TRUE')[()]
72
+ self.features_mean = self.smooth_feature(features_mean)
73
+
74
+
75
+ # normalize current audio input with the order of designed audio FX
76
+ def normalize_audio(self, audio, src):
77
+ assert src in self.STEMS
78
+
79
+ normalized_audio = audio
80
+ for cur_effect in self.EFFECTS:
81
+ normalized_audio = self.normalize_audio_per_effect(normalized_audio, src=src, effect=cur_effect)
82
+
83
+ return normalized_audio
84
+
85
+
86
+ # normalize current audio input with current targeted audio FX
87
+ def normalize_audio_per_effect(self, audio, src, effect):
88
+ audio = audio.astype(dtype=np.float32)
89
+ audio_track = np.pad(audio, ((self.FFT_SIZE, self.FFT_SIZE), (0, 0)), mode='constant')
90
+
91
+ assert len(audio_track.shape) == 2 # Always expects two dimensions
92
+
93
+ if audio_track.shape[1] == 1: # Converts mono to stereo with repeated channels
94
+ audio_track = np.repeat(audio_track, 2, axis=-1)
95
+
96
+ output_audio = audio_track.copy()
97
+
98
+ max_db = amp_to_db(np.max(np.abs(output_audio)))
99
+ if max_db > self.MIN_DB:
100
+
101
+ if effect == 'eq':
102
+ # normalize each channel
103
+ for ch in range(audio_track.shape[1]):
104
+ audio_eq_matched = get_eq_matching(output_audio[:, ch],
105
+ self.features_mean[effect][src],
106
+ sr=self.SR,
107
+ n_fft=self.FFT_SIZE,
108
+ hop_length=self.HOP_LENGTH,
109
+ min_db=self.MIN_DB,
110
+ ntaps=self.NTAPS,
111
+ lufs=self.LUFS)
112
+
113
+
114
+ np.copyto(output_audio[:,ch], audio_eq_matched)
115
+
116
+ elif effect == 'compression':
117
+ assert(len(self.features_mean[effect][src])==2)
118
+ # normalize each channel
119
+ for ch in range(audio_track.shape[1]):
120
+ try:
121
+ audio_comp_matched = get_comp_matching(output_audio[:, ch],
122
+ self.features_mean[effect][src][0],
123
+ self.features_mean[effect][src][1],
124
+ self.comp_settings[src]['ratio'],
125
+ self.comp_settings[src]['attack'],
126
+ self.comp_settings[src]['release'],
127
+ sr=self.SR,
128
+ min_db=self.MIN_DB,
129
+ min_th=self.COMP_MIN_TH,
130
+ comp_peak_norm=self.COMP_PEAK_NORM,
131
+ max_ratio=self.COMP_MAX_RATIO,
132
+ n_mels=self.comp_settings[src]['n_mels'],
133
+ true_peak=self.COMP_TRUE_PEAK,
134
+ percentile=self.COMP_PERCENTILE,
135
+ expander=self.COMP_USE_EXPANDER)
136
+
137
+ np.copyto(output_audio[:,ch], audio_comp_matched[:, 0])
138
+ except:
139
+ break
140
+
141
+ elif effect == 'loudness':
142
+ output_audio = fx_utils.lufs_normalize(output_audio, self.SR, self.features_mean[effect][src], log=False)
143
+
144
+ elif effect == 'imager':
145
+ # threshold of applying Haas effects
146
+ mono_threshold = 0.99 if src=='bass' else 0.975
147
+ audio_imager_matched = normalize_imager(output_audio, \
148
+ target_side_mid_bal=self.features_mean[effect][src], \
149
+ mono_threshold=mono_threshold, \
150
+ sr=self.SR)
151
+
152
+ np.copyto(output_audio, audio_imager_matched)
153
+
154
+ output_audio = output_audio[self.FFT_SIZE:self.FFT_SIZE+audio.shape[0]]
155
+ return output_audio
156
+
157
+
158
+ def smooth_feature(self, feature_dict_):
159
+
160
+ for effect in self.EFFECTS:
161
+ for key in self.STEMS:
162
+ if effect == 'eq':
163
+ if key in ['other', 'vocals']:
164
+ f = 401
165
+ else:
166
+ f = 151
167
+ feature_dict_[effect][key] = scipy.signal.savgol_filter(feature_dict_[effect][key],
168
+ f, 1, mode='mirror')
169
+ elif effect == 'panning':
170
+ feature_dict_[effect][key] = scipy.signal.savgol_filter(feature_dict_[effect][key],
171
+ 501, 1, mode='mirror')
172
+ return feature_dict_
173
+
mixing_style_transfer/mixing_manipulator/fx_utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
3
+
4
+ import numpy as np
5
+ import scipy
6
+ import math
7
+ import librosa
8
+ import librosa.display
9
+ import fnmatch
10
+ import os
11
+ from functools import partial
12
+ import pyloudnorm
13
+ from scipy.signal import lfilter
14
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
15
+ from sklearn.metrics.pairwise import paired_distances
16
+
17
+
18
+ import matplotlib.pyplot as plt
19
+
20
+ def db(x):
21
+ """Computes the decible energy of a signal"""
22
+ return 20*np.log10(np.sqrt(np.mean(np.square(x))))
23
+
24
+ def melspectrogram(y, mirror_pad=False):
25
+ """Compute melspectrogram feature extraction
26
+
27
+ Keyword arguments:
28
+ signal -- input audio as a signal in a numpy object
29
+ inputnorm -- normalization of output
30
+ mirror_pad -- pre and post-pend mirror signals
31
+
32
+ Returns freq x time
33
+
34
+
35
+ Assumes the input sampling rate is 22050Hz
36
+ """
37
+
38
+ # Extract mel.
39
+ fftsize = 1024
40
+ window = 1024
41
+ hop = 512
42
+ melBin = 128
43
+ sr = 22050
44
+
45
+ # mirror pad signal
46
+ # first embedding centered on time 0
47
+ # last embedding centered on end of signal
48
+ if mirror_pad:
49
+ y = np.insert(y, 0, y[0:int(half_frame_length_sec * sr)][::-1])
50
+ y = np.insert(y, len(y), y[-int(half_frame_length_sec * sr):][::-1])
51
+
52
+ S = librosa.core.stft(y,n_fft=fftsize,hop_length=hop,win_length=window)
53
+ X = np.abs(S)
54
+ mel_basis = librosa.filters.mel(sr,n_fft=fftsize,n_mels=melBin)
55
+ mel_S = np.dot(mel_basis,X)
56
+
57
+ # value log compression
58
+ mel_S = np.log10(1+10*mel_S)
59
+ mel_S = mel_S.astype(np.float32)
60
+
61
+
62
+ return mel_S
63
+
64
+
65
+ def getFilesPath(directory, extension):
66
+
67
+ n_path=[]
68
+ for path, subdirs, files in os.walk(directory):
69
+ for name in files:
70
+ if fnmatch.fnmatch(name, extension):
71
+ n_path.append(os.path.join(path,name))
72
+ n_path.sort()
73
+
74
+ return n_path
75
+
76
+
77
+
78
+ def getRandomTrim(x, length, pad=0, start=None):
79
+
80
+ length = length+pad
81
+ if x.shape[0] <= length:
82
+ x_ = x
83
+ while(x.shape[0] <= length):
84
+ x_ = np.concatenate((x_,x_))
85
+ else:
86
+ if start is None:
87
+ start = np.random.randint(0, x.shape[0]-length, size=None)
88
+ end = length+start
89
+ if end > x.shape[0]:
90
+ x_ = x[start:]
91
+ x_ = np.concatenate((x_, x[:length-x.shape[0]]))
92
+ else:
93
+ x_ = x[start:length+start]
94
+
95
+ return x_[:length]
96
+
97
+ def fadeIn(x, length=128):
98
+
99
+ w = scipy.signal.hann(length*2, sym=True)
100
+ w1 = w[0:length]
101
+ ones = np.ones(int(x.shape[0]-length))
102
+ w = np.append(w1, ones)
103
+
104
+ return x*w
105
+
106
+ def fadeOut(x, length=128):
107
+
108
+ w = scipy.signal.hann(length*2, sym=True)
109
+ w2 = w[length:length*2]
110
+ ones = np.ones(int(x.shape[0]-length))
111
+ w = np.append(ones, w2)
112
+
113
+ return x*w
114
+
115
+
116
+ def plotTimeFreq(audio, sr, n_fft=512, hop_length=128, ylabels=None):
117
+
118
+ n = len(audio)
119
+ # plt.figure(figsize=(14, 4*n))
120
+ colors = list(plt.cm.viridis(np.linspace(0,1,n)))
121
+
122
+ X = []
123
+ X_db = []
124
+ maxs = np.zeros((n,))
125
+ mins = np.zeros((n,))
126
+ maxs_t = np.zeros((n,))
127
+ for i, x in enumerate(audio):
128
+
129
+ if x.ndim == 2 and x.shape[-1] == 2:
130
+ x = librosa.core.to_mono(x.T)
131
+ X_ = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
132
+ X_db_ = librosa.amplitude_to_db(abs(X_))
133
+ X.append(X_)
134
+ X_db.append(X_db_)
135
+ maxs[i] = np.max(X_db_)
136
+ mins[i] = np.min(X_db_)
137
+ maxs_t[i] = np.max(np.abs(x))
138
+ vmax = np.max(maxs)
139
+ vmin = np.min(mins)
140
+ tmax = np.max(maxs_t)
141
+ for i, x in enumerate(audio):
142
+
143
+ if x.ndim == 2 and x.shape[-1] == 2:
144
+ x = librosa.core.to_mono(x.T)
145
+
146
+ plt.subplot(n, 2, 2*i+1)
147
+ librosa.display.waveplot(x, sr=sr, color=colors[i])
148
+ if ylabels:
149
+ plt.ylabel(ylabels[i])
150
+
151
+ plt.ylim(-tmax,tmax)
152
+ plt.subplot(n, 2, 2*i+2)
153
+ librosa.display.specshow(X_db[i], sr=sr, x_axis='time', y_axis='log',
154
+ hop_length=hop_length, cmap='GnBu', vmax=vmax, vmin=vmin)
155
+ # plt.colorbar(format='%+2.0f dB')
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+ def slicing(x, win_length, hop_length, center = True, windowing = False, pad = 0):
165
+ # Pad the time series so that frames are centered
166
+ if center:
167
+ # x = np.pad(x, int((win_length-hop_length+pad) // 2), mode='constant')
168
+ x = np.pad(x, ((int((win_length-hop_length+pad)//2), int((win_length+hop_length+pad)//2)),), mode='constant')
169
+
170
+ # Window the time series.
171
+ y_frames = librosa.util.frame(x, frame_length=win_length, hop_length=hop_length)
172
+ if windowing:
173
+ window = scipy.signal.hann(win_length, sym=False)
174
+ else:
175
+ window = 1.0
176
+ f = []
177
+ for i in range(len(y_frames.T)):
178
+ f.append(y_frames.T[i]*window)
179
+ return np.float32(np.asarray(f))
180
+
181
+
182
+ def overlap(x, x_len, win_length, hop_length, windowing = True, rate = 1):
183
+ x = x.reshape(x.shape[0],x.shape[1]).T
184
+ if windowing:
185
+ window = scipy.signal.hann(win_length, sym=False)
186
+ rate = rate*hop_length/win_length
187
+ else:
188
+ window = 1
189
+ rate = 1
190
+ n_frames = x_len / hop_length
191
+ expected_signal_len = int(win_length + hop_length * (n_frames))
192
+ y = np.zeros(expected_signal_len)
193
+ for i in range(int(n_frames)):
194
+ sample = i * hop_length
195
+ w = x[:, i]
196
+ y[sample:(sample + win_length)] = y[sample:(sample + win_length)] + w*window
197
+ y = y[int(win_length // 2):-int(win_length // 2)]
198
+ return np.float32(y*rate)
199
+
200
+
201
+
202
+
203
+
204
+
205
+
206
+ def highpassFiltering(x_list, f0, sr):
207
+
208
+ b1, a1 = scipy.signal.butter(4, f0/(sr/2),'highpass')
209
+ x_f = []
210
+ for x in x_list:
211
+ x_f_ = scipy.signal.filtfilt(b1, a1, x).copy(order='F')
212
+ x_f.append(x_f_)
213
+ return x_f
214
+
215
+ def lineartodB(x):
216
+ return 20*np.log10(x)
217
+ def dBtoLinear(x):
218
+ return np.power(10,x/20)
219
+
220
+ def lufs_normalize(x, sr, lufs, log=True):
221
+
222
+ # measure the loudness first
223
+ meter = pyloudnorm.Meter(sr) # create BS.1770 meter
224
+ loudness = meter.integrated_loudness(x+1e-10)
225
+ if log:
226
+ print("original loudness: ", loudness," max value: ", np.max(np.abs(x)))
227
+
228
+ loudness_normalized_audio = pyloudnorm.normalize.loudness(x, loudness, lufs)
229
+
230
+ maxabs_amp = np.maximum(1.0, 1e-6 + np.max(np.abs(loudness_normalized_audio)))
231
+ loudness_normalized_audio /= maxabs_amp
232
+
233
+ loudness = meter.integrated_loudness(loudness_normalized_audio)
234
+ if log:
235
+ print("new loudness: ", loudness," max value: ", np.max(np.abs(loudness_normalized_audio)))
236
+
237
+
238
+ return loudness_normalized_audio
239
+
240
+ import soxbindings as sox
241
+
242
+ def lufs_normalize_compand(x, sr, lufs):
243
+
244
+ tfm = sox.Transformer()
245
+ tfm.compand(attack_time = 0.001,
246
+ decay_time = 0.01,
247
+ soft_knee_db = 1.0,
248
+ tf_points = [(-70, -70), (-0.1, -20), (0, 0)])
249
+
250
+ x = tfm.build_array(input_array=x, sample_rate_in=sr).astype(np.float32)
251
+
252
+ # measure the loudness first
253
+ meter = pyloudnorm.Meter(sr) # create BS.1770 meter
254
+ loudness = meter.integrated_loudness(x)
255
+ print("original loudness: ", loudness," max value: ", np.max(np.abs(x)))
256
+
257
+ loudness_normalized_audio = pyloudnorm.normalize.loudness(x, loudness, lufs)
258
+
259
+ maxabs_amp = np.maximum(1.0, 1e-6 + np.max(np.abs(loudness_normalized_audio)))
260
+ loudness_normalized_audio /= maxabs_amp
261
+
262
+ loudness = meter.integrated_loudness(loudness_normalized_audio)
263
+ print("new loudness: ", loudness," max value: ", np.max(np.abs(loudness_normalized_audio)))
264
+
265
+
266
+
267
+
268
+
269
+
270
+ return loudness_normalized_audio
271
+
272
+
273
+
274
+
275
+
276
+ def getDistances(x,y):
277
+
278
+ distances = {}
279
+ distances['mae'] = mean_absolute_error(x, y)
280
+ distances['mse'] = mean_squared_error(x, y)
281
+ distances['euclidean'] = np.mean(paired_distances(x, y, metric='euclidean'))
282
+ distances['manhattan'] = np.mean(paired_distances(x, y, metric='manhattan'))
283
+ distances['cosine'] = np.mean(paired_distances(x, y, metric='cosine'))
284
+
285
+ distances['mae'] = round(distances['mae'], 5)
286
+ distances['mse'] = round(distances['mse'], 5)
287
+ distances['euclidean'] = round(distances['euclidean'], 5)
288
+ distances['manhattan'] = round(distances['manhattan'], 5)
289
+ distances['cosine'] = round(distances['cosine'], 5)
290
+
291
+ return distances
292
+
293
+ def getMFCC(x, sr, mels=128, mfcc=13, mean_norm=False):
294
+
295
+ melspec = librosa.feature.melspectrogram(y=x, sr=sr, S=None,
296
+ n_fft=1024, hop_length=256,
297
+ n_mels=mels, power=2.0)
298
+ melspec_dB = librosa.power_to_db(melspec, ref=np.max)
299
+ mfcc = librosa.feature.mfcc(S=melspec_dB, sr=sr, n_mfcc=mfcc)
300
+ if mean_norm:
301
+ mfcc -= (np.mean(mfcc, axis=0))
302
+ return mfcc
303
+
304
+
305
+ def getMSE_MFCC(y_true, y_pred, sr, mels=128, mfcc=13, mean_norm=False):
306
+
307
+ ratio = np.mean(np.abs(y_true))/np.mean(np.abs(y_pred))
308
+ y_pred = ratio*y_pred
309
+
310
+ y_mfcc = getMFCC(y_true, sr, mels=mels, mfcc=mfcc, mean_norm=mean_norm)
311
+ z_mfcc = getMFCC(y_pred, sr, mels=mels, mfcc=mfcc, mean_norm=mean_norm)
312
+
313
+ return getDistances(y_mfcc[:,:], z_mfcc[:,:])
mixing_style_transfer/mixing_manipulator/normalization_imager.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of the normalization process of stereo-imaging and panning effects
3
+ """
4
+ import numpy as np
5
+ import sys
6
+ import os
7
+
8
+ currentdir = os.path.dirname(os.path.realpath(__file__))
9
+ sys.path.append(currentdir)
10
+ from common_audioeffects import AugmentationChain, Haas
11
+
12
+
13
+ '''
14
+ ### normalization algorithm for stereo imaging and panning effects ###
15
+ process :
16
+ 1. inputs 2-channeled audio
17
+ 2. apply Haas effects if the input audio is almost mono
18
+ 3. normalize mid-side channels according to target precomputed feature value
19
+ 4. normalize left-right channels 50-50
20
+ 5. normalize mid-side channels again
21
+ '''
22
+ def normalize_imager(data, \
23
+ target_side_mid_bal=0.9, \
24
+ mono_threshold=0.95, \
25
+ sr=44100, \
26
+ eps=1e-04, \
27
+ verbose=False):
28
+
29
+ # to mid-side channels
30
+ mid, side = lr_to_ms(data[:,0], data[:,1])
31
+
32
+ if verbose:
33
+ print_balance(data[:,0], data[:,1])
34
+ print_balance(mid, side)
35
+ print()
36
+
37
+ # apply mid-side weights according to energy
38
+ mid_e, side_e = np.sum(mid**2), np.sum(side**2)
39
+ total_e = mid_e + side_e
40
+ # apply haas effect to almost-mono signal
41
+ if mid_e/total_e > mono_threshold:
42
+ aug_chain = AugmentationChain(fxs=[(Haas(sample_rate=sr), 1, True)])
43
+ data = aug_chain([data])[0]
44
+ mid, side = lr_to_ms(data[:,0], data[:,1])
45
+
46
+ if verbose:
47
+ print_balance(data[:,0], data[:,1])
48
+ print_balance(mid, side)
49
+ print()
50
+
51
+ # normalize mid-side channels (stereo imaging)
52
+ new_mid, new_side = process_balance(mid, side, tgt_e1_bal=target_side_mid_bal, eps=eps)
53
+ left, right = ms_to_lr(new_mid, new_side)
54
+ imaged = np.stack([left, right], 1)
55
+
56
+ if verbose:
57
+ print_balance(new_mid, new_side)
58
+ print_balance(left, right)
59
+ print()
60
+
61
+ # normalize panning to have the balance of left-right channels 50-50
62
+ left, right = process_balance(left, right, tgt_e1_bal=0.5, eps=eps)
63
+ mid, side = lr_to_ms(left, right)
64
+
65
+ if verbose:
66
+ print_balance(mid, side)
67
+ print_balance(left, right)
68
+ print()
69
+
70
+ # normalize again mid-side channels (stereo imaging)
71
+ new_mid, new_side = process_balance(mid, side, tgt_e1_bal=target_side_mid_bal, eps=eps)
72
+ left, right = ms_to_lr(new_mid, new_side)
73
+ imaged = np.stack([left, right], 1)
74
+
75
+ if verbose:
76
+ print_balance(new_mid, new_side)
77
+ print_balance(left, right)
78
+ print()
79
+
80
+ return imaged
81
+
82
+
83
+ # balance out 2 input data's energy according to given balance
84
+ # tgt_e1_bal range = [0.0, 1.0]
85
+ # tgt_e2_bal = 1.0 - tgt_e1_bal_range
86
+ def process_balance(data_1, data_2, tgt_e1_bal=0.5, eps=1e-04):
87
+
88
+ e_1, e_2 = np.sum(data_1**2), np.sum(data_2**2)
89
+ total_e = e_1 + e_2
90
+
91
+ tgt_1_gain = np.sqrt(tgt_e1_bal * total_e / (e_1 + eps))
92
+
93
+ new_data_1 = data_1 * tgt_1_gain
94
+ new_e_1 = e_1 * (tgt_1_gain ** 2)
95
+ left_e_1 = total_e - new_e_1
96
+ tgt_2_gain = np.sqrt(left_e_1 / (e_2 + 1e-3))
97
+ new_data_2 = data_2 * tgt_2_gain
98
+
99
+ return new_data_1, new_data_2
100
+
101
+
102
+ # left-right channeled signal to mid-side signal
103
+ def lr_to_ms(left, right):
104
+ mid = left + right
105
+ side = left - right
106
+ return mid, side
107
+
108
+
109
+ # mid-side channeled signal to left-right signal
110
+ def ms_to_lr(mid, side):
111
+ left = (mid + side) / 2
112
+ right = (mid - side) / 2
113
+ return left, right
114
+
115
+
116
+ # print energy balance of 2 inputs
117
+ def print_balance(data_1, data_2):
118
+ e_1, e_2 = np.sum(data_1**2), np.sum(data_2**2)
119
+ total_e = e_1 + e_2
120
+ print(total_e, e_1/total_e, e_2/total_e)
121
+
mixing_style_transfer/mixing_manipulator/utils_data_normalization.py ADDED
@@ -0,0 +1,906 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import sys
4
+ import time
5
+ import numpy as np
6
+ import scipy
7
+ import librosa
8
+ import pyloudnorm as pyln
9
+
10
+ sys.setrecursionlimit(int(1e6))
11
+
12
+ import sklearn
13
+
14
+ currentdir = os.path.dirname(os.path.realpath(__file__))
15
+ sys.path.append(currentdir)
16
+ from common_miscellaneous import compute_stft, compute_istft
17
+ from common_audioeffects import Panner, Compressor, AugmentationChain, ConvolutionalReverb, Equaliser, AlgorithmicReverb
18
+ import fx_utils
19
+
20
+ import soundfile as sf
21
+ import aubio
22
+
23
+ import time
24
+
25
+ import warnings
26
+
27
+ # Functions
28
+
29
+ def print_dict(dict_):
30
+ for i in dict_:
31
+ print(i)
32
+ for j in dict_[i]:
33
+ print('\t', j)
34
+
35
+ def amp_to_db(x):
36
+ return 20*np.log10(x + 1e-30)
37
+
38
+ def db_to_amp(x):
39
+ return 10**(x/20)
40
+
41
+ def get_running_stats(x, features, N=20):
42
+ mean = []
43
+ std = []
44
+ for i in range(len(features)):
45
+ mean_, std_ = running_mean_std(x[:,i], N)
46
+ mean.append(mean_)
47
+ std.append(std_)
48
+ mean = np.asarray(mean)
49
+ std = np.asarray(std)
50
+
51
+ return mean, std
52
+
53
+ def running_mean_std(x, N):
54
+
55
+ with warnings.catch_warnings():
56
+ warnings.simplefilter("ignore", category=RuntimeWarning)
57
+ cumsum = np.cumsum(np.insert(x, 0, 0))
58
+ cumsum2 = np.cumsum(np.insert(x**2, 0, 0))
59
+ mean = (cumsum[N:] - cumsum[:-N]) / float(N)
60
+
61
+ std = np.sqrt(((cumsum2[N:] - cumsum2[:-N]) / N) - (mean * mean))
62
+
63
+ return mean, std
64
+
65
+ def get_eq_matching(audio_t, ref_spec, sr=44100, n_fft=65536, hop_length=16384,
66
+ min_db=-50, ntaps=101, lufs=-30):
67
+
68
+ audio_t = np.copy(audio_t)
69
+ max_db = amp_to_db(np.max(np.abs(audio_t)))
70
+ if max_db > min_db:
71
+
72
+ audio_t = fx_utils.lufs_normalize(audio_t, sr, lufs, log=False)
73
+ audio_D = compute_stft(np.expand_dims(audio_t, 1),
74
+ hop_length,
75
+ n_fft,
76
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
77
+ audio_D = np.abs(audio_D)
78
+ audio_D_avg = np.mean(audio_D, axis=0)[0]
79
+
80
+ m = ref_spec.shape[0]
81
+
82
+ Ts = 1.0/sr # sampling interval
83
+ n = m # length of the signal
84
+ kk = np.arange(n)
85
+ T = n/sr
86
+ frq = kk/T # two sides frequency range
87
+ frq /=2
88
+
89
+ diff_eq = amp_to_db(ref_spec)-amp_to_db(audio_D_avg)
90
+ diff_eq = db_to_amp(diff_eq)
91
+ diff_eq = np.sqrt(diff_eq)
92
+
93
+ diff_filter = scipy.signal.firwin2(ntaps,
94
+ frq/np.max(frq),
95
+ diff_eq,
96
+ nfreqs=None, window='hamming',
97
+ nyq=None, antisymmetric=False)
98
+
99
+
100
+ output = scipy.signal.filtfilt(diff_filter, 1, audio_t,
101
+ axis=-1, padtype='odd', padlen=None,
102
+ method='pad', irlen=None)
103
+
104
+ else:
105
+ output = audio_t
106
+
107
+ return output
108
+
109
+ def get_SPS(x, n_fft=2048, hop_length=1024, smooth=False, frames=False):
110
+
111
+ x = np.copy(x)
112
+ eps = 1e-20
113
+
114
+ audio_D = compute_stft(x,
115
+ hop_length,
116
+ n_fft,
117
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
118
+
119
+ audio_D_l = np.abs(audio_D[:, 0, :] + eps)
120
+ audio_D_r = np.abs(audio_D[:, 1, :] + eps)
121
+
122
+ phi = 2 * (np.abs(audio_D_l*np.conj(audio_D_r)))/(np.abs(audio_D_l)**2+np.abs(audio_D_r)**2)
123
+
124
+ phi_l = np.abs(audio_D_l*np.conj(audio_D_r))/(np.abs(audio_D_l)**2)
125
+ phi_r = np.abs(audio_D_r*np.conj(audio_D_l))/(np.abs(audio_D_r)**2)
126
+ delta = phi_l - phi_r
127
+ delta_ = np.sign(delta)
128
+ SPS = (1-phi)*delta_
129
+
130
+ phi_mean = np.mean(phi, axis=0)
131
+ if smooth:
132
+ phi_mean = scipy.signal.savgol_filter(phi_mean, 501, 1, mode='mirror')
133
+
134
+ SPS_mean = np.mean(SPS, axis=0)
135
+ if smooth:
136
+ SPS_mean = scipy.signal.savgol_filter(SPS_mean, 501, 1, mode='mirror')
137
+
138
+
139
+ return SPS_mean, phi_mean, SPS, phi
140
+
141
+
142
+ def get_mean_side(sps, freqs=[50,2500], sr=44100, n_fft=2048):
143
+
144
+ sign = np.sign(sps+ 1e-10)
145
+
146
+ idx1 = freqs[0]
147
+ idx2 = freqs[1]
148
+
149
+ f1 = int(np.floor(idx1*n_fft/sr))
150
+ f2 = int(np.floor(idx2*n_fft/sr))
151
+
152
+ sign_mean = np.mean(sign[f1:f2])/np.abs(np.mean(sign[f1:f2]))
153
+ sign_mean
154
+
155
+ return sign_mean
156
+
157
+ def get_panning_param_values(phi, side):
158
+
159
+ p = np.zeros_like(phi)
160
+
161
+ g = (np.clip(phi+1e-30, 0, 1))/2
162
+
163
+ for i, g_ in enumerate(g):
164
+
165
+ if side > 0:
166
+ p[i] = 1 - g_
167
+
168
+ elif side < 0:
169
+ p[i] = g_
170
+
171
+ else:
172
+ p[i] = 0.5
173
+
174
+ g_l = 1-p
175
+ g_r = p
176
+
177
+ return p, [g_l, g_r]
178
+
179
+ def get_panning_matching(audio, ref_phi,
180
+ sr=44100, n_fft=2048, hop_length=1024,
181
+ min_db_f=-10, max_freq_pan=16000, frames=True):
182
+
183
+ eps = 1e-20
184
+ window = np.sqrt(np.hanning(n_fft+1)[:-1])
185
+ audio = np.copy(audio)
186
+ audio_t = np.pad(audio, ((n_fft, n_fft), (0, 0)), mode='constant')
187
+
188
+ sps_mean_, phi_mean_, _, _ = get_SPS(audio_t, n_fft=n_fft, hop_length=hop_length, smooth=True)
189
+
190
+ side = get_mean_side(sps_mean_, sr=sr, n_fft=n_fft)
191
+
192
+ if side > 0:
193
+ alpha = 0.7
194
+ else:
195
+ alpha = 0.3
196
+
197
+ processor = Panner()
198
+ processor.parameters.pan.value = alpha
199
+ processor.parameters.pan_law.value = 'linear'
200
+ processor.update()
201
+ audio_t_ = processor.process(audio_t)
202
+
203
+ sps_mean_, phi_mean, sps_frames, phi_frames = get_SPS(audio_t_, n_fft=n_fft,
204
+ hop_length=hop_length,
205
+ smooth=True, frames=frames)
206
+
207
+ if frames:
208
+
209
+ p_i_ = []
210
+ g_i_ = []
211
+ p_ref = []
212
+ g_ref = []
213
+ for i in range(len(sps_frames)):
214
+ sps_ = sps_frames[i]
215
+ phi_ = phi_frames[i]
216
+ p_, g_ = get_panning_param_values(phi_, side)
217
+ p_i_.append(p_)
218
+ g_i_.append(g_)
219
+ p_, g_ = get_panning_param_values(ref_phi, side)
220
+ p_ref.append(p_)
221
+ g_ref.append(g_)
222
+ ratio = (np.asarray(g_ref)/(np.asarray(g_i_)+eps))
223
+ g_l = ratio[:,0,:]
224
+ g_r = ratio[:,1,:]
225
+
226
+
227
+ else:
228
+ p, g = get_panning_param_values(ref_phi, side)
229
+ p_i, g_i = get_panning_param_values(phi_mean, side)
230
+ ratio = (np.asarray(g)/np.asarray(g_i))
231
+ g_l = ratio[0]
232
+ g_r = ratio[1]
233
+
234
+ audio_new_D = compute_stft(audio_t_,
235
+ hop_length,
236
+ n_fft,
237
+ window)
238
+
239
+ audio_new_D_mono = audio_new_D.copy()
240
+ audio_new_D_mono = audio_new_D_mono[:, 0, :] + audio_new_D_mono[:, 1, :]
241
+ audio_new_D_mono = np.abs(audio_new_D_mono)
242
+
243
+ audio_new_D_phase = np.angle(audio_new_D)
244
+ audio_new_D = np.abs(audio_new_D)
245
+
246
+ audio_new_D_l = audio_new_D[:, 0, :]
247
+ audio_new_D_r = audio_new_D[:, 1, :]
248
+
249
+ if frames:
250
+ for i, frame in enumerate(audio_new_D_mono):
251
+ max_db = amp_to_db(np.max(np.abs(frame)))
252
+ if max_db < min_db_f:
253
+ g_r[i] = np.ones_like(frame)
254
+ g_l[i] = np.ones_like(frame)
255
+
256
+ idx1 = max_freq_pan
257
+ f1 = int(np.floor(idx1*n_fft/sr))
258
+ ones = np.ones_like(g_l)
259
+ g_l[f1:] = ones[f1:]
260
+ g_r[f1:] = ones[f1:]
261
+
262
+ audio_new_D_l = audio_new_D_l*g_l
263
+ audio_new_D_r = audio_new_D_r*g_r
264
+
265
+ audio_new_D_l = np.expand_dims(audio_new_D_l, 0)
266
+ audio_new_D_r = np.expand_dims(audio_new_D_r, 0)
267
+
268
+ audio_new_D_ = np.concatenate((audio_new_D_l,audio_new_D_r))
269
+
270
+ audio_new_D_ = np.moveaxis(audio_new_D_, 0, 1)
271
+
272
+ audio_new_D_ = audio_new_D_ * (np.cos(audio_new_D_phase) + np.sin(audio_new_D_phase)*1j)
273
+
274
+ audio_new_t = compute_istft(audio_new_D_,
275
+ hop_length,
276
+ window)
277
+
278
+ audio_new_t = audio_new_t[n_fft:n_fft+audio.shape[0]]
279
+
280
+ return audio_new_t
281
+
282
+
283
+
284
+ def get_mean_peak(audio, sr=44100, true_peak=False, n_mels=128, percentile=75):
285
+
286
+ # Returns mean peak value in dB after the 1Q is removed.
287
+ # Input should be in the shape samples x channel
288
+
289
+ audio_ = audio
290
+ window_size = 2**10 # FFT size
291
+ hop_size = window_size
292
+
293
+ peak = []
294
+ std = []
295
+ for ch in range(audio_.shape[-1]):
296
+ x = np.ascontiguousarray(audio_[:, ch])
297
+
298
+ if true_peak:
299
+ x = librosa.resample(x, sr, 4*sr)
300
+ sr = 4*sr
301
+ window_size = 4*window_size
302
+ hop_size = 4*hop_size
303
+
304
+ onset_func = aubio.onset('hfc', buf_size=window_size, hop_size=hop_size, samplerate=sr)
305
+
306
+ frames = np.float32(librosa.util.frame(x, frame_length=window_size, hop_length=hop_size))
307
+
308
+ onset_times = []
309
+ for frame in frames.T:
310
+
311
+ if onset_func(frame):
312
+
313
+ onset_time = onset_func.get_last()
314
+ onset_times.append(onset_time)
315
+
316
+ samples=[]
317
+ if onset_times:
318
+ for i, p in enumerate(onset_times[:-1]):
319
+ samples.append(onset_times[i]+np.argmax(np.abs(x[onset_times[i]:onset_times[i+1]])))
320
+ samples.append(onset_times[-1]+np.argmax(np.abs(x[onset_times[-1]:])))
321
+
322
+ p_value = []
323
+ for p in samples:
324
+ p_ = amp_to_db(np.abs(x[p]))
325
+ p_value.append(p_)
326
+ p_value_=[]
327
+ for p in p_value:
328
+ if p > np.percentile(p_value, percentile):
329
+ p_value_.append(p)
330
+ if p_value_:
331
+ peak.append(np.mean(p_value_))
332
+ std.append(np.std(p_value_))
333
+ elif p_value:
334
+ peak.append(np.mean(p_value))
335
+ std.append(np.std(p_value))
336
+ else:
337
+ return None
338
+ return [np.mean(peak), np.mean(std)]
339
+
340
+ def compress(processor, audio, sr, th, ratio, attack, release):
341
+
342
+ eps = 1e-20
343
+ x = audio
344
+
345
+ processor.parameters.threshold.value = th
346
+ processor.parameters.ratio.value = ratio
347
+ processor.parameters.attack_time.value = attack
348
+ processor.parameters.release_time.value = release
349
+ processor.update()
350
+ output = processor.process(x)
351
+
352
+ if np.max(np.abs(output)) >= 1.0:
353
+ output = np.clip(output, -1.0, 1.0)
354
+
355
+ return output
356
+
357
+ def get_comp_matching(audio,
358
+ ref_peak, ref_std,
359
+ ratio, attack, release, sr=44100,
360
+ min_db=-50, comp_peak_norm=-10.0,
361
+ min_th=-40, max_ratio=20, n_mels=128,
362
+ true_peak=False, percentile=75, expander=True):
363
+
364
+ x = audio.copy()
365
+
366
+ if x.ndim < 2:
367
+ x = np.expand_dims(x, 1)
368
+
369
+ max_db = amp_to_db(np.max(np.abs(x)))
370
+ if max_db > min_db:
371
+
372
+ x = pyln.normalize.peak(x, comp_peak_norm)
373
+
374
+ peak, std = get_mean_peak(x, sr,
375
+ n_mels=n_mels,
376
+ true_peak=true_peak,
377
+ percentile=percentile)
378
+
379
+ if peak > (ref_peak - ref_std) and peak < (ref_peak + ref_std):
380
+ return x
381
+
382
+ # DownwardCompress
383
+ elif peak > (ref_peak - ref_std):
384
+ processor = Compressor(sample_rate=sr)
385
+ # print('compress')
386
+ ratios = np.linspace(ratio, max_ratio, max_ratio-ratio+1)
387
+ ths = np.linspace(-1-9, min_th, 2*np.abs(min_th)-1-18)
388
+ for rt in ratios:
389
+ for th in ths:
390
+ y = compress(processor, x, sr, th, rt, attack, release)
391
+ peak, std = get_mean_peak(y, sr,
392
+ n_mels=n_mels,
393
+ true_peak=true_peak,
394
+ percentile=percentile)
395
+ if peak < (ref_peak + ref_std):
396
+ break
397
+ else:
398
+ continue
399
+ break
400
+
401
+ return y
402
+
403
+ # Upward Expand
404
+ elif peak < (ref_peak + ref_std):
405
+
406
+ if expander:
407
+ processor = Compressor(sample_rate=sr)
408
+ ratios = np.linspace(ratio, max_ratio, max_ratio-ratio+1)
409
+ ths = np.linspace(-1, min_th, 2*np.abs(min_th)-1)[::-1]
410
+
411
+ for rt in ratios:
412
+ for th in ths:
413
+ y = compress(processor, x, sr, th, 1/rt, attack, release)
414
+ peak, std = get_mean_peak(y, sr,
415
+ n_mels=n_mels,
416
+ true_peak=true_peak,
417
+ percentile=percentile)
418
+ if peak > (ref_peak - ref_std):
419
+ break
420
+ else:
421
+ continue
422
+ break
423
+
424
+ return y
425
+
426
+ else:
427
+ return x
428
+ else:
429
+ return x
430
+
431
+
432
+
433
+ # REVERB
434
+
435
+
436
+ def get_reverb_send(audio, eq_parameters, rv_parameters, impulse_responses=None,
437
+ eq_prob=1.0, rv_prob=1.0, parallel=True, shuffle=False, sr=44100, bands=['low_shelf', 'high_shelf']):
438
+
439
+ x = audio.copy()
440
+
441
+ if x.ndim < 2:
442
+ x = np.expand_dims(x, 1)
443
+
444
+ channels = x.shape[-1]
445
+ eq_gain = eq_parameters.low_shelf_gain.value
446
+
447
+
448
+ eq = Equaliser(n_channels=channels,
449
+ sample_rate=sr,
450
+ gain_range=(eq_gain, eq_gain),
451
+ bands=bands,
452
+ hard_clip=False,
453
+ name='Equaliser', parameters=eq_parameters)
454
+ eq.randomize()
455
+
456
+ if impulse_responses:
457
+
458
+ reverb = ConvolutionalReverb(impulse_responses=impulse_responses,
459
+ sample_rate=sr,
460
+ parameters=rv_parameters)
461
+
462
+ else:
463
+
464
+ reverb = AlgorithmicReverb(sample_rate=sr,
465
+ parameters=rv_parameters)
466
+
467
+ reverb.randomize()
468
+
469
+ fxchain = AugmentationChain([
470
+ (eq, rv_prob, False),
471
+ (reverb, eq_prob, False)
472
+ ],
473
+ shuffle=shuffle, parallel=parallel)
474
+
475
+ output = fxchain(x)
476
+
477
+ return output
478
+
479
+
480
+
481
+ # FUNCTIONS TO COMPUTE FEATURES
482
+
483
+ def compute_loudness_features(args_):
484
+
485
+ audio_out_ = args_[0]
486
+ audio_tar_ = args_[1]
487
+ idx = args_[2]
488
+ sr = args_[3]
489
+
490
+ loudness_ = {key:[] for key in ['d_lufs', 'd_peak',]}
491
+
492
+ peak_tar = np.max(np.abs(audio_tar_))
493
+ peak_tar_db = 20.0 * np.log10(peak_tar)
494
+
495
+ peak_out = np.max(np.abs(audio_out_))
496
+ peak_out_db = 20.0 * np.log10(peak_out)
497
+
498
+ with warnings.catch_warnings():
499
+ warnings.simplefilter("ignore", category=RuntimeWarning)
500
+ meter = pyln.Meter(sr) # create BS.1770 meter
501
+ loudness_tar = meter.integrated_loudness(audio_tar_)
502
+ loudness_out = meter.integrated_loudness(audio_out_)
503
+
504
+ loudness_['d_lufs'].append(sklearn.metrics.mean_absolute_percentage_error([loudness_tar], [loudness_out]))
505
+ loudness_['d_peak'].append(sklearn.metrics.mean_absolute_percentage_error([peak_tar_db], [peak_out_db]))
506
+
507
+ return loudness_
508
+
509
+ def compute_spectral_features(args_):
510
+
511
+ audio_out_ = args_[0]
512
+ audio_tar_ = args_[1]
513
+ idx = args_[2]
514
+ sr = args_[3]
515
+ fft_size = args_[4]
516
+ hop_length = args_[5]
517
+ channels = args_[6]
518
+
519
+ audio_out_ = pyln.normalize.peak(audio_out_, -1.0)
520
+ audio_tar_ = pyln.normalize.peak(audio_tar_, -1.0)
521
+
522
+ spec_out_ = compute_stft(audio_out_,
523
+ hop_length,
524
+ fft_size,
525
+ np.sqrt(np.hanning(fft_size+1)[:-1]))
526
+ spec_out_ = np.transpose(spec_out_, axes=[1, -1, 0])
527
+ spec_out_ = np.abs(spec_out_)
528
+
529
+ spec_tar_ = compute_stft(audio_tar_,
530
+ hop_length,
531
+ fft_size,
532
+ np.sqrt(np.hanning(fft_size+1)[:-1]))
533
+ spec_tar_ = np.transpose(spec_tar_, axes=[1, -1, 0])
534
+ spec_tar_ = np.abs(spec_tar_)
535
+
536
+ spectral_ = {key:[] for key in ['centroid_mean',
537
+ 'bandwidth_mean',
538
+ 'contrast_l_mean',
539
+ 'contrast_m_mean',
540
+ 'contrast_h_mean',
541
+ 'rolloff_mean',
542
+ 'flatness_mean',
543
+ 'mape_mean',
544
+ ]}
545
+
546
+ centroid_mean_ = []
547
+ centroid_std_ = []
548
+ bandwidth_mean_ = []
549
+ bandwidth_std_ = []
550
+ contrast_l_mean_ = []
551
+ contrast_l_std_ = []
552
+ contrast_m_mean_ = []
553
+ contrast_m_std_ = []
554
+ contrast_h_mean_ = []
555
+ contrast_h_std_ = []
556
+ rolloff_mean_ = []
557
+ rolloff_std_ = []
558
+ flatness_mean_ = []
559
+
560
+ for ch in range(channels):
561
+ tar = spec_tar_[ch]
562
+ out = spec_out_[ch]
563
+
564
+ tar_sc = librosa.feature.spectral_centroid(y=None, sr=sr, S=tar,
565
+ n_fft=fft_size, hop_length=hop_length)
566
+
567
+ out_sc = librosa.feature.spectral_centroid(y=None, sr=sr, S=out,
568
+ n_fft=fft_size, hop_length=hop_length)
569
+
570
+ tar_bw = librosa.feature.spectral_bandwidth(y=None, sr=sr, S=tar,
571
+ n_fft=fft_size, hop_length=hop_length,
572
+ centroid=tar_sc, norm=True, p=2)
573
+
574
+ out_bw = librosa.feature.spectral_bandwidth(y=None, sr=sr, S=out,
575
+ n_fft=fft_size, hop_length=hop_length,
576
+ centroid=out_sc, norm=True, p=2)
577
+ # l = 0-250, m = 1-2-3 = 250 - 2000, h = 2000 - SR/2
578
+ tar_ct = librosa.feature.spectral_contrast(y=None, sr=sr, S=tar,
579
+ n_fft=fft_size, hop_length=hop_length,
580
+ fmin=250.0, n_bands=4, quantile=0.02, linear=False)
581
+
582
+ out_ct = librosa.feature.spectral_contrast(y=None, sr=sr, S=out,
583
+ n_fft=fft_size, hop_length=hop_length,
584
+ fmin=250.0, n_bands=4, quantile=0.02, linear=False)
585
+
586
+ tar_ro = librosa.feature.spectral_rolloff(y=None, sr=sr, S=tar,
587
+ n_fft=fft_size, hop_length=hop_length,
588
+ roll_percent=0.85)
589
+
590
+ out_ro = librosa.feature.spectral_rolloff(y=None, sr=sr, S=out,
591
+ n_fft=fft_size, hop_length=hop_length,
592
+ roll_percent=0.85)
593
+
594
+ tar_ft = librosa.feature.spectral_flatness(y=None, S=tar,
595
+ n_fft=fft_size, hop_length=hop_length,
596
+ amin=1e-10, power=2.0)
597
+
598
+ out_ft = librosa.feature.spectral_flatness(y=None, S=out,
599
+ n_fft=fft_size, hop_length=hop_length,
600
+ amin=1e-10, power=2.0)
601
+
602
+
603
+ eps = 1e-0
604
+ N = 40
605
+ mean_sc_tar, std_sc_tar = get_running_stats(tar_sc.T+eps, [0], N=N)
606
+ mean_sc_out, std_sc_out = get_running_stats(out_sc.T+eps, [0], N=N)
607
+
608
+ assert np.isnan(mean_sc_tar).any() == False, f'NAN values mean_sc_tar {idx}'
609
+ assert np.isnan(mean_sc_out).any() == False, f'NAN values mean_sc_out {idx}'
610
+
611
+
612
+ mean_bw_tar, std_bw_tar = get_running_stats(tar_bw.T+eps, [0], N=N)
613
+ mean_bw_out, std_bw_out = get_running_stats(out_bw.T+eps, [0], N=N)
614
+
615
+ assert np.isnan(mean_bw_tar).any() == False, f'NAN values tar mean bw {idx}'
616
+ assert np.isnan(mean_bw_out).any() == False, f'NAN values out mean bw {idx}'
617
+
618
+ mean_ct_tar, std_ct_tar = get_running_stats(tar_ct.T, list(range(tar_ct.shape[0])), N=N)
619
+ mean_ct_out, std_ct_out = get_running_stats(out_ct.T, list(range(out_ct.shape[0])), N=N)
620
+
621
+ assert np.isnan(mean_ct_tar).any() == False, f'NAN values tar mean ct {idx}'
622
+ assert np.isnan(mean_ct_out).any() == False, f'NAN values out mean ct {idx}'
623
+
624
+ mean_ro_tar, std_ro_tar = get_running_stats(tar_ro.T+eps, [0], N=N)
625
+ mean_ro_out, std_ro_out = get_running_stats(out_ro.T+eps, [0], N=N)
626
+
627
+ assert np.isnan(mean_ro_tar).any() == False, f'NAN values tar mean ro {idx}'
628
+ assert np.isnan(mean_ro_out).any() == False, f'NAN values out mean ro {idx}'
629
+
630
+ mean_ft_tar, std_ft_tar = get_running_stats(tar_ft.T, [0], N=800) # gives very high numbers due to N (80) value
631
+ mean_ft_out, std_ft_out = get_running_stats(out_ft.T, [0], N=800)
632
+
633
+ mape_mean_sc = sklearn.metrics.mean_absolute_percentage_error(mean_sc_tar[0], mean_sc_out[0])
634
+
635
+ mape_mean_bw = sklearn.metrics.mean_absolute_percentage_error(mean_bw_tar[0], mean_bw_out[0])
636
+
637
+ mape_mean_ct_l = sklearn.metrics.mean_absolute_percentage_error(mean_ct_tar[0], mean_ct_out[0])
638
+
639
+ mape_mean_ct_m = sklearn.metrics.mean_absolute_percentage_error(np.mean(mean_ct_tar[1:4], axis=0),
640
+ np.mean(mean_ct_out[1:4], axis=0))
641
+
642
+ mape_mean_ct_h = sklearn.metrics.mean_absolute_percentage_error(mean_ct_tar[-1], mean_ct_out[-1])
643
+
644
+ mape_mean_ro = sklearn.metrics.mean_absolute_percentage_error(mean_ro_tar[0], mean_ro_out[0])
645
+
646
+ mape_mean_ft = sklearn.metrics.mean_absolute_percentage_error(mean_ft_tar[0], mean_ft_out[0])
647
+
648
+ centroid_mean_.append(mape_mean_sc)
649
+ bandwidth_mean_.append(mape_mean_bw)
650
+ contrast_l_mean_.append(mape_mean_ct_l)
651
+ contrast_m_mean_.append(mape_mean_ct_m)
652
+ contrast_h_mean_.append(mape_mean_ct_h)
653
+ rolloff_mean_.append(mape_mean_ro)
654
+ flatness_mean_.append(mape_mean_ft)
655
+
656
+ spectral_['centroid_mean'].append(np.mean(centroid_mean_))
657
+
658
+ spectral_['bandwidth_mean'].append(np.mean(bandwidth_mean_))
659
+
660
+ spectral_['contrast_l_mean'].append(np.mean(contrast_l_mean_))
661
+
662
+ spectral_['contrast_m_mean'].append(np.mean(contrast_m_mean_))
663
+
664
+ spectral_['contrast_h_mean'].append(np.mean(contrast_h_mean_))
665
+
666
+ spectral_['rolloff_mean'].append(np.mean(rolloff_mean_))
667
+
668
+ spectral_['flatness_mean'].append(np.mean(flatness_mean_))
669
+
670
+ spectral_['mape_mean'].append(np.mean([np.mean(centroid_mean_),
671
+ np.mean(bandwidth_mean_),
672
+ np.mean(contrast_l_mean_),
673
+ np.mean(contrast_m_mean_),
674
+ np.mean(contrast_h_mean_),
675
+ np.mean(rolloff_mean_),
676
+ np.mean(flatness_mean_),
677
+ ]))
678
+
679
+ return spectral_
680
+
681
+ # PANNING
682
+ def get_panning_rms_frame(sps_frame, freqs=[0,22050], sr=44100, n_fft=2048):
683
+
684
+ idx1 = freqs[0]
685
+ idx2 = freqs[1]
686
+
687
+ f1 = int(np.floor(idx1*n_fft/sr))
688
+ f2 = int(np.floor(idx2*n_fft/sr))
689
+
690
+ p_rms = np.sqrt((1/(f2-f1)) * np.sum(sps_frame[f1:f2]**2))
691
+
692
+ return p_rms
693
+ def get_panning_rms(sps, freqs=[[0, 22050]], sr=44100, n_fft=2048):
694
+
695
+ p_rms = []
696
+ for frame in sps:
697
+ p_rms_ = []
698
+ for f in freqs:
699
+ rms = get_panning_rms_frame(frame, freqs=f, sr=sr, n_fft=n_fft)
700
+ p_rms_.append(rms)
701
+ p_rms.append(p_rms_)
702
+
703
+ return np.asarray(p_rms)
704
+
705
+
706
+
707
+ def compute_panning_features(args_):
708
+
709
+ audio_out_ = args_[0]
710
+ audio_tar_ = args_[1]
711
+ idx = args_[2]
712
+ sr = args_[3]
713
+ fft_size = args_[4]
714
+ hop_length = args_[5]
715
+
716
+ audio_out_ = pyln.normalize.peak(audio_out_, -1.0)
717
+ audio_tar_ = pyln.normalize.peak(audio_tar_, -1.0)
718
+
719
+ panning_ = {}
720
+
721
+ freqs=[[0, sr//2], [0, 250], [250, 2500], [2500, sr//2]]
722
+
723
+ _, _, sps_frames_tar, _ = get_SPS(audio_tar_, n_fft=fft_size,
724
+ hop_length=hop_length,
725
+ smooth=True, frames=True)
726
+
727
+ _, _, sps_frames_out, _ = get_SPS(audio_out_, n_fft=fft_size,
728
+ hop_length=hop_length,
729
+ smooth=True, frames=True)
730
+
731
+
732
+ p_rms_tar = get_panning_rms(sps_frames_tar,
733
+ freqs=freqs,
734
+ sr=sr,
735
+ n_fft=fft_size)
736
+
737
+ p_rms_out = get_panning_rms(sps_frames_out,
738
+ freqs=freqs,
739
+ sr=sr,
740
+ n_fft=fft_size)
741
+
742
+ # to avoid num instability, deletes frames with zero rms from target
743
+ if np.min(p_rms_tar) == 0.0:
744
+ id_zeros = np.where(p_rms_tar.T[0] == 0)
745
+ p_rms_tar_ = []
746
+ p_rms_out_ = []
747
+ for i in range(len(freqs)):
748
+ temp_tar = np.delete(p_rms_tar.T[i], id_zeros)
749
+ temp_out = np.delete(p_rms_out.T[i], id_zeros)
750
+ p_rms_tar_.append(temp_tar)
751
+ p_rms_out_.append(temp_out)
752
+ p_rms_tar_ = np.asarray(p_rms_tar_)
753
+ p_rms_tar = p_rms_tar_.T
754
+ p_rms_out_ = np.asarray(p_rms_out_)
755
+ p_rms_out = p_rms_out_.T
756
+
757
+ N = 40
758
+
759
+ mean_tar, std_tar = get_running_stats(p_rms_tar, freqs, N=N)
760
+ mean_out, std_out = get_running_stats(p_rms_out, freqs, N=N)
761
+
762
+ panning_['P_t_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[0], mean_out[0])]
763
+ panning_['P_l_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[1], mean_out[1])]
764
+ panning_['P_m_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[2], mean_out[2])]
765
+ panning_['P_h_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_tar[3], mean_out[3])]
766
+
767
+ panning_['mape_mean'] = [np.mean([panning_['P_t_mean'],
768
+ panning_['P_l_mean'],
769
+ panning_['P_m_mean'],
770
+ panning_['P_h_mean'],
771
+ ])]
772
+
773
+ return panning_
774
+
775
+ # DYNAMIC
776
+
777
+ def get_rms_dynamic_crest(x, frame_length, hop_length):
778
+
779
+ rms = []
780
+ dynamic_spread = []
781
+ crest = []
782
+ for ch in range(x.shape[-1]):
783
+ frames = librosa.util.frame(x[:, ch], frame_length=frame_length, hop_length=hop_length)
784
+ rms_ = []
785
+ dynamic_spread_ = []
786
+ crest_ = []
787
+ for i in frames.T:
788
+ x_rms = amp_to_db(np.sqrt(np.sum(i**2)/frame_length))
789
+ x_d = np.sum(amp_to_db(np.abs(i)) - x_rms)/frame_length
790
+ x_c = amp_to_db(np.max(np.abs(i)))/x_rms
791
+
792
+ rms_.append(x_rms)
793
+ dynamic_spread_.append(x_d)
794
+ crest_.append(x_c)
795
+ rms.append(rms_)
796
+ dynamic_spread.append(dynamic_spread_)
797
+ crest.append(crest_)
798
+
799
+ rms = np.asarray(rms)
800
+ dynamic_spread = np.asarray(dynamic_spread)
801
+ crest = np.asarray(crest)
802
+
803
+ rms = np.mean(rms, axis=0)
804
+ dynamic_spread = np.mean(dynamic_spread, axis=0)
805
+ crest = np.mean(crest, axis=0)
806
+
807
+ rms = np.expand_dims(rms, axis=0)
808
+ dynamic_spread = np.expand_dims(dynamic_spread, axis=0)
809
+ crest = np.expand_dims(crest, axis=0)
810
+
811
+ return rms, dynamic_spread, crest
812
+
813
+ def lowpassFiltering(x, f0, sr):
814
+
815
+ b1, a1 = scipy.signal.butter(4, f0/(sr/2),'lowpass')
816
+ x_f = []
817
+ for ch in range(x.shape[-1]):
818
+ x_f_ = scipy.signal.filtfilt(b1, a1, x[:, ch]).copy(order='F')
819
+ x_f.append(x_f_)
820
+ return np.asarray(x_f).T
821
+
822
+
823
+ def get_low_freq_weighting(x, sr, n_fft, hop_length, f0 = 1000):
824
+
825
+ x_low = lowpassFiltering(x, f0, sr)
826
+
827
+ X_low = compute_stft(x_low,
828
+ hop_length,
829
+ n_fft,
830
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
831
+ X_low = np.transpose(X_low, axes=[1, -1, 0])
832
+ X_low = np.abs(X_low)
833
+
834
+ X = compute_stft(x,
835
+ hop_length,
836
+ n_fft,
837
+ np.sqrt(np.hanning(n_fft+1)[:-1]))
838
+ X = np.transpose(X, axes=[1, -1, 0])
839
+ X = np.abs(X)
840
+
841
+ eps = 1e-5
842
+ ratio = (X_low)/(X+eps)
843
+ ratio = np.sum(ratio, axis = 1)
844
+ ratio = np.mean(ratio, axis = 0)
845
+
846
+ return np.expand_dims(ratio, axis=0)
847
+
848
+ def compute_dynamic_features(args_):
849
+
850
+ audio_out_ = args_[0]
851
+ audio_tar_ = args_[1]
852
+ idx = args_[2]
853
+ sr = args_[3]
854
+ fft_size = args_[4]
855
+ hop_length = args_[5]
856
+
857
+ audio_out_ = pyln.normalize.peak(audio_out_, -1.0)
858
+ audio_tar_ = pyln.normalize.peak(audio_tar_, -1.0)
859
+
860
+ dynamic_ = {}
861
+
862
+ with warnings.catch_warnings():
863
+ warnings.simplefilter("ignore", category=UserWarning)
864
+
865
+ rms_tar, dyn_tar, crest_tar = get_rms_dynamic_crest(audio_tar_, fft_size, hop_length)
866
+ rms_out, dyn_out, crest_out = get_rms_dynamic_crest(audio_out_, fft_size, hop_length)
867
+
868
+ low_ratio_tar = get_low_freq_weighting(audio_tar_, sr, fft_size, hop_length, f0=1000)
869
+
870
+ low_ratio_out = get_low_freq_weighting(audio_out_, sr, fft_size, hop_length, f0=1000)
871
+
872
+ N = 40
873
+
874
+ eps = 1e-10
875
+
876
+ rms_tar = (-1*rms_tar) + 1.0
877
+ rms_out = (-1*rms_out) + 1.0
878
+ dyn_tar = (-1*dyn_tar) + 1.0
879
+ dyn_out = (-1*dyn_out) + 1.0
880
+
881
+ mean_rms_tar, std_rms_tar = get_running_stats(rms_tar.T, [0], N=N)
882
+ mean_rms_out, std_rms_out = get_running_stats(rms_out.T, [0], N=N)
883
+
884
+ mean_dyn_tar, std_dyn_tar = get_running_stats(dyn_tar.T, [0], N=N)
885
+ mean_dyn_out, std_dyn_out = get_running_stats(dyn_out.T, [0], N=N)
886
+
887
+ mean_crest_tar, std_crest_tar = get_running_stats(crest_tar.T, [0], N=N)
888
+ mean_crest_out, std_crest_out = get_running_stats(crest_out.T, [0], N=N)
889
+
890
+ mean_low_ratio_tar, std_low_ratio_tar = get_running_stats(low_ratio_tar.T, [0], N=N)
891
+ mean_low_ratio_out, std_low_ratio_out = get_running_stats(low_ratio_out.T, [0], N=N)
892
+
893
+ dynamic_['rms_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_rms_tar, mean_rms_out)]
894
+ dynamic_['dyn_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_dyn_tar, mean_dyn_out)]
895
+ dynamic_['crest_mean'] = [sklearn.metrics.mean_absolute_percentage_error(mean_crest_tar, mean_crest_out)]
896
+
897
+ dynamic_['l_ratio_mean_mape'] = [sklearn.metrics.mean_absolute_percentage_error(mean_low_ratio_tar, mean_low_ratio_out)]
898
+ dynamic_['l_ratio_mean_l2'] = [sklearn.metrics.mean_squared_error(mean_low_ratio_tar, mean_low_ratio_out)]
899
+
900
+ dynamic_['mape_mean'] = [np.mean([dynamic_['rms_mean'],
901
+ dynamic_['dyn_mean'],
902
+ dynamic_['crest_mean'],
903
+ ])]
904
+
905
+ return dynamic_
906
+
mixing_style_transfer/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .front_back_end import *
2
+ from .loss import *
3
+ from .training_utils import *
mixing_style_transfer/modules/front_back_end.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Front-end: processing raw data input """
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio.functional as ta_F
5
+ import torchaudio
6
+
7
+
8
+
9
+ class FrontEnd(nn.Module):
10
+ def __init__(self, channel='stereo', \
11
+ n_fft=2048, \
12
+ hop_length=None, \
13
+ win_length=None, \
14
+ window="hann", \
15
+ device=torch.device("cpu")):
16
+ super(FrontEnd, self).__init__()
17
+ self.channel = channel
18
+ self.n_fft = n_fft
19
+ self.hop_length = n_fft//4 if hop_length==None else hop_length
20
+ self.win_length = n_fft if win_length==None else win_length
21
+ if window=="hann":
22
+ self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(device)
23
+ elif window=="hamming":
24
+ self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(device)
25
+
26
+
27
+ def forward(self, input, mode):
28
+ # front-end function which channel-wise combines all demanded features
29
+ # input shape : batch x channel x raw waveform
30
+ # output shape : batch x channel x frequency x time
31
+
32
+ front_output_list = []
33
+ for cur_mode in mode:
34
+ # Real & Imaginary
35
+ if cur_mode=="cplx":
36
+ if self.channel=="mono":
37
+ output = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
38
+ elif self.channel=="stereo":
39
+ output_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
40
+ output_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
41
+ output = torch.cat((output_l, output_r), axis=-1)
42
+ if input.shape[2] % round(self.n_fft/4) == 0:
43
+ output = output[:, :, :-1]
44
+ if self.n_fft % 2 == 0:
45
+ output = output[:, :-1]
46
+ front_output_list.append(output.permute(0, 3, 1, 2))
47
+ # Magnitude & Phase
48
+ elif cur_mode=="mag":
49
+ if self.channel=="mono":
50
+ cur_cplx = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
51
+ output = self.mag(cur_cplx).unsqueeze(-1)[..., 0:1]
52
+ elif self.channel=="stereo":
53
+ cplx_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
54
+ cplx_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
55
+ mag_l = self.mag(cplx_l).unsqueeze(-1)
56
+ mag_r = self.mag(cplx_r).unsqueeze(-1)
57
+ output = torch.cat((mag_l, mag_r), axis=-1)
58
+
59
+ if input.shape[-1] % round(self.n_fft/4) == 0:
60
+ output = output[:, :, :-1]
61
+ if self.n_fft % 2 == 0: # discard highest frequency
62
+ output = output[:, 1:]
63
+ front_output_list.append(output.permute(0, 3, 1, 2))
64
+
65
+ # combine all demanded features
66
+ if not front_output_list:
67
+ raise NameError("NameError at FrontEnd: check using features for front-end")
68
+ elif len(mode)!=1:
69
+ for i, cur_output in enumerate(front_output_list):
70
+ if i==0:
71
+ front_output = cur_output
72
+ else:
73
+ front_output = torch.cat((front_output, cur_output), axis=1)
74
+ else:
75
+ front_output = front_output_list[0]
76
+
77
+ return front_output
78
+
79
+
80
+ def mag(self, cplx_input, eps=1e-07):
81
+ mag_summed = cplx_input.pow(2.).sum(-1) + eps
82
+ return mag_summed.pow(0.5)
83
+
84
+
85
+
86
+
87
+ class BackEnd(nn.Module):
88
+ def __init__(self, channel='stereo', \
89
+ n_fft=2048, \
90
+ hop_length=None, \
91
+ win_length=None, \
92
+ window="hann", \
93
+ eps=1e-07, \
94
+ orig_freq=44100, \
95
+ new_freq=16000, \
96
+ device=torch.device("cpu")):
97
+ super(BackEnd, self).__init__()
98
+ self.device = device
99
+ self.channel = channel
100
+ self.n_fft = n_fft
101
+ self.hop_length = n_fft//4 if hop_length==None else hop_length
102
+ self.win_length = n_fft if win_length==None else win_length
103
+ self.eps = eps
104
+ if window=="hann":
105
+ self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(self.device)
106
+ elif window=="hamming":
107
+ self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(self.device)
108
+ self.resample_func_8k = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=8000).to(self.device)
109
+ self.resample_func = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq).to(self.device)
110
+
111
+ def magphase_to_cplx(self, magphase_spec):
112
+ real = magphase_spec[..., 0] * torch.cos(magphase_spec[..., 1])
113
+ imaginary = magphase_spec[..., 0] * torch.sin(magphase_spec[..., 1])
114
+ return torch.cat((real.unsqueeze(-1), imaginary.unsqueeze(-1)), dim=-1)
115
+
116
+
117
+ def forward(self, input, phase, mode):
118
+ # back-end function which convert output spectrograms into waveform
119
+ # input shape : batch x channel x frequency x time
120
+ # output shape : batch x channel x raw waveform
121
+
122
+ # convert to shape : batch x frequency x time x channel
123
+ input = input.permute(0, 2, 3, 1)
124
+ # pad highest frequency
125
+ pad = torch.zeros((input.shape[0], 1, input.shape[2], input.shape[3])).to(self.device)
126
+ input = torch.cat((pad, input), dim=1)
127
+
128
+ back_output_list = []
129
+ channel_count = 0
130
+ for i, cur_mode in enumerate(mode):
131
+ # Real & Imaginary
132
+ if cur_mode=="cplx":
133
+ if self.channel=="mono":
134
+ output = ta_F.istft(input[...,channel_count:channel_count+2], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
135
+ channel_count += 2
136
+ elif self.channel=="stereo":
137
+ cplx_spec = torch.cat([input[...,channel_count:channel_count+2], input[...,channel_count+2:channel_count+4]], dim=0)
138
+ output_wav = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
139
+ output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
140
+ channel_count += 4
141
+ back_output_list.append(output)
142
+ # Magnitude & Phase
143
+ elif cur_mode=="mag_phase" or cur_mode=="mag":
144
+ if self.channel=="mono":
145
+ if cur_mode=="mag":
146
+ input_spec = torch.cat((input[...,channel_count:channel_count+1], phase), axis=-1)
147
+ channel_count += 1
148
+ else:
149
+ input_spec = input[...,channel_count:channel_count+2]
150
+ channel_count += 2
151
+ cplx_spec = self.magphase_to_cplx(input_spec)
152
+ output = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
153
+ elif self.channel=="stereo":
154
+ if cur_mode=="mag":
155
+ input_spec_l = torch.cat((input[...,channel_count:channel_count+1], phase[...,0:1]), axis=-1)
156
+ input_spec_r = torch.cat((input[...,channel_count+1:channel_count+2], phase[...,1:2]), axis=-1)
157
+ channel_count += 2
158
+ else:
159
+ input_spec_l = input[...,channel_count:channel_count+2]
160
+ input_spec_r = input[...,channel_count+2:channel_count+4]
161
+ channel_count += 4
162
+ cplx_spec_l = self.magphase_to_cplx(input_spec_l)
163
+ cplx_spec_r = self.magphase_to_cplx(input_spec_r)
164
+ cplx_spec = torch.cat([cplx_spec_l, cplx_spec_r], dim=0)
165
+ output_wav = torch.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
166
+ output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
167
+ channel_count += 4
168
+ back_output_list.append(output)
169
+ elif cur_mode=="griff":
170
+ if self.channel=="mono":
171
+ output = self.griffin_lim(input.squeeze(-1), input.device).unsqueeze(1)
172
+ # output = self.griff(input.permute(0, 3, 1, 2))
173
+ else:
174
+ output_l = self.griffin_lim(input[..., 0], input.device).unsqueeze(1)
175
+ output_r = self.griffin_lim(input[..., 1], input.device).unsqueeze(1)
176
+ output = torch.cat((output_l, output_r), axis=1)
177
+
178
+ back_output_list.append(output)
179
+
180
+ # combine all demanded feature outputs
181
+ if not back_output_list:
182
+ raise NameError("NameError at BackEnd: check using features for back-end")
183
+ elif len(mode)!=1:
184
+ for i, cur_output in enumerate(back_output_list):
185
+ if i==0:
186
+ back_output = cur_output
187
+ else:
188
+ back_output = torch.cat((back_output, cur_output), axis=1)
189
+ else:
190
+ back_output = back_output_list[0]
191
+
192
+ return back_output
193
+
194
+
195
+ def griffin_lim(self, l_est, gpu, n_iter=100):
196
+ l_est = l_est.cpu().detach()
197
+
198
+ l_est = torch.pow(l_est, 1/0.80)
199
+ # l_est [batch, channel, time]
200
+ l_mag = l_est.unsqueeze(-1)
201
+ l_phase = 2 * np.pi * torch.rand_like(l_mag) - np.pi
202
+ real = l_mag * torch.cos(l_phase)
203
+ imag = l_mag * torch.sin(l_phase)
204
+ S = torch.cat((real, imag), axis=-1)
205
+ S_mag = (real**2 + imag**2 + self.eps) ** 1/2
206
+ for i in range(n_iter):
207
+ x = ta_F.istft(S, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
208
+ S_new = torch.stft(x, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
209
+ S_new_phase = S_new/mag(S_new)
210
+ S = S_mag * S_new_phase
211
+ return x / torch.max(torch.abs(x))
212
+
213
+
214
+
215
+ if __name__ == '__main__':
216
+
217
+ batch_size = 16
218
+ channel = 2
219
+ segment_length = 512*128*6
220
+ input_wav = torch.rand((batch_size, channel, segment_length))
221
+
222
+ mode = ["cplx", "mag"]
223
+ fe = FrontEnd(channel="stereo")
224
+
225
+ output = fe(input_wav, mode=mode)
226
+ print(f"Input shape : {input_wav.shape}\nOutput shape : {output.shape}")
mixing_style_transfer/modules/loss.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of objective functions used in the task 'End-to-end Remastering System'
3
+ """
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+
9
+ import os
10
+ import sys
11
+ currentdir = os.path.dirname(os.path.realpath(__file__))
12
+ sys.path.append(os.path.dirname(currentdir))
13
+
14
+ from modules.training_utils import *
15
+ from modules.front_back_end import *
16
+
17
+
18
+
19
+ '''
20
+ Normalized Temperature-scaled Cross Entropy (NT-Xent) Loss
21
+ below source code (class NT_Xent) is a replication from the github repository - https://github.com/Spijkervet/SimCLR
22
+ the original implementation can be found here: https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
23
+ '''
24
+ class NT_Xent(nn.Module):
25
+ def __init__(self, batch_size, temperature, world_size):
26
+ super(NT_Xent, self).__init__()
27
+ self.batch_size = batch_size
28
+ self.temperature = temperature
29
+ self.world_size = world_size
30
+
31
+ self.mask = self.mask_correlated_samples(batch_size, world_size)
32
+ self.criterion = nn.CrossEntropyLoss(reduction="sum")
33
+ self.similarity_f = nn.CosineSimilarity(dim=2)
34
+
35
+ def mask_correlated_samples(self, batch_size, world_size):
36
+ N = 2 * batch_size * world_size
37
+ mask = torch.ones((N, N), dtype=bool)
38
+ mask = mask.fill_diagonal_(0)
39
+ for i in range(batch_size * world_size):
40
+ mask[i, batch_size + i] = 0
41
+ mask[batch_size + i, i] = 0
42
+ # mask[i, batch_size * world_size + i] = 0
43
+ # mask[batch_size * world_size + i, i] = 0
44
+ return mask
45
+
46
+ def forward(self, z_i, z_j):
47
+ """
48
+ We do not sample negative examples explicitly.
49
+ Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
50
+ """
51
+ N = 2 * self.batch_size * self.world_size
52
+
53
+ z = torch.cat((z_i, z_j), dim=0)
54
+ # combine embeddings from all GPUs
55
+ if self.world_size > 1:
56
+ z = torch.cat(GatherLayer.apply(z), dim=0)
57
+
58
+ sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
59
+
60
+ sim_i_j = torch.diag(sim, self.batch_size * self.world_size)
61
+ sim_j_i = torch.diag(sim, -self.batch_size * self.world_size)
62
+
63
+ # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN
64
+ positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
65
+ negative_samples = sim[self.mask].reshape(N, -1)
66
+
67
+ labels = torch.zeros(N).to(positive_samples.device).long()
68
+ logits = torch.cat((positive_samples, negative_samples), dim=1)
69
+ loss = self.criterion(logits, labels)
70
+ loss /= N
71
+ return loss
72
+
73
+
74
+
75
+ # Root Mean Squared Loss
76
+ # penalizes the volume factor with non-linearlity
77
+ class RMSLoss(nn.Module):
78
+ def __init__(self, reduce, loss_type="l2"):
79
+ super(RMSLoss, self).__init__()
80
+ self.weight_factor = 100.
81
+ if loss_type=="l2":
82
+ self.loss = nn.MSELoss(reduce=None)
83
+
84
+
85
+ def forward(self, est_targets, targets):
86
+ est_targets = est_targets.reshape(est_targets.shape[0]*est_targets.shape[1], est_targets.shape[2])
87
+ targets = targets.reshape(targets.shape[0]*targets.shape[1], targets.shape[2])
88
+ normalized_est = torch.sqrt(torch.mean(est_targets**2, dim=-1))
89
+ normalized_tgt = torch.sqrt(torch.mean(targets**2, dim=-1))
90
+
91
+ weight = torch.clamp(torch.abs(normalized_tgt-normalized_est), min=1/self.weight_factor) * self.weight_factor
92
+
93
+ return torch.mean(weight**1.5 * self.loss(normalized_est, normalized_tgt))
94
+
95
+
96
+
97
+ # Multi-Scale Spectral Loss proposed at the paper "DDSP: DIFFERENTIABLE DIGITAL SIGNAL PROCESSING" (https://arxiv.org/abs/2001.04643)
98
+ # we extend this loss by applying it to mid/side channels
99
+ class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module):
100
+ def __init__(self, mode='midside', \
101
+ reduce=True, \
102
+ n_filters=None, \
103
+ windows_size=None, \
104
+ hops_size=None, \
105
+ window="hann", \
106
+ eps=1e-7, \
107
+ device=torch.device("cpu")):
108
+ super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__()
109
+ self.mode = mode
110
+ self.eps = eps
111
+ self.mid_weight = 0.5 # value in the range of 0.0 ~ 1.0
112
+ self.logmag_weight = 0.1
113
+
114
+ if n_filters is None:
115
+ n_filters = [4096, 2048, 1024, 512]
116
+ # n_filters = [4096]
117
+ if windows_size is None:
118
+ windows_size = [4096, 2048, 1024, 512]
119
+ # windows_size = [4096]
120
+ if hops_size is None:
121
+ hops_size = [1024, 512, 256, 128]
122
+ # hops_size = [1024]
123
+
124
+ self.multiscales = []
125
+ for i in range(len(windows_size)):
126
+ cur_scale = {'window_size' : float(windows_size[i])}
127
+ if self.mode=='midside':
128
+ cur_scale['front_end'] = FrontEnd(channel='mono', \
129
+ n_fft=n_filters[i], \
130
+ hop_length=hops_size[i], \
131
+ win_length=windows_size[i], \
132
+ window=window, \
133
+ device=device)
134
+ elif self.mode=='ori':
135
+ cur_scale['front_end'] = FrontEnd(channel='stereo', \
136
+ n_fft=n_filters[i], \
137
+ hop_length=hops_size[i], \
138
+ win_length=windows_size[i], \
139
+ window=window, \
140
+ device=device)
141
+ self.multiscales.append(cur_scale)
142
+
143
+ self.objective_l1 = nn.L1Loss(reduce=reduce)
144
+ self.objective_l2 = nn.MSELoss(reduce=reduce)
145
+
146
+
147
+ def forward(self, est_targets, targets):
148
+ if self.mode=='midside':
149
+ return self.forward_midside(est_targets, targets)
150
+ elif self.mode=='ori':
151
+ return self.forward_ori(est_targets, targets)
152
+
153
+
154
+ def forward_ori(self, est_targets, targets):
155
+ total_loss = 0.0
156
+ total_mag_loss = 0.0
157
+ total_logmag_loss = 0.0
158
+ for cur_scale in self.multiscales:
159
+ est_mag = cur_scale['front_end'](est_targets, mode=["mag"])
160
+ tgt_mag = cur_scale['front_end'](targets, mode=["mag"])
161
+
162
+ mag_loss = self.magnitude_loss(est_mag, tgt_mag)
163
+ logmag_loss = self.log_magnitude_loss(est_mag, tgt_mag)
164
+ # cur_loss = mag_loss + logmag_loss
165
+ # total_loss += cur_loss
166
+ total_mag_loss += mag_loss
167
+ total_logmag_loss += logmag_loss
168
+ # return total_loss
169
+ # print(f"ori - mag : {total_mag_loss}\tlog mag : {total_logmag_loss}")
170
+ return (1-self.logmag_weight)*total_mag_loss + \
171
+ (self.logmag_weight)*total_logmag_loss
172
+
173
+
174
+ def forward_midside(self, est_targets, targets):
175
+ est_mid, est_side = self.to_mid_side(est_targets)
176
+ tgt_mid, tgt_side = self.to_mid_side(targets)
177
+ total_loss = 0.0
178
+ total_mag_loss = 0.0
179
+ total_logmag_loss = 0.0
180
+ for cur_scale in self.multiscales:
181
+ est_mid_mag = cur_scale['front_end'](est_mid, mode=["mag"])
182
+ est_side_mag = cur_scale['front_end'](est_side, mode=["mag"])
183
+ tgt_mid_mag = cur_scale['front_end'](tgt_mid, mode=["mag"])
184
+ tgt_side_mag = cur_scale['front_end'](tgt_side, mode=["mag"])
185
+
186
+ mag_loss = self.mid_weight*self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \
187
+ (1-self.mid_weight)*self.magnitude_loss(est_side_mag, tgt_side_mag)
188
+ logmag_loss = self.mid_weight*self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \
189
+ (1-self.mid_weight)*self.log_magnitude_loss(est_side_mag, tgt_side_mag)
190
+ # cur_loss = mag_loss + logmag_loss
191
+ # total_loss += cur_loss
192
+ total_mag_loss += mag_loss
193
+ total_logmag_loss += logmag_loss
194
+ # return total_loss
195
+ # print(f"midside - mag : {total_mag_loss}\tlog mag : {total_logmag_loss}")
196
+ return (1-self.logmag_weight)*total_mag_loss + \
197
+ (self.logmag_weight)*total_logmag_loss
198
+
199
+
200
+ def to_mid_side(self, stereo_in):
201
+ mid = stereo_in[:,0] + stereo_in[:,1]
202
+ side = stereo_in[:,0] - stereo_in[:,1]
203
+ return mid, side
204
+
205
+
206
+ def magnitude_loss(self, est_mag_spec, tgt_mag_spec):
207
+ return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec))
208
+
209
+
210
+ def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec):
211
+ est_log_mag_spec = torch.log10(est_mag_spec+self.eps)
212
+ tgt_log_mag_spec = torch.log10(tgt_mag_spec+self.eps)
213
+ return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec)
214
+
215
+
216
+
217
+ # hinge loss for discriminator
218
+ def dis_hinge(dis_fake, dis_real):
219
+ return torch.mean(torch.relu(1. - dis_real)) + torch.mean(torch.relu(1. + dis_fake))
220
+
221
+
222
+ # hinge loss for generator
223
+ def gen_hinge(dis_fake, dis_real=None):
224
+ return -torch.mean(dis_fake)
225
+
226
+
227
+ # DirectCLR's implementation of infoNCE loss
228
+ def infoNCE(nn, p, temperature=0.1):
229
+ nn = torch.nn.functional.normalize(nn, dim=1)
230
+ p = torch.nn.functional.normalize(p, dim=1)
231
+ nn = gather_from_all(nn)
232
+ p = gather_from_all(p)
233
+ logits = nn @ p.T
234
+ logits /= temperature
235
+ n = p.shape[0]
236
+ labels = torch.arange(0, n, dtype=torch.long).cuda()
237
+ loss = torch.nn.functional.cross_entropy(logits, labels)
238
+ return loss
239
+
240
+
241
+
242
+
243
+ # Class of available loss functions
244
+ class Loss:
245
+ def __init__(self, args, reduce=True):
246
+ device = torch.device("cpu")
247
+ if torch.cuda.is_available():
248
+ device = torch.device(f"cuda:{args.gpu}")
249
+ self.l1 = nn.L1Loss(reduce=reduce)
250
+ self.mse = nn.MSELoss(reduce=reduce)
251
+ self.ce = nn.CrossEntropyLoss()
252
+ self.triplet = nn.TripletMarginLoss(margin=1., p=2)
253
+
254
+ # self.ntxent = NT_Xent(args.train_batch*2, args.temperature, world_size=len(args.using_gpu.split(",")))
255
+ self.ntxent = NT_Xent(args.batch_size_total*(args.num_strong_negatives+1), args.temperature, world_size=1)
256
+ self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device)
257
+ self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device)
258
+ self.gain = RMSLoss(reduce=reduce)
259
+ self.infonce = infoNCE
260
+
mixing_style_transfer/modules/training_utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Utility file for trainers """
2
+ import os
3
+ import shutil
4
+ from glob import glob
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+
10
+
11
+ ''' checkpoint functions '''
12
+ # saves checkpoint
13
+ def save_checkpoint(model, \
14
+ optimizer, \
15
+ scheduler, \
16
+ epoch, \
17
+ checkpoint_dir, \
18
+ name, \
19
+ model_name):
20
+ os.makedirs(checkpoint_dir, exist_ok=True)
21
+ checkpoint_state = {
22
+ "model": model.state_dict(),
23
+ "optimizer": optimizer.state_dict(),
24
+ "scheduler": scheduler.state_dict(),
25
+ "epoch": epoch
26
+ }
27
+ checkpoint_path = os.path.join(checkpoint_dir,'{}_{}_{}.pt'.format(name, model_name, epoch))
28
+ torch.save(checkpoint_state, checkpoint_path)
29
+ print("Saved checkpoint: {}".format(checkpoint_path))
30
+
31
+
32
+ # reload model weights from checkpoint file
33
+ def reload_ckpt(args, \
34
+ network, \
35
+ optimizer, \
36
+ scheduler, \
37
+ gpu, \
38
+ model_name, \
39
+ manual_reload_name=None, \
40
+ manual_reload=False, \
41
+ manual_reload_dir=None, \
42
+ epoch=None, \
43
+ fit_sefa=False):
44
+ if manual_reload:
45
+ reload_name = manual_reload_name
46
+ else:
47
+ reload_name = args.name
48
+ if manual_reload_dir:
49
+ ckpt_dir = manual_reload_dir + reload_name + "/ckpt/"
50
+ else:
51
+ ckpt_dir = args.output_dir + reload_name + "/ckpt/"
52
+ temp_ckpt_dir = f'{args.output_dir}{reload_name}/ckpt_temp/'
53
+ reload_epoch = epoch
54
+ # find best or latest epoch
55
+ if epoch==None:
56
+ reload_epoch_temp = 0
57
+ reload_epoch_ckpt = 0
58
+ if len(os.listdir(temp_ckpt_dir))!=0:
59
+ reload_epoch_temp = find_best_epoch(temp_ckpt_dir)
60
+ if len(os.listdir(ckpt_dir))!=0:
61
+ reload_epoch_ckpt = find_best_epoch(ckpt_dir)
62
+ if reload_epoch_ckpt >= reload_epoch_temp:
63
+ reload_epoch = reload_epoch_ckpt
64
+ else:
65
+ reload_epoch = reload_epoch_temp
66
+ ckpt_dir = temp_ckpt_dir
67
+ else:
68
+ if os.path.isfile(f"{temp_ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"):
69
+ ckpt_dir = temp_ckpt_dir
70
+ # reloading weight
71
+ if model_name==None:
72
+ resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{reload_epoch}.pt"
73
+ else:
74
+ resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"
75
+ if gpu==0:
76
+ print("===Resume checkpoint from: {}===".format(resuming_path))
77
+ loc = 'cuda:{}'.format(gpu)
78
+ checkpoint = torch.load(resuming_path, map_location=loc)
79
+ start_epoch = 0 if manual_reload and not fit_sefa else checkpoint["epoch"]
80
+
81
+ if manual_reload_dir is not None and 'parameter_estimation' in manual_reload_dir:
82
+ from collections import OrderedDict
83
+ new_state_dict = OrderedDict()
84
+ for k, v in checkpoint["model"].items():
85
+ name = 'module.' + k # add `module.`
86
+ new_state_dict[name] = v
87
+ network.load_state_dict(new_state_dict)
88
+ else:
89
+ network.load_state_dict(checkpoint["model"])
90
+ if not manual_reload:
91
+ optimizer.load_state_dict(checkpoint["optimizer"])
92
+ scheduler.load_state_dict(checkpoint["scheduler"])
93
+ if gpu==0:
94
+ # print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, checkpoint['epoch']))
95
+ print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, epoch))
96
+ return start_epoch
97
+
98
+
99
+ # find best epoch for reloading current model
100
+ def find_best_epoch(input_dir):
101
+ cur_epochs = glob("{}*".format(input_dir))
102
+ return find_by_name(cur_epochs)
103
+
104
+
105
+ # sort string epoch names by integers
106
+ def find_by_name(epochs):
107
+ int_epochs = []
108
+ for e in epochs:
109
+ int_epochs.append(int(os.path.basename(e)))
110
+ int_epochs.sort()
111
+ return (int_epochs[-1])
112
+
113
+
114
+ # remove ckpt files
115
+ def remove_ckpt(cur_ckpt_path_dir, leave=2):
116
+ ckpt_nums = [int(i) for i in os.listdir(cur_ckpt_path_dir)]
117
+ ckpt_nums.sort()
118
+ del_num = len(ckpt_nums) - leave
119
+ cur_del_num = 0
120
+ while del_num > 0:
121
+ shutil.rmtree("{}{}".format(cur_ckpt_path_dir, ckpt_nums[cur_del_num]))
122
+ del_num -= 1
123
+ cur_del_num += 1
124
+
125
+
126
+
127
+ ''' multi-GPU functions '''
128
+
129
+ # gather function implemented from DirectCLR
130
+ class GatherLayer_Direct(torch.autograd.Function):
131
+ """
132
+ Gather tensors from all workers with support for backward propagation:
133
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
134
+ """
135
+
136
+ @staticmethod
137
+ def forward(ctx, x):
138
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
139
+ dist.all_gather(output, x)
140
+ return tuple(output)
141
+
142
+ @staticmethod
143
+ def backward(ctx, *grads):
144
+ all_gradients = torch.stack(grads)
145
+ dist.all_reduce(all_gradients)
146
+ return all_gradients[dist.get_rank()]
147
+
148
+ from classy_vision.generic.distributed_util import (
149
+ convert_to_distributed_tensor,
150
+ convert_to_normal_tensor,
151
+ is_distributed_training_run,
152
+ )
153
+ def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
154
+ """
155
+ Similar to classy_vision.generic.distributed_util.gather_from_all
156
+ except that it does not cut the gradients
157
+ """
158
+ if tensor.ndim == 0:
159
+ # 0 dim tensors cannot be gathered. so unsqueeze
160
+ tensor = tensor.unsqueeze(0)
161
+
162
+ if is_distributed_training_run():
163
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
164
+ gathered_tensors = GatherLayer_Direct.apply(tensor)
165
+ gathered_tensors = [
166
+ convert_to_normal_tensor(_tensor, orig_device)
167
+ for _tensor in gathered_tensors
168
+ ]
169
+ else:
170
+ gathered_tensors = [tensor]
171
+ gathered_tensor = torch.cat(gathered_tensors, 0)
172
+ return gathered_tensor
173
+
174
+
mixing_style_transfer/networks/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .architectures import *
2
+ from .network_utils import *
mixing_style_transfer/networks/architectures.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
3
+
4
+ Implementation of neural networks used in the task 'Music Mixing Style Transfer'
5
+ - 'FXencoder'
6
+ - TCN based 'MixFXcloner'
7
+
8
+ We modify the TCN implementation from: https://github.com/csteinmetz1/micro-tcn
9
+ which was introduced in the work "Efficient neural networks for real-time modeling of analog dynamic range compression" by Christian J. Steinmetz, and Joshua D. Reiss
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.nn.init as init
15
+
16
+ import os
17
+ import sys
18
+ currentdir = os.path.dirname(os.path.realpath(__file__))
19
+ sys.path.append(os.path.dirname(currentdir))
20
+
21
+ from networks.network_utils import *
22
+
23
+
24
+
25
+ # FXencoder that extracts audio effects from music recordings trained with a contrastive objective
26
+ class FXencoder(nn.Module):
27
+ def __init__(self, config):
28
+ super(FXencoder, self).__init__()
29
+ # input is stereo channeled audio
30
+ config["channels"].insert(0, 2)
31
+
32
+ # encoder layers
33
+ encoder = []
34
+ for i in range(len(config["kernels"])):
35
+ if config["conv_block"]=='res':
36
+ encoder.append(Res_ConvBlock(dimension=1, \
37
+ in_channels=config["channels"][i], \
38
+ out_channels=config["channels"][i+1], \
39
+ kernel_size=config["kernels"][i], \
40
+ stride=config["strides"][i], \
41
+ padding="SAME", \
42
+ dilation=config["dilation"][i], \
43
+ norm=config["norm"], \
44
+ activation=config["activation"], \
45
+ last_activation=config["activation"]))
46
+ elif config["conv_block"]=='conv':
47
+ encoder.append(ConvBlock(dimension=1, \
48
+ layer_num=1, \
49
+ in_channels=config["channels"][i], \
50
+ out_channels=config["channels"][i+1], \
51
+ kernel_size=config["kernels"][i], \
52
+ stride=config["strides"][i], \
53
+ padding="VALID", \
54
+ dilation=config["dilation"][i], \
55
+ norm=config["norm"], \
56
+ activation=config["activation"], \
57
+ last_activation=config["activation"], \
58
+ mode='conv'))
59
+ self.encoder = nn.Sequential(*encoder)
60
+
61
+ # pooling method
62
+ self.glob_pool = nn.AdaptiveAvgPool1d(1)
63
+
64
+ # network forward operation
65
+ def forward(self, input):
66
+ enc_output = self.encoder(input)
67
+ glob_pooled = self.glob_pool(enc_output).squeeze(-1)
68
+
69
+ # outputs c feature
70
+ return glob_pooled
71
+
72
+
73
+ # MixFXcloner which is based on a Temporal Convolutional Network (TCN) module
74
+ # original implementation : https://github.com/csteinmetz1/micro-tcn
75
+ import pytorch_lightning as pl
76
+ class TCNModel(pl.LightningModule):
77
+ """ Temporal convolutional network with conditioning module.
78
+ Args:
79
+ nparams (int): Number of conditioning parameters.
80
+ ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1
81
+ noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1
82
+ nblocks (int): Number of total TCN blocks. Default: 10
83
+ kernel_size (int): Width of the convolutional kernels. Default: 3
84
+ dialation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
85
+ channel_growth (int): Compute the output channels at each black as in_ch * channel_growth. Default: 2
86
+ channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
87
+ stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
88
+ grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
89
+ causal (bool): Causal TCN configuration does not consider future input values. Default: False
90
+ skip_connections (bool): Skip connections from each block to the output. Default: False
91
+ num_examples (int): Number of evaluation audio examples to log after each epochs. Default: 4
92
+ """
93
+ def __init__(self,
94
+ nparams,
95
+ ninputs=1,
96
+ noutputs=1,
97
+ nblocks=10,
98
+ kernel_size=3,
99
+ dilation_growth=1,
100
+ channel_growth=1,
101
+ channel_width=32,
102
+ stack_size=10,
103
+ cond_dim=2048,
104
+ grouped=False,
105
+ causal=False,
106
+ skip_connections=False,
107
+ num_examples=4,
108
+ save_dir=None,
109
+ **kwargs):
110
+ super(TCNModel, self).__init__()
111
+ self.save_hyperparameters()
112
+
113
+ self.blocks = torch.nn.ModuleList()
114
+ for n in range(nblocks):
115
+ in_ch = out_ch if n > 0 else ninputs
116
+
117
+ if self.hparams.channel_growth > 1:
118
+ out_ch = in_ch * self.hparams.channel_growth
119
+ else:
120
+ out_ch = self.hparams.channel_width
121
+
122
+ dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
123
+ self.blocks.append(TCNBlock(in_ch,
124
+ out_ch,
125
+ kernel_size=self.hparams.kernel_size,
126
+ dilation=dilation,
127
+ padding="same" if self.hparams.causal else "valid",
128
+ causal=self.hparams.causal,
129
+ cond_dim=cond_dim,
130
+ grouped=self.hparams.grouped,
131
+ conditional=True if self.hparams.nparams > 0 else False))
132
+
133
+ self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
134
+
135
+ def forward(self, x, cond):
136
+ # iterate over blocks passing conditioning
137
+ for idx, block in enumerate(self.blocks):
138
+ # for SeFa
139
+ if isinstance(cond, list):
140
+ x = block(x, cond[idx])
141
+ else:
142
+ x = block(x, cond)
143
+ skips = 0
144
+
145
+ out = torch.clamp(self.output(x + skips), min=-1, max=1)
146
+
147
+ return out
148
+
149
+ def compute_receptive_field(self):
150
+ """ Compute the receptive field in samples."""
151
+ rf = self.hparams.kernel_size
152
+ for n in range(1,self.hparams.nblocks):
153
+ dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
154
+ rf = rf + ((self.hparams.kernel_size-1) * dilation)
155
+ return rf
156
+
157
+ # add any model hyperparameters here
158
+ @staticmethod
159
+ def add_model_specific_args(parent_parser):
160
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
161
+ # --- model related ---
162
+ parser.add_argument('--ninputs', type=int, default=1)
163
+ parser.add_argument('--noutputs', type=int, default=1)
164
+ parser.add_argument('--nblocks', type=int, default=4)
165
+ parser.add_argument('--kernel_size', type=int, default=5)
166
+ parser.add_argument('--dilation_growth', type=int, default=10)
167
+ parser.add_argument('--channel_growth', type=int, default=1)
168
+ parser.add_argument('--channel_width', type=int, default=32)
169
+ parser.add_argument('--stack_size', type=int, default=10)
170
+ parser.add_argument('--grouped', default=False, action='store_true')
171
+ parser.add_argument('--causal', default=False, action="store_true")
172
+ parser.add_argument('--skip_connections', default=False, action="store_true")
173
+
174
+ return parser
175
+
176
+
177
+ class TCNBlock(torch.nn.Module):
178
+ def __init__(self,
179
+ in_ch,
180
+ out_ch,
181
+ kernel_size=3,
182
+ dilation=1,
183
+ cond_dim=2048,
184
+ grouped=False,
185
+ causal=False,
186
+ conditional=False,
187
+ **kwargs):
188
+ super(TCNBlock, self).__init__()
189
+
190
+ self.in_ch = in_ch
191
+ self.out_ch = out_ch
192
+ self.kernel_size = kernel_size
193
+ self.dilation = dilation
194
+ self.grouped = grouped
195
+ self.causal = causal
196
+ self.conditional = conditional
197
+
198
+ groups = out_ch if grouped and (in_ch % out_ch == 0) else 1
199
+
200
+ self.pad_length = ((kernel_size-1)*dilation) if self.causal else ((kernel_size-1)*dilation)//2
201
+ self.conv1 = torch.nn.Conv1d(in_ch,
202
+ out_ch,
203
+ kernel_size=kernel_size,
204
+ padding=self.pad_length,
205
+ dilation=dilation,
206
+ groups=groups,
207
+ bias=False)
208
+ if grouped:
209
+ self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)
210
+
211
+ if conditional:
212
+ self.film = FiLM(cond_dim, out_ch)
213
+ self.bn = torch.nn.BatchNorm1d(out_ch)
214
+
215
+ self.relu = torch.nn.LeakyReLU()
216
+ self.res = torch.nn.Conv1d(in_ch,
217
+ out_ch,
218
+ kernel_size=1,
219
+ groups=in_ch,
220
+ bias=False)
221
+
222
+ def forward(self, x, p):
223
+ x_in = x
224
+
225
+ x = self.relu(self.bn(self.conv1(x)))
226
+ x = self.film(x, p)
227
+
228
+ x_res = self.res(x_in)
229
+
230
+ if self.causal:
231
+ x = x[..., :-self.pad_length]
232
+ x += x_res
233
+
234
+ return x
235
+
236
+
237
+
238
+ if __name__ == '__main__':
239
+ ''' check model I/O shape '''
240
+ import yaml
241
+ with open('networks/configs.yaml', 'r') as f:
242
+ configs = yaml.full_load(f)
243
+
244
+ batch_size = 32
245
+ sr = 44100
246
+ input_length = sr*5
247
+
248
+ input = torch.rand(batch_size, 2, input_length)
249
+ print(f"Input Shape : {input.shape}\n")
250
+
251
+
252
+ print('\n========== Audio Effects Encoder (FXencoder) ==========')
253
+ model_arc = "FXencoder"
254
+ model_options = "default"
255
+
256
+ config = configs[model_arc][model_options]
257
+ print(f"configuration: {config}")
258
+
259
+ network = FXencoder(config)
260
+ pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
261
+ print(f"Number of trainable parameters : {pytorch_total_params}")
262
+
263
+ # model inference
264
+ output_c = network(input)
265
+ print(f"Output Shape : {output_c.shape}")
266
+
267
+
268
+ print('\n========== TCN based MixFXcloner ==========')
269
+ model_arc = "TCN"
270
+ model_options = "default"
271
+
272
+ config = configs[model_arc][model_options]
273
+ print(f"configuration: {config}")
274
+
275
+ network = TCNModel(nparams=config["condition_dimension"], ninputs=2, noutputs=2, \
276
+ nblocks=config["nblocks"], \
277
+ dilation_growth=config["dilation_growth"], \
278
+ kernel_size=config["kernel_size"], \
279
+ channel_width=config["channel_width"], \
280
+ stack_size=config["stack_size"], \
281
+ cond_dim=config["condition_dimension"], \
282
+ causal=config["causal"])
283
+ pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
284
+ print(f"Number of trainable parameters : {pytorch_total_params}\tReceptive field duration : {network.compute_receptive_field() / sr:.3f}")
285
+
286
+ ref_embedding = output_c
287
+ # model inference
288
+ output = network(input, output_c)
289
+ print(f"Output Shape : {output.shape}")
290
+
mixing_style_transfer/networks/configs.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model architecture configurations
2
+
3
+
4
+ # Music Effects Encoder
5
+ Effects_Encoder:
6
+
7
+ default:
8
+ channels: [16, 32, 64, 128, 256, 256, 512, 512, 1024, 1024, 2048, 2048]
9
+ kernels: [25, 25, 15, 15, 10, 10, 10, 10, 5, 5, 5, 5]
10
+ strides: [4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]
11
+ dilation: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
12
+ bias: True
13
+ norm: 'batch'
14
+ conv_block: 'res'
15
+ activation: "relu"
16
+
17
+
18
+ # TCN
19
+ TCN:
20
+
21
+ # receptive field = 5.2 seconds
22
+ default:
23
+ condition_dimension: 2048
24
+ nblocks: 14
25
+ dilation_growth: 2
26
+ kernel_size: 15
27
+ channel_width: 128
28
+ stack_size: 15
29
+ causal: False
30
+
mixing_style_transfer/networks/network_utils.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility File
3
+ containing functions for neural networks
4
+ """
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.nn.init as init
8
+ import torch
9
+ import torchaudio
10
+
11
+
12
+
13
+ # 1-dimensional convolutional layer
14
+ # in the order of conv -> norm -> activation
15
+ class Conv1d_layer(nn.Module):
16
+ def __init__(self, in_channels, out_channels, kernel_size, \
17
+ stride=1, \
18
+ padding="SAME", dilation=1, bias=True, \
19
+ norm="batch", activation="relu", \
20
+ mode="conv"):
21
+ super(Conv1d_layer, self).__init__()
22
+
23
+ self.conv1d = nn.Sequential()
24
+
25
+ ''' padding '''
26
+ if mode=="deconv":
27
+ padding = int(dilation * (kernel_size-1) / 2)
28
+ out_padding = 0 if stride==1 else 1
29
+ elif mode=="conv" or "alias_free" in mode:
30
+ if padding == "SAME":
31
+ pad = int((kernel_size-1) * dilation)
32
+ l_pad = int(pad//2)
33
+ r_pad = pad - l_pad
34
+ padding_area = (l_pad, r_pad)
35
+ elif padding == "VALID":
36
+ padding_area = (0, 0)
37
+ else:
38
+ pass
39
+
40
+ ''' convolutional layer '''
41
+ if mode=="deconv":
42
+ self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(in_channels, out_channels, kernel_size, \
43
+ stride=stride, padding=padding, output_padding=out_padding, \
44
+ dilation=dilation, \
45
+ bias=bias))
46
+ elif mode=="conv":
47
+ self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
48
+ self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
49
+ stride=stride, padding=0, \
50
+ dilation=dilation, \
51
+ bias=bias))
52
+ elif "alias_free" in mode:
53
+ if "up" in mode:
54
+ up_factor = stride * 2
55
+ down_factor = 2
56
+ elif "down" in mode:
57
+ up_factor = 2
58
+ down_factor = stride * 2
59
+ else:
60
+ raise ValueError("choose alias-free method : 'up' or 'down'")
61
+ # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample
62
+ # the torchaudio.transforms.Resample's default resampling_method is 'sinc_interpolation' which performs low-pass filter during the process
63
+ # details at https://pytorch.org/audio/stable/transforms.html
64
+ self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
65
+ self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
66
+ stride=1, padding=0, \
67
+ dilation=dilation, \
68
+ bias=bias))
69
+ self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor))
70
+ self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU())
71
+ self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1))
72
+
73
+ ''' normalization '''
74
+ if norm=="batch":
75
+ self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels))
76
+ # self.conv1d.add_module("batch_norm", nn.SyncBatchNorm(out_channels))
77
+
78
+ ''' activation '''
79
+ if 'alias_free' not in mode:
80
+ if activation=="relu":
81
+ self.conv1d.add_module("relu", nn.ReLU())
82
+ elif activation=="lrelu":
83
+ self.conv1d.add_module("lrelu", nn.LeakyReLU())
84
+
85
+
86
+ def forward(self, input):
87
+ # input shape should be : batch x channel x height x width
88
+ output = self.conv1d(input)
89
+ return output
90
+
91
+
92
+
93
+ # Residual Block
94
+ # the input is added after the first convolutional layer, retaining its original channel size
95
+ # therefore, the second convolutional layer's output channel may differ
96
+ class Res_ConvBlock(nn.Module):
97
+ def __init__(self, dimension, \
98
+ in_channels, out_channels, \
99
+ kernel_size, \
100
+ stride=1, padding="SAME", \
101
+ dilation=1, \
102
+ bias=True, \
103
+ norm="batch", \
104
+ activation="relu", last_activation="relu", \
105
+ mode="conv"):
106
+ super(Res_ConvBlock, self).__init__()
107
+
108
+ if dimension==1:
109
+ self.conv1 = Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
110
+ self.conv2 = Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
111
+ elif dimension==2:
112
+ self.conv1 = Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
113
+ self.conv2 = Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
114
+
115
+
116
+ def forward(self, input):
117
+ c1_out = self.conv1(input) + input
118
+ c2_out = self.conv2(c1_out)
119
+ return c2_out
120
+
121
+
122
+
123
+ # Convoluaionl Block
124
+ # consists of multiple (number of layer_num) convolutional layers
125
+ # only the final convoluational layer outputs the desired 'out_channels'
126
+ class ConvBlock(nn.Module):
127
+ def __init__(self, dimension, layer_num, \
128
+ in_channels, out_channels, \
129
+ kernel_size, \
130
+ stride=1, padding="SAME", \
131
+ dilation=1, \
132
+ bias=True, \
133
+ norm="batch", \
134
+ activation="relu", last_activation="relu", \
135
+ mode="conv"):
136
+ super(ConvBlock, self).__init__()
137
+
138
+ conv_block = []
139
+ if dimension==1:
140
+ for i in range(layer_num-1):
141
+ conv_block.append(Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
142
+ conv_block.append(Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
143
+ elif dimension==2:
144
+ for i in range(layer_num-1):
145
+ conv_block.append(Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
146
+ conv_block.append(Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
147
+ self.conv_block = nn.Sequential(*conv_block)
148
+
149
+
150
+ def forward(self, input):
151
+ return self.conv_block(input)
152
+
153
+
154
+
155
+ # Feature-wise Linear Modulation
156
+ class FiLM(nn.Module):
157
+ def __init__(self, condition_len=2048, feature_len=1024):
158
+ super(FiLM, self).__init__()
159
+ self.film_fc = nn.Linear(condition_len, feature_len*2)
160
+ self.feat_len = feature_len
161
+
162
+
163
+ def forward(self, feature, condition, sefa=None):
164
+ # SeFA
165
+ if sefa:
166
+ weight = self.film_fc.weight.T
167
+ weight = weight / torch.linalg.norm((weight+1e-07), dim=0, keepdims=True)
168
+ eigen_values, eigen_vectors = torch.eig(torch.matmul(weight, weight.T), eigenvectors=True)
169
+
170
+ ####### custom parameters #######
171
+ chosen_eig_idx = sefa[0]
172
+ alpha = eigen_values[chosen_eig_idx][0] * sefa[1]
173
+ #################################
174
+
175
+ An = eigen_vectors[chosen_eig_idx].repeat(condition.shape[0], 1)
176
+ alpha_An = alpha * An
177
+
178
+ condition += alpha_An
179
+
180
+ film_factor = self.film_fc(condition).unsqueeze(-1)
181
+ r, b = torch.split(film_factor, self.feat_len, dim=1)
182
+ return r*feature + b
183
+
184
+
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aubio==0.4.9
2
+ classy_vision==0.6.0
3
+ config==0.5.1
4
+ librosa==0.9.2
5
+ matplotlib==3.3.3
6
+ numba==0.48.0
7
+ numpy==1.23.0
8
+ psutil==5.7.2
9
+ pyloudnorm==0.1.0
10
+ git+https://github.com/csteinmetz1/pymixconsole
11
+ pypesq==1.2.4
12
+ pytorch_lightning==1.3.2
13
+ PyYAML==5.4
14
+ scikit_learn==1.1.3
15
+ scipy==1.6
16
+ SoundFile==0.10.3.post1
17
+ soxbindings==1.2.3
18
+ torch==1.9.0
19
+ torchaudio==0.9.0
20
+ torchvision==0.10.0
21
+ torchmetrics==0.6.0
22
+ torchtext==0.10.0
23
+ demucs
samples/interpolation/#0/input.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc4a7d1283666051d43d07e9a11d4c5014426b0753b316fb64d1aef30288b0bd
3
+ size 5274396
samples/interpolation/#0/reference.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:207cdf21724c640a1b7006013d2519c3ac4176604019d7738905911990036a6d
3
+ size 3842338
samples/interpolation/#0/reference_B.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b591708b64f54c3f5c2d03c4359bfa6bc7b125041bb2792a34cd0fd18ae4961
3
+ size 3790802
samples/interpolation/#0/separated/mdx_extra/input/bass.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be2299dc8d612104b305f7a80c93c5dcb05da62e815c3c865bdfac91e56cbedf
3
+ size 5274396
samples/interpolation/#0/separated/mdx_extra/input/drums.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b503786e6b8503346ed0f2ceb482bc16440cd63ec3e0b83952c47998a292fb84
3
+ size 5274396
samples/interpolation/#0/separated/mdx_extra/input/other.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e0f4593e841ae116e0e1a71c66e84c032c82c5780c295af3a6cb1968bc7d4dd
3
+ size 5274396
samples/interpolation/#0/separated/mdx_extra/input/vocals.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:690bedc47fca08257eba3ec411a3a9f7afb340b5ad6d8bf909f372f4ec369853
3
+ size 5274396
samples/interpolation/#0/separated/mdx_extra/reference/bass.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20023ad9e5b90ee8c338dcc77e5f1644b5131d3acc2ab17f8469bc8bfe57a353
3
+ size 3529776
samples/interpolation/#0/separated/mdx_extra/reference/drums.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a0c050eb941aa939d0c4f78216ad3a40d5ee43d51a7ba23266d228d661785e6
3
+ size 3529776
samples/interpolation/#0/separated/mdx_extra/reference/other.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c405c3ed345388b809520ea5f43229210d67f73ab4b1c1903ca2c761c7467214
3
+ size 3529776
samples/interpolation/#0/separated/mdx_extra/reference/vocals.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07e7685489ab7fe68f3e3752ba01789a8be350a67b86ecf95fa877bb6965e0b4
3
+ size 3529776
samples/interpolation/#0/separated/mdx_extra/reference_B/bass.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:237b9986f23f90c8c7a92d3f284243eaceb0b053b530e11984cc487915f96d45
3
+ size 3482464
samples/interpolation/#0/separated/mdx_extra/reference_B/drums.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49e9ff173b853fcb22ccc6dee3c63def308c2a4cc5c76bf46f3bbd80f0f4bd41
3
+ size 3482464
samples/interpolation/#0/separated/mdx_extra/reference_B/other.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65f48ae65fd288c052efca0fedace5ad8ad9504d664be040c3af48796689f3c1
3
+ size 3482464
samples/interpolation/#0/separated/mdx_extra/reference_B/vocals.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ac504e751f3c601e8cb5ee5175b28e8d0606e78de66bf606f5097fec11ada4c
3
+ size 3482464
samples/style_transfer/#0/input.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d4365f3972c0a58f01479ea90a532e319e4ebe8773cae5e88b5a13fca3de26c
3
+ size 2646196
samples/style_transfer/#0/reference.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01d5d2560ae9d368aab59f1c5297b34f962126a81b535ecbf9031d272b907823
3
+ size 5421522
samples/style_transfer/#0/separated/mdx_extra/input/bass.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b46e0c2d82c11f8ff407fe89eb662834429c40f5642066eb341ca2e9b56cd264
3
+ size 2646196
samples/style_transfer/#0/separated/mdx_extra/input/drums.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a98c5400d00ba28b7b4644c473a34dd2b027b8e5496546cfbfe76f45c554c8da
3
+ size 2646196
samples/style_transfer/#0/separated/mdx_extra/input/other.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c26161d9fc98898100c9c229e55eab28a41787ba57253872b82f3f8e916e39fb
3
+ size 2646196
samples/style_transfer/#0/separated/mdx_extra/input/vocals.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:337af4058bcdaf404fb110d5e071c24bbf77753a55aaa7e8db47726ce502bcab
3
+ size 2646196
samples/style_transfer/#0/separated/mdx_extra/reference/bass.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b79cc96acf4550bbd9c1bd79ea3ab43a1fb90ea15e65288806d379c7de6e8015
3
+ size 5421404
samples/style_transfer/#0/separated/mdx_extra/reference/drums.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:395bc9f89b9e996a12ec7e57bedf9c84899b67fb0d9beabd7e979fe5ac7e8ebb
3
+ size 5421404
samples/style_transfer/#0/separated/mdx_extra/reference/other.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68bab09fa8dee9591a34400ba54284255f4ab059d18e83351ed2e95a894e40e1
3
+ size 5421404