Upload 18 files

- LICENSE +21 -0
- README.md +122 -12
- cog.yaml +33 -0
- config.py +62 -0
- create_balanced_list.py +24 -0
- create_index.sh +12 -0
- create_indexes.py +126 -0
- data_processor.py +179 -0
- htsat_config.py +122 -0
- htsat_utils.py +226 -0
- losses.py +23 -0
- main.py +502 -0
- opt_thres.pkl +3 -0
- predict.py +111 -0
- requirements.txt +19 -0
- sed_model.py +358 -0
- utils.py +580 -0
- zero_shot_create_vector.py +158 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Knut(Ke) Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,122 @@
# Zero Shot Audio Source Separation

## Introduction

The code repository for "[Zero-shot Audio Source Separation through Query-based Learning from Weakly-labeled Data](https://arxiv.org/abs/2112.07891)", in AAAI 2022.

In this paper, we propose a three-component pipeline that lets you train an audio source separator to separate *any source* from a track. All you need is a mixture to separate and a sample of the target source as a query; the model then separates your specified source from the track. The model operates in a zero-shot setting: it never sees a separation dataset during training, only the general audio dataset **AudioSet**. Nevertheless, it achieves very competitive separation performance (SDR) on the MUSDB18 dataset compared with supervised models, and it generalizes to sources unseen in the training set.

The demos and introduction are presented in our [short introduction video](https://youtu.be/8XQ5ZyYRLQM) and [full presentation video](https://youtu.be/RgNwB_pJ7Cw).

More demos will be presented on [my personal website](https://www.knutchen.com) (now under construction).

Check out this interactive demo at Replicate <a href="https://replicate.com/retrocirce/zero_shot_audio_source_separation"><img src="https://replicate.com/retrocirce/zero_shot_audio_source_separation/badge"></a> Thanks @[ariel415el](https://github.com/ariel415el) for creating this!



## Main Separation Performance on the MUSDB18 Dataset

We achieve very competitive separation performance (SDR) on the MUSDB18 dataset, **with neither seeing the MUSDB18 training data nor specifying source targets**, compared with supervised models.

Additionally, our model can easily separate many other sources, such as violin, harmonica, guitar, etc. (demos are shown in the video links above).

<p align="center">
<img src="fig/results.png" align="center" alt="MUSDB results" width="50%"/>
</p>

## Getting Started

### Install Requirements
```
pip install -r requirements.txt
```

### Download and Process Datasets

* config.py
```
change the variable "dataset_path" to your audioset address
change the classes_num to 527
```

* [AudioSet](https://research.google.com/audioset/download.html)
```
./create_index.sh
// remember to change the paths in the script
// more information about this script is at https://github.com/qiuqiangkong/audioset_tagging_cnn

python main.py save_idc
// count the number of samples in each class and save the npy files
```

* [MUSDB18](https://sigsep.github.io/datasets/musdb.html) - You can directly use [our processed musdb audio files](https://drive.google.com/drive/folders/1VwRnCxp3t2bXUS_MbXiFiggwkkJQEmha?usp=sharing) at a 32000 Hz sample rate. Otherwise, set "musdb_path" to the download path and run:

```
python main.py musdb_process
// Notice that the training set is a highlight version, while the testing set is the full version
```
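After processing, each `.npy` file holds a list of per-track dicts keyed by "mixture" plus the four stems in `config.test_key` (this is the layout `test()` in main.py reads back). A minimal sanity-check sketch, assuming the processed test file sits at the default `testset_path` from config.py:

```python
import numpy as np

# each element is one track: {"mixture": ..., "vocals": ..., "drums": ..., "bass": ..., "other": ...}
tracks = np.load("/home/Research/ZS_ASP/data/musdb-test-32000fs.npy", allow_pickle=True)
print(len(tracks))                 # number of test tracks
print(sorted(tracks[0].keys()))    # ['bass', 'drums', 'mixture', 'other', 'vocals']
print(tracks[0]["mixture"].shape)  # (n_samples,) mono waveform at 32000 Hz
```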

### Set the Configuration File: config.py

The script *config.py* contains all the configurations you need to run the code.

Please read the introductory comments in the file and change the settings as needed.

Most importantly, if you want to train/test your model on AudioSet, you need to set:
```
dataset_path = "your processed audioset folder"
balanced_data = True
sample_rate = 32000
hop_size = 320
classes_num = 527
```

### Train and Evaluation

#### Train the sound event detection system ST-SED/HTS-AT

We further integrated this ST-SED system into an independent repository, evaluated it on more datasets, improved it substantially, and achieved better performance.

You can follow [this repo](https://github.com/RetroCirce/HTS-Audio-Transformer) to train and evaluate the sound event detection system ST-SED (or, under its more relevant name, HTS-AT); the configuration file for training the model for this separation task should be [htsat_config.py](htsat_config.py).

For this separation task, if you want to save time, you can also download [the checkpoint](https://drive.google.com/drive/folders/1RouwHsGsMs8n3l_jF8XifWtbPzur_YQS?usp=sharing) directly.

#### Train, Evaluate, and Run Inference with the Separation Model

All scripts are run through main.py:
```
Train: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py train

Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py test

```
We recommend using at least 4 GPU cards with more than 20 GB of memory each. In our training phase, we used 8 NVIDIA V100 (32 GB) GPUs.

We provide a quick **inference** interface:
```
CUDA_VISIBLE_DEVICES=1 python main.py inference
```
It separates any given source from a track: set the values of "inference_file" and "inference_query" in *config.py*, and check the comments there to get started. For inference we recommend using a single card, which is sufficient.
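For example, to separate a violin from a recording, the inference block of *config.py* might look like this (all paths here are illustrative placeholders, not shipped defaults):

```python
# config.py -- hypothetical inference setup
inference_file = "/data/my_mixture.wav"    # the track you want to separate
inference_query = "/data/violin_queries/"  # a folder of .wav samples of the target source
test_key = ["violin"]                      # just a name used for the output file
wave_output_path = "/data/output"          # separated wavs are written here
```

Each run separates one source; to extract several sources, point `inference_query` (and `test_key`) at a different source folder and run again, as noted in the comments in main.py.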

#### Model Checkpoints:

We provide the model checkpoints at this [link](https://drive.google.com/drive/folders/1RouwHsGsMs8n3l_jF8XifWtbPzur_YQS?usp=sharing). Feel free to download and test them.

## Citing
```
@inproceedings{zsasp-ke2022,
  author = {Ke Chen* and Xingjian Du* and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
  title = {Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data},
  booktitle = {{AAAI} 2022}
}

@inproceedings{htsat-ke2022,
  author = {Ke Chen and Xingjian Du and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
  title = {HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection},
  booktitle = {{ICASSP} 2022}
}
```
cog.yaml
ADDED
@@ -0,0 +1,33 @@
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "libsndfile1-dev"
    - "ffmpeg"
  python_packages:
    - torch==1.9.0
    - torchmetrics==0.6.0
    - torchaudio==0.9.0
    - torchcontrib==0.0.2
    - torchlibrosa==0.0.9
    - librosa==0.8.0
    - pytorch_lightning==1.4.1
    - museval==0.4.0
    - noisereduce==2.0.0
    - numba==0.55.1
    - numpy==1.19.4
    - scikit_learn==0.24.0
    - scipy==1.6.0
    - soundfile==0.10.3.post1
    - tensorboard==2.2.0
    - tqdm==4.55.0
    - h5py==3.1.0
    - musdb==0.4.0

# run:
#   - pip install open3d
# #   - "gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g"

predict: "predict.py:Predictor"
config.py
ADDED
@@ -0,0 +1,62 @@
# Ke Chen
# knutchen@ucsd.edu
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The configuration file

# for model training
exp_name = "exp_zs_asp_full" # the saved ckpt prefix name of the model
workspace = "/home/Research/ZS_ASP/" # the folder of your code
dataset_path = "/home/Research/ZS_ASP/data/audioset" # the dataset path
index_type = "full_train"
idc_path = "/home/Research/ZS_ASP/" # the folder of audioset class count files
balanced_data = True

# train from a checkpoint, or evaluate a single model
resume_checkpoint = None
# "/home/Research/ZS_ASP/model_backup/zeroshot_asp_full.ckpt"

loss_type = "mae"

gather_mode = False
debug = False

classes_num = 527
eval_list = [] # leave blank to preserve all classes, otherwise the specified classes will be filtered
# [15, 63, 81, 184, 335, 449, 474, 348, 486, 4] # randomly generated from the 527 classes for held-out evaluation


batch_size = 16 * 8 # batch size per GPU x GPU number, default is 16 x 8 = 128
learning_rate = 1e-3 # 3e-4 is also workable
max_epoch = 100
num_workers = 3
lr_scheduler_epoch = [90, 110]
latent_dim = 2048

# for signal processing
sample_rate = 32000
clip_samples = sample_rate * 10 # AudioSet 10-sec clip
segment_frames = 200
hop_samples = 320
random_seed = 12412 # 444612 1536123 12412
random_mode = "one_class" # "no_random, one_class, random, order", one_class is the best

# for evaluation
musdb_path = "/home/Research/ZS_ASP/data/musdb-wav/" # musdb download folder
testavg_path = "/home/Research/ZS_ASP/data/musdb30-train-32000fs.npy" # the processed training set (to get the latent query)
testset_path = "/home/Research/ZS_ASP/data/musdb-test-32000fs.npy" # the processed testing set (to calculate the performance)
test_key = ["vocals", "drums", "bass", "other"] # the four tracks for musdb, or your named track for other inference
test_type = "mix"
infer_type = "mean"
energy_thres = 0.1
wave_output_path = "/home/Research/ZS_ASP/wavoutput" # output folder
using_wiener = True # use the Wiener filter or not (default: True)
using_whiting = False # use whitening or not (default: False)

# weight average
wa_model_folder = "/home/Research/ZS_ASP/version_3/checkpoints/"
wa_model_path = "zs_wa.ckpt"

# for inference
inference_file = "/home/Research/ZS_ASP/data/pagenini.wav" # an audio file to separate
inference_query = "/home/Research/ZS_ASP/data/query" # a folder containing all samples for obtaining the query
overlap_rate = 0.0 # [0.0, 1.0), 0 to disable; 0.5 (50% overlap) is recommended. Overlap increases computation time and improves result quality
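Since config.py is a plain Python module, every script reads it directly via `import config`; a minimal sketch of how the values above are consumed, mirroring the dataloader setup in main.py (the 8-GPU count is just an example):

```python
import config

device_num = 8                                    # e.g. 8 GPUs, matching the 16 * 8 default
print(config.batch_size // device_num)            # per-GPU batch size: 16
print(config.clip_samples)                        # 320000 samples = 10 s at 32000 Hz
print(config.clip_samples // config.hop_samples)  # 1000 STFT hops per clip
```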
create_balanced_list.py
ADDED
@@ -0,0 +1,24 @@
# Ke Chen
# knutchen@ucsd.edu

import os
import sys
import config
import logging
import numpy as np

from utils import get_balanced_class_list

def main():
    train_indexes_hdf5_path = os.path.join(config.dataset_path, "hdf5s", "indexes",
        "{}.h5".format(config.data_type))

    eval_indexes_hdf5_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
    logging.info("Process training data")
    indexes_per_class = get_balanced_class_list(train_indexes_hdf5_path, random_seed = config.random_seed)
    np.save("idc_train.npy", indexes_per_class)
    logging.info("Process testing data")
    indexes_per_class = get_balanced_class_list(eval_indexes_hdf5_path, random_seed = config.random_seed)
    np.save("idc_eval.npy", indexes_per_class)

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()
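The saved `idc_*.npy` files store, for each of the 527 classes, the list of sample indices labeled with that class (the same layout `process_idc` in htsat_utils.py produces). A quick inspection sketch, assuming the file has been generated:

```python
import numpy as np

# a ragged array of per-class index lists, so allow_pickle is required
idc = np.load("idc_train.npy", allow_pickle=True)
print(len(idc))       # 527 classes
print(len(idc[0]))    # number of clips labeled with class 0
print(idc[0][:5])     # first few clip indices for class 0
```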
create_index.sh
ADDED
@@ -0,0 +1,12 @@
#!/bin/bash

# Balanced training indexes (create_indexes.py requires both paths)
python3 create_indexes.py create_indexes --waveforms_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/waveforms/balanced_train.h5" --indexes_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/balanced_train.h5"

# Unbalanced training indexes
for IDX in {00..40}; do
  echo $IDX
  python3 create_indexes.py create_indexes --waveforms_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/waveforms/unbalanced_train/unbalanced_train_part$IDX.h5" --indexes_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/unbalanced_train/unbalanced_train_part$IDX.h5"
done

# Combine balanced and unbalanced training indexes into a full training indexes hdf5
python3 create_indexes.py combine_full_indexes --indexes_hdf5s_dir="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes" --full_indexes_hdf5_path="/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/full_train.h5"
create_indexes.py
ADDED
@@ -0,0 +1,126 @@
import numpy as np
import argparse
import csv
import os
import glob
import datetime
import time
import logging
import h5py
import librosa

from utils import create_folder, get_sub_filepaths
import config


def create_indexes(args):
    """Create indexes for a dataloader to read for training. When users have
    a new task and their own data, they need to create similar indexes. The
    indexes contain meta information on "where to find the data for training".
    """

    # Arguments & parameters
    waveforms_hdf5_path = args.waveforms_hdf5_path
    indexes_hdf5_path = args.indexes_hdf5_path

    # Paths
    create_folder(os.path.dirname(indexes_hdf5_path))

    with h5py.File(waveforms_hdf5_path, 'r') as hr:
        with h5py.File(indexes_hdf5_path, 'w') as hw:
            audios_num = len(hr['audio_name'])
            hw.create_dataset('audio_name', data=hr['audio_name'][:], dtype='S20')
            hw.create_dataset('target', data=hr['target'][:], dtype=np.bool)
            hw.create_dataset('hdf5_path', data=[waveforms_hdf5_path.encode()] * audios_num, dtype='S200')
            hw.create_dataset('index_in_hdf5', data=np.arange(audios_num), dtype=np.int32)

    print('Write to {}'.format(indexes_hdf5_path))


def combine_full_indexes(args):
    """Combine all balanced and unbalanced indexes hdf5s into a single hdf5. This
    combined indexes hdf5 is used for training with the full data (~20k balanced
    audio clips + ~1.9m unbalanced audio clips).
    """

    # Arguments & parameters
    indexes_hdf5s_dir = args.indexes_hdf5s_dir
    full_indexes_hdf5_path = args.full_indexes_hdf5_path

    classes_num = config.classes_num

    # Paths
    paths = get_sub_filepaths(indexes_hdf5s_dir)
    paths = [path for path in paths if (
        'train' in path and 'full_train' not in path and 'mini' not in path)]

    print('Total {} hdf5 to combine.'.format(len(paths)))

    with h5py.File(full_indexes_hdf5_path, 'w') as full_hf:
        full_hf.create_dataset(
            name='audio_name',
            shape=(0,),
            maxshape=(None,),
            dtype='S20')

        full_hf.create_dataset(
            name='target',
            shape=(0, classes_num),
            maxshape=(None, classes_num),
            dtype=np.bool)

        full_hf.create_dataset(
            name='hdf5_path',
            shape=(0,),
            maxshape=(None,),
            dtype='S200')

        full_hf.create_dataset(
            name='index_in_hdf5',
            shape=(0,),
            maxshape=(None,),
            dtype=np.int32)

        for path in paths:
            with h5py.File(path, 'r') as part_hf:
                print(path)
                n = len(full_hf['audio_name'][:])
                new_n = n + len(part_hf['audio_name'][:])

                full_hf['audio_name'].resize((new_n,))
                full_hf['audio_name'][n : new_n] = part_hf['audio_name'][:]

                full_hf['target'].resize((new_n, classes_num))
                full_hf['target'][n : new_n] = part_hf['target'][:]

                full_hf['hdf5_path'].resize((new_n,))
                full_hf['hdf5_path'][n : new_n] = part_hf['hdf5_path'][:]

                full_hf['index_in_hdf5'].resize((new_n,))
                full_hf['index_in_hdf5'][n : new_n] = part_hf['index_in_hdf5'][:]

    print('Write combined full hdf5 to {}'.format(full_indexes_hdf5_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='mode')

    parser_create_indexes = subparsers.add_parser('create_indexes')
    parser_create_indexes.add_argument('--waveforms_hdf5_path', type=str, required=True, help='Path of packed waveforms hdf5.')
    parser_create_indexes.add_argument('--indexes_hdf5_path', type=str, required=True, help='Path to write out indexes hdf5.')

    parser_combine_full_indexes = subparsers.add_parser('combine_full_indexes')
    parser_combine_full_indexes.add_argument('--indexes_hdf5s_dir', type=str, required=True, help='Directory containing indexes hdf5s to be combined.')
    parser_combine_full_indexes.add_argument('--full_indexes_hdf5_path', type=str, required=True, help='Path to write out full indexes hdf5 file.')

    args = parser.parse_args()

    if args.mode == 'create_indexes':
        create_indexes(args)

    elif args.mode == 'combine_full_indexes':
        combine_full_indexes(args)

    else:
        raise Exception('Incorrect arguments!')
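After combining, the full index file can be verified with h5py; a minimal sketch using the datasets created above (the path matches create_index.sh):

```python
import h5py

path = "/home/Research/ZS_ASP/data/audioset/hdf5s/indexes/full_train.h5"
with h5py.File(path, "r") as hf:
    print(len(hf["audio_name"]))   # total indexed clips (~20k balanced + ~1.9m unbalanced)
    print(hf["target"].shape)      # (n_clips, 527) boolean label matrix
    # every row points back to the waveform hdf5 it came from
    print(hf["hdf5_path"][0], hf["index_in_hdf5"][0])
```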
data_processor.py
ADDED
@@ -0,0 +1,179 @@
# Ke Chen
# knutchen@ucsd.edu
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The dataset classes

import numpy as np
import torch
import logging
import os
import sys
import h5py
import csv
import time
import random
import json
from datetime import datetime
from utils import int16_to_float32

from torch.utils.data import Dataset, Sampler

# output the dict["index"].key form to save memory in multi-GPU training
def reverse_dict(data_path, sed_path, output_dir):
    # filename
    waveform_dir = os.path.join(output_dir, "audioset_eval_waveform_balanced.h5")
    sed_dir = os.path.join(output_dir, "audioset_eval_sed_balanced.h5")
    # load data
    logging.info("Write Data...............")
    h_data = h5py.File(data_path, "r")
    h_sed = h5py.File(sed_path, "r")
    audio_num = len(h_data["waveform"])
    assert len(h_data["waveform"]) == len(h_sed["sed_vector"]), "waveform and sed should be of the same length"
    with h5py.File(waveform_dir, 'w') as hw:
        for i in range(audio_num):
            hw.create_dataset(str(i), data=int16_to_float32(h_data['waveform'][i]), dtype=np.float32)
    logging.info("Write Data Succeed...............")
    logging.info("Write Sed...............")
    with h5py.File(sed_dir, 'w') as hw:
        for i in range(audio_num):
            hw.create_dataset(str(i), data=h_sed['sed_vector'][i], dtype=np.float32)
    logging.info("Write Sed Succeed...............")

# A dataset for handling musdb
class MusdbDataset(Dataset):
    def __init__(self, tracks):
        self.tracks = tracks
        self.dataset_len = len(tracks)
    def __getitem__(self, index):
        """Load waveform and target of an audio clip.
        Args:
            index: the index number
        Return:
            track: [mixture + n_sources, n_samples]
        """
        return self.tracks[index]
    def __len__(self):
        return self.dataset_len

class InferDataset(Dataset):
    def __init__(self, tracks):
        self.tracks = tracks
        self.dataset_len = len(tracks)
    def __getitem__(self, index):
        """Load waveform and target of an audio clip.
        Args:
            index: the index number
        Return:
            track: [mixture + n_sources, n_samples]
        """
        return self.tracks[index]
    def __len__(self):
        return self.dataset_len

# polished LGSPDataset, the main dataset for processing the audioset files
class LGSPDataset(Dataset):
    def __init__(self, index_path, idc, config, factor = 3, eval_mode = False):
        self.index_path = index_path
        self.fp = h5py.File(index_path, "r")
        self.config = config
        self.idc = idc
        self.factor = factor
        self.classes_num = self.config.classes_num
        self.eval_mode = eval_mode
        self.total_size = int(len(self.fp["audio_name"]) * self.factor)
        self.generate_queue()
        logging.info("total dataset size: %d" %(self.total_size))
        logging.info("class num: %d" %(self.classes_num))

    def generate_queue(self):
        self.queue = []
        self.class_queue = []
        if self.config.debug:
            self.total_size = 1000
        if self.config.balanced_data:
            while len(self.queue) < self.total_size * 2:
                if self.eval_mode:
                    if len(self.config.eval_list) == 0:
                        class_set = [*range(self.classes_num)]
                    else:
                        class_set = self.config.eval_list[:]
                else:
                    class_set = [*range(self.classes_num)]
                    class_set = list(set(class_set) - set(self.config.eval_list))
                random.shuffle(class_set)
                self.queue += [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in class_set]
                self.class_queue += class_set[:]
            self.queue = self.queue[:self.total_size * 2]
            self.class_queue = self.class_queue[:self.total_size * 2]
            self.queue = [[self.queue[i], self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            self.class_queue = [[self.class_queue[i], self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            assert len(self.queue) == self.total_size, "generate data error!!"
        else:
            if self.eval_mode:
                if len(self.config.eval_list) == 0:
                    class_set = [*range(self.classes_num)]
                else:
                    class_set = self.config.eval_list[:]
            else:
                class_set = [*range(self.classes_num)]
                class_set = list(set(class_set) - set(self.config.eval_list))
            self.class_queue = random.choices(class_set, k = self.total_size * 2)
            self.queue = [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in self.class_queue]
            self.queue = [[self.queue[i], self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            self.class_queue = [[self.class_queue[i], self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            assert len(self.queue) == self.total_size, "generate data error!!"
        logging.info("queue regenerated:%s" %(self.queue[-5:]))

    def __getitem__(self, index):
        """Load waveform and target of an audio clip.
        Args:
            index: the index number
        Return: {
            "audio_name_1": str,
            "waveform_1": (clip_samples,),
            "class_id_1": int,
            "audio_name_2": str,
            "waveform_2": (clip_samples,),
            "class_id_2": int,
            ...
            "check_num": int
        }
        """
        # put the right index here!!!
        data_dict = {}
        for k in range(2):
            s_index = self.queue[index][k]
            target = self.class_queue[index][k]
            audio_name = self.fp["audio_name"][s_index].decode()
            hdf5_path = self.fp["hdf5_path"][s_index].decode().replace("/home/tiger/DB/knut/data/audioset", self.config.dataset_path)
            r_idx = self.fp["index_in_hdf5"][s_index]
            with h5py.File(hdf5_path, "r") as f:
                waveform = int16_to_float32(f["waveform"][r_idx])
            data_dict["audio_name_" + str(k+1)] = audio_name
            data_dict["waveform_" + str(k+1)] = waveform
            data_dict["class_id_" + str(k+1)] = target
        data_dict["check_num"] = str(self.queue[-5:])
        return data_dict

    def __len__(self):
        return self.total_size

# only for test
class TestDataset(Dataset):
    def __init__(self, dataset_size):
        print("init")
        self.dataset_size = dataset_size
        self.base_num = 100
        self.dicts = [(self.base_num + 2 * i, self.base_num + 2 * i + 1) for i in range(self.dataset_size)]

    def get_new_list(self):
        self.base_num = random.randint(0, 10)
        print("base num changed:", self.base_num)
        self.dicts = [(self.base_num + 2 * i, self.base_num + 2 * i + 1) for i in range(self.dataset_size)]

    def __getitem__(self, index):
        return self.dicts[index]

    def __len__(self):
        return self.dataset_size
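A minimal sketch of wiring `LGSPDataset` to the index and idc files produced by the earlier steps (paths follow config.py and main.py's `save_idc`; the files must already exist):

```python
import os
import numpy as np
import config
from data_processor import LGSPDataset

index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
idc = np.load(config.index_type + "_idc.npy", allow_pickle=True)

# factor scales the virtual epoch length; eval_mode switches to the held-out class list
dataset = LGSPDataset(index_path, idc, config, factor=3, eval_mode=False)
sample = dataset[0]  # a dict pairing two clips of different classes
print(sample["audio_name_1"], sample["class_id_1"], sample["waveform_1"].shape)
```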
htsat_config.py
ADDED
@@ -0,0 +1,122 @@
# Ke Chen
# knutchen@ucsd.edu
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The configuration file of the ST-SED model or HTS-AT model

exp_name = "exp_htsat_2048d" # the saved ckpt prefix name of the model
workspace = "/home/kechen/Research/HTSAT" # the folder of your code
dataset_path = "/home/Research/audioset" # the dataset path
desed_folder = "/home/Research/DESED" # the desed file

dataset_type = "audioset"

loss_type = "clip_bce"
balanced_data = True

resume_checkpoint = "/home/kechen/Research/Latent_ASP/model_backup/htsat_audioset_2048d.ckpt"

esc_fold = 0 # just for the esc dataset, select the fold you need for evaluation and (+1) validation

debug = False

random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
batch_size = 32 * 4 # batch size per GPU x GPU number, default is 32 x 4 = 128
learning_rate = 1e-3 # 1e-4 is also workable
max_epoch = 100
num_workers = 3

lr_scheduler_epoch = [10, 20, 30]
lr_rate = [0.02, 0.05, 0.1]

# these data preparation optimizations do not bring many improvements, so they are deprecated
enable_token_label = False # token label
class_map_path = "class_hier_map.npy"
class_filter = None
retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
    9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
token_label_range = [0.2, 0.6]
enable_time_shift = False # shift time
enable_label_enhance = False # enhance hierarchical label
enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram



# for the model's design
enable_tscam = True # enable the token-semantic layer

# for signal processing
sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
clip_samples = sample_rate * 10 # audioset 10-sec clip
window_size = 1024
hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
mel_bins = 64
fmin = 50
fmax = 14000
shift_max = int(clip_samples * 0.5)

# for data collection
classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
patch_size = (25, 4) # deprecated
crop_size = None # int(clip_samples * 0.5) deprecated

# htsat hyperparameters
htsat_window_size = 8
htsat_spec_size = 256
htsat_patch_size = 4
htsat_stride = (4, 4)
htsat_num_head = [4, 8, 16, 32]
htsat_dim = 256 # for the 2048-d model
htsat_depth = [2, 2, 6, 2]

swin_pretrain_path = None
# "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"

# some deprecated optimizations in the model design, check the model code for details
htsat_attn_heatmap = False
htsat_hier_output = False
htsat_use_max = False


# no use here
ensemble_checkpoints = []
ensemble_strides = []


# weight average folder
wa_folder = "/home/version_0/checkpoints/"
# weight average output filename
wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"

esm_model_pathes = [
    "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
    "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
    "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
    "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
    "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
    "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
]

# for framewise localization
heatmap_dir = "/home/Research/heatmap_output"
test_file = "htsat-test-ensemble"
fl_local = False # indicate if we need to use this dataset for the framewise detection
fl_dataset = "/home/Research/desed/desed_eval.npy"
fl_class_num = [
    "Speech", "Frying", "Dishes", "Running_water",
    "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
    "Cat", "Dog", "Vacuum_cleaner"
]

# map 527 classes into 10 classes
fl_audioset_mapping = [
    [0, 1, 2, 3, 4, 5, 6, 7],
    [366, 367, 368],
    [364],
    [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
    [369],
    [382],
    [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
    [81, 82, 83, 84, 85],
    [74, 75, 76, 77, 78, 79],
    [377]
]
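`fl_audioset_mapping` collapses the 527 AudioSet classes onto the 10 DESED classes in `fl_class_num`. A sketch of one plausible way to apply it: taking the max score over the mapped AudioSet indices is an assumption here, not necessarily the repo's exact reduction, and `map_to_desed` is a hypothetical helper name:

```python
import numpy as np
import htsat_config

def map_to_desed(audioset_scores):
    """(527,) AudioSet scores -> (10,) DESED scores (max over mapped indices; assumed)."""
    return np.array([audioset_scores[idxs].max()
                     for idxs in htsat_config.fl_audioset_mapping])

scores = np.random.rand(527)
for name, s in zip(htsat_config.fl_class_num, map_to_desed(scores)):
    print(name, round(float(s), 3))
```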
htsat_utils.py
ADDED
@@ -0,0 +1,226 @@
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# Some Useful Common Methods

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional
import logging
import os
import sys
import h5py
import csv
import time
import json
import museval
import librosa
from datetime import datetime
from tqdm import tqdm
from scipy import stats
import torch.nn.functional as F


# imported from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        # x_sigmoid = torch.sigmoid(x)
        x_sigmoid = x # without sigmoid since it has already been computed
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.mean()


def get_mix_lambda(mixup_alpha, batch_size):
    mixup_lambdas = [np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)]
    return np.array(mixup_lambdas).astype(np.float32)

def create_folder(fd):
    if not os.path.exists(fd):
        os.makedirs(fd)

def dump_config(config, filename, include_time = False):
    save_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    config_json = {}
    for key in dir(config):
        if not key.startswith("_"):
            config_json[key] = eval("config." + key)
    if include_time:
        filename = filename + "_" + save_time
    with open(filename + ".json", "w") as f:
        json.dump(config_json, f, indent=4)

def int16_to_float32(x):
    return (x / 32767.).astype(np.float32)

def float32_to_int16(x):
    x = np.clip(x, a_min = -1., a_max = 1.)
    return (x * 32767.).astype(np.int16)


# index for each class
def process_idc(index_path, classes_num, filename):
    # load data
    logging.info("Load Data...............")
    idc = [[] for _ in range(classes_num)]
    with h5py.File(index_path, "r") as f:
        for i in tqdm(range(len(f["target"]))):
            t_class = np.where(f["target"][i])[0]
            for t in t_class:
                idc[t].append(i)
    print(idc)
    np.save(filename, idc)
    logging.info("Load Data Succeed...............")

def clip_bce(pred, target):
    """Binary crossentropy loss.
    """
    return F.binary_cross_entropy(pred, target)


def clip_ce(pred, target):
    return F.cross_entropy(pred, target)

def d_prime(auc):
    d_prime = stats.norm().ppf(auc) * np.sqrt(2.0)
    return d_prime


def get_loss_func(loss_type):
    if loss_type == 'clip_bce':
        return clip_bce
    if loss_type == 'clip_ce':
        return clip_ce
    if loss_type == 'asl_loss':
        loss_func = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05)
        return loss_func

def do_mixup_label(x):
    out = torch.logical_or(x, torch.flip(x, dims = [0])).float()
    return out

def do_mixup(x, mixup_lambda):
    """
    Args:
        x: (batch_size, ...)
        mixup_lambda: (batch_size,)

    Returns:
        out: (batch_size, ...)
    """
    out = (x.transpose(0, -1) * mixup_lambda + torch.flip(x, dims = [0]).transpose(0, -1) * (1 - mixup_lambda)).transpose(0, -1)
    return out

def interpolate(x, ratio):
    """Interpolate data in the time domain. This is used to compensate for the
    resolution reduction in the downsampling of a CNN.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, ratio to interpolate

    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled


def pad_framewise_output(framewise_output, frames_num):
    """Pad framewise_output to the same length as the input frames. The pad value
    is the same as the value of the last frame.

    Args:
        framewise_output: (batch_size, frames_num, classes_num)
        frames_num: int, number of frames to pad to

    Outputs:
        output: (batch_size, frames_num, classes_num)
    """
    pad = framewise_output[:, -1:, :].repeat(1, frames_num - framewise_output.shape[1], 1)
    """tensor for padding"""

    output = torch.cat((framewise_output, pad), dim=1)
    """(batch_size, frames_num, classes_num)"""

    return output

# set the audio into the format that can be fed into the model
# resample -> convert to mono -> output the audio
# track [n_sample, n_channel]
def prepprocess_audio(track, ofs, rfs, mono_type = "mix"):
    if track.shape[-1] > 1:
        # stereo
        if mono_type == "mix":
            track = np.transpose(track, (1, 0))
            track = librosa.to_mono(track)
        elif mono_type == "left":
            track = track[:, 0]
        elif mono_type == "right":
            track = track[:, 1]
    else:
        track = track[:, 0]
    # track [n_sample]
    if ofs != rfs:
        track = librosa.resample(track, ofs, rfs)
    return track

def init_hier_head(class_map, num_class):
    class_map = np.load(class_map, allow_pickle = True)

    head_weight = torch.zeros(num_class, num_class).float()
    head_bias = torch.zeros(num_class).float()

    for i in range(len(class_map)):
        for d in class_map[i][1]:
            head_weight[d][i] = 1.0
        for d in class_map[i][2]:
            head_weight[d][i] = 1.0 / len(class_map[i][2])
        head_weight[i][i] = 1.0
    return head_weight, head_bias
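A short usage sketch of the mixup helpers above (shapes follow the docstrings; the tensors are random stand-ins):

```python
import torch
from htsat_utils import get_mix_lambda, do_mixup, do_mixup_label

batch = torch.randn(8, 320000)                   # (batch_size, clip_samples)
labels = torch.randint(0, 2, (8, 527)).float()   # multi-hot targets

lam = torch.from_numpy(get_mix_lambda(0.5, batch.shape[0]))
mixed = do_mixup(batch, lam)           # each clip blended with the batch-reversed clip
mixed_labels = do_mixup_label(labels)  # logical OR of the paired label vectors
print(mixed.shape, mixed_labels.shape)
```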
losses.py
ADDED
@@ -0,0 +1,23 @@
import torch.nn.functional as F
import torch
import numpy as np


def mae(input, target):
    return torch.mean(torch.abs(input - target))


def logmae_wav(model, output_dict, target):
    loss = torch.log10(torch.clamp(mae(output_dict['wav'], target), 1e-8, np.inf))
    return loss


def get_loss_func(loss_type):
    if loss_type == 'logmae_wav':
        return logmae_wav

    elif loss_type == 'mae':
        return mae

    else:
        raise Exception('Incorrect loss_type!')
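A quick usage sketch ("mae" matches the default loss_type in config.py; note that logmae_wav instead takes `(model, output_dict, target)`):

```python
import torch
from losses import get_loss_func

loss_func = get_loss_func("mae")
pred = torch.randn(4, 32000)
target = torch.randn(4, 32000)
print(loss_func(pred, target))  # mean absolute error over all elements
```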
main.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
|
4 |
+
# The Main Script
|
5 |
+
|
6 |
+
import os
|
7 |
+
# this is to avoid the sdr calculation from occupying all cpus
|
8 |
+
os.environ["OMP_NUM_THREADS"] = "4"
|
9 |
+
os.environ["OPENBLAS_NUM_THREADS"] = "4"
|
10 |
+
os.environ["MKL_NUM_THREADS"] = "6"
|
11 |
+
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
|
12 |
+
os.environ["NUMEXPR_NUM_THREADS"] = "6"
|
13 |
+
|
14 |
+
import sys
|
15 |
+
import librosa
|
16 |
+
import numpy as np
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
from torch.utils.data.distributed import DistributedSampler
|
23 |
+
|
24 |
+
from utils import collect_fn, dump_config, create_folder, prepprocess_audio
|
25 |
+
import musdb
|
26 |
+
|
27 |
+
from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
|
28 |
+
from data_processor import LGSPDataset, MusdbDataset
|
29 |
+
import config
|
30 |
+
import htsat_config
|
31 |
+
from models.htsat import HTSAT_Swin_Transformer
|
32 |
+
from sed_model import SEDWrapper
|
33 |
+
|
34 |
+
import pytorch_lightning as pl
|
35 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
36 |
+
|
37 |
+
from htsat_utils import process_idc
|
38 |
+
|
39 |
+
import warnings
|
40 |
+
warnings.filterwarnings("ignore")
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class data_prep(pl.LightningDataModule):
|
45 |
+
def __init__(self, train_dataset, eval_dataset, device_num, config):
|
46 |
+
super().__init__()
|
47 |
+
self.train_dataset = train_dataset
|
48 |
+
self.eval_dataset = eval_dataset
|
49 |
+
self.device_num = device_num
|
50 |
+
self.config = config
|
51 |
+
|
52 |
+
def train_dataloader(self):
|
53 |
+
train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None
|
54 |
+
train_loader = DataLoader(
|
55 |
+
dataset = self.train_dataset,
|
56 |
+
num_workers = config.num_workers,
|
57 |
+
batch_size = config.batch_size // self.device_num,
|
58 |
+
shuffle = False,
|
59 |
+
sampler = train_sampler,
|
60 |
+
collate_fn = collect_fn
|
61 |
+
)
|
62 |
+
return train_loader
|
63 |
+
def val_dataloader(self):
|
64 |
+
eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
|
65 |
+
eval_loader = DataLoader(
|
66 |
+
dataset = self.eval_dataset,
|
67 |
+
num_workers = config.num_workers,
|
68 |
+
batch_size = config.batch_size // self.device_num,
|
69 |
+
shuffle = False,
|
70 |
+
sampler = eval_sampler,
|
71 |
+
collate_fn = collect_fn
|
72 |
+
)
|
73 |
+
return eval_loader
|
74 |
+
def test_dataloader(self):
|
75 |
+
test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
|
76 |
+
test_loader = DataLoader(
|
77 |
+
dataset = self.eval_dataset,
|
78 |
+
num_workers = config.num_workers,
|
79 |
+
batch_size = config.batch_size // self.device_num,
|
80 |
+
shuffle = False,
|
81 |
+
sampler = test_sampler,
|
82 |
+
collate_fn = collect_fn
|
83 |
+
)
|
84 |
+
return test_loader
|
85 |
+
|
86 |
+
def save_idc():
|
87 |
+
train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
|
88 |
+
eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
|
89 |
+
process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
|
90 |
+
process_idc(eval_index_path, config.classes_num, "eval_idc.npy")
|
91 |
+
|
92 |
+
# Process the musdb tracks into the sample rate of 32000 Hz sample rate, the original is 44100 Hz
|
93 |
+
def process_musdb():
|
94 |
+
# use musdb as testset
|
95 |
+
test_data = musdb.DB(
|
96 |
+
root = config.musdb_path,
|
97 |
+
download = False,
|
98 |
+
subsets = "test",
|
99 |
+
is_wav = True
|
100 |
+
)
|
101 |
+
print(len(test_data.tracks))
|
102 |
+
mus_tracks = []
|
103 |
+
# in musdb, all fs is the same (44100)
|
104 |
+
orig_fs = test_data.tracks[0].rate
|
105 |
+
print(orig_fs)
|
106 |
+
for track in test_data.tracks:
|
107 |
+
temp = {}
|
108 |
+
mixture = prepprocess_audio(
|
109 |
+
track.audio,
|
110 |
+
orig_fs, config.sample_rate,
|
111 |
+
config.test_type
|
112 |
+
)
|
113 |
+
temp["mixture" ]= mixture
|
114 |
+
for dickey in config.test_key:
|
115 |
+
source = prepprocess_audio(
|
116 |
+
track.targets[dickey].audio,
|
117 |
+
orig_fs, config.sample_rate,
|
118 |
+
config.test_type
|
119 |
+
)
|
120 |
+
temp[dickey] = source
|
121 |
+
print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
|
122 |
+
mus_tracks.append(temp)
|
123 |
+
print(len(mus_tracks))
|
124 |
+
# save the file to npy
|
125 |
+
np.save("musdb-32000fs.npy", mus_tracks)
|
126 |
+
|
127 |
+
# weight average will perform in the given folder
|
128 |
+
# It will output one model checkpoint, which avergas the weight of all models in the folder
|
129 |
+
def weight_average():
|
130 |
+
model_ckpt = []
|
131 |
+
model_files = os.listdir(config.wa_model_folder)
|
132 |
+
wa_ckpt = {
|
133 |
+
"state_dict": {}
|
134 |
+
}
|
135 |
+
|
136 |
+
for model_file in model_files:
|
137 |
+
model_file = os.path.join(config.esm_model_folder, model_file)
|
138 |
+
model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
|
139 |
+
keys = model_ckpt[0].keys()
|
140 |
+
for key in keys:
|
141 |
+
model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
|
142 |
+
model_ckpt_key = torch.mean(model_ckpt_key, dim = 0)
|
143 |
+
assert model_ckpt_key.shape == model_ckpt[0][key].shape, "the shape is unmatched " + model_ckpt_key.shape + " " + model_ckpt[0][key].shape
|
144 |
+
wa_ckpt["state_dict"][key] = model_ckpt_key
|
145 |
+
torch.save(wa_ckpt, config.wa_model_path)
|
146 |
+
|
147 |
+
|
148 |
+
# use the model to quickly separate a track given a query
|
149 |
+
# it requires four variables in config.py:
|
150 |
+
# inference_file: the track you want to separate
|
151 |
+
# inference_query: a **folder** containing all samples from the same source
|
152 |
+
# test_key: ["name"] indicate the source name (just a name for final output, no other functions)
|
153 |
+
# wave_output_path: the output folder

# make sure the query folder contains samples from the same source
# each time, the model is able to separate one source from the track
# if you want to separate multiple sources, change the query folder or write a script to automate that
def inference():
    # set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "there should be a separation key"
    create_folder(config.wave_output_path)
    test_track, fs = librosa.load(config.inference_file, sr = None)
    test_track = test_track[:, None]
    print(test_track.shape)
    print(fs)
    # convert the track to a 32000 Hz sample rate
    test_track = prepprocess_audio(
        test_track,
        fs, config.sample_rate,
        config.test_type
    )
    test_tracks = []
    temp = [test_track]
    for dickey in config.test_key:
        temp.append(test_track)
    temp = np.array(temp)
    test_tracks.append(temp)
    dataset = MusdbDataset(tracks = test_tracks) # behaves like the MUSDB dataset, so reuse it
    loader = DataLoader(
        dataset = dataset,
        num_workers = 1,
        batch_size = 1,
        shuffle = False
    )
    # obtain the samples for the query
    queries = []
    for query_file in os.listdir(config.inference_query):
        f_path = os.path.join(config.inference_query, query_file)
        if query_file.endswith(".wav"):
            temp_q, fs = librosa.load(f_path, sr = None)
            temp_q = temp_q[:, None]
            temp_q = prepprocess_audio(
                temp_q,
                fs, config.sample_rate,
                config.test_type
            )
            temp = [temp_q]
            for dickey in config.test_key:
                temp.append(temp_q)
            temp = np.array(temp)
            queries.append(temp)

    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"

    sed_model = HTSAT_Swin_Transformer(
        spec_size = htsat_config.htsat_spec_size,
        patch_size = htsat_config.htsat_patch_size,
        in_chans = 1,
        num_classes = htsat_config.classes_num,
        window_size = htsat_config.htsat_window_size,
        config = htsat_config,
        depths = htsat_config.htsat_depth,
        embed_dim = htsat_config.htsat_dim,
        patch_stride = htsat_config.htsat_stride,
        num_heads = htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model = sed_model,
        config = htsat_config,
        dataset = None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location = "cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    trainer = pl.Trainer(
        gpus = 1
    )
    avg_at = None
    # obtain the latent embedding as the query
    if config.infer_type == "mean":
        avg_dataset = MusdbDataset(tracks = queries)
        avg_loader = DataLoader(
            dataset = avg_dataset,
            num_workers = 1,
            batch_size = 1,
            shuffle = False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model = at_model,
            config = config,
            target_keys = config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders = avg_loader)
        avg_at = at_wrapper.avg_at

    # set up the separation model
    model = ZeroShotASP(
        channels = 1, config = config,
        at_model = at_model,
        dataset = dataset
    )
    # resume from checkpoint
    ckpt = torch.load(config.resume_checkpoint, map_location = "cpu")
    model.load_state_dict(ckpt["state_dict"], strict = False)
    exp_model = SeparatorModel(
        model = model,
        config = config,
        target_keys = config.test_key,
        avg_at = avg_at,
        using_wiener = False,
        calc_sdr = False,
        output_wav = True
    )
    trainer.test(exp_model, test_dataloaders = loader)
# test the separation model, mainly on MUSDB18
def test():
    # set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "there should be a separation key"
    create_folder(config.wave_output_path)
    # use musdb as the testset: a pickled array of per-track dicts
    # of the form {"mixture": ..., "<stem name>": ...}
    test_data = np.load(config.testset_path, allow_pickle = True)
    print(len(test_data))
    mus_tracks = []
    # in musdb, all tracks share the same fs (44100)
    # load the dataset
    for track in test_data:
        temp = []
        mixture = track["mixture"]
        temp.append(mixture)
        for dickey in config.test_key:
            source = track[dickey]
            temp.append(source)
        temp = np.array(temp)
        print(temp.shape)
        mus_tracks.append(temp)
    print(len(mus_tracks))
    dataset = MusdbDataset(tracks = mus_tracks)
    loader = DataLoader(
        dataset = dataset,
        num_workers = 1,
        batch_size = 1,
        shuffle = False
    )
    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"

    sed_model = HTSAT_Swin_Transformer(
        spec_size = htsat_config.htsat_spec_size,
        patch_size = htsat_config.htsat_patch_size,
        in_chans = 1,
        num_classes = htsat_config.classes_num,
        window_size = htsat_config.htsat_window_size,
        config = htsat_config,
        depths = htsat_config.htsat_depth,
        embed_dim = htsat_config.htsat_dim,
        patch_stride = htsat_config.htsat_stride,
        num_heads = htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model = sed_model,
        config = htsat_config,
        dataset = None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location = "cpu")
    at_model.load_state_dict(ckpt["state_dict"])
    trainer = pl.Trainer(
        gpus = 1
    )
    avg_at = None
    # obtain the queries of the four stems from the training set
    if config.infer_type == "mean":
        avg_data = np.load(config.testavg_path, allow_pickle = True)[:90]
        print(len(avg_data))
        avgmus_tracks = []
        # in musdb, all tracks share the same fs (44100)
        # load the dataset
        for track in avg_data:
            temp = []
            mixture = track["mixture"]
            temp.append(mixture)
            for dickey in config.test_key:
                source = track[dickey]
                temp.append(source)
            temp = np.array(temp)
            print(temp.shape)
            avgmus_tracks.append(temp)
        print(len(avgmus_tracks))
        avg_dataset = MusdbDataset(tracks = avgmus_tracks)
        avg_loader = DataLoader(
            dataset = avg_dataset,
            num_workers = 1,
            batch_size = 1,
            shuffle = False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model = at_model,
            config = config,
            target_keys = config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders = avg_loader)
        avg_at = at_wrapper.avg_at

    model = ZeroShotASP(
        channels = 1, config = config,
        at_model = at_model,
        dataset = dataset
    )
    ckpt = torch.load(config.resume_checkpoint, map_location = "cpu")
    model.load_state_dict(ckpt["state_dict"], strict = False)
    exp_model = SeparatorModel(
        model = model,
        config = config,
        target_keys = config.test_key,
        avg_at = avg_at,
        using_wiener = config.using_wiener
    )
    trainer.test(exp_model, test_dataloaders = loader)
def train():
    # set exp settings
    # device_name = "cuda" if torch.cuda.is_available() else "cpu"
    # device = torch.device("cuda")

    device_num = torch.cuda.device_count()
    print("each batch size:", config.batch_size // device_num)

    train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
    train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle = True)

    eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
    eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle = True)

    # set exp folder
    exp_dir = os.path.join(config.workspace, "results", config.exp_name)
    checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")

    if not config.debug:
        create_folder(os.path.join(config.workspace, "results"))
        create_folder(exp_dir)
        create_folder(checkpoint_dir)
        dump_config(config, os.path.join(exp_dir, config.exp_name), False)

    # load data
    # import dataset LGSPDataset (latent general source separation) and sampler
    dataset = LGSPDataset(
        index_path = train_index_path,
        idc = train_idc,
        config = config,
        factor = 0.05,
        eval_mode = False
    )
    eval_dataset = LGSPDataset(
        index_path = eval_index_path,
        idc = eval_idc,
        config = config,
        factor = 0.05,
        eval_mode = True
    )

    audioset_data = data_prep(train_dataset = dataset, eval_dataset = eval_dataset, device_num = device_num, config = config)
    checkpoint_callback = ModelCheckpoint(
        monitor = "mixture_sdr",
        filename = 'l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
        save_top_k = 10,
        mode = "max"
    )
    # infer at model
    sed_model = HTSAT_Swin_Transformer(
        spec_size = htsat_config.htsat_spec_size,
        patch_size = htsat_config.htsat_patch_size,
        in_chans = 1,
        num_classes = htsat_config.classes_num,
        window_size = htsat_config.htsat_window_size,
        config = htsat_config,
        depths = htsat_config.htsat_depth,
        embed_dim = htsat_config.htsat_dim,
        patch_stride = htsat_config.htsat_stride,
        num_heads = htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model = sed_model,
        config = htsat_config,
        dataset = None
    )
    # load the checkpoint
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location = "cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    trainer = pl.Trainer(
        deterministic = True,
        default_root_dir = checkpoint_dir,
        gpus = device_num,
        val_check_interval = 0.2,
        # check_val_every_n_epoch = 1,
        max_epochs = config.max_epoch,
        auto_lr_find = True,
        sync_batchnorm = True,
        callbacks = [checkpoint_callback],
        accelerator = "ddp" if device_num > 1 else None,
        resume_from_checkpoint = None, # config.resume_checkpoint,
        replace_sampler_ddp = False,
        gradient_clip_val = 1.0,
        num_sanity_val_steps = 0,
    )
    model = ZeroShotASP(
        channels = 1, config = config,
        at_model = at_model,
        dataset = dataset
    )
    if config.resume_checkpoint is not None:
        ckpt = torch.load(config.resume_checkpoint, map_location = "cpu")
        model.load_state_dict(ckpt["state_dict"])
    # trainer.test(model, datamodule = audioset_data)
    trainer.fit(model, audioset_data)
def main():
    parser = argparse.ArgumentParser(description = "latent general source separation parser")
    subparsers = parser.add_subparsers(dest = "mode")
    parser_train = subparsers.add_parser("train")
    parser_test = subparsers.add_parser("test")
    parser_musdb = subparsers.add_parser("musdb_process")
    parser_saveidc = subparsers.add_parser("save_idc")
    parser_wa = subparsers.add_parser("weight_average")
    parser_infer = subparsers.add_parser("inference")
    args = parser.parse_args()
    # default settings
    logging.basicConfig(level = logging.INFO)
    pl.utilities.seed.seed_everything(seed = config.random_seed)

    if args.mode == "train":
        train()
    elif args.mode == "test":
        test()
    elif args.mode == "musdb_process":
        process_musdb()
    elif args.mode == "weight_average":
        weight_average()
    elif args.mode == "save_idc":
        save_idc()
    elif args.mode == "inference":
        inference()
    else:
        raise Exception("Error Mode!")


if __name__ == '__main__':
    main()
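Taken together, `main()` exposes the whole pipeline as subcommands. Below is a plausible set of invocations (the one-line descriptions are inferences from the function bodies above), followed by a sketch of the `config.py` fields that the `inference` mode reads; every value is a placeholder, not a setting shipped with the repo.

```python
# Run from a shell; each mode dispatches to the matching function in main.py:
#   python main.py save_idc        # precompute the per-class index arrays
#   python main.py train           # train ZeroShotASP on the AudioSet indexes
#   python main.py weight_average  # average several saved checkpoints
#   python main.py test            # evaluate separation on the MUSDB18 pickle
#   python main.py inference       # separate config.inference_file with the query folder

# Hypothetical config.py entries for the inference mode:
inference_file = "examples/mixture.wav"       # the track to separate
inference_query = "examples/query_violin"     # a folder of .wav samples from one source
test_key = ["violin"]                         # label(s) used to name the output file(s)
wave_output_path = "outputs"                  # where the separated .wav files are written
resume_checkpoint = "pretrained/zeroshot_asp_full.ckpt"
infer_type = "mean"                           # average the query embeddings into one query
sample_rate = 32000                           # the model operates at 32 kHz
test_type = "mix"                             # downmix stereo inputs to mono
```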
opt_thres.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f40e97a4946e70392576d4f4c171596bcb6243883a54f48aaa9ae5b86c0976c
size 13585
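The three lines above are a Git LFS pointer, not the pickle itself; after cloning, the actual payload has to be fetched through Git LFS:

```python
# Run from a shell, assuming git-lfs is installed:
#   git lfs install
#   git lfs pull
```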
predict.py ADDED
@@ -0,0 +1,111 @@
import os
import types

import librosa
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

import htsat_config
from cog import BasePredictor, Input, Path
from data_processor import MusdbDataset
from models.asp_model import AutoTaggingWarpper, SeparatorModel, ZeroShotASP
from models.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
from utils import prepprocess_audio

def get_inference_configs():
    config = types.SimpleNamespace()
    config.ckpt_path = "pretrained/zeroshot_asp_full.ckpt"
    config.sed_ckpt_path = "pretrained/htsat_audioset_2048d.ckpt"
    config.wave_output_path = "predict_outputs"
    config.test_key = "query_name"
    config.test_type = "mix"
    config.loss_type = "mae"
    config.infer_type = "mean"
    config.sample_rate = 32000
    config.segment_frames = 200
    config.hop_samples = 320
    config.energy_thres = 0.1
    config.using_whiting = False
    config.latent_dim = 2048
    config.classes_num = 527
    config.overlap_rate = 0.5
    config.num_workers = 1

    return config

def load_models(config):
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head,
    )
    at_model = SEDWrapper(sed_model=sed_model, config=htsat_config, dataset=None)

    ckpt = torch.load(config.sed_ckpt_path, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    at_wrapper = AutoTaggingWarpper(
        at_model=at_model, config=config, target_keys=[config.test_key]
    )

    asp_model = ZeroShotASP(channels=1, config=config, at_model=at_model, dataset=None)
    ckpt = torch.load(config.ckpt_path, map_location="cpu")
    asp_model.load_state_dict(ckpt["state_dict"], strict=False)

    return at_wrapper, asp_model

def get_dataloader_from_sound_file(sound_file_path, config):
    signal, sampling_rate = librosa.load(str(sound_file_path), sr=None)
    signal = prepprocess_audio(
        signal[:, None], sampling_rate, config.sample_rate, config.test_type
    )
    signal = np.array([signal, signal])  # duplicate the signal for later use
    dataset = MusdbDataset(tracks=[signal])
    data_loader = DataLoader(dataset, num_workers=config.num_workers, batch_size=1, shuffle=False)
    return data_loader


class Predictor(BasePredictor):
    def setup(self):
        self.config = get_inference_configs()
        os.makedirs(self.config.wave_output_path, exist_ok=True)
        self.at_wrapper, self.asp_model = load_models(self.config)

    def predict(
        self,
        mix_file: Path = Input(description="Reference sound to extract source from"),
        query_file: Path = Input(description="Query sound to be searched and extracted from mix"),
    ) -> Path:
        ref_loader = get_dataloader_from_sound_file(str(mix_file), self.config)

        query_loader = get_dataloader_from_sound_file(str(query_file), self.config)

        trainer = pl.Trainer(gpus=1)
        trainer.test(self.at_wrapper, test_dataloaders=query_loader)
        avg_at = self.at_wrapper.avg_at

        exp_model = SeparatorModel(
            model=self.asp_model,
            config=self.config,
            target_keys=[self.config.test_key],
            avg_at=avg_at,
            using_wiener=False,
            calc_sdr=False,
            output_wav=True,
        )
        trainer.test(exp_model, test_dataloaders=ref_loader)

        prediction_path = os.path.join(
            self.config.wave_output_path, f"0_{self.config.test_key}_pred_(0.0).wav"
        )
        return prediction_path
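A rough sketch of driving this predictor outside of the Replicate runtime, assuming a GPU and that the two checkpoints under `pretrained/` exist locally; the audio paths are placeholders:

```python
from predict import Predictor

predictor = Predictor()
predictor.setup()  # builds HTS-AT + ZeroShotASP and loads both checkpoints
out_wav = predictor.predict(
    mix_file="examples/mixture.wav",   # cog normally passes cog.Path objects;
    query_file="examples/query.wav",   # plain strings also work since predict() casts with str()
)
print("separated source written to:", out_wav)
```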
requirements.txt ADDED
@@ -0,0 +1,19 @@
h5py==3.6.0
hydra-core>=1.0
librosa==0.8.1
musdb==0.4.0
museval==0.4.0
noisereduce==2.0.0
numba==0.55.1
numpy==1.21.5
omegaconf>=2.0.0
pytorch_lightning==1.5.9
scikit_learn==1.0.2
scipy==1.7.3
soundfile==0.10.3.post1
tensorboard==2.8.0
torch==1.10.2
torchaudio==0.10.2
torchcontrib==0.0.2
torchlibrosa==0.0.9
tqdm==4.62.3
```
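These pins target the PyTorch 1.10 / Lightning 1.5 era; a quick sanity check after installing (a sketch):

```python
import torch
import pytorch_lightning as pl

print(torch.__version__)            # expected: 1.10.2
print(pl.__version__)               # expected: 1.5.9
print(torch.cuda.is_available())    # True if a GPU build of torch is installed
```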
sed_model.py ADDED
@@ -0,0 +1,358 @@
# Ke Chen
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# The Model Training Wrapper
import numpy as np
import librosa
import os
import sys
import math
import bisect
import pickle
from numpy.lib.function_base import average
from sklearn import metrics
import soundfile as sf
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score

import tensorboard
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.optim as optim
from torch.nn.parameter import Parameter
import torch.distributed as dist
from torchlibrosa.stft import STFT, ISTFT, magphase
import pytorch_lightning as pl
from htsat_utils import do_mixup, get_mix_lambda, do_mixup_label, get_loss_func, d_prime
import random

from torchcontrib.optim import SWA


class SEDWrapper(pl.LightningModule):
    def __init__(self, sed_model, config, dataset):
        super().__init__()
        self.sed_model = sed_model
        self.config = config
        self.dataset = dataset
        self.loss_func = get_loss_func(config.loss_type)

    def evaluate_metric(self, pred, ans):
        ap = []
        if self.config.dataset_type == "audioset":
            mAP = np.mean(average_precision_score(ans, pred, average = None))
            mAUC = np.mean(roc_auc_score(ans, pred, average = None))
            dprime = d_prime(mAUC)
            return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
        else:
            acc = accuracy_score(ans, np.argmax(pred, 1))
            return {"acc": acc}

    def forward(self, x, mix_lambda = None):
        output_dict = self.sed_model(x, mix_lambda)
        return output_dict["clipwise_output"], output_dict["framewise_output"]

    def inference(self, x):
        self.device_type = next(self.parameters()).device
        self.eval()
        x = torch.from_numpy(x).float().to(self.device_type)
        output_dict = self.sed_model(x, None, True)
        for key in output_dict.keys():
            output_dict[key] = output_dict[key].detach().cpu().numpy()
        return output_dict

    def training_step(self, batch, batch_idx):
        self.device_type = next(self.parameters()).device
        mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(batch["waveform"]))).to(self.device_type)
        # Another choice: also mixup the target, but AudioSet is not a perfect dataset,
        # so "adding noise" might be better than purely "mixing"
        # batch["target"] = do_mixup_label(batch["target"])
        # batch["target"] = do_mixup(batch["target"], mix_lambda)

        pred, _ = self(batch["waveform"], mix_lambda)
        loss = self.loss_func(pred, batch["target"])
        self.log("loss", loss, on_epoch = True, prog_bar = True)
        return loss

    def training_epoch_end(self, outputs):
        # Change: SWA, deprecated
        # for opt in self.trainer.optimizers:
        #     if not type(opt) is SWA:
        #         continue
        #     opt.swap_swa_sgd()
        self.dataset.generate_queue()


    def validation_step(self, batch, batch_idx):
        pred, _ = self(batch["waveform"])
        return [pred.detach(), batch["target"].detach()]

    def validation_epoch_end(self, validation_step_outputs):
        self.device_type = next(self.parameters()).device
        pred = torch.cat([d[0] for d in validation_step_outputs], dim = 0)
        target = torch.cat([d[1] for d in validation_step_outputs], dim = 0)
        gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
        gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
        dist.barrier()
        if self.config.dataset_type == "audioset":
            metric_dict = {
                "mAP": 0.,
                "mAUC": 0.,
                "dprime": 0.
            }
        else:
            metric_dict = {
                "acc": 0.
            }
        dist.all_gather(gather_pred, pred)
        dist.all_gather(gather_target, target)
        if dist.get_rank() == 0:
            gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
            gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
            if self.config.dataset_type == "scv2":
                gather_target = np.argmax(gather_target, 1)
            metric_dict = self.evaluate_metric(gather_pred, gather_target)
            print(self.device_type, dist.get_world_size(), metric_dict, flush = True)

        if self.config.dataset_type == "audioset":
            self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
            self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
            self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
        else:
            self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
        dist.barrier()

    def time_shifting(self, x, shift_len):
        shift_len = int(shift_len)
        new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
        return new_sample

    def test_step(self, batch, batch_idx):
        self.device_type = next(self.parameters()).device
        preds = []
        # the time-shifting test augmentation is cancelled to speed things up
        shift_num = 1
        for i in range(shift_num):
            pred, pred_map = self(batch["waveform"])
            preds.append(pred.unsqueeze(0))
            batch["waveform"] = self.time_shifting(batch["waveform"], shift_len = 100 * (i + 1))
        preds = torch.cat(preds, dim = 0)
        pred = preds.mean(dim = 0)
        if self.config.fl_local:
            return [
                pred.detach().cpu().numpy(),
                pred_map.detach().cpu().numpy(),
                batch["audio_name"],
                batch["real_len"].cpu().numpy()
            ]
        else:
            return [pred.detach(), batch["target"].detach()]

    def test_epoch_end(self, test_step_outputs):
        self.device_type = next(self.parameters()).device
        if self.config.fl_local:
            pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
            pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
            audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
            real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
            heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
            save_npy = [
                {
                    "audio_name": audio_name[i],
                    "heatmap": pred_map[i],
                    "pred": pred[i],
                    "real_len": real_len[i]
                }
                for i in range(len(pred))
            ]
            np.save(heatmap_file, save_npy)
        else:
            self.device_type = next(self.parameters()).device
            pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
            target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
            gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
            gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
            dist.barrier()
            if self.config.dataset_type == "audioset":
                metric_dict = {
                    "mAP": 0.,
                    "mAUC": 0.,
                    "dprime": 0.
                }
            else:
                metric_dict = {
                    "acc": 0.
                }
            dist.all_gather(gather_pred, pred)
            dist.all_gather(gather_target, target)
            if dist.get_rank() == 0:
                gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
                gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
                if self.config.dataset_type == "scv2":
                    gather_target = np.argmax(gather_target, 1)
                metric_dict = self.evaluate_metric(gather_pred, gather_target)
                print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
            if self.config.dataset_type == "audioset":
                self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
                self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
                self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
            else:
                self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
            dist.barrier()


    def configure_optimizers(self):
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr = self.config.learning_rate,
            betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.05,
        )
        # Change: SWA, deprecated
        # optimizer = SWA(optimizer, swa_start=10, swa_freq=5)
        def lr_foo(epoch):
            if epoch < 3:
                # warm up lr
                lr_scale = self.config.lr_rate[epoch]
            else:
                # decay schedule after warmup
                lr_pos = int(-1 - bisect.bisect_left(self.config.lr_scheduler_epoch, epoch))
                if lr_pos < -3:
                    lr_scale = max(self.config.lr_rate[0] * (0.98 ** epoch), 0.03)
                else:
                    lr_scale = self.config.lr_rate[lr_pos]
            return lr_scale
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda = lr_foo
        )

        return [optimizer], [scheduler]



class Ensemble_SEDWrapper(pl.LightningModule):
    def __init__(self, sed_models, config, dataset):
        super().__init__()

        self.sed_models = nn.ModuleList(sed_models)
        self.config = config
        self.dataset = dataset

    def evaluate_metric(self, pred, ans):
        if self.config.dataset_type == "audioset":
            mAP = np.mean(average_precision_score(ans, pred, average = None))
            mAUC = np.mean(roc_auc_score(ans, pred, average = None))
            dprime = d_prime(mAUC)
            return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
        else:
            acc = accuracy_score(ans, np.argmax(pred, 1))
            return {"acc": acc}

    def forward(self, x, sed_index, mix_lambda = None):
        self.sed_models[sed_index].eval()
        preds = []
        pred_maps = []
        # the time-shifting test augmentation is cancelled to speed things up
        shift_num = 1
        for i in range(shift_num):
            pred, pred_map = self.sed_models[sed_index](x)
            pred_maps.append(pred_map.unsqueeze(0))
            preds.append(pred.unsqueeze(0))
            x = self.time_shifting(x, shift_len = 100 * (i + 1))
        preds = torch.cat(preds, dim = 0)
        pred_maps = torch.cat(pred_maps, dim = 0)
        pred = preds.mean(dim = 0)
        pred_map = pred_maps.mean(dim = 0)
        return pred, pred_map


    def time_shifting(self, x, shift_len):
        shift_len = int(shift_len)
        new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
        return new_sample

    def test_step(self, batch, batch_idx):
        self.device_type = next(self.parameters()).device
        if self.config.fl_local:
            pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
            pred_map = torch.zeros(len(batch["waveform"]), 1024, self.config.classes_num).float().to(self.device_type)
            for j in range(len(self.sed_models)):
                temp_pred, temp_pred_map = self(batch["waveform"], j)
                pred = pred + temp_pred
                pred_map = pred_map + temp_pred_map
            pred = pred / len(self.sed_models)
            pred_map = pred_map / len(self.sed_models)
            return [
                pred.detach().cpu().numpy(),
                pred_map.detach().cpu().numpy(),
                batch["audio_name"],
                batch["real_len"].cpu().numpy()
            ]
        else:
            pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
            for j in range(len(self.sed_models)):
                temp_pred, _ = self(batch["waveform"], j)
                pred = pred + temp_pred
            pred = pred / len(self.sed_models)
            return [
                pred.detach(),
                batch["target"].detach(),
            ]

    def test_epoch_end(self, test_step_outputs):
        self.device_type = next(self.parameters()).device
        if self.config.fl_local:
            pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
            pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
            audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
            real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
            heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
            print(pred.shape)
            print(pred_map.shape)
            print(real_len.shape)
            save_npy = [
                {
                    "audio_name": audio_name[i],
                    "heatmap": pred_map[i],
                    "pred": pred[i],
                    "real_len": real_len[i]
                }
                for i in range(len(pred))
            ]
            np.save(heatmap_file, save_npy)
        else:
            pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
            target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
            gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
            gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]

            dist.barrier()
            if self.config.dataset_type == "audioset":
                metric_dict = {
                    "mAP": 0.,
                    "mAUC": 0.,
                    "dprime": 0.
                }
            else:
                metric_dict = {
                    "acc": 0.
                }
            dist.all_gather(gather_pred, pred)
            dist.all_gather(gather_target, target)
            if dist.get_rank() == 0:
                gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
                gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
                if self.config.dataset_type == "scv2":
                    gather_target = np.argmax(gather_target, 1)
                metric_dict = self.evaluate_metric(gather_pred, gather_target)
                print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
            if self.config.dataset_type == "audioset":
                self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
                self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
                self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
            else:
                self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar = True, sync_dist = True)
            dist.barrier()
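The `lr_foo` lambda in `configure_optimizers` implements a three-epoch warmup followed by stepped decay. A standalone sketch of the scale it yields, under assumed values for `config.lr_rate` and `config.lr_scheduler_epoch` (the real values live in the config files):

```python
import bisect

lr_rate = [0.02, 0.05, 0.1]        # assumed warmup/decay scales
lr_scheduler_epoch = [10, 20, 30]  # assumed decay milestones

def lr_foo(epoch):
    if epoch < 3:
        return lr_rate[epoch]  # warmup: 0.02 -> 0.05 -> 0.1
    lr_pos = int(-1 - bisect.bisect_left(lr_scheduler_epoch, epoch))
    if lr_pos < -3:  # past the last milestone: exponential decay, floored at 0.03
        return max(lr_rate[0] * (0.98 ** epoch), 0.03)
    return lr_rate[lr_pos]  # step back down through lr_rate

print([round(lr_foo(e), 3) for e in (0, 1, 2, 5, 15, 25, 40)])
# -> [0.02, 0.05, 0.1, 0.1, 0.05, 0.02, 0.03]
```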
utils.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
|
4 |
+
# Some Common Methods
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from scipy.signal import butter, filtfilt
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch import Tensor
|
11 |
+
from typing import Optional
|
12 |
+
import logging
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
import h5py
|
16 |
+
import csv
|
17 |
+
import time
|
18 |
+
import json
|
19 |
+
import museval
|
20 |
+
import librosa
|
21 |
+
from datetime import datetime
|
22 |
+
|
23 |
+
def create_folder(fd):
|
24 |
+
if not os.path.exists(fd):
|
25 |
+
os.makedirs(fd)
|
26 |
+
|
27 |
+
def get_filename(path):
|
28 |
+
path = os.path.realpath(path)
|
29 |
+
na_ext = path.split('/')[-1]
|
30 |
+
na = os.path.splitext(na_ext)[0]
|
31 |
+
return na
|
32 |
+
|
33 |
+
def get_sub_filepaths(folder):
|
34 |
+
paths = []
|
35 |
+
for root, dirs, files in os.walk(folder):
|
36 |
+
for name in files:
|
37 |
+
path = os.path.join(root, name)
|
38 |
+
paths.append(path)
|
39 |
+
return paths
|
40 |
+
|
41 |
+
def np_to_pytorch(x, device = None):
|
42 |
+
if 'float' in str(x.dtype):
|
43 |
+
x = torch.Tensor(x)
|
44 |
+
elif 'int' in str(x.dtype):
|
45 |
+
x = torch.LongTensor(x)
|
46 |
+
else:
|
47 |
+
return x
|
48 |
+
return x.to(device)
|
49 |
+
|
50 |
+
def count_parameters(model):
|
51 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
52 |
+
|
53 |
+
def calculate_average_energy(x):
|
54 |
+
return np.mean(np.square(x))
|
55 |
+
|
56 |
+
def id_to_one_hot(id, classes_num):
|
57 |
+
one_hot = np.zeros(classes_num)
|
58 |
+
one_hot[id] = 1
|
59 |
+
return one_hot
|
60 |
+
|
61 |
+
def ids_to_hots(ids, classes_num):
|
62 |
+
hots = np.zeros(classes_num)
|
63 |
+
for id in ids:
|
64 |
+
hots[id] = 1
|
65 |
+
return hots
|
66 |
+
|
67 |
+
def float32_to_int16(x):
|
68 |
+
assert np.max(np.abs(x)) <= 1.
|
69 |
+
return (x * 32767.).astype(np.int16)
|
70 |
+
|
71 |
+
def int16_to_float32(x):
|
72 |
+
return (x / 32767.).astype(np.float32)
|
73 |
+
|
74 |
+
def collect_fn(list_data_dict):
|
75 |
+
np_data_dict = {}
|
76 |
+
for key in list_data_dict[0].keys():
|
77 |
+
np_data_dict[key] = np.array([data_dict[key] for data_dict in list_data_dict])
|
78 |
+
return np_data_dict
|
79 |
+
|
80 |
+
def dump_config(config, filename, include_time = False):
|
81 |
+
save_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
82 |
+
config_json = {}
|
83 |
+
for key in dir(config):
|
84 |
+
if not key.startswith("_"):
|
85 |
+
config_json[key] = eval("config." + key)
|
86 |
+
if include_time:
|
87 |
+
filename = filename + "_" + save_time
|
88 |
+
with open(filename + ".json", "w") as f:
|
89 |
+
json.dump(config_json, f ,indent=4)
|
90 |
+
|
91 |
+
|
92 |
+
def get_segment_bgn_end_samples(anchor_index, segment_frames, hop_samples, clip_samples):
|
93 |
+
bgn_frame = anchor_index - segment_frames // 2
|
94 |
+
end_frame = anchor_index + segment_frames // 2
|
95 |
+
bgn_sample = bgn_frame * hop_samples
|
96 |
+
end_sample = end_frame * hop_samples
|
97 |
+
|
98 |
+
segment_samples = segment_frames * hop_samples
|
99 |
+
|
100 |
+
if bgn_sample < 0:
|
101 |
+
bgn_sample = 0
|
102 |
+
end_sample = segment_samples
|
103 |
+
|
104 |
+
if end_sample > clip_samples:
|
105 |
+
bgn_sample = clip_samples - segment_samples
|
106 |
+
end_sample = clip_samples
|
107 |
+
|
108 |
+
return bgn_sample, end_sample
|
109 |
+
|
110 |
+
def get_mix_data(waveforms, con_vectors, class_ids, indexes, mix_type = "mixture"):
|
111 |
+
# define return data
|
112 |
+
mixtures = []
|
113 |
+
sources = []
|
114 |
+
conditions = []
|
115 |
+
gds = []
|
116 |
+
for i in range(0, len(indexes), 2):
|
117 |
+
n1 = indexes[i]
|
118 |
+
n2 = indexes[i + 1]
|
119 |
+
# energy normalization
|
120 |
+
e1 = np.mean(np.square(waveforms[n1]))
|
121 |
+
e2 = np.mean(np.square(waveforms[n2]))
|
122 |
+
ratio = (e1 / max(1e-8, e2)) ** 0.5
|
123 |
+
ratio = np.clip(ratio, 0.02, 50)
|
124 |
+
waveforms[n2] *= ratio
|
125 |
+
mixture = waveforms[n1] + waveforms[n2]
|
126 |
+
# form data
|
127 |
+
if mix_type == "clean":
|
128 |
+
mixtures.append(waveforms[n1])
|
129 |
+
mixtures.append(waveforms[n2])
|
130 |
+
sources.append(waveforms[n1])
|
131 |
+
sources.append(waveforms[n2])
|
132 |
+
elif mix_type == "silence":
|
133 |
+
mixtures.append(waveforms[n2])
|
134 |
+
mixtures.append(waveforms[n1])
|
135 |
+
sources.append(np.zeros_like(waveforms[n1]))
|
136 |
+
sources.append(np.zeros_like(waveforms[n2]))
|
137 |
+
else:
|
138 |
+
mixtures.append(mixture)
|
139 |
+
mixtures.append(mixture)
|
140 |
+
sources.append(waveforms[n1])
|
141 |
+
sources.append(waveforms[n2])
|
142 |
+
|
143 |
+
conditions.append(con_vectors[n1])
|
144 |
+
conditions.append(con_vectors[n2])
|
145 |
+
gds.append(class_ids[n1])
|
146 |
+
gds.append(class_ids[n2])
|
147 |
+
return mixtures, sources, conditions, gds
|
148 |
+
|
149 |
+
# generate a list
|
150 |
+
def get_balanced_class_list(index_path, factor = 3, black_list = None, random_seed = 0):
|
151 |
+
# initialization
|
152 |
+
random_state = np.random.RandomState(random_seed)
|
153 |
+
logging.info("Load Indexes...............")
|
154 |
+
with h5py.File(index_path, "r") as hf:
|
155 |
+
indexes = hf["index_in_hdf5"][:]
|
156 |
+
targets = hf["target"][:].astype(np.float32)
|
157 |
+
(audios_num, classes_num) = targets.shape
|
158 |
+
# set the indexes per class for balanced list
|
159 |
+
indexes_per_class = []
|
160 |
+
for k in range(classes_num):
|
161 |
+
indexes_per_class.append(
|
162 |
+
np.where(targets[:, k] == 1)[0]
|
163 |
+
)
|
164 |
+
|
165 |
+
logging.info("Load Indexes Succeed...............")
|
166 |
+
|
167 |
+
return indexes_per_class
|
168 |
+
|
169 |
+
def dataset_worker_init_fn_seed(worker_id):
|
170 |
+
seed = np.random.randint(0, 224141) + worker_id * np.random.randint(100,1000)
|
171 |
+
print(seed)
|
172 |
+
np.random.seed(seed)
|
173 |
+
|
174 |
+
def calculate_sdr(ref, est, scaling=False):
|
175 |
+
s = museval.evaluate(ref[None,:,None], est[None,:,None], win = len(ref), hop = len(ref))
|
176 |
+
return s[0][0]
|
177 |
+
|
178 |
+
def butter_lowpass_filter(data, cuton, cutoff, fs, order):
|
179 |
+
normal_cutoff = cutoff / (0.5 * fs)
|
180 |
+
normal_cuton = cuton / (0.5 * fs)
|
181 |
+
b, a = butter(order, [normal_cuton, normal_cutoff], btype="band", analog=False)
|
182 |
+
y = filtfilt(b,a, data)
|
183 |
+
return y
|
184 |
+
|
185 |
+
def calculate_silence_sdr(mixture, est):
|
186 |
+
sdr = 10. * (
|
187 |
+
np.log10(np.clip(np.mean(mixture ** 2), 1e-8, np.inf)) \
|
188 |
+
- np.log10(np.clip(np.mean(est ** 2), 1e-8, np.inf)))
|
189 |
+
return sdr
|
190 |
+
|
191 |
+
|
192 |
+
def evaluate_sdr(ref, est, class_ids, mix_type = "mixture"):
|
193 |
+
sdr_results = []
|
194 |
+
if mix_type == "silence":
|
195 |
+
for i in range(len(ref)):
|
196 |
+
sdr = calculate_silence_sdr(ref[i,:,0], est[i,:,0])
|
197 |
+
sdr_results.append([sdr, class_ids[i]])
|
198 |
+
else:
|
199 |
+
for i in range(len(ref)):
|
200 |
+
if np.sum(ref[i,:,0]) == 0 or np.sum(est[i,:,0]) == 0:
|
201 |
+
continue
|
202 |
+
else:
|
203 |
+
sdr_c = calculate_sdr(ref[i,:,0], est[i,:,0], scaling = True)
|
204 |
+
sdr_results.append([sdr_c, class_ids[i]])
|
205 |
+
return sdr_results
|
206 |
+
|
207 |
+
# set the audio into the format that can be fed into the model
|
208 |
+
# resample -> convert to mono -> output the audio
|
209 |
+
# track [n_sample, n_channel]
|
210 |
+
def prepprocess_audio(track, ofs, rfs, mono_type = "mix"):
|
211 |
+
if track.shape[-1] > 1:
|
212 |
+
# stereo
|
213 |
+
if mono_type == "mix":
|
214 |
+
track = np.transpose(track, (1,0))
|
215 |
+
track = librosa.to_mono(track)
|
216 |
+
elif mono_type == "left":
|
217 |
+
track = track[:, 0]
|
218 |
+
elif mono_type == "right":
|
219 |
+
track = track[:, 1]
|
220 |
+
else:
|
221 |
+
track = track[:, 0]
|
222 |
+
# track [n_sample]
|
223 |
+
if ofs != rfs:
|
224 |
+
track = librosa.resample(track, ofs, rfs)
|
225 |
+
return track
|
226 |
+
|
227 |
+
# *************************************************
|
228 |
+
# all below is referred from the wiener filter code
|
229 |
+
|
230 |
+
def atan2(y, x):
|
231 |
+
r"""Element-wise arctangent function of y/x.
|
232 |
+
Returns a new tensor with signed angles in radians.
|
233 |
+
It is an alternative implementation of torch.atan2
|
234 |
+
Args:
|
235 |
+
y (Tensor): First input tensor
|
236 |
+
x (Tensor): Second input tensor [shape=y.shape]
|
237 |
+
Returns:
|
238 |
+
Tensor: [shape=y.shape].
|
239 |
+
"""
|
240 |
+
pi = 2 * torch.asin(torch.tensor(1.0))
|
241 |
+
x += ((x == 0) & (y == 0)) * 1.0
|
242 |
+
out = torch.atan(y / x)
|
243 |
+
out += ((y >= 0) & (x < 0)) * pi
|
244 |
+
out -= ((y < 0) & (x < 0)) * pi
|
245 |
+
out *= 1 - ((y > 0) & (x == 0)) * 1.0
|
246 |
+
out += ((y > 0) & (x == 0)) * (pi / 2)
|
247 |
+
out *= 1 - ((y < 0) & (x == 0)) * 1.0
|
248 |
+
out += ((y < 0) & (x == 0)) * (-pi / 2)
|
249 |
+
return out
|
250 |
+
|
251 |
+
|
252 |
+
# Define basic complex operations on torch.Tensor objects whose last dimension
|
253 |
+
# consists in the concatenation of the real and imaginary parts.
|
254 |
+
def _norm(x: torch.Tensor) -> torch.Tensor:
|
255 |
+
r"""Computes the norm value of a torch Tensor, assuming that it
|
256 |
+
comes as real and imaginary part in its last dimension.
|
257 |
+
Args:
|
258 |
+
x (Tensor): Input Tensor of shape [shape=(..., 2)]
|
259 |
+
Returns:
|
260 |
+
Tensor: shape as x excluding the last dimension.
|
261 |
+
"""
|
262 |
+
return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
|
263 |
+
|
264 |
+
|
265 |
+
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
266 |
+
"""Element-wise multiplication of two complex Tensors described
|
267 |
+
through their real and imaginary parts.
|
268 |
+
The result is added to the `out` tensor"""
|
269 |
+
|
270 |
+
# check `out` and allocate it if needed
|
271 |
+
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
272 |
+
if out is None or out.shape != target_shape:
|
273 |
+
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
274 |
+
if out is a:
|
275 |
+
real_a = a[..., 0]
|
276 |
+
out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
|
277 |
+
out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
|
278 |
+
else:
|
279 |
+
out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
|
280 |
+
out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
|
281 |
+
return out
|
282 |
+
|
283 |
+
|
284 |
+
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
285 |
+
"""Element-wise multiplication of two complex Tensors described
|
286 |
+
through their real and imaginary parts
|
287 |
+
can work in place in case out is a only"""
|
288 |
+
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
289 |
+
if out is None or out.shape != target_shape:
|
290 |
+
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
291 |
+
if out is a:
|
292 |
+
real_a = a[..., 0]
|
293 |
+
out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
|
294 |
+
out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
|
295 |
+
else:
|
296 |
+
out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
|
297 |
+
out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
|
298 |
+
return out
|
299 |
+
|
300 |
+
|
301 |
+
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
302 |
+
"""Element-wise multiplicative inverse of a Tensor with complex
|
303 |
+
entries described through their real and imaginary parts.
|
304 |
+
can work in place in case out is z"""
|
305 |
+
ez = _norm(z)
|
306 |
+
if out is None or out.shape != z.shape:
|
307 |
+
out = torch.zeros_like(z)
|
308 |
+
out[..., 0] = z[..., 0] / ez
|
309 |
+
out[..., 1] = -z[..., 1] / ez
|
310 |
+
return out
|
311 |
+
|
312 |
+
|
313 |
+
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
314 |
+
"""Element-wise complex conjugate of a Tensor with complex entries
|
315 |
+
described through their real and imaginary parts.
|
316 |
+
can work in place in case out is z"""
|
317 |
+
if out is None or out.shape != z.shape:
|
318 |
+
out = torch.zeros_like(z)
|
319 |
+
out[..., 0] = z[..., 0]
|
320 |
+
out[..., 1] = -z[..., 1]
|
321 |
+
return out
|
322 |
+
|
323 |
+
|
324 |
+
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
325 |
+
"""
|
326 |
+
Invert 1x1 or 2x2 matrices
|
327 |
+
Will generate errors if the matrices are singular: user must handle this
|
328 |
+
through his own regularization schemes.
|
329 |
+
Args:
|
330 |
+
M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
|
331 |
+
matrices to invert: must be square along dimensions -3 and -2
|
332 |
+
Returns:
|
333 |
+
invM (Tensor): [shape=M.shape]
|
334 |
+
inverses of M
|
335 |
+
"""
|
336 |
+
nb_channels = M.shape[-2]
|
337 |
+
|
338 |
+
if out is None or out.shape != M.shape:
|
339 |
+
out = torch.empty_like(M)
|
340 |
+
|
341 |
+
if nb_channels == 1:
|
342 |
+
# scalar case
|
343 |
+
out = _inv(M, out)
|
344 |
+
elif nb_channels == 2:
|
345 |
+
# two channels case: analytical expression
|
346 |
+
|
347 |
+
# first compute the determinent
|
348 |
+
det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
|
349 |
+
det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
|
350 |
+
# invert it
|
351 |
+
invDet = _inv(det)
|
352 |
+
|
353 |
+
# then fill out the matrix with the inverse
|
354 |
+
out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
|
355 |
+
out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
|
356 |
+
out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
|
357 |
+
out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
|
358 |
+
else:
|
359 |
+
raise Exception("Only 2 channels are supported for the torch version.")
|
360 |
+
return out
|
361 |
+
|
362 |
+
|
363 |
+
|
364 |
+
def expectation_maximization(
|
365 |
+
y: torch.Tensor,
|
366 |
+
x: torch.Tensor,
|
367 |
+
iterations: int = 2,
|
368 |
+
eps: float = 1e-10,
|
369 |
+
batch_size: int = 200,
|
370 |
+
):
|
371 |
+
r"""Expectation maximization algorithm, for refining source separation
|
372 |
+
estimates.
|
373 |
+
Args:
|
374 |
+
y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
|
375 |
+
initial estimates for the sources
|
376 |
+
x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
|
377 |
+
complex STFT of the mixture signal
|
378 |
+
iterations (int): [scalar]
|
379 |
+
number of iterations for the EM algorithm.
|
380 |
+
eps (float or None): [scalar]
|
381 |
+
The epsilon value to use for regularization and filters.
|
382 |
+
Returns:
|
383 |
+
y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
|
384 |
+
estimated sources after iterations
|
385 |
+
v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
|
386 |
+
estimated power spectral densities
|
387 |
+
R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
|
388 |
+
estimated spatial covariance matrices
|
389 |
+
"""
|
390 |
+
# dimensions
|
391 |
+
(nb_frames, nb_bins, nb_channels) = x.shape[:-1]
|
392 |
+
nb_sources = y.shape[-1]
|
393 |
+
|
394 |
+
regularization = torch.cat(
|
395 |
+
(
|
396 |
+
torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
|
397 |
+
torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
|
398 |
+
),
|
399 |
+
dim=2,
|
400 |
+
)
|
401 |
+
regularization = torch.sqrt(torch.as_tensor(eps)) * (
|
402 |
+
regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
|
403 |
+
)
|
404 |
+
|
405 |
+
# allocate the spatial covariance matrices
|
406 |
+
R = [
|
407 |
+
torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
|
408 |
+
for j in range(nb_sources)
|
409 |
+
]
|
410 |
+
weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
|
411 |
+
|
412 |
+
v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
|
413 |
+
for it in range(iterations):
|
414 |
+
# constructing the mixture covariance matrix. Doing it with a loop
|
415 |
+
# to avoid storing anytime in RAM the whole 6D tensor
|
416 |
+
|
417 |
+
# update the PSD as the average spectrogram over channels
|
418 |
+
v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
|
419 |
+
|
420 |
+
# update spatial covariance matrices (weighted update)
|
421 |
+
for j in range(nb_sources):
|
422 |
+
R[j] = torch.tensor(0.0, device=x.device)
|
423 |
+
weight = torch.tensor(eps, device=x.device)
|
424 |
+
pos: int = 0
|
425 |
+
batch_size = batch_size if batch_size else nb_frames
|
426 |
+
while pos < nb_frames:
|
427 |
+
t = torch.arange(pos, min(nb_frames, pos + batch_size))
|
428 |
+
pos = int(t[-1]) + 1
|
429 |
+
|
430 |
+
R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
|
431 |
+
weight = weight + torch.sum(v[t, ..., j], dim=0)
|
432 |
+
R[j] = R[j] / weight[..., None, None, None]
|
433 |
+
weight = torch.zeros_like(weight)
|
434 |
+
|
435 |
+
# cloning y if we track gradient, because we're going to update it
|
436 |
+
if y.requires_grad:
|
437 |
+
y = y.clone()
|
438 |
+
|
439 |
+
pos = 0
|
440 |
+
while pos < nb_frames:
|
441 |
+
t = torch.arange(pos, min(nb_frames, pos + batch_size))
|
442 |
+
pos = int(t[-1]) + 1
|
443 |
+
|
444 |
+
y[t, ...] = torch.tensor(0.0, device=x.device)
|
445 |
+
|
446 |
+
# compute mix covariance matrix
|
447 |
+
Cxx = regularization
|
448 |
+
for j in range(nb_sources):
|
449 |
+
Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
|
450 |
+
|
451 |
+
# invert it
|
452 |
+
inv_Cxx = _invert(Cxx)
|
453 |
+
|
454 |
+
# separate the sources
|
455 |
+
for j in range(nb_sources):
|
456 |
+
|
457 |
+
# create a wiener gain for this source
|
458 |
+
gain = torch.zeros_like(inv_Cxx)
|
459 |
+
|
460 |
+
# computes multichannel Wiener gain as v_j R_j inv_Cxx
|
461 |
+
indices = torch.cartesian_prod(
|
462 |
+
torch.arange(nb_channels),
|
463 |
+
torch.arange(nb_channels),
|
464 |
+
torch.arange(nb_channels),
|
465 |
+
)
|
466 |
+
for index in indices:
|
467 |
+
gain[:, :, index[0], index[1], :] = _mul_add(
|
468 |
+
R[j][None, :, index[0], index[2], :].clone(),
|
469 |
+
inv_Cxx[:, :, index[2], index[1], :],
|
470 |
+
gain[:, :, index[0], index[1], :],
|
471 |
+
)
|
472 |
+
gain = gain * v[t, ..., None, None, None, j]
|
473 |
+
|
474 |
+
# apply it to the mixture
|
475 |
+
for i in range(nb_channels):
|
476 |
+
y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
|
477 |
+
|
478 |
+
return y, v, R
|
479 |
+
|
480 |
+
def _covariance(y_j):
|
481 |
+
"""
|
482 |
+
Compute the empirical covariance for a source.
|
483 |
+
Args:
|
484 |
+
y_j (Tensor): complex stft of the source.
|
485 |
+
[shape=(nb_frames, nb_bins, nb_channels, 2)].
|
486 |
+
Returns:
|
487 |
+
Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
|
488 |
+
just y_j * conj(y_j.T): empirical covariance for each TF bin.
|
489 |
+
"""
|
490 |
+
(nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
|
491 |
+
Cj = torch.zeros(
|
492 |
+
(nb_frames, nb_bins, nb_channels, nb_channels, 2),
|
493 |
+
dtype=y_j.dtype,
|
494 |
+
device=y_j.device,
|
495 |
+
)
|
496 |
+
indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
|
497 |
+
for index in indices:
|
498 |
+
Cj[:, :, index[0], index[1], :] = _mul_add(
|
499 |
+
y_j[:, :, index[0], :],
|
500 |
+
_conj(y_j[:, :, index[1], :]),
|
501 |
+
Cj[:, :, index[0], index[1], :],
|
502 |
+
)
|
503 |
+
return Cj
|
504 |
+
|
+
+def wiener(
+    targets_spectrograms: torch.Tensor,
+    mix_stft: torch.Tensor,
+    iterations: int = 1,
+    softmask: bool = False,
+    residual: bool = False,
+    scale_factor: float = 10.0,
+    eps: float = 1e-10,
+):
+    """Wiener-based separation for multichannel audio.
+    Returns:
+        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
+            STFT of the estimated sources
+    """
+    if softmask:
+        # if we use softmask, we compute the ratio mask for all targets and
+        # multiply it by the mix stft
+        y = (
+            mix_stft[..., None]
+            * (
+                targets_spectrograms
+                / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
+            )[..., None, :]
+        )
+    else:
+        # otherwise, we just combine the targets spectrograms with the mix phase:
+        # we tacitly assume that we have magnitude estimates
+        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
+        nb_sources = targets_spectrograms.shape[-1]
+        y = torch.zeros(
+            mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
+        )
+        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
+        y[..., 1, :] = targets_spectrograms * torch.sin(angle)
+
+    if residual:
+        # if required, add one extra target defined as the mix minus
+        # the available targets
+        y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
+
+    if iterations == 0:
+        return y
+
+    # we need to refine the estimates: scale them down first for
+    # numerical stability
+    max_abs = torch.max(
+        torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
+        torch.sqrt(_norm(mix_stft)).max() / scale_factor,
+    )
+
+    mix_stft = mix_stft / max_abs
+    y = y / max_abs
+
+    # call expectation maximization
+    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
+
+    # scale the estimates up again
+    y = y * max_abs
+    return y
+
def split_nparray_with_overlap(array, array_size, overlap_size):
|
566 |
+
result = []
|
567 |
+
element_size = int(len(array) / array_size)
|
568 |
+
for i in range(array_size):
|
569 |
+
offset = int(i * element_size)
|
570 |
+
last_loop = i == array_size
|
571 |
+
chunk = array[offset : offset + element_size + (0 if last_loop else overlap_size)]
|
572 |
+
chunk = chunk.copy()
|
573 |
+
chunk.resize(element_size + overlap_size, refcheck = False)
|
574 |
+
result.append(chunk)
|
575 |
+
|
576 |
+
return np.array(result)
|
577 |
+
|
578 |
+
|
579 |
+
|
580 |
+
|
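A quick illustration of the chunking behavior, with hypothetical numbers: each chunk carries `overlap_size` extra samples from the next chunk, and the final chunk is zero-padded up to the common length.

```python
import numpy as np
from utils import split_nparray_with_overlap

x = np.arange(10)
chunks = split_nparray_with_overlap(x, array_size=2, overlap_size=2)
print(chunks.shape)  # (2, 7): 5 samples per chunk + 2 overlap samples
print(chunks[0])     # [0 1 2 3 4 5 6]
print(chunks[1])     # [5 6 7 8 9 0 0]  (zero-padded tail)
```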
zero_shot_create_vector.py
ADDED
@@ -0,0 +1,158 @@
+# Ke Chen
+
+# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
+# The Query Vector Creation Script
+
+import os
+gpu_use = 0
+# this is to prevent the sdr calculation from occupying all cpus
+os.environ["OMP_NUM_THREADS"] = "4"
+os.environ["OPENBLAS_NUM_THREADS"] = "4"
+os.environ["MKL_NUM_THREADS"] = "6"
+os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
+os.environ["NUMEXPR_NUM_THREADS"] = "6"
+os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)
+
+import librosa
+import numpy as np
+import soundfile as sf
+from hashlib import md5
+
+import torch
+from torch.utils.data import DataLoader
+from utils import collect_fn, dump_config, create_folder, prepprocess_audio
+from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
+from data_processor import LGSPDataset, MusdbDataset
+import config
+import htsat_config
+from models.htsat import HTSAT_Swin_Transformer
+from sed_model import SEDWrapper
+
+import pytorch_lightning as pl
+
+import time
+import tqdm
+import warnings
+import shutil
+import pickle
+warnings.filterwarnings("ignore")
+
+# use the model to quickly create the query vector for separating a track
+# it requires four variables in config.py:
+#   inference_file: the track you want to separate
+#   inference_query: a **folder** containing all samples from the same source
+#   test_key: ["name"] indicates the source name (only used to label the final output)
+#   wave_output_path: the output folder
+
+# make sure the query folder contains only samples from the same source
+# each run, the model creates the query vector for one source
+# if you want to separate multiple sources, change the query folder or write a script to automate that
+
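The four variables read from `config.py` might look like this; a hypothetical excerpt, with placeholder paths and values that are not shipped with the repo:

```python
# hypothetical excerpt of config.py for this script (all paths are placeholders)
inference_file = "examples/mixture.wav"      # the track you want to separate
inference_query = "examples/query_vocals/"   # folder of .wav samples from one source
test_key = ["vocals"]                        # label used only in the output file name
wave_output_path = "outputs/"                # where the query vector is written

# the script below also reads two existing config entries:
# config.sample_rate (target sample rate) and config.resume_checkpoint
```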
+
+def save_in_file_fast(arr, file_name):
+    # protocol 4 supports large pickles and is available on python >= 3.4
+    with open(file_name, 'wb') as f:
+        pickle.dump(arr, f, protocol=4)
+
+
+def load_from_file_fast(file_name):
+    with open(file_name, 'rb') as f:
+        return pickle.load(f)
+
+
+def create_vector():
+    test_type = 'mix'
+    inference_file = config.inference_file
+    inference_query = config.inference_query
+    test_key = config.test_key
+    wave_output_path = config.wave_output_path
+    sample_rate = config.sample_rate
+    resume_checkpoint_zeroshot = config.resume_checkpoint
+    resume_checkpoint_htsat = htsat_config.resume_checkpoint
+    print('Inference query folder: {}'.format(inference_query))
+    print('Test key: {}'.format(test_key))
+    print('Vector out folder: {}'.format(wave_output_path))
+    print('Sample rate: {}'.format(sample_rate))
+    print('Model 1 (zeroshot): {}'.format(resume_checkpoint_zeroshot))
+
+    # set exp settings (follow device_name so cpu-only machines do not break)
+    device_name = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device_name)
+    create_folder(wave_output_path)
+
+    # obtain the samples for the query
+    queries = []
+    query_names = []
+    for query_file in tqdm.tqdm(os.listdir(inference_query)):
+        f_path = os.path.join(inference_query, query_file)
+        if query_file.endswith(".wav"):
+            temp_q, fs = librosa.load(f_path, sr=None)
+            temp_q = temp_q[:, None]
+            temp_q = prepprocess_audio(
+                temp_q,
+                fs,
+                sample_rate,
+                test_type
+            )
+            temp = [temp_q]
+            # duplicate the query once per target key to match the dataset layout
+            for dickey in test_key:
+                temp.append(temp_q)
+            temp = np.array(temp)
+            queries.append(temp)
+            query_names.append(os.path.basename(query_file))
+
+    sed_model = HTSAT_Swin_Transformer(
+        spec_size=htsat_config.htsat_spec_size,
+        patch_size=htsat_config.htsat_patch_size,
+        in_chans=1,
+        num_classes=htsat_config.classes_num,
+        window_size=htsat_config.htsat_window_size,
+        config=htsat_config,
+        depths=htsat_config.htsat_depth,
+        embed_dim=htsat_config.htsat_dim,
+        patch_stride=htsat_config.htsat_stride,
+        num_heads=htsat_config.htsat_num_head
+    )
+    at_model = SEDWrapper(
+        sed_model=sed_model,
+        config=htsat_config,
+        dataset=None
+    )
+    ckpt = torch.load(resume_checkpoint_htsat, map_location="cpu")
+    at_model.load_state_dict(ckpt["state_dict"])
+
+    if device_name == 'cpu':
+        trainer = pl.Trainer(
+            accelerator="cpu", gpus=None
+        )
+    else:
+        trainer = pl.Trainer(
+            gpus=1
+        )
+
+    print('Process: {}'.format(len(queries)))
+    avg_dataset = MusdbDataset(
+        tracks=queries
+    )
+    avg_loader = DataLoader(
+        dataset=avg_dataset,
+        num_workers=1,
+        batch_size=1,
+        shuffle=False
+    )
+    at_wrapper = AutoTaggingWarpper(
+        at_model=at_model,
+        config=config,
+        target_keys=test_key
+    )
+    # run the audio tagging model over all query samples and average their outputs
+    trainer.test(
+        at_wrapper,
+        test_dataloaders=avg_loader
+    )
+    avg_at = at_wrapper.avg_at
+
+    # the file name embeds an md5 of the query set so different query folders do not collide
+    md5_str = str(md5(str(queries).encode('utf-8')).hexdigest())
+    out_vector_path = wave_output_path + '/{}_vector_{}.pkl'.format(test_key[0], md5_str)
+    save_in_file_fast(avg_at, out_vector_path)
+    print('Vector saved in: {}'.format(out_vector_path))
+
+
+if __name__ == '__main__':
+    create_vector()
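Once `config.py` is filled in, the script runs standalone and writes a pickle whose name combines the source label with an md5 of the query set. A minimal sketch of consuming that output (the file name below is a hypothetical example, not a real artifact):

```python
import pickle

# run first: python zero_shot_create_vector.py
# then load the averaged tagging vector (hypothetical file name):
with open("outputs/vocals_vector_1a2b3c4d.pkl", "rb") as f:
    avg_at = pickle.load(f)
# avg_at is the averaged query vector to be fed to the separator
```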