Mudrock committed
Commit 530a7d1 · 1 Parent(s): b05a8d5

Upload 18 files

Files changed (18)
  1. LICENSE +21 -0
  2. README.md +122 -12
  3. cog.yaml +33 -0
  4. config.py +62 -0
  5. create_balanced_list.py +24 -0
  6. create_index.sh +12 -0
  7. create_indexes.py +126 -0
  8. data_processor.py +179 -0
  9. htsat_config.py +122 -0
  10. htsat_utils.py +226 -0
  11. losses.py +23 -0
  12. main.py +502 -0
  13. opt_thres.pkl +3 -0
  14. predict.py +111 -0
  15. requirements.txt +19 -0
  16. sed_model.py +358 -0
  17. utils.py +580 -0
  18. zero_shot_create_vector.py +158 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Knut(Ke) Chen
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,122 @@
- ---
- title: Sebas
- emoji: 🔥
- colorFrom: pink
- colorTo: pink
- sdk: streamlit
- sdk_version: 1.15.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Zero Shot Audio Source Separation
+
+ ## Introduction
+
+ The code repository for "[Zero-shot Audio Source Separation through Query-based Learning from Weakly-labeled Data](https://arxiv.org/abs/2112.07891)", in AAAI 2022.
+
+ In this paper, we propose a three-component pipeline that allows you to train an audio source separator to separate *any source* from a track. All you need is a mixture audio to separate and a sample of the desired source as a query; the model then separates the specified source from the track. Our model works in a zero-shot setting: it is trained solely on the general audio dataset **AudioSet**, never on a separation dataset. Nevertheless, it achieves a very competitive separation performance (SDR) on the MUSDB18 dataset compared with supervised models, and it generalizes to sources unseen during training.
+
+ The demos and introduction are presented in our [short introduction video](https://youtu.be/8XQ5ZyYRLQM) and [full presentation video](https://youtu.be/RgNwB_pJ7Cw).
+
+ More demos will be presented on [my personal website](https://www.knutchen.com) (now under construction).
+
+ Check out the interactive demo at Replicate <a href="https://replicate.com/retrocirce/zero_shot_audio_source_separation"><img src="https://replicate.com/retrocirce/zero_shot_audio_source_separation/badge"></a> Thanks @[ariel415el](https://github.com/ariel415el) for creating this!
+
+ ![Model Arch](fig/arch.png)
+
+
+
+ ## Main Separation Performance on the MUSDB18 Dataset
+ We achieve a very competitive separation performance (SDR) on the MUSDB18 dataset **without seeing the MUSDB18 training data or specifying source targets**, compared with supervised models.
+
+ Additionally, our model can easily separate many other sources, such as violin, harmonica, guitar, etc. (demos are shown in the video links above).
+
+ <p align="center">
+ <img src="fig/results.png" align="center" alt="MUSDB results" width="50%"/>
+ </p>
+
+ ## Getting Started
+
+ ### Install Requirements
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ### Download and Process the Datasets
+
+ * config.py
+ ```
+ change the variable "dataset_path" to your AudioSet location
+ change classes_num to 527
+ ```
+
+ * [AudioSet](https://research.google.com/audioset/download.html)
+ ```
+ ./create_index.sh
+ // remember to change the paths in the script
+ // more information about this script is at https://github.com/qiuqiangkong/audioset_tagging_cnn
+
+ python main.py save_idc
+ // count the number of samples in each class and save the npy files
+ ```
+
+ * [MUSDB18](https://sigsep.github.io/datasets/musdb.html) - You can directly use [our processed musdb audio files](https://drive.google.com/drive/folders/1VwRnCxp3t2bXUS_MbXiFiggwkkJQEmha?usp=sharing) at a 32000 Hz sample rate. Or set "musdb_path" to the download path and run:
+
+ ```
+ python main.py musdb_process
+ // Note that the training set is a highlight version, while the testing set is the full version
+ ```
+
+
+ ### Set the Configuration File: config.py
+
+ The script *config.py* contains all the configuration options you need to set before running the code.
+
+ Please read the introductory comments in the file and adjust the settings to your setup.
+
+ Most importantly, if you want to train/test the model on AudioSet, you need to set:
+ ```
+ dataset_path = "your processed audioset folder"
+ balanced_data = True
+ sample_rate = 32000
+ hop_size = 320
+ classes_num = 527
+ ```
+
+ ### Train and Evaluation
+
+ #### Train the sound event detection system ST-SED/HTS-AT
+ We have further integrated the ST-SED system into an independent repository, evaluated it on more datasets, improved it substantially, and achieved better performance.
+
+ You can follow [this repo](https://github.com/RetroCirce/HTS-Audio-Transformer) to train and evaluate the sound event detection system ST-SED (also known by the more relevant name HTS-AT); the configuration file for training the model for this separation task should be [htsat_config.py](htsat_config.py).
+
+ For this separation task, if you want to save time, you can also download [the checkpoint](https://drive.google.com/drive/folders/1RouwHsGsMs8n3l_jF8XifWtbPzur_YQS?usp=sharing) directly.
+
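+ If you use the downloaded checkpoint, a minimal sketch is to point `resume_checkpoint` in *htsat_config.py* at the downloaded file (the directory below is an illustrative placeholder; only the filename comes from the released checkpoint):
+ ```
+ # in htsat_config.py
+ resume_checkpoint = "/your/path/model_backup/htsat_audioset_2048d.ckpt"
+ ```
+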
+ #### Train, Evaluate, and Run Inference with the Separation Model
+
+ All scripts are run through main.py:
+ ```
+ Train: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py train
+
+ Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py test
+
+ ```
+ We recommend using at least 4 GPUs with more than 20 GB of memory per card. In our training phase, we used 8 NVIDIA V100 (32 GB) GPUs.
+
+ We provide a quick **inference** interface via:
+ ```
+ CUDA_VISIBLE_DEVICES=1 python main.py inference
+ ```
+ With it, you can separate any given source from a track. You need to set "inference_file" and "inference_query" in *config.py*; check the comments there to get started. For inference, we recommend using only one GPU, which is sufficient.
+
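+ For reference, a minimal *config.py* sketch for inference might look like this (the paths are illustrative placeholders; `inference_query` must be a folder whose .wav files all come from the same source, and `test_key` is just the name used for the separated output):
+ ```
+ # config.py (inference-related entries only)
+ inference_file = "/your/path/mixture.wav"    # the track you want to separate
+ inference_query = "/your/path/query"         # folder of query samples from one source
+ test_key = ["violin"]                        # a label for the separated output
+ wave_output_path = "/your/path/wavoutput"    # where the separated track is written
+ ```
+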
+
+ #### Model Checkpoints
+
+ We provide the model checkpoints at this [link](https://drive.google.com/drive/folders/1RouwHsGsMs8n3l_jF8XifWtbPzur_YQS?usp=sharing). Feel free to download and test them.
+
+ ## Citing
+ ```
+ @inproceedings{zsasp-ke2022,
+   author = {Ke Chen* and Xingjian Du* and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
+   title = {Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data},
+   booktitle = {{AAAI} 2022}
+ }
+
+ @inproceedings{htsat-ke2022,
+   author = {Ke Chen and Xingjian Du and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
+   title = {HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection},
+   booktitle = {{ICASSP} 2022}
+ }
+ ```
cog.yaml ADDED
@@ -0,0 +1,33 @@
+ build:
+   gpu: true
+   python_version: "3.8"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+     - "libsndfile1-dev"
+     - "ffmpeg"
+   python_packages:
+     - torch==1.9.0
+     - torchmetrics==0.6.0
+     - torchaudio==0.9.0
+     - torchcontrib==0.0.2
+     - torchlibrosa==0.0.9
+     - librosa==0.8.0
+     - pytorch_lightning==1.4.1
+     - museval==0.4.0
+     - noisereduce==2.0.0
+     - numba==0.55.1
+     - numpy==1.19.4
+     - scikit_learn==0.24.0
+     - scipy==1.6.0
+     - soundfile==0.10.3.post1
+     - tensorboard==2.2.0
+     - tqdm==4.55.0
+     - h5py==3.1.0
+     - musdb==0.4.0
+
+ # run:
+ # - pip install open3d
+ # # - "gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g"
+
+ predict: "predict.py:Predictor"
config.py ADDED
@@ -0,0 +1,62 @@
+ # Ke Chen
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
+ # The configuration file
+
+ # for model training
+ exp_name = "exp_zs_asp_full" # the saved ckpt prefix name of the model
+ workspace = "/home/Research/ZS_ASP/" # the folder of your code
+ dataset_path = "/home/Research/ZS_ASP/data/audioset" # the dataset path
+ index_type = "full_train"
+ idc_path = "/home/Research/ZS_ASP/" # the folder of audioset class count files
+ balanced_data = True
+
+ # train from a checkpoint, or evaluate a single model
+ resume_checkpoint = None
+ # "/home/Research/ZS_ASP/model_backup/zeroshot_asp_full.ckpt"
+
+ loss_type = "mae"
+
+ gather_mode = False
+ debug = False
+
+ classes_num = 527
+ eval_list = [] # leave blank to preserve all classes; otherwise the listed classes are held out of training and used for evaluation
+ # [15, 63, 81, 184, 335, 449, 474, 348, 486, 4] # randomly generated from the 527 classes for held-out evaluation
+
+
+ batch_size = 16 * 8 # batch size per GPU x GPU number, default is 16 x 8 = 128
+ learning_rate = 1e-3 # 3e-4 is also workable
+ max_epoch = 100
+ num_workers = 3
+ lr_scheduler_epoch = [90, 110]
+ latent_dim = 2048
+
+ # for signal processing
+ sample_rate = 32000
+ clip_samples = sample_rate * 10 # AudioSet 10-sec clip
+ segment_frames = 200
+ hop_samples = 320
+ random_seed = 12412 # 444612 1536123 12412
+ random_mode = "one_class" # "no_random, one_class, random, order"; one_class works best
+
+ # for evaluation
+ musdb_path = "/home/Research/ZS_ASP/data/musdb-wav/" # musdb download folder
+ testavg_path = "/home/Research/ZS_ASP/data/musdb30-train-32000fs.npy" # the processed training set (to get the latent query)
+ testset_path = "/home/Research/ZS_ASP/data/musdb-test-32000fs.npy" # the processed testing set (to calculate the performance)
+ test_key = ["vocals", "drums", "bass", "other"] # the four MUSDB tracks, or your named track for other inference
+ test_type = "mix"
+ infer_type = "mean"
+ energy_thres = 0.1
+ wave_output_path = "/home/Research/ZS_ASP/wavoutput" # output folder
+ using_wiener = True # use the Wiener filter or not (default: True)
+ using_whiting = False # use whitening ("whiting") or not (default: False)
+
+ # weight average
+ wa_model_folder = "/home/Research/ZS_ASP/version_3/checkpoints/"
+ wa_model_path = "zs_wa.ckpt"
+
+ # for inference
+ inference_file = "/home/Research/ZS_ASP/data/pagenini.wav" # an audio file to separate
+ inference_query = "/home/Research/ZS_ASP/data/query" # a folder containing all samples for obtaining the query
+ overlap_rate = 0.0 # [0.0, 1.0); 0 disables overlap; 0.5 (50% overlap) is recommended. Overlap increases computation time but improves result quality
create_balanced_list.py ADDED
@@ -0,0 +1,24 @@
+ # Ke Chen
+ import os
+ import sys
+ import config
+ import logging
+ import numpy as np
+
+ from utils import get_balanced_class_list
+
+ def main():
+     train_indexes_hdf5_path = os.path.join(config.dataset_path, "hdf5s", "indexes",
+         "{}.h5".format(config.data_type))
+
+     eval_indexes_hdf5_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
+     logging.info("Process training data")
+     indexes_per_class = get_balanced_class_list(train_indexes_hdf5_path, random_seed = config.random_seed)
+     np.save("idc_train.npy", indexes_per_class)
+     logging.info("Process testing data")
+     indexes_per_class = get_balanced_class_list(eval_indexes_hdf5_path, random_seed = config.random_seed)
+     np.save("idc_eval.npy", indexes_per_class)
+
+ if __name__ == '__main__':
+     logging.basicConfig(level=logging.INFO)
+     main()
create_index.sh ADDED
@@ -0,0 +1,12 @@
+ #!/bin/bash
+
+ # Balanced training indexes (create_indexes.py requires both paths; adjust them to your setup)
+ python3 create_indexes.py create_indexes --waveforms_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/waveforms/balanced_train.h5" --indexes_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/balanced_train.h5"
+
+ # Unbalanced training indexes
+ for IDX in {00..40}; do
+     echo $IDX
+     python3 create_indexes.py create_indexes --waveforms_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/waveforms/unbalanced_train/unbalanced_train_part$IDX.h5" --indexes_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/unbalanced_train/unbalanced_train_part$IDX.h5"
+ done
+
+ # Combine balanced and unbalanced training indexes into a full training indexes hdf5
+ python3 create_indexes.py combine_full_indexes --indexes_hdf5s_dir="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes" --full_indexes_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/full_train.h5"
create_indexes.py ADDED
@@ -0,0 +1,126 @@
1
+ import numpy as np
2
+ import argparse
3
+ import csv
4
+ import os
5
+ import glob
6
+ import datetime
7
+ import time
8
+ import logging
9
+ import h5py
10
+ import librosa
11
+
12
+ from utils import create_folder, get_sub_filepaths
13
+ import config
14
+
15
+
16
+ def create_indexes(args):
17
+ """Create indexes a for dataloader to read for training. When users have
18
+ a new task and their own data, they need to create similar indexes. The
19
+ indexes contain meta information of "where to find the data for training".
20
+ """
21
+
22
+ # Arguments & parameters
23
+ waveforms_hdf5_path = args.waveforms_hdf5_path
24
+ indexes_hdf5_path = args.indexes_hdf5_path
25
+
26
+ # Paths
27
+ create_folder(os.path.dirname(indexes_hdf5_path))
28
+
29
+ with h5py.File(waveforms_hdf5_path, 'r') as hr:
30
+ with h5py.File(indexes_hdf5_path, 'w') as hw:
31
+ audios_num = len(hr['audio_name'])
32
+ hw.create_dataset('audio_name', data=hr['audio_name'][:], dtype='S20')
33
+ hw.create_dataset('target', data=hr['target'][:], dtype=np.bool)
34
+ hw.create_dataset('hdf5_path', data=[waveforms_hdf5_path.encode()] * audios_num, dtype='S200')
35
+ hw.create_dataset('index_in_hdf5', data=np.arange(audios_num), dtype=np.int32)
36
+
37
+ print('Write to {}'.format(indexes_hdf5_path))
38
+
39
+
40
+ def combine_full_indexes(args):
41
+ """Combine all balanced and unbalanced indexes hdf5s to a single hdf5. This
42
+ combined indexes hdf5 is used for training with full data (~20k balanced
43
+ audio clips + ~1.9m unbalanced audio clips).
44
+ """
45
+
46
+ # Arguments & parameters
47
+ indexes_hdf5s_dir = args.indexes_hdf5s_dir
48
+ full_indexes_hdf5_path = args.full_indexes_hdf5_path
49
+
50
+ classes_num = config.classes_num
51
+
52
+ # Paths
53
+ paths = get_sub_filepaths(indexes_hdf5s_dir)
54
+ paths = [path for path in paths if (
55
+ 'train' in path and 'full_train' not in path and 'mini' not in path)]
56
+
57
+ print('Total {} hdf5 to combine.'.format(len(paths)))
58
+
59
+ with h5py.File(full_indexes_hdf5_path, 'w') as full_hf:
60
+ full_hf.create_dataset(
61
+ name='audio_name',
62
+ shape=(0,),
63
+ maxshape=(None,),
64
+ dtype='S20')
65
+
66
+ full_hf.create_dataset(
67
+ name='target',
68
+ shape=(0, classes_num),
69
+ maxshape=(None, classes_num),
70
+ dtype=np.bool)
71
+
72
+ full_hf.create_dataset(
73
+ name='hdf5_path',
74
+ shape=(0,),
75
+ maxshape=(None,),
76
+ dtype='S200')
77
+
78
+ full_hf.create_dataset(
79
+ name='index_in_hdf5',
80
+ shape=(0,),
81
+ maxshape=(None,),
82
+ dtype=np.int32)
83
+
84
+ for path in paths:
85
+ with h5py.File(path, 'r') as part_hf:
86
+ print(path)
87
+ n = len(full_hf['audio_name'][:])
88
+ new_n = n + len(part_hf['audio_name'][:])
89
+
90
+ full_hf['audio_name'].resize((new_n,))
91
+ full_hf['audio_name'][n : new_n] = part_hf['audio_name'][:]
92
+
93
+ full_hf['target'].resize((new_n, classes_num))
94
+ full_hf['target'][n : new_n] = part_hf['target'][:]
95
+
96
+ full_hf['hdf5_path'].resize((new_n,))
97
+ full_hf['hdf5_path'][n : new_n] = part_hf['hdf5_path'][:]
98
+
99
+ full_hf['index_in_hdf5'].resize((new_n,))
100
+ full_hf['index_in_hdf5'][n : new_n] = part_hf['index_in_hdf5'][:]
101
+
102
+ print('Write combined full hdf5 to {}'.format(full_indexes_hdf5_path))
103
+
104
+
105
+ if __name__ == '__main__':
106
+ parser = argparse.ArgumentParser()
107
+ subparsers = parser.add_subparsers(dest='mode')
108
+
109
+ parser_create_indexes = subparsers.add_parser('create_indexes')
110
+ parser_create_indexes.add_argument('--waveforms_hdf5_path', type=str, required=True, help='Path of packed waveforms hdf5.')
111
+ parser_create_indexes.add_argument('--indexes_hdf5_path', type=str, required=True, help='Path to write out indexes hdf5.')
112
+
113
+ parser_combine_full_indexes = subparsers.add_parser('combine_full_indexes')
114
+ parser_combine_full_indexes.add_argument('--indexes_hdf5s_dir', type=str, required=True, help='Directory containing indexes hdf5s to be combined.')
115
+ parser_combine_full_indexes.add_argument('--full_indexes_hdf5_path', type=str, required=True, help='Path to write out full indexes hdf5 file.')
116
+
117
+ args = parser.parse_args()
118
+
119
+ if args.mode == 'create_indexes':
120
+ create_indexes(args)
121
+
122
+ elif args.mode == 'combine_full_indexes':
123
+ combine_full_indexes(args)
124
+
125
+ else:
126
+ raise Exception('Incorrect arguments!')
data_processor.py ADDED
@@ -0,0 +1,179 @@
1
+ # Ke Chen
2
3
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
4
+ # The dataset classes
5
+
6
+ import numpy as np
7
+ import torch
8
+ import logging
9
+ import os
10
+ import sys
11
+ import h5py
12
+ import csv
13
+ import time
14
+ import random
15
+ import json
16
+ from datetime import datetime
17
+ from utils import int16_to_float32
18
+
19
+ from torch.utils.data import Dataset, Sampler
20
+
21
+ # output the dict["index"].key form to save the memory in multi-GPU training
22
+ def reverse_dict(data_path, sed_path, output_dir):
23
+ # filename
24
+ waveform_dir = os.path.join(output_dir, "audioset_eval_waveform_balanced.h5")
25
+ sed_dir = os.path.join(output_dir, "audioset_eval_sed_balanced.h5")
26
+ # load data
27
+ logging.info("Write Data...............")
28
+ h_data = h5py.File(data_path, "r")
29
+ h_sed = h5py.File(sed_path, "r")
30
+ audio_num = len(h_data["waveform"])
31
+ assert len(h_data["waveform"]) == len(h_sed["sed_vector"]), "waveform and sed should be in the same length"
32
+ with h5py.File(waveform_dir, 'w') as hw:
33
+ for i in range(audio_num):
34
+ hw.create_dataset(str(i), data=int16_to_float32(h_data['waveform'][i]), dtype=np.float32)
35
+ logging.info("Write Data Succeed...............")
36
+ logging.info("Write Sed...............")
37
+ with h5py.File(sed_dir, 'w') as hw:
38
+ for i in range(audio_num):
39
+ hw.create_dataset(str(i), data=h_sed['sed_vector'][i], dtype=np.float32)
40
+ logging.info("Write Sed Succeed...............")
41
+
42
+ # A dataset for handling musdb
43
+ class MusdbDataset(Dataset):
44
+ def __init__(self, tracks):
45
+ self.tracks = tracks
46
+ self.dataset_len = len(tracks)
47
+ def __getitem__(self, index):
48
+ """Load waveform and target of an audio clip.
49
+ Args:
50
+ index: the index number
51
+ Return:
52
+ track: [mixture + n_sources, n_samples]
53
+ """
54
+ return self.tracks[index]
55
+ def __len__(self):
56
+ return self.dataset_len
57
+
58
+ class InferDataset(Dataset):
59
+ def __init__(self, tracks):
60
+ self.tracks = tracks
61
+ self.dataset_len = len(tracks)
62
+ def __getitem__(self, index):
63
+ """Load waveform and target of an audio clip.
64
+ Args:
65
+ index: the index number
66
+ Return:
67
+ track: [mixture + n_sources, n_samples]
68
+ """
69
+ return self.tracks[index]
70
+ def __len__(self):
71
+ return self.dataset_len
72
+
73
+ # polished LGSPDataset, the main dataset for procssing the audioset files
74
+ class LGSPDataset(Dataset):
75
+ def __init__(self, index_path, idc, config, factor = 3, eval_mode = False):
76
+ self.index_path = index_path
77
+ self.fp = h5py.File(index_path, "r")
78
+ self.config = config
79
+ self.idc = idc
80
+ self.factor = factor
81
+ self.classes_num = self.config.classes_num
82
+ self.eval_mode = eval_mode
83
+ self.total_size = int(len(self.fp["audio_name"]) * self.factor)
84
+ self.generate_queue()
85
+ logging.info("total dataset size: %d" %(self.total_size))
86
+ logging.info("class num: %d" %(self.classes_num))
87
+
88
+ def generate_queue(self):
89
+ self.queue = []
90
+ self.class_queue = []
91
+ if self.config.debug:
92
+ self.total_size = 1000
93
+ if self.config.balanced_data:
94
+ while len(self.queue) < self.total_size * 2:
95
+ if self.eval_mode:
96
+ if len(self.config.eval_list) == 0:
97
+ class_set = [*range(self.classes_num)]
98
+ else:
99
+ class_set = self.config.eval_list[:]
100
+ else:
101
+ class_set = [*range(self.classes_num)]
102
+ class_set = list(set(class_set) - set(self.config.eval_list))
103
+ random.shuffle(class_set)
104
+ self.queue += [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in class_set]
105
+ self.class_queue += class_set[:]
106
+ self.queue = self.queue[:self.total_size * 2]
107
+ self.class_queue = self.class_queue[:self.total_size * 2]
108
+ self.queue = [[self.queue[i],self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
109
+ self.class_queue = [[self.class_queue[i],self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
110
+ assert len(self.queue) == self.total_size, "generate data error!!"
111
+ else:
112
+ if self.eval_mode:
113
+ if len(self.config.eval_list) == 0:
114
+ class_set = [*range(self.classes_num)]
115
+ else:
116
+ class_set = self.config.eval_list[:]
117
+ else:
118
+ class_set = [*range(self.classes_num)]
119
+ class_set = list(set(class_set) - set(self.config.eval_list))
120
+ self.class_queue = random.choices(class_set, k = self.total_size * 2)
121
+ self.queue = [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in self.class_queue]
122
+ self.queue = [[self.queue[i],self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
123
+ self.class_queue = [[self.class_queue[i],self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
124
+ assert len(self.queue) == self.total_size, "generate data error!!"
125
+ logging.info("queue regenerated:%s" %(self.queue[-5:]))
126
+
127
+ def __getitem__(self, index):
128
+ """Load waveform and target of an audio clip.
129
+ Args:
130
+ index: the index number
131
+ Return: {
132
+ "audio_name_1": str,
133
+ "waveform_1": (clip_samples,),
134
+ "class_id_1": int,
135
+ "audio_name_2": str,
136
+ "waveform_2": (clip_samples,),
137
+ "class_id_2": int,
138
+ ...
139
+ "check_num": int
140
+ }
141
+ """
142
+ # put the right index here!!!
143
+ data_dict = {}
144
+ for k in range(2):
145
+ s_index = self.queue[index][k]
146
+ target = self.class_queue[index][k]
147
+ audio_name = self.fp["audio_name"][s_index].decode()
148
+ hdf5_path = self.fp["hdf5_path"][s_index].decode().replace("/home/tiger/DB/knut/data/audioset", self.config.dataset_path)
149
+ r_idx = self.fp["index_in_hdf5"][s_index]
150
+ with h5py.File(hdf5_path, "r") as f:
151
+ waveform = int16_to_float32(f["waveform"][r_idx])
152
+ data_dict["audio_name_" + str(k+1)] = audio_name
153
+ data_dict["waveform_" + str(k+1)] = waveform
154
+ data_dict["class_id_" + str(k+1)] = target
155
+ data_dict["check_num"] = str(self.queue[-5:])
156
+ return data_dict
157
+
158
+ def __len__(self):
159
+ return self.total_size
160
+
161
+ # only for test
162
+ class TestDataset(Dataset):
163
+ def __init__(self, dataset_size):
164
+ print("init")
165
+ self.dataset_size = dataset_size
166
+ self.base_num = 100
167
+ self.dicts = [(self.base_num + 2 * i, self.base_num + 2 * i + 1) for i in range(self.dataset_size)]
168
+
169
+ def get_new_list(self):
170
+ self.base_num = random.randint(0,10)
171
+ print("base num changed:", self.base_num)
172
+ self.dicts = [(self.base_num + 2 * i, self.base_num + 2 * i + 1) for i in range(self.dataset_size)]
173
+
174
+ def __getitem__(self, index):
175
+ return self.dicts[index]
176
+
177
+ def __len__(self):
178
+ return self.dataset_size
179
+
htsat_config.py ADDED
@@ -0,0 +1,122 @@
1
+ # Ke Chen
2
3
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
4
+ # The configuration file of ST-SED model or HTS-AT model
5
+
6
+ exp_name = "exp_htsat_2048d" # the saved ckpt prefix name of the model
7
+ workspace = "/home/kechen/Research/HTSAT" # the folder of your code
8
+ dataset_path = "/home/Research/audioset" # the dataset path
9
+ desed_folder = "/home/Research/DESED" # the desed file
10
+
11
+ dataset_type = "audioset"
12
+
13
+ loss_type = "clip_bce"
14
+ balanced_data = True
15
+
16
+ resume_checkpoint = "/home/kechen/Research/Latent_ASP/model_backup/htsat_audioset_2048d.ckpt"
17
+
18
+ esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
19
+
20
+ debug = False
21
+
22
+ random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
23
+ batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
24
+ learning_rate = 1e-3 # 1e-4 also workable
25
+ max_epoch = 100
26
+ num_workers = 3
27
+
28
+ lr_scheduler_epoch = [10,20,30]
29
+ lr_rate = [0.02, 0.05, 0.1]
30
+
31
+ # these data preparation optimizations do not bring many improvements, so deprecated
32
+ enable_token_label = False # token label
33
+ class_map_path = "class_hier_map.npy"
34
+ class_filter = None
35
+ retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
36
+ 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
37
+ token_label_range = [0.2,0.6]
38
+ enable_time_shift = False # shift time
39
+ enable_label_enhance = False # enhance hierarchical label
40
+ enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
41
+
42
+
43
+
44
+ # for model's design
45
+ enable_tscam = True # enable the token-semantic layer
46
+
47
+ # for signal processing
48
+ sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
49
+ clip_samples = sample_rate * 10 # audio_set 10-sec clip
50
+ window_size = 1024
51
+ hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
52
+ mel_bins = 64
53
+ fmin = 50
54
+ fmax = 14000
55
+ shift_max = int(clip_samples * 0.5)
56
+
57
+ # for data collection
58
+ classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
59
+ patch_size = (25, 4) # deprecated
60
+ crop_size = None # int(clip_samples * 0.5) deprecated
61
+
62
+ # for htsat hyperparamater
63
+ htsat_window_size = 8
64
+ htsat_spec_size = 256
65
+ htsat_patch_size = 4
66
+ htsat_stride = (4, 4)
67
+ htsat_num_head = [4,8,16,32]
68
+ htsat_dim = 256 # for 2048-d model
69
+ htsat_depth = [2,2,6,2]
70
+
71
+ swin_pretrain_path = None
72
+ # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
73
+
74
+ # Some Deprecated Optimization in the model design, check the model code for details
75
+ htsat_attn_heatmap = False
76
+ htsat_hier_output = False
77
+ htsat_use_max = False
78
+
79
+
80
+ # no use here
81
+ ensemble_checkpoints = []
82
+ ensemble_strides = []
83
+
84
+
85
+ # weight average folder
86
+ wa_folder = "/home/version_0/checkpoints/"
87
+ # weight average output filename
88
+ wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
89
+
90
+ esm_model_pathes = [
91
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
92
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
93
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
94
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
95
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
96
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
97
+ ]
98
+
99
+ # for framewise localization
100
+ heatmap_dir = "/home/Research/heatmap_output"
101
+ test_file = "htsat-test-ensemble"
102
+ fl_local = False # indicate if we need to use this dataset for the framewise detection
103
+ fl_dataset = "/home/Research/desed/desed_eval.npy"
104
+ fl_class_num = [
105
+ "Speech", "Frying", "Dishes", "Running_water",
106
+ "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
107
+ "Cat", "Dog", "Vacuum_cleaner"
108
+ ]
109
+
110
+ # map 527 classes into 10 classes
111
+ fl_audioset_mapping = [
112
+ [0,1,2,3,4,5,6,7],
113
+ [366, 367, 368],
114
+ [364],
115
+ [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
116
+ [369],
117
+ [382],
118
+ [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
119
+ [81, 82, 83, 84, 85],
120
+ [74, 75, 76, 77, 78, 79],
121
+ [377]
122
+ ]
htsat_utils.py ADDED
@@ -0,0 +1,226 @@
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some Useful Common Methods
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+ from typing import Optional
11
+ import logging
12
+ import os
13
+ import sys
14
+ import h5py
15
+ import csv
16
+ import time
17
+ import json
18
+ import museval
19
+ import librosa
20
+ from datetime import datetime
21
+ from tqdm import tqdm
22
+ from scipy import stats
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ # import from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
28
+ class AsymmetricLoss(nn.Module):
29
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
30
+ super(AsymmetricLoss, self).__init__()
31
+
32
+ self.gamma_neg = gamma_neg
33
+ self.gamma_pos = gamma_pos
34
+ self.clip = clip
35
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
36
+ self.eps = eps
37
+
38
+ def forward(self, x, y):
39
+ """"
40
+ Parameters
41
+ ----------
42
+ x: input logits
43
+ y: targets (multi-label binarized vector)
44
+ """
45
+
46
+ # Calculating Probabilities
47
+ # x_sigmoid = torch.sigmoid(x)
48
+ x_sigmoid = x # without sigmoid since it has been computed
49
+ xs_pos = x_sigmoid
50
+ xs_neg = 1 - x_sigmoid
51
+
52
+ # Asymmetric Clipping
53
+ if self.clip is not None and self.clip > 0:
54
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
55
+
56
+ # Basic CE calculation
57
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
58
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
59
+ loss = los_pos + los_neg
60
+
61
+ # Asymmetric Focusing
62
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
63
+ if self.disable_torch_grad_focal_loss:
64
+ torch.set_grad_enabled(False)
65
+ pt0 = xs_pos * y
66
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
67
+ pt = pt0 + pt1
68
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
69
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
70
+ if self.disable_torch_grad_focal_loss:
71
+ torch.set_grad_enabled(True)
72
+ loss *= one_sided_w
73
+
74
+ return -loss.mean()
75
+
76
+
77
+ def get_mix_lambda(mixup_alpha, batch_size):
78
+ mixup_lambdas = [np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)]
79
+ return np.array(mixup_lambdas).astype(np.float32)
80
+
81
+ def create_folder(fd):
82
+ if not os.path.exists(fd):
83
+ os.makedirs(fd)
84
+
85
+ def dump_config(config, filename, include_time = False):
86
+ save_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
87
+ config_json = {}
88
+ for key in dir(config):
89
+ if not key.startswith("_"):
90
+ config_json[key] = eval("config." + key)
91
+ if include_time:
92
+ filename = filename + "_" + save_time
93
+ with open(filename + ".json", "w") as f:
94
+ json.dump(config_json, f ,indent=4)
95
+
96
+ def int16_to_float32(x):
97
+ return (x / 32767.).astype(np.float32)
98
+
99
+ def float32_to_int16(x):
100
+ x = np.clip(x, a_min = -1., a_max = 1.)
101
+ return (x * 32767.).astype(np.int16)
102
+
103
+
104
+ # index for each class
105
+ def process_idc(index_path, classes_num, filename):
106
+ # load data
107
+ logging.info("Load Data...............")
108
+ idc = [[] for _ in range(classes_num)]
109
+ with h5py.File(index_path, "r") as f:
110
+ for i in tqdm(range(len(f["target"]))):
111
+ t_class = np.where(f["target"][i])[0]
112
+ for t in t_class:
113
+ idc[t].append(i)
114
+ print(idc)
115
+ np.save(filename, idc)
116
+ logging.info("Load Data Succeed...............")
117
+
118
+ def clip_bce(pred, target):
119
+ """Binary crossentropy loss.
120
+ """
121
+ return F.binary_cross_entropy(pred, target)
122
+ # return F.binary_cross_entropy(pred, target)
123
+
124
+
125
+ def clip_ce(pred, target):
126
+ return F.cross_entropy(pred, target)
127
+
128
+ def d_prime(auc):
129
+ d_prime = stats.norm().ppf(auc) * np.sqrt(2.0)
130
+ return d_prime
131
+
132
+
133
+ def get_loss_func(loss_type):
134
+ if loss_type == 'clip_bce':
135
+ return clip_bce
136
+ if loss_type == 'clip_ce':
137
+ return clip_ce
138
+ if loss_type == 'asl_loss':
139
+ loss_func = AsymmetricLoss(gamma_neg=4, gamma_pos=0,clip=0.05)
140
+ return loss_func
141
+
142
+ def do_mixup_label(x):
143
+ out = torch.logical_or(x, torch.flip(x, dims = [0])).float()
144
+ return out
145
+
146
+ def do_mixup(x, mixup_lambda):
147
+ """
148
+ Args:
149
+ x: (batch_size , ...)
150
+ mixup_lambda: (batch_size,)
151
+
152
+ Returns:
153
+ out: (batch_size, ...)
154
+ """
155
+ out = (x.transpose(0,-1) * mixup_lambda + torch.flip(x, dims = [0]).transpose(0,-1) * (1 - mixup_lambda)).transpose(0,-1)
156
+ return out
157
+
158
+ def interpolate(x, ratio):
159
+ """Interpolate data in time domain. This is used to compensate the
160
+ resolution reduction in downsampling of a CNN.
161
+
162
+ Args:
163
+ x: (batch_size, time_steps, classes_num)
164
+ ratio: int, ratio to interpolate
165
+
166
+ Returns:
167
+ upsampled: (batch_size, time_steps * ratio, classes_num)
168
+ """
169
+ (batch_size, time_steps, classes_num) = x.shape
170
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
171
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
172
+ return upsampled
173
+
174
+
175
+ def pad_framewise_output(framewise_output, frames_num):
176
+ """Pad framewise_output to the same length as input frames. The pad value
177
+ is the same as the value of the last frame.
178
+
179
+ Args:
180
+ framewise_output: (batch_size, frames_num, classes_num)
181
+ frames_num: int, number of frames to pad
182
+
183
+ Outputs:
184
+ output: (batch_size, frames_num, classes_num)
185
+ """
186
+ pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
187
+ """tensor for padding"""
188
+
189
+ output = torch.cat((framewise_output, pad), dim=1)
190
+ """(batch_size, frames_num, classes_num)"""
191
+
192
+ return output
193
+
194
+ # set the audio into the format that can be fed into the model
195
+ # resample -> convert to mono -> output the audio
196
+ # track [n_sample, n_channel]
197
+ def prepprocess_audio(track, ofs, rfs, mono_type = "mix"):
198
+ if track.shape[-1] > 1:
199
+ # stereo
200
+ if mono_type == "mix":
201
+ track = np.transpose(track, (1,0))
202
+ track = librosa.to_mono(track)
203
+ elif mono_type == "left":
204
+ track = track[:, 0]
205
+ elif mono_type == "right":
206
+ track = track[:, 1]
207
+ else:
208
+ track = track[:, 0]
209
+ # track [n_sample]
210
+ if ofs != rfs:
211
+ track = librosa.resample(track, ofs, rfs)
212
+ return track
213
+
214
+ def init_hier_head(class_map, num_class):
215
+ class_map = np.load(class_map, allow_pickle = True)
216
+
217
+ head_weight = torch.zeros(num_class,num_class).float()
218
+ head_bias = torch.zeros(num_class).float()
219
+
220
+ for i in range(len(class_map)):
221
+ for d in class_map[i][1]:
222
+ head_weight[d][i] = 1.0
223
+ for d in class_map[i][2]:
224
+ head_weight[d][i] = 1.0 / len(class_map[i][2])
225
+ head_weight[i][i] = 1.0
226
+ return head_weight, head_bias
losses.py ADDED
@@ -0,0 +1,23 @@
+ import torch.nn.functional as F
+ import torch
+ import numpy as np
+
+
+ def mae(input, target):
+     return torch.mean(torch.abs(input - target))
+
+
+ def logmae_wav(model, output_dict, target):
+     loss = torch.log10(torch.clamp(mae(output_dict['wav'], target), 1e-8, np.inf))
+     return loss
+
+
+ def get_loss_func(loss_type):
+     if loss_type == 'logmae_wav':
+         return logmae_wav
+
+     elif loss_type == 'mae':
+         return mae
+
+     else:
+         raise Exception('Incorrect loss_type!')
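For context, a minimal usage sketch of this loss factory (assuming losses.py is on the import path; the tensors below are illustrative, not project data):
```
from losses import get_loss_func
import torch

loss_fn = get_loss_func("mae")        # matches config.loss_type = "mae"
pred = torch.zeros(2, 32000)          # dummy separated waveform batch
target = torch.ones(2, 32000)         # dummy target waveform batch
print(loss_fn(pred, target).item())   # mean absolute error -> 1.0
```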
main.py ADDED
@@ -0,0 +1,502 @@
1
+ # Ke Chen
2
3
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
4
+ # The Main Script
5
+
6
+ import os
7
+ # this is to prevent the SDR calculation from occupying all CPUs
8
+ os.environ["OMP_NUM_THREADS"] = "4"
9
+ os.environ["OPENBLAS_NUM_THREADS"] = "4"
10
+ os.environ["MKL_NUM_THREADS"] = "6"
11
+ os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
12
+ os.environ["NUMEXPR_NUM_THREADS"] = "6"
13
+
14
+ import sys
15
+ import librosa
16
+ import numpy as np
17
+ import argparse
18
+ import logging
19
+
20
+ import torch
21
+ from torch.utils.data import DataLoader
22
+ from torch.utils.data.distributed import DistributedSampler
23
+
24
+ from utils import collect_fn, dump_config, create_folder, prepprocess_audio
25
+ import musdb
26
+
27
+ from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
28
+ from data_processor import LGSPDataset, MusdbDataset
29
+ import config
30
+ import htsat_config
31
+ from models.htsat import HTSAT_Swin_Transformer
32
+ from sed_model import SEDWrapper
33
+
34
+ import pytorch_lightning as pl
35
+ from pytorch_lightning.callbacks import ModelCheckpoint
36
+
37
+ from htsat_utils import process_idc
38
+
39
+ import warnings
40
+ warnings.filterwarnings("ignore")
41
+
42
+
43
+
44
+ class data_prep(pl.LightningDataModule):
45
+ def __init__(self, train_dataset, eval_dataset, device_num, config):
46
+ super().__init__()
47
+ self.train_dataset = train_dataset
48
+ self.eval_dataset = eval_dataset
49
+ self.device_num = device_num
50
+ self.config = config
51
+
52
+ def train_dataloader(self):
53
+ train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None
54
+ train_loader = DataLoader(
55
+ dataset = self.train_dataset,
56
+ num_workers = config.num_workers,
57
+ batch_size = config.batch_size // self.device_num,
58
+ shuffle = False,
59
+ sampler = train_sampler,
60
+ collate_fn = collect_fn
61
+ )
62
+ return train_loader
63
+ def val_dataloader(self):
64
+ eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
65
+ eval_loader = DataLoader(
66
+ dataset = self.eval_dataset,
67
+ num_workers = config.num_workers,
68
+ batch_size = config.batch_size // self.device_num,
69
+ shuffle = False,
70
+ sampler = eval_sampler,
71
+ collate_fn = collect_fn
72
+ )
73
+ return eval_loader
74
+ def test_dataloader(self):
75
+ test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
76
+ test_loader = DataLoader(
77
+ dataset = self.eval_dataset,
78
+ num_workers = config.num_workers,
79
+ batch_size = config.batch_size // self.device_num,
80
+ shuffle = False,
81
+ sampler = test_sampler,
82
+ collate_fn = collect_fn
83
+ )
84
+ return test_loader
85
+
86
+ def save_idc():
87
+ train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
88
+ eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
89
+ process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
90
+ process_idc(eval_index_path, config.classes_num, "eval_idc.npy")
91
+
92
+ # Resample the musdb tracks from the original 44100 Hz to a 32000 Hz sample rate
93
+ def process_musdb():
94
+ # use musdb as testset
95
+ test_data = musdb.DB(
96
+ root = config.musdb_path,
97
+ download = False,
98
+ subsets = "test",
99
+ is_wav = True
100
+ )
101
+ print(len(test_data.tracks))
102
+ mus_tracks = []
103
+ # in musdb, all fs is the same (44100)
104
+ orig_fs = test_data.tracks[0].rate
105
+ print(orig_fs)
106
+ for track in test_data.tracks:
107
+ temp = {}
108
+ mixture = prepprocess_audio(
109
+ track.audio,
110
+ orig_fs, config.sample_rate,
111
+ config.test_type
112
+ )
113
+ temp["mixture" ]= mixture
114
+ for dickey in config.test_key:
115
+ source = prepprocess_audio(
116
+ track.targets[dickey].audio,
117
+ orig_fs, config.sample_rate,
118
+ config.test_type
119
+ )
120
+ temp[dickey] = source
121
+ print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
122
+ mus_tracks.append(temp)
123
+ print(len(mus_tracks))
124
+ # save the file to npy
125
+ np.save("musdb-32000fs.npy", mus_tracks)
126
+
127
+ # weight averaging is performed over the given folder
128
+ # It outputs a single model checkpoint that averages the weights of all models in the folder
129
+ def weight_average():
130
+ model_ckpt = []
131
+ model_files = os.listdir(config.wa_model_folder)
132
+ wa_ckpt = {
133
+ "state_dict": {}
134
+ }
135
+
136
+ for model_file in model_files:
137
+ model_file = os.path.join(config.wa_model_folder, model_file)  # config.py defines wa_model_folder (there is no esm_model_folder)
138
+ model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
139
+ keys = model_ckpt[0].keys()
140
+ for key in keys:
141
+ model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
142
+ model_ckpt_key = torch.mean(model_ckpt_key, dim = 0)
143
+ assert model_ckpt_key.shape == model_ckpt[0][key].shape, "the shape is unmatched " + model_ckpt_key.shape + " " + model_ckpt[0][key].shape
144
+ wa_ckpt["state_dict"][key] = model_ckpt_key
145
+ torch.save(wa_ckpt, config.wa_model_path)
146
+
147
+
148
+ # use the model to quickly separate a track given a query
149
+ # it requires four variables in config.py:
150
+ # inference_file: the track you want to separate
151
+ # inference_query: a **folder** containing all samples from the same source
152
+ # test_key: ["name"] indicate the source name (just a name for final output, no other functions)
153
+ # wave_output_path: the output folder
154
+
155
+ # make sure the query folder contain the samples from the same source
156
+ # each time, the model is able to separate one source from the track
157
+ # if you want to separate multiple sources, you need to change the query folder or write a script to help you do that
158
+ def inference():
159
+ # set exp settings
160
+ device_name = "cuda" if torch.cuda.is_available() else "cpu"
161
+ device = torch.device("cuda")
162
+ assert config.test_key is not None, "there should be a separate key"
163
+ create_folder(config.wave_output_path)
164
+ test_track, fs = librosa.load(config.inference_file, sr = None)
165
+ test_track = test_track[:,None]
166
+ print(test_track.shape)
167
+ print(fs)
168
+ # convert the track into 32000 Hz sample rate
169
+ test_track = prepprocess_audio(
170
+ test_track,
171
+ fs, config.sample_rate,
172
+ config.test_type
173
+ )
174
+ test_tracks = []
175
+ temp = [test_track]
176
+ for dickey in config.test_key:
177
+ temp.append(test_track)
178
+ temp = np.array(temp)
179
+ test_tracks.append(temp)
180
+ dataset = MusdbDataset(tracks = test_tracks) # the action is similar to musdbdataset, reuse it
181
+ loader = DataLoader(
182
+ dataset = dataset,
183
+ num_workers = 1,
184
+ batch_size = 1,
185
+ shuffle = False
186
+ )
187
+ # obtain the samples for query
188
+ queries = []
189
+ for query_file in os.listdir(config.inference_query):
190
+ f_path = os.path.join(config.inference_query, query_file)
191
+ if query_file.endswith(".wav"):
192
+ temp_q, fs = librosa.load(f_path, sr = None)
193
+ temp_q = temp_q[:, None]
194
+ temp_q = prepprocess_audio(
195
+ temp_q,
196
+ fs, config.sample_rate,
197
+ config.test_type
198
+ )
199
+ temp = [temp_q]
200
+ for dickey in config.test_key:
201
+ temp.append(temp_q)
202
+ temp = np.array(temp)
203
+ queries.append(temp)
204
+
205
+ assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
206
+
207
+ sed_model = HTSAT_Swin_Transformer(
208
+ spec_size=htsat_config.htsat_spec_size,
209
+ patch_size=htsat_config.htsat_patch_size,
210
+ in_chans=1,
211
+ num_classes=htsat_config.classes_num,
212
+ window_size=htsat_config.htsat_window_size,
213
+ config = htsat_config,
214
+ depths = htsat_config.htsat_depth,
215
+ embed_dim = htsat_config.htsat_dim,
216
+ patch_stride=htsat_config.htsat_stride,
217
+ num_heads=htsat_config.htsat_num_head
218
+ )
219
+ at_model = SEDWrapper(
220
+ sed_model = sed_model,
221
+ config = htsat_config,
222
+ dataset = None
223
+ )
224
+ ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
225
+ at_model.load_state_dict(ckpt["state_dict"])
226
+
227
+ trainer = pl.Trainer(
228
+ gpus = 1
229
+ )
230
+ avg_at = None
231
+ # obtain the latent embedding as query
232
+ if config.infer_type == "mean":
233
+ avg_dataset = MusdbDataset(tracks = queries)
234
+ avg_loader = DataLoader(
235
+ dataset = avg_dataset,
236
+ num_workers = 1,
237
+ batch_size = 1,
238
+ shuffle = False
239
+ )
240
+ at_wrapper = AutoTaggingWarpper(
241
+ at_model = at_model,
242
+ config = config,
243
+ target_keys = config.test_key
244
+ )
245
+ trainer.test(at_wrapper, test_dataloaders = avg_loader)
246
+ avg_at = at_wrapper.avg_at
247
+
248
+ # import the separation model
249
+ model = ZeroShotASP(
250
+ channels = 1, config = config,
251
+ at_model = at_model,
252
+ dataset = dataset
253
+ )
254
+ # resume checkpoint
255
+ ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
256
+ model.load_state_dict(ckpt["state_dict"], strict= False)
257
+ exp_model = SeparatorModel(
258
+ model = model,
259
+ config = config,
260
+ target_keys = config.test_key,
261
+ avg_at = avg_at,
262
+ using_wiener = False,
263
+ calc_sdr = False,
264
+ output_wav = True
265
+ )
266
+ trainer.test(exp_model, test_dataloaders = loader)
267
+
268
+ # test the separation model, mainly in musdb
269
+ def test():
270
+ # set exp settings
271
+ device_name = "cuda" if torch.cuda.is_available() else "cpu"
272
+ device = torch.device("cuda")
273
+ assert config.test_key is not None, "there should be a separate key"
274
+ create_folder(config.wave_output_path)
275
+ # use musdb as testset
276
+ test_data = np.load(config.testset_path, allow_pickle = True)
277
+ print(len(test_data))
278
+ mus_tracks = []
279
+ # in musdb, all fs is the same (44100)
280
+ # load the dataset
281
+ for track in test_data:
282
+ temp = []
283
+ mixture = track["mixture"]
284
+ temp.append(mixture)
285
+ for dickey in config.test_key:
286
+ source = track[dickey]
287
+ temp.append(source)
288
+ temp = np.array(temp)
289
+ print(temp.shape)
290
+ mus_tracks.append(temp)
291
+ print(len(mus_tracks))
292
+ dataset = MusdbDataset(tracks = mus_tracks)
293
+ loader = DataLoader(
294
+ dataset = dataset,
295
+ num_workers = 1,
296
+ batch_size = 1,
297
+ shuffle = False
298
+ )
299
+ assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
300
+
301
+ sed_model = HTSAT_Swin_Transformer(
302
+ spec_size=htsat_config.htsat_spec_size,
303
+ patch_size=htsat_config.htsat_patch_size,
304
+ in_chans=1,
305
+ num_classes=htsat_config.classes_num,
306
+ window_size=htsat_config.htsat_window_size,
307
+ config = htsat_config,
308
+ depths = htsat_config.htsat_depth,
309
+ embed_dim = htsat_config.htsat_dim,
310
+ patch_stride=htsat_config.htsat_stride,
311
+ num_heads=htsat_config.htsat_num_head
312
+ )
313
+ at_model = SEDWrapper(
314
+ sed_model = sed_model,
315
+ config = htsat_config,
316
+ dataset = None
317
+ )
318
+ ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
319
+ at_model.load_state_dict(ckpt["state_dict"])
320
+ trainer = pl.Trainer(
321
+ gpus = 1
322
+ )
323
+ avg_at = None
324
+ # obtain the query of four stems from the training set
325
+ if config.infer_type == "mean":
326
+ avg_data = np.load(config.testavg_path, allow_pickle = True)[:90]
327
+ print(len(avg_data))
328
+ avgmus_tracks = []
329
+ # in musdb, all fs is the same (44100)
330
+ # load the dataset
331
+ for track in avg_data:
332
+ temp = []
333
+ mixture = track["mixture"]
334
+ temp.append(mixture)
335
+ for dickey in config.test_key:
336
+ source = track[dickey]
337
+ temp.append(source)
338
+ temp = np.array(temp)
339
+ print(temp.shape)
340
+ avgmus_tracks.append(temp)
341
+ print(len(avgmus_tracks))
342
+ avg_dataset = MusdbDataset(tracks = avgmus_tracks)
343
+ avg_loader = DataLoader(
344
+ dataset = avg_dataset,
345
+ num_workers = 1,
346
+ batch_size = 1,
347
+ shuffle = False
348
+ )
349
+ at_wrapper = AutoTaggingWarpper(
350
+ at_model = at_model,
351
+ config = config,
352
+ target_keys = config.test_key
353
+ )
354
+ trainer.test(at_wrapper, test_dataloaders = avg_loader)
355
+ avg_at = at_wrapper.avg_at
356
+
357
+ model = ZeroShotASP(
358
+ channels = 1, config = config,
359
+ at_model = at_model,
360
+ dataset = dataset
361
+ )
362
+ ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
363
+ model.load_state_dict(ckpt["state_dict"], strict= False)
364
+ exp_model = SeparatorModel(
365
+ model = model,
366
+ config = config,
367
+ target_keys = config.test_key,
368
+ avg_at = avg_at,
369
+ using_wiener = config.using_wiener
370
+ )
371
+ trainer.test(exp_model, test_dataloaders = loader)
372
+
373
+ def train():
374
+ # set exp settings
375
+ # device_name = "cuda" if torch.cuda.is_available() else "cpu"
376
+ # device = torch.device("cuda")
377
+
378
+ device_num = torch.cuda.device_count()
379
+ print("each batch size:", config.batch_size // device_num)
380
+
381
+ train_index_path = os.path.join(config.dataset_path, "hdf5s","indexes", config.index_type + ".h5")
382
+ train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle = True)
383
+
384
+ eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
385
+ eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle = True)
386
+
387
+ # set exp folder
388
+ exp_dir = os.path.join(config.workspace, "results", config.exp_name)
389
+ checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
390
+
391
+ if not config.debug:
392
+ create_folder(os.path.join(config.workspace, "results"))
393
+ create_folder(exp_dir)
394
+ create_folder(checkpoint_dir)
395
+ dump_config(config, os.path.join(exp_dir, config.exp_name), False)
396
+
397
+ # load data
398
+ # import dataset LGSPDataset (latent general source separation) and sampler
399
+ dataset = LGSPDataset(
400
+ index_path = train_index_path,
401
+ idc = train_idc,
402
+ config = config,
403
+ factor = 0.05,
404
+ eval_mode = False
405
+ )
406
+ eval_dataset = LGSPDataset(
407
+ index_path = eval_index_path,
408
+ idc = eval_idc,
409
+ config = config,
410
+ factor = 0.05,
411
+ eval_mode = True
412
+ )
413
+
414
+ audioset_data = data_prep(train_dataset=dataset,eval_dataset=eval_dataset,device_num=device_num, config=config)
415
+ checkpoint_callback = ModelCheckpoint(
416
+ monitor = "mixture_sdr",
417
+ filename='l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
418
+ save_top_k = 10,
419
+ mode = "max"
420
+ )
421
+ # infer at model
422
+ sed_model = HTSAT_Swin_Transformer(
423
+ spec_size=htsat_config.htsat_spec_size,
424
+ patch_size=htsat_config.htsat_patch_size,
425
+ in_chans=1,
426
+ num_classes=htsat_config.classes_num,
427
+ window_size=htsat_config.htsat_window_size,
428
+ config = htsat_config,
429
+ depths = htsat_config.htsat_depth,
430
+ embed_dim = htsat_config.htsat_dim,
431
+ patch_stride=htsat_config.htsat_stride,
432
+ num_heads=htsat_config.htsat_num_head
433
+ )
434
+ at_model = SEDWrapper(
435
+ sed_model = sed_model,
436
+ config = htsat_config,
437
+ dataset = None
438
+ )
439
+ # load the checkpoint
440
+ ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
441
+ at_model.load_state_dict(ckpt["state_dict"])
442
+
443
+ trainer = pl.Trainer(
444
+ deterministic=True,
445
+ default_root_dir = checkpoint_dir,
446
+ gpus = device_num,
447
+ val_check_interval = 0.2,
448
+ # check_val_every_n_epoch = 1,
449
+ max_epochs = config.max_epoch,
450
+ auto_lr_find = True,
451
+ sync_batchnorm = True,
452
+ callbacks = [checkpoint_callback],
453
+ accelerator = "ddp" if device_num > 1 else None,
454
+ resume_from_checkpoint = None, #config.resume_checkpoint,
455
+ replace_sampler_ddp = False,
456
+ gradient_clip_val=1.0,
457
+ num_sanity_val_steps = 0,
458
+ )
459
+ model = ZeroShotASP(
460
+ channels = 1, config = config,
461
+ at_model = at_model,
462
+ dataset = dataset
463
+ )
464
+ if config.resume_checkpoint is not None:
465
+ ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
466
+ model.load_state_dict(ckpt["state_dict"])
467
+ # trainer.test(model, datamodule = audioset_data)
468
+ trainer.fit(model, audioset_data)
469
+
470
+ def main():
471
+ parser = argparse.ArgumentParser(description="latent general source separation parser")
472
+ subparsers = parser.add_subparsers(dest = "mode")
473
+ parser_train = subparsers.add_parser("train")
474
+ parser_test = subparsers.add_parser("test")
475
+ parser_musdb = subparsers.add_parser("musdb_process")
476
+ parser_saveidc = subparsers.add_parser("save_idc")
477
+ parser_wa = subparsers.add_parser("weight_average")
478
+ parser_infer = subparsers.add_parser("inference")
479
+ args = parser.parse_args()
480
+ # default settings
481
+ logging.basicConfig(level=logging.INFO)
482
+ pl.utilities.seed.seed_everything(seed = config.random_seed)
483
+
484
+ if args.mode == "train":
485
+ train()
486
+ elif args.mode == "test":
487
+ test()
488
+ elif args.mode == "musdb_process":
489
+ process_musdb()
490
+ elif args.mode == "weight_average":
491
+ weight_average()
492
+ elif args.mode == "save_idc":
493
+ save_idc()
494
+ elif args.mode == "inference":
495
+ inference()
496
+ else:
497
+ raise Exception("Error Mode!")
498
+
499
+
500
+ if __name__ == '__main__':
501
+ main()
502
+
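For reference, a minimal sketch of how the subcommands above might be chained end to end. The stage order (build indexes, train, average weights, evaluate) is an assumption about a typical run, not the authors' prescribed workflow, and it presumes config.py already points at a prepared AudioSet workspace:

# hypothetical driver: each stage shells out to the CLI defined in main() above
import subprocess

stages = ["save_idc", "train", "weight_average", "test"]   # assumed order; adjust to your needs
for stage in stages:
    # e.g. "python main.py train" runs the train() routine above
    subprocess.run(["python", "main.py", stage], check=True)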
opt_thres.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f40e97a4946e70392576d4f4c171596bcb6243883a54f48aaa9ae5b86c0976c
3
+ size 13585
predict.py ADDED
@@ -0,0 +1,111 @@
1
+ import os
2
+ import types
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+
10
+ import htsat_config
11
+ from cog import BasePredictor, Input, Path
12
+ from data_processor import MusdbDataset
13
+ from models.asp_model import AutoTaggingWarpper, SeparatorModel, ZeroShotASP
14
+ from models.htsat import HTSAT_Swin_Transformer
15
+ from sed_model import SEDWrapper
16
+ from utils import prepprocess_audio
17
+
18
+ def get_inference_configs():
19
+ config = types.SimpleNamespace()
20
+ config.ckpt_path = "pretrained/zeroshot_asp_full.ckpt"
21
+ config.sed_ckpt_path = "pretrained/htsat_audioset_2048d.ckpt"
22
+ config.wave_output_path = "predict_outputs"
23
+ config.test_key = "query_name"
24
+ config.test_type = "mix"
25
+ config.loss_type = "mae"
26
+ config.infer_type = "mean"
27
+ config.sample_rate = 32000
28
+ config.segment_frames = 200
29
+ config.hop_samples = 320
30
+ config.energy_thres = 0.1
31
+ config.using_whiting = False
32
+ config.latent_dim = 2048
33
+ config.classes_num = 527
34
+ config.overlap_rate = 0.5
35
+ config.num_workers = 1
36
+
37
+ return config
38
+
39
+ def load_models(config):
40
+ sed_model = HTSAT_Swin_Transformer(
41
+ spec_size=htsat_config.htsat_spec_size,
42
+ patch_size=htsat_config.htsat_patch_size,
43
+ in_chans=1,
44
+ num_classes=htsat_config.classes_num,
45
+ window_size=htsat_config.htsat_window_size,
46
+ config=htsat_config,
47
+ depths=htsat_config.htsat_depth,
48
+ embed_dim=htsat_config.htsat_dim,
49
+ patch_stride=htsat_config.htsat_stride,
50
+ num_heads=htsat_config.htsat_num_head,
51
+ )
52
+ at_model = SEDWrapper(sed_model=sed_model, config=htsat_config, dataset=None)
53
+
54
+ ckpt = torch.load(config.sed_ckpt_path, map_location="cpu")
55
+ at_model.load_state_dict(ckpt["state_dict"])
56
+
57
+ at_wrapper = AutoTaggingWarpper(
58
+ at_model=at_model, config=config, target_keys=[config.test_key]
59
+ )
60
+
61
+ asp_model = ZeroShotASP(channels=1, config=config, at_model=at_model, dataset=None)
62
+ ckpt = torch.load(config.ckpt_path, map_location="cpu")
63
+ asp_model.load_state_dict(ckpt["state_dict"], strict=False)
64
+
65
+ return at_wrapper, asp_model
66
+
67
+ def get_dataloader_from_sound_file(sound_file_path, config):
68
+ signal, sampling_rate = librosa.load(str(sound_file_path), sr=None)
69
+ signal = prepprocess_audio(
70
+ signal[:, None], sampling_rate, config.sample_rate, config.test_type
71
+ )
72
+ signal = np.array([signal, signal]) # Duplicate signal for later use
73
+ dataset = MusdbDataset(tracks=[signal])
74
+ data_loader = DataLoader(dataset, num_workers=config.num_workers, batch_size=1, shuffle=False)
75
+ return data_loader
76
+
77
+
78
+ class Predictor(BasePredictor):
79
+ def setup(self):
80
+ self.config = get_inference_configs()
81
+ os.makedirs(self.config.wave_output_path, exist_ok=True)
82
+ self.at_wrapper, self.asp_model = load_models(self.config)
83
+
84
+ def predict(
85
+ self,
86
+ mix_file: Path = Input(description="Reference sound to extract source from"),
87
+ query_file: Path = Input(description="Query sound to be searched and extracted from mix"),
88
+ ) -> Path:
89
+ ref_loader = get_dataloader_from_sound_file(str(mix_file), self.config)
90
+
91
+ query_loader = get_dataloader_from_sound_file(str(query_file), self.config)
92
+
93
+ trainer = pl.Trainer(gpus=1)
94
+ trainer.test(self.at_wrapper, test_dataloaders=query_loader)
95
+ avg_at = self.at_wrapper.avg_at
96
+
97
+ exp_model = SeparatorModel(
98
+ model=self.asp_model,
99
+ config=self.config,
100
+ target_keys=[self.config.test_key],
101
+ avg_at=avg_at,
102
+ using_wiener=False,
103
+ calc_sdr=False,
104
+ output_wav=True,
105
+ )
106
+ trainer.test(exp_model, test_dataloaders=ref_loader)
107
+
108
+ prediction_path = os.path.join(
109
+ self.config.wave_output_path, f"0_{self.config.test_key}_pred_(0.0).wav"
110
+ )
111
+ return prediction_path
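For local experimentation outside Cog, here is a minimal usage sketch for the Predictor above. It assumes a GPU is available (the class builds pl.Trainer(gpus=1)), that the two checkpoints named in get_inference_configs() have been downloaded, and that mixture.wav / query.wav are placeholder file names:

from predict import Predictor

predictor = Predictor()
predictor.setup()                       # loads the HTS-AT tagger and the ZeroShotASP separator
out_wav = predictor.predict(
    mix_file="mixture.wav",             # placeholder: track to separate
    query_file="query.wav",             # placeholder: sample of the source to extract
)
print("separated source written to", out_wav)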
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ h5py==3.6.0
2
+ hydra-core>=1.0
3
+ librosa==0.8.1
4
+ musdb==0.4.0
5
+ museval==0.4.0
6
+ noisereduce==2.0.0
7
+ numba==0.55.1
8
+ numpy==1.21.5
9
+ omegaconf>=2.0.0
10
+ pytorch_lightning==1.5.9
11
+ scikit_learn==1.0.2
12
+ scipy==1.7.3
13
+ soundfile==0.10.3.post1
14
+ tensorboard==2.8.0
15
+ torch==1.10.2
16
+ torchaudio==0.10.2
17
+ torchcontrib==0.0.2
18
+ torchlibrosa==0.0.9
19
+ tqdm==4.62.3
sed_model.py ADDED
@@ -0,0 +1,358 @@
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # The Model Training Wrapper
5
+ import numpy as np
6
+ import librosa
7
+ import os
8
+ import sys
9
+ import math
10
+ import bisect
11
+ import pickle
12
+ from numpy.lib.function_base import average
13
+ from sklearn import metrics
14
+ import soundfile as sf
15
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
16
+
17
+ import tensorboard
18
+ import torch
19
+ import torchaudio
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint as cp
23
+ import torch.optim as optim
24
+ from torch.nn.parameter import Parameter
25
+ import torch.distributed as dist
26
+ from torchlibrosa.stft import STFT, ISTFT, magphase
27
+ import pytorch_lightning as pl
28
+ from htsat_utils import do_mixup, get_mix_lambda, do_mixup_label, get_loss_func, d_prime
29
+ import random
30
+
31
+ from torchcontrib.optim import SWA
32
+
33
+
34
+ class SEDWrapper(pl.LightningModule):
35
+ def __init__(self, sed_model, config, dataset):
36
+ super().__init__()
37
+ self.sed_model = sed_model
38
+ self.config = config
39
+ self.dataset = dataset
40
+ self.loss_func = get_loss_func(config.loss_type)
41
+
42
+ def evaluate_metric(self, pred, ans):
43
+ ap = []
44
+ if self.config.dataset_type == "audioset":
45
+ mAP = np.mean(average_precision_score(ans, pred, average = None))
46
+ mAUC = np.mean(roc_auc_score(ans, pred, average = None))
47
+ dprime = d_prime(mAUC)
48
+ return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
49
+ else:
50
+ acc = accuracy_score(ans, np.argmax(pred, 1))
51
+ return {"acc": acc}
52
+ def forward(self, x, mix_lambda = None):
53
+ output_dict = self.sed_model(x, mix_lambda)
54
+ return output_dict["clipwise_output"], output_dict["framewise_output"]
55
+
56
+ def inference(self, x):
57
+ self.device_type = next(self.parameters()).device
58
+ self.eval()
59
+ x = torch.from_numpy(x).float().to(self.device_type)
60
+ output_dict = self.sed_model(x, None, True)
61
+ for key in output_dict.keys():
62
+ output_dict[key] = output_dict[key].detach().cpu().numpy()
63
+ return output_dict
64
+
65
+ def training_step(self, batch, batch_idx):
66
+ self.device_type = next(self.parameters()).device
67
+ mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(batch["waveform"]))).to(self.device_type)
68
+ # Another choice: also mixup the target, but AudioSet is not a perfectly labeled dataset,
69
+ # so "adding noise" might be better than purely "mix"
70
+ # batch["target"] = do_mixup_label(batch["target"])
71
+ # batch["target"] = do_mixup(batch["target"], mix_lambda)
72
+
73
+ pred, _ = self(batch["waveform"], mix_lambda)
74
+ loss = self.loss_func(pred, batch["target"])
75
+ self.log("loss", loss, on_epoch= True, prog_bar=True)
76
+ return loss
77
+ def training_epoch_end(self, outputs):
78
+ # Change: SWA, deprecated
79
+ # for opt in self.trainer.optimizers:
80
+ # if not type(opt) is SWA:
81
+ # continue
82
+ # opt.swap_swa_sgd()
83
+ self.dataset.generate_queue()
84
+
85
+
86
+ def validation_step(self, batch, batch_idx):
87
+ pred, _ = self(batch["waveform"])
88
+ return [pred.detach(), batch["target"].detach()]
89
+
90
+ def validation_epoch_end(self, validation_step_outputs):
91
+ self.device_type = next(self.parameters()).device
92
+ pred = torch.cat([d[0] for d in validation_step_outputs], dim = 0)
93
+ target = torch.cat([d[1] for d in validation_step_outputs], dim = 0)
94
+ gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
95
+ gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
96
+ dist.barrier()
97
+ if self.config.dataset_type == "audioset":
98
+ metric_dict = {
99
+ "mAP": 0.,
100
+ "mAUC": 0.,
101
+ "dprime": 0.
102
+ }
103
+ else:
104
+ metric_dict = {
105
+ "acc":0.
106
+ }
107
+ dist.all_gather(gather_pred, pred)
108
+ dist.all_gather(gather_target, target)
109
+ if dist.get_rank() == 0:
110
+ gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
111
+ gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
112
+ if self.config.dataset_type == "scv2":
113
+ gather_target = np.argmax(gather_target, 1)
114
+ metric_dict = self.evaluate_metric(gather_pred, gather_target)
115
+ print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
116
+
117
+ if self.config.dataset_type == "audioset":
118
+ self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
119
+ self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
120
+ self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
121
+ else:
122
+ self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
123
+ dist.barrier()
124
+
125
+ def time_shifting(self, x, shift_len):
126
+ shift_len = int(shift_len)
127
+ new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
128
+ return new_sample
129
+
130
+ def test_step(self, batch, batch_idx):
131
+ self.device_type = next(self.parameters()).device
132
+ preds = []
133
+ # cancel the time-shifting augmentation to speed up inference
134
+ shift_num = 1
135
+ for i in range(shift_num):
136
+ pred, pred_map = self(batch["waveform"])
137
+ preds.append(pred.unsqueeze(0))
138
+ batch["waveform"] = self.time_shifting(batch["waveform"], shift_len = 100 * (i + 1))
139
+ preds = torch.cat(preds, dim=0)
140
+ pred = preds.mean(dim = 0)
141
+ if self.config.fl_local:
142
+ return [
143
+ pred.detach().cpu().numpy(),
144
+ pred_map.detach().cpu().numpy(),
145
+ batch["audio_name"],
146
+ batch["real_len"].cpu().numpy()
147
+ ]
148
+ else:
149
+ return [pred.detach(), batch["target"].detach()]
150
+
151
+ def test_epoch_end(self, test_step_outputs):
152
+ self.device_type = next(self.parameters()).device
153
+ if self.config.fl_local:
154
+ pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
155
+ pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
156
+ audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
157
+ real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
158
+ heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
159
+ save_npy = [
160
+ {
161
+ "audio_name": audio_name[i],
162
+ "heatmap": pred_map[i],
163
+ "pred": pred[i],
164
+ "real_len":real_len[i]
165
+ }
166
+ for i in range(len(pred))
167
+ ]
168
+ np.save(heatmap_file, save_npy)
169
+ else:
170
+ self.device_type = next(self.parameters()).device
171
+ pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
172
+ target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
173
+ gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
174
+ gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
175
+ dist.barrier()
176
+ if self.config.dataset_type == "audioset":
177
+ metric_dict = {
178
+ "mAP": 0.,
179
+ "mAUC": 0.,
180
+ "dprime": 0.
181
+ }
182
+ else:
183
+ metric_dict = {
184
+ "acc":0.
185
+ }
186
+ dist.all_gather(gather_pred, pred)
187
+ dist.all_gather(gather_target, target)
188
+ if dist.get_rank() == 0:
189
+ gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
190
+ gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
191
+ if self.config.dataset_type == "scv2":
192
+ gather_target = np.argmax(gather_target, 1)
193
+ metric_dict = self.evaluate_metric(gather_pred, gather_target)
194
+ print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
195
+ if self.config.dataset_type == "audioset":
196
+ self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
197
+ self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
198
+ self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
199
+ else:
200
+ self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
201
+ dist.barrier()
202
+
203
+
204
+ def configure_optimizers(self):
205
+ optimizer = optim.AdamW(
206
+ filter(lambda p: p.requires_grad, self.parameters()),
207
+ lr = self.config.learning_rate,
208
+ betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.05,
209
+ )
210
+ # Change: SWA, deprecated
211
+ # optimizer = SWA(optimizer, swa_start=10, swa_freq=5)
212
+ def lr_foo(epoch):
213
+ if epoch < 3:
214
+ # warm up lr
215
+ lr_scale = self.config.lr_rate[epoch]
216
+ else:
217
+ # staged decay schedule after warm-up
218
+ lr_pos = int(-1 - bisect.bisect_left(self.config.lr_scheduler_epoch, epoch))
219
+ if lr_pos < -3:
220
+ lr_scale = max(self.config.lr_rate[0] * (0.98 ** epoch), 0.03 )
221
+ else:
222
+ lr_scale = self.config.lr_rate[lr_pos]
223
+ return lr_scale
224
+ scheduler = optim.lr_scheduler.LambdaLR(
225
+ optimizer,
226
+ lr_lambda=lr_foo
227
+ )
228
+
229
+ return [optimizer], [scheduler]
230
+
231
+
232
+
233
+ class Ensemble_SEDWrapper(pl.LightningModule):
234
+ def __init__(self, sed_models, config, dataset):
235
+ super().__init__()
236
+
237
+ self.sed_models = nn.ModuleList(sed_models)
238
+ self.config = config
239
+ self.dataset = dataset
240
+
241
+ def evaluate_metric(self, pred, ans):
242
+ if self.config.dataset_type == "audioset":
243
+ mAP = np.mean(average_precision_score(ans, pred, average = None))
244
+ mAUC = np.mean(roc_auc_score(ans, pred, average = None))
245
+ dprime = d_prime(mAUC)
246
+ return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
247
+ else:
248
+ acc = accuracy_score(ans, np.argmax(pred, 1))
249
+ return {"acc": acc}
250
+
251
+ def forward(self, x, sed_index, mix_lambda = None):
252
+ self.sed_models[sed_index].eval()
253
+ preds = []
254
+ pred_maps = []
255
+ # cancel the time-shifting augmentation to speed up inference
256
+ shift_num = 1
257
+ for i in range(shift_num):
258
+ pred, pred_map = self.sed_models[sed_index](x)
259
+ pred_maps.append(pred_map.unsqueeze(0))
260
+ preds.append(pred.unsqueeze(0))
261
+ x = self.time_shifting(x, shift_len = 100 * (i + 1))
262
+ preds = torch.cat(preds, dim=0)
263
+ pred_maps = torch.cat(pred_maps, dim = 0)
264
+ pred = preds.mean(dim = 0)
265
+ pred_map = pred_maps.mean(dim = 0)
266
+ return pred, pred_map
267
+
268
+
269
+ def time_shifting(self, x, shift_len):
270
+ shift_len = int(shift_len)
271
+ new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
272
+ return new_sample
273
+
274
+ def test_step(self, batch, batch_idx):
275
+ self.device_type = next(self.parameters()).device
276
+ if self.config.fl_local:
277
+ pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
278
+ pred_map = torch.zeros(len(batch["waveform"]), 1024, self.config.classes_num).float().to(self.device_type)
279
+ for j in range(len(self.sed_models)):
280
+ temp_pred, temp_pred_map = self(batch["waveform"], j)
281
+ pred = pred + temp_pred
282
+ pred_map = pred_map + temp_pred_map
283
+ pred = pred / len(self.sed_models)
284
+ pred_map = pred_map / len(self.sed_models)
285
+ return [
286
+ pred.detach().cpu().numpy(),
287
+ pred_map.detach().cpu().numpy(),
288
+ batch["audio_name"],
289
+ batch["real_len"].cpu().numpy()
290
+ ]
291
+ else:
292
+ pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
293
+ for j in range(len(self.sed_models)):
294
+ temp_pred, _ = self(batch["waveform"], j)
295
+ pred = pred + temp_pred
296
+ pred = pred / len(self.sed_models)
297
+ return [
298
+ pred.detach(),
299
+ batch["target"].detach(),
300
+ ]
301
+
302
+ def test_epoch_end(self, test_step_outputs):
303
+ self.device_type = next(self.parameters()).device
304
+ if self.config.fl_local:
305
+ pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
306
+ pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
307
+ audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
308
+ real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
309
+ heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
310
+ print(pred.shape)
311
+ print(pred_map.shape)
312
+ print(real_len.shape)
313
+ save_npy = [
314
+ {
315
+ "audio_name": audio_name[i],
316
+ "heatmap": pred_map[i],
317
+ "pred": pred[i],
318
+ "real_len":real_len[i]
319
+ }
320
+ for i in range(len(pred))
321
+ ]
322
+ np.save(heatmap_file, save_npy)
323
+ else:
324
+ pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
325
+ target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
326
+ gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
327
+ gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
328
+
329
+ dist.barrier()
330
+ if self.config.dataset_type == "audioset":
331
+ metric_dict = {
332
+ "mAP": 0.,
333
+ "mAUC": 0.,
334
+ "dprime": 0.
335
+ }
336
+ else:
337
+ metric_dict = {
338
+ "acc":0.
339
+ }
340
+ dist.all_gather(gather_pred, pred)
341
+ dist.all_gather(gather_target, target)
342
+ if dist.get_rank() == 0:
343
+ gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
344
+ gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
345
+ if self.config.dataset_type == "scv2":
346
+ gather_target = np.argmax(gather_target, 1)
347
+ metric_dict = self.evaluate_metric(gather_pred, gather_target)
348
+ print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
349
+ if self.config.dataset_type == "audioset":
350
+ self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
351
+ self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
352
+ self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
353
+ else:
354
+ self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
355
+ dist.barrier()
356
+
357
+
358
+
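As a side note on configure_optimizers() above: the LambdaLR implements a short warm-up followed by a staged decay. Below is a self-contained sketch of the same rule with illustrative (assumed) values for lr_rate and lr_scheduler_epoch, which in this repo actually come from htsat_config.py:

import bisect

lr_rate = [0.02, 0.05, 0.1]          # assumed warm-up scales for epochs 0-2
lr_scheduler_epoch = [10, 20, 30]    # assumed decay milestones

def lr_scale(epoch):
    if epoch < 3:
        return lr_rate[epoch]                                 # warm-up
    lr_pos = int(-1 - bisect.bisect_left(lr_scheduler_epoch, epoch))
    if lr_pos < -3:
        return max(lr_rate[0] * (0.98 ** epoch), 0.03)        # exponential tail
    return lr_rate[lr_pos]                                    # staged decay

print([round(lr_scale(e), 4) for e in range(0, 40, 5)])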
utils.py ADDED
@@ -0,0 +1,580 @@
1
+ # Ke Chen
2
3
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
4
+ # Some Common Methods
5
+
6
+ import numpy as np
7
+ from scipy.signal import butter, filtfilt
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+ from typing import Optional
12
+ import logging
13
+ import os
14
+ import sys
15
+ import h5py
16
+ import csv
17
+ import time
18
+ import json
19
+ import museval
20
+ import librosa
21
+ from datetime import datetime
22
+
23
+ def create_folder(fd):
24
+ if not os.path.exists(fd):
25
+ os.makedirs(fd)
26
+
27
+ def get_filename(path):
28
+ path = os.path.realpath(path)
29
+ na_ext = path.split('/')[-1]
30
+ na = os.path.splitext(na_ext)[0]
31
+ return na
32
+
33
+ def get_sub_filepaths(folder):
34
+ paths = []
35
+ for root, dirs, files in os.walk(folder):
36
+ for name in files:
37
+ path = os.path.join(root, name)
38
+ paths.append(path)
39
+ return paths
40
+
41
+ def np_to_pytorch(x, device = None):
42
+ if 'float' in str(x.dtype):
43
+ x = torch.Tensor(x)
44
+ elif 'int' in str(x.dtype):
45
+ x = torch.LongTensor(x)
46
+ else:
47
+ return x
48
+ return x.to(device)
49
+
50
+ def count_parameters(model):
51
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
52
+
53
+ def calculate_average_energy(x):
54
+ return np.mean(np.square(x))
55
+
56
+ def id_to_one_hot(id, classes_num):
57
+ one_hot = np.zeros(classes_num)
58
+ one_hot[id] = 1
59
+ return one_hot
60
+
61
+ def ids_to_hots(ids, classes_num):
62
+ hots = np.zeros(classes_num)
63
+ for id in ids:
64
+ hots[id] = 1
65
+ return hots
66
+
67
+ def float32_to_int16(x):
68
+ assert np.max(np.abs(x)) <= 1.
69
+ return (x * 32767.).astype(np.int16)
70
+
71
+ def int16_to_float32(x):
72
+ return (x / 32767.).astype(np.float32)
73
+
74
+ def collect_fn(list_data_dict):
75
+ np_data_dict = {}
76
+ for key in list_data_dict[0].keys():
77
+ np_data_dict[key] = np.array([data_dict[key] for data_dict in list_data_dict])
78
+ return np_data_dict
79
+
80
+ def dump_config(config, filename, include_time = False):
81
+ save_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
82
+ config_json = {}
83
+ for key in dir(config):
84
+ if not key.startswith("_"):
85
+ config_json[key] = eval("config." + key)
86
+ if include_time:
87
+ filename = filename + "_" + save_time
88
+ with open(filename + ".json", "w") as f:
89
+ json.dump(config_json, f ,indent=4)
90
+
91
+
92
+ def get_segment_bgn_end_samples(anchor_index, segment_frames, hop_samples, clip_samples):
93
+ bgn_frame = anchor_index - segment_frames // 2
94
+ end_frame = anchor_index + segment_frames // 2
95
+ bgn_sample = bgn_frame * hop_samples
96
+ end_sample = end_frame * hop_samples
97
+
98
+ segment_samples = segment_frames * hop_samples
99
+
100
+ if bgn_sample < 0:
101
+ bgn_sample = 0
102
+ end_sample = segment_samples
103
+
104
+ if end_sample > clip_samples:
105
+ bgn_sample = clip_samples - segment_samples
106
+ end_sample = clip_samples
107
+
108
+ return bgn_sample, end_sample
109
+
110
+ def get_mix_data(waveforms, con_vectors, class_ids, indexes, mix_type = "mixture"):
111
+ # define return data
112
+ mixtures = []
113
+ sources = []
114
+ conditions = []
115
+ gds = []
116
+ for i in range(0, len(indexes), 2):
117
+ n1 = indexes[i]
118
+ n2 = indexes[i + 1]
119
+ # energy normalization
120
+ e1 = np.mean(np.square(waveforms[n1]))
121
+ e2 = np.mean(np.square(waveforms[n2]))
122
+ ratio = (e1 / max(1e-8, e2)) ** 0.5
123
+ ratio = np.clip(ratio, 0.02, 50)
124
+ waveforms[n2] *= ratio
125
+ mixture = waveforms[n1] + waveforms[n2]
126
+ # form data
127
+ if mix_type == "clean":
128
+ mixtures.append(waveforms[n1])
129
+ mixtures.append(waveforms[n2])
130
+ sources.append(waveforms[n1])
131
+ sources.append(waveforms[n2])
132
+ elif mix_type == "silence":
133
+ mixtures.append(waveforms[n2])
134
+ mixtures.append(waveforms[n1])
135
+ sources.append(np.zeros_like(waveforms[n1]))
136
+ sources.append(np.zeros_like(waveforms[n2]))
137
+ else:
138
+ mixtures.append(mixture)
139
+ mixtures.append(mixture)
140
+ sources.append(waveforms[n1])
141
+ sources.append(waveforms[n2])
142
+
143
+ conditions.append(con_vectors[n1])
144
+ conditions.append(con_vectors[n2])
145
+ gds.append(class_ids[n1])
146
+ gds.append(class_ids[n2])
147
+ return mixtures, sources, conditions, gds
148
+
149
+ # generate a list
150
+ def get_balanced_class_list(index_path, factor = 3, black_list = None, random_seed = 0):
151
+ # initialization
152
+ random_state = np.random.RandomState(random_seed)
153
+ logging.info("Load Indexes...............")
154
+ with h5py.File(index_path, "r") as hf:
155
+ indexes = hf["index_in_hdf5"][:]
156
+ targets = hf["target"][:].astype(np.float32)
157
+ (audios_num, classes_num) = targets.shape
158
+ # set the indexes per class for balanced list
159
+ indexes_per_class = []
160
+ for k in range(classes_num):
161
+ indexes_per_class.append(
162
+ np.where(targets[:, k] == 1)[0]
163
+ )
164
+
165
+ logging.info("Load Indexes Succeed...............")
166
+
167
+ return indexes_per_class
168
+
169
+ def dataset_worker_init_fn_seed(worker_id):
170
+ seed = np.random.randint(0, 224141) + worker_id * np.random.randint(100,1000)
171
+ print(seed)
172
+ np.random.seed(seed)
173
+
174
+ def calculate_sdr(ref, est, scaling=False):
175
+ s = museval.evaluate(ref[None,:,None], est[None,:,None], win = len(ref), hop = len(ref))
176
+ return s[0][0]
177
+
178
+ def butter_lowpass_filter(data, cuton, cutoff, fs, order):
179
+ normal_cutoff = cutoff / (0.5 * fs)
180
+ normal_cuton = cuton / (0.5 * fs)
181
+ b, a = butter(order, [normal_cuton, normal_cutoff], btype="band", analog=False)
182
+ y = filtfilt(b,a, data)
183
+ return y
184
+
185
+ def calculate_silence_sdr(mixture, est):
186
+ sdr = 10. * (
187
+ np.log10(np.clip(np.mean(mixture ** 2), 1e-8, np.inf)) \
188
+ - np.log10(np.clip(np.mean(est ** 2), 1e-8, np.inf)))
189
+ return sdr
190
+
191
+
192
+ def evaluate_sdr(ref, est, class_ids, mix_type = "mixture"):
193
+ sdr_results = []
194
+ if mix_type == "silence":
195
+ for i in range(len(ref)):
196
+ sdr = calculate_silence_sdr(ref[i,:,0], est[i,:,0])
197
+ sdr_results.append([sdr, class_ids[i]])
198
+ else:
199
+ for i in range(len(ref)):
200
+ if np.sum(ref[i,:,0]) == 0 or np.sum(est[i,:,0]) == 0:
201
+ continue
202
+ else:
203
+ sdr_c = calculate_sdr(ref[i,:,0], est[i,:,0], scaling = True)
204
+ sdr_results.append([sdr_c, class_ids[i]])
205
+ return sdr_results
206
+
207
+ # set the audio into the format that can be fed into the model
208
+ # convert to mono -> resample -> output the audio
209
+ # track [n_sample, n_channel]
210
+ def prepprocess_audio(track, ofs, rfs, mono_type = "mix"):
211
+ if track.shape[-1] > 1:
212
+ # stereo
213
+ if mono_type == "mix":
214
+ track = np.transpose(track, (1,0))
215
+ track = librosa.to_mono(track)
216
+ elif mono_type == "left":
217
+ track = track[:, 0]
218
+ elif mono_type == "right":
219
+ track = track[:, 1]
220
+ else:
221
+ track = track[:, 0]
222
+ # track [n_sample]
223
+ if ofs != rfs:
224
+ track = librosa.resample(track, ofs, rfs)
225
+ return track
226
+
227
+ # *************************************************
228
+ # everything below is adapted from the reference Wiener filter code
229
+
230
+ def atan2(y, x):
231
+ r"""Element-wise arctangent function of y/x.
232
+ Returns a new tensor with signed angles in radians.
233
+ It is an alternative implementation of torch.atan2
234
+ Args:
235
+ y (Tensor): First input tensor
236
+ x (Tensor): Second input tensor [shape=y.shape]
237
+ Returns:
238
+ Tensor: [shape=y.shape].
239
+ """
240
+ pi = 2 * torch.asin(torch.tensor(1.0))
241
+ x += ((x == 0) & (y == 0)) * 1.0
242
+ out = torch.atan(y / x)
243
+ out += ((y >= 0) & (x < 0)) * pi
244
+ out -= ((y < 0) & (x < 0)) * pi
245
+ out *= 1 - ((y > 0) & (x == 0)) * 1.0
246
+ out += ((y > 0) & (x == 0)) * (pi / 2)
247
+ out *= 1 - ((y < 0) & (x == 0)) * 1.0
248
+ out += ((y < 0) & (x == 0)) * (-pi / 2)
249
+ return out
250
+
251
+
252
+ # Define basic complex operations on torch.Tensor objects whose last dimension
253
+ # consists in the concatenation of the real and imaginary parts.
254
+ def _norm(x: torch.Tensor) -> torch.Tensor:
255
+ r"""Computes the norm value of a torch Tensor, assuming that it
256
+ comes as real and imaginary part in its last dimension.
257
+ Args:
258
+ x (Tensor): Input Tensor of shape [shape=(..., 2)]
259
+ Returns:
260
+ Tensor: shape as x excluding the last dimension.
261
+ """
262
+ return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
263
+
264
+
265
+ def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
266
+ """Element-wise multiplication of two complex Tensors described
267
+ through their real and imaginary parts.
268
+ The result is added to the `out` tensor"""
269
+
270
+ # check `out` and allocate it if needed
271
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
272
+ if out is None or out.shape != target_shape:
273
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
274
+ if out is a:
275
+ real_a = a[..., 0]
276
+ out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
277
+ out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
278
+ else:
279
+ out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
280
+ out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
281
+ return out
282
+
283
+
284
+ def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
285
+ """Element-wise multiplication of two complex Tensors described
286
+ through their real and imaginary parts
287
+ can work in place in case out is a only"""
288
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
289
+ if out is None or out.shape != target_shape:
290
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
291
+ if out is a:
292
+ real_a = a[..., 0]
293
+ out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
294
+ out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
295
+ else:
296
+ out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
297
+ out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
298
+ return out
299
+
300
+
301
+ def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
302
+ """Element-wise multiplicative inverse of a Tensor with complex
303
+ entries described through their real and imaginary parts.
304
+ can work in place in case out is z"""
305
+ ez = _norm(z)
306
+ if out is None or out.shape != z.shape:
307
+ out = torch.zeros_like(z)
308
+ out[..., 0] = z[..., 0] / ez
309
+ out[..., 1] = -z[..., 1] / ez
310
+ return out
311
+
312
+
313
+ def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
314
+ """Element-wise complex conjugate of a Tensor with complex entries
315
+ described through their real and imaginary parts.
316
+ can work in place in case out is z"""
317
+ if out is None or out.shape != z.shape:
318
+ out = torch.zeros_like(z)
319
+ out[..., 0] = z[..., 0]
320
+ out[..., 1] = -z[..., 1]
321
+ return out
322
+
323
+
324
+ def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
325
+ """
326
+ Invert 1x1 or 2x2 matrices
327
+ Will generate errors if the matrices are singular: user must handle this
328
+ through their own regularization schemes.
329
+ Args:
330
+ M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
331
+ matrices to invert: must be square along dimensions -3 and -2
332
+ Returns:
333
+ invM (Tensor): [shape=M.shape]
334
+ inverses of M
335
+ """
336
+ nb_channels = M.shape[-2]
337
+
338
+ if out is None or out.shape != M.shape:
339
+ out = torch.empty_like(M)
340
+
341
+ if nb_channels == 1:
342
+ # scalar case
343
+ out = _inv(M, out)
344
+ elif nb_channels == 2:
345
+ # two channels case: analytical expression
346
+
347
+ # first compute the determinant
348
+ det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
349
+ det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
350
+ # invert it
351
+ invDet = _inv(det)
352
+
353
+ # then fill out the matrix with the inverse
354
+ out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
355
+ out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
356
+ out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
357
+ out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
358
+ else:
359
+ raise Exception("Only 2 channels are supported for the torch version.")
360
+ return out
361
+
362
+
363
+
364
+ def expectation_maximization(
365
+ y: torch.Tensor,
366
+ x: torch.Tensor,
367
+ iterations: int = 2,
368
+ eps: float = 1e-10,
369
+ batch_size: int = 200,
370
+ ):
371
+ r"""Expectation maximization algorithm, for refining source separation
372
+ estimates.
373
+ Args:
374
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
375
+ initial estimates for the sources
376
+ x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
377
+ complex STFT of the mixture signal
378
+ iterations (int): [scalar]
379
+ number of iterations for the EM algorithm.
380
+ eps (float or None): [scalar]
381
+ The epsilon value to use for regularization and filters.
382
+ Returns:
383
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
384
+ estimated sources after iterations
385
+ v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
386
+ estimated power spectral densities
387
+ R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
388
+ estimated spatial covariance matrices
389
+ """
390
+ # dimensions
391
+ (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
392
+ nb_sources = y.shape[-1]
393
+
394
+ regularization = torch.cat(
395
+ (
396
+ torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
397
+ torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
398
+ ),
399
+ dim=2,
400
+ )
401
+ regularization = torch.sqrt(torch.as_tensor(eps)) * (
402
+ regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
403
+ )
404
+
405
+ # allocate the spatial covariance matrices
406
+ R = [
407
+ torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
408
+ for j in range(nb_sources)
409
+ ]
410
+ weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
411
+
412
+ v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
413
+ for it in range(iterations):
414
+ # constructing the mixture covariance matrix. Doing it with a loop
415
+ # to avoid storing anytime in RAM the whole 6D tensor
416
+
417
+ # update the PSD as the average spectrogram over channels
418
+ v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
419
+
420
+ # update spatial covariance matrices (weighted update)
421
+ for j in range(nb_sources):
422
+ R[j] = torch.tensor(0.0, device=x.device)
423
+ weight = torch.tensor(eps, device=x.device)
424
+ pos: int = 0
425
+ batch_size = batch_size if batch_size else nb_frames
426
+ while pos < nb_frames:
427
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
428
+ pos = int(t[-1]) + 1
429
+
430
+ R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
431
+ weight = weight + torch.sum(v[t, ..., j], dim=0)
432
+ R[j] = R[j] / weight[..., None, None, None]
433
+ weight = torch.zeros_like(weight)
434
+
435
+ # cloning y if we track gradient, because we're going to update it
436
+ if y.requires_grad:
437
+ y = y.clone()
438
+
439
+ pos = 0
440
+ while pos < nb_frames:
441
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
442
+ pos = int(t[-1]) + 1
443
+
444
+ y[t, ...] = torch.tensor(0.0, device=x.device)
445
+
446
+ # compute mix covariance matrix
447
+ Cxx = regularization
448
+ for j in range(nb_sources):
449
+ Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
450
+
451
+ # invert it
452
+ inv_Cxx = _invert(Cxx)
453
+
454
+ # separate the sources
455
+ for j in range(nb_sources):
456
+
457
+ # create a wiener gain for this source
458
+ gain = torch.zeros_like(inv_Cxx)
459
+
460
+ # computes multichannel Wiener gain as v_j R_j inv_Cxx
461
+ indices = torch.cartesian_prod(
462
+ torch.arange(nb_channels),
463
+ torch.arange(nb_channels),
464
+ torch.arange(nb_channels),
465
+ )
466
+ for index in indices:
467
+ gain[:, :, index[0], index[1], :] = _mul_add(
468
+ R[j][None, :, index[0], index[2], :].clone(),
469
+ inv_Cxx[:, :, index[2], index[1], :],
470
+ gain[:, :, index[0], index[1], :],
471
+ )
472
+ gain = gain * v[t, ..., None, None, None, j]
473
+
474
+ # apply it to the mixture
475
+ for i in range(nb_channels):
476
+ y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
477
+
478
+ return y, v, R
479
+
480
+ def _covariance(y_j):
481
+ """
482
+ Compute the empirical covariance for a source.
483
+ Args:
484
+ y_j (Tensor): complex stft of the source.
485
+ [shape=(nb_frames, nb_bins, nb_channels, 2)].
486
+ Returns:
487
+ Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
488
+ just y_j * conj(y_j.T): empirical covariance for each TF bin.
489
+ """
490
+ (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
491
+ Cj = torch.zeros(
492
+ (nb_frames, nb_bins, nb_channels, nb_channels, 2),
493
+ dtype=y_j.dtype,
494
+ device=y_j.device,
495
+ )
496
+ indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
497
+ for index in indices:
498
+ Cj[:, :, index[0], index[1], :] = _mul_add(
499
+ y_j[:, :, index[0], :],
500
+ _conj(y_j[:, :, index[1], :]),
501
+ Cj[:, :, index[0], index[1], :],
502
+ )
503
+ return Cj
504
+
505
+ def wiener(
506
+ targets_spectrograms: torch.Tensor,
507
+ mix_stft: torch.Tensor,
508
+ iterations: int = 1,
509
+ softmask: bool = False,
510
+ residual: bool = False,
511
+ scale_factor: float = 10.0,
512
+ eps: float = 1e-10,
513
+ ):
514
+ """Wiener-based separation for multichannel audio.
515
+ Returns:
516
+ Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
517
+ STFT of estimated sources
518
+ """
519
+ if softmask:
520
+ # if we use softmask, we compute the ratio mask for all targets and
521
+ # multiply by the mix stft
522
+ y = (
523
+ mix_stft[..., None]
524
+ * (
525
+ targets_spectrograms
526
+ / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
527
+ )[..., None, :]
528
+ )
529
+ else:
530
+ # otherwise, we just multiply the targets spectrograms with mix phase
531
+ # we tacitly assume that we have magnitude estimates.
532
+ angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
533
+ nb_sources = targets_spectrograms.shape[-1]
534
+ y = torch.zeros(
535
+ mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
536
+ )
537
+ y[..., 0, :] = targets_spectrograms * torch.cos(angle)
538
+ y[..., 1, :] = targets_spectrograms * torch.sin(angle)
539
+
540
+ if residual:
541
+ # if required, adding an additional target as the mix minus
542
+ # available targets
543
+ y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
544
+
545
+ if iterations == 0:
546
+ return y
547
+
548
+ # we need to refine the estimates. Scales down the estimates for
549
+ # numerical stability
550
+ max_abs = torch.max(
551
+ torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
552
+ torch.sqrt(_norm(mix_stft)).max() / scale_factor,
553
+ )
554
+
555
+ mix_stft = mix_stft / max_abs
556
+ y = y / max_abs
557
+
558
+ # call expectation maximization
559
+ y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
560
+
561
+ # scale estimates up again
562
+ y = y * max_abs
563
+ return y
564
+
565
+ def split_nparray_with_overlap(array, array_size, overlap_size):
566
+ result = []
567
+ element_size = int(len(array) / array_size)
568
+ for i in range(array_size):
569
+ offset = int(i * element_size)
570
+ last_loop = i == array_size
571
+ chunk = array[offset : offset + element_size + (0 if last_loop else overlap_size)]
572
+ chunk = chunk.copy()
573
+ chunk.resize(element_size + overlap_size, refcheck = False)
574
+ result.append(chunk)
575
+
576
+ return np.array(result)
577
+
578
+
579
+
580
+
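A small sketch of the SDR helpers above on synthetic signals (the waveforms are random placeholders, not real audio; museval must be installed as listed in requirements.txt):

import numpy as np
from utils import calculate_silence_sdr, evaluate_sdr

rng = np.random.default_rng(0)
mixture = rng.normal(size=32000).astype(np.float32)   # ~1 s of noise at 32 kHz
estimate = 0.01 * mixture                             # a nearly-silent estimate

# silence SDR: how much quieter the estimate is than the mixture (higher is better)
print(calculate_silence_sdr(mixture, estimate))       # about 40 dB for a 0.01 gain

# evaluate_sdr expects arrays shaped [n_clips, n_samples, 1] plus class ids
ref = mixture[None, :, None]
est = (0.8 * mixture)[None, :, None]
print(evaluate_sdr(ref, est, class_ids=[0], mix_type="mixture"))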
zero_shot_create_vector.py ADDED
@@ -0,0 +1,158 @@
1
+ # Ke Chen
2
3
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
4
+ # The Main Script
5
+
6
+ import os
7
+ gpu_use = 0
8
+ # limit thread counts so that the SDR calculation does not occupy all CPUs
9
+ os.environ["OMP_NUM_THREADS"] = "4"
10
+ os.environ["OPENBLAS_NUM_THREADS"] = "4"
11
+ os.environ["MKL_NUM_THREADS"] = "6"
12
+ os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
13
+ os.environ["NUMEXPR_NUM_THREADS"] = "6"
14
+ os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)
15
+
16
+ import librosa
17
+ import numpy as np
18
+ import soundfile as sf
19
+ from hashlib import md5
20
+
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ from utils import collect_fn, dump_config, create_folder, prepprocess_audio
24
+ from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
25
+ from data_processor import LGSPDataset, MusdbDataset
26
+ import config
27
+ import htsat_config
28
+ from models.htsat import HTSAT_Swin_Transformer
29
+ from sed_model import SEDWrapper
30
+
31
+ import pytorch_lightning as pl
32
+
33
+ import time
34
+ import tqdm
35
+ import warnings
36
+ import shutil
37
+ import pickle
38
+ warnings.filterwarnings("ignore")
39
+
40
+ # use the model to quickly separate a track given a query
41
+ # it requires four variables in config.py:
42
+ # inference_file: the track you want to separate
43
+ # inference_query: a **folder** containing all samples from the same source
44
+ # test_key: ["name"] indicate the source name (just a name for final output, no other functions)
45
+ # wave_output_path: the output folder
46
+
47
+ # make sure the query folder contain the samples from the same source
48
+ # each time, the model is able to separate one source from the track
49
+ # if you want to separate multiple sources, you need to change the query folder or write a script to help you do that
50
+
51
+
52
+ def save_in_file_fast(arr, file_name):
53
+ pickle.dump(arr, open(file_name, 'wb'), protocol=4)
54
+
55
+
56
+ def load_from_file_fast(file_name):
57
+ return pickle.load(open(file_name, 'rb'))
58
+
59
+
60
+ def create_vector():
61
+ test_type = 'mix'
62
+ inference_file = config.inference_file
63
+ inference_query = config.inference_query
64
+ test_key = config.test_key
65
+ wave_output_path = config.wave_output_path
66
+ sample_rate = config.sample_rate
67
+ resume_checkpoint_zeroshot = config.resume_checkpoint
68
+ resume_checkpoint_htsat = htsat_config.resume_checkpoint
69
+ print('Inference query folder: {}'.format(inference_query))
70
+ print('Test key: {}'.format(test_key))
71
+ print('Vector out folder: {}'.format(wave_output_path))
72
+ print('Sample rate: {}'.format(sample_rate))
73
+ print('Model 1 (zeroshot): {}'.format(resume_checkpoint_zeroshot))
74
+
75
+ # set exp settings
76
+ device_name = "cuda" if torch.cuda.is_available() else "cpu"
77
+ device = torch.device(device_name)
78
+ create_folder(wave_output_path)
79
+
80
+ # obtain the samples for query
81
+ queries = []
82
+ query_names = []
83
+ for query_file in tqdm.tqdm(os.listdir(inference_query)):
84
+ f_path = os.path.join(inference_query, query_file)
85
+ if query_file.endswith(".wav"):
86
+ temp_q, fs = librosa.load(f_path, sr=None)
87
+ temp_q = temp_q[:, None]
88
+ temp_q = prepprocess_audio(
89
+ temp_q,
90
+ fs,
91
+ sample_rate,
92
+ test_type
93
+ )
94
+ temp = [temp_q]
95
+ for dickey in test_key:
96
+ temp.append(temp_q)
97
+ temp = np.array(temp)
98
+ queries.append(temp)
99
+ query_names.append(os.path.basename(query_file))
100
+
101
+ sed_model = HTSAT_Swin_Transformer(
102
+ spec_size=htsat_config.htsat_spec_size,
103
+ patch_size=htsat_config.htsat_patch_size,
104
+ in_chans=1,
105
+ num_classes=htsat_config.classes_num,
106
+ window_size=htsat_config.htsat_window_size,
107
+ config=htsat_config,
108
+ depths=htsat_config.htsat_depth,
109
+ embed_dim=htsat_config.htsat_dim,
110
+ patch_stride=htsat_config.htsat_stride,
111
+ num_heads=htsat_config.htsat_num_head
112
+ )
113
+ at_model = SEDWrapper(
114
+ sed_model=sed_model,
115
+ config=htsat_config,
116
+ dataset=None
117
+ )
118
+ ckpt = torch.load(resume_checkpoint_htsat, map_location="cpu")
119
+ at_model.load_state_dict(ckpt["state_dict"])
120
+
121
+ if device_name == 'cpu':
122
+ trainer = pl.Trainer(
123
+ accelerator="cpu", gpus=None
124
+ )
125
+ else:
126
+ trainer = pl.Trainer(
127
+ gpus=1
128
+ )
129
+
130
+ print('Process: {}'.format(len(queries)))
131
+ avg_dataset = MusdbDataset(
132
+ tracks=queries
133
+ )
134
+ avg_loader = DataLoader(
135
+ dataset=avg_dataset,
136
+ num_workers=1,
137
+ batch_size=1,
138
+ shuffle=False
139
+ )
140
+ at_wrapper = AutoTaggingWarpper(
141
+ at_model=at_model,
142
+ config=config,
143
+ target_keys=test_key
144
+ )
145
+ trainer.test(
146
+ at_wrapper,
147
+ test_dataloaders=avg_loader
148
+ )
149
+ avg_at = at_wrapper.avg_at
150
+
151
+ md5_str = str(md5(str(queries).encode('utf-8')).hexdigest())
152
+ out_vector_path = wave_output_path + '/{}_vector_{}.pkl'.format(test_key[0], md5_str)
153
+ save_in_file_fast(avg_at, out_vector_path)
154
+ print('Vector saved in: {}'.format(out_vector_path))
155
+
156
+
157
+ if __name__ == '__main__':
158
+ create_vector()
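As the usage notes at the top of this script say, create_vector() reads everything from config.py. Below is a minimal sketch of the relevant entries, with placeholder paths (fields that config.py defines for training are omitted here):

# config.py (excerpt) -- placeholder values for zero_shot_create_vector.py
inference_file = "examples/mixture.wav"       # track to separate later with the saved vector
inference_query = "examples/query_vocals"     # folder of .wav samples from one source
test_key = ["vocals"]                         # name used to label the output vector
wave_output_path = "outputs"                  # where the .pkl query vector is written
sample_rate = 32000                           # model sample rate
resume_checkpoint = "pretrained/zeroshot_asp_full.ckpt"   # ZeroShotASP checkpoint (placeholder path)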