cronrpc commited on
Commit
5586f24
·
1 Parent(s): 1313332

upload all

Browse files
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
cache/embs_cache/e609b16c50d879c.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18708e1f64ab9e90daf8a13edbd24c4ff419db1d239668be6dc5dd334837bc68
3
+ size 6791535
download_audios.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tarfile
2
+ import glob
3
+ import os
4
+ from huggingface_hub import hf_hub_download
5
+
6
+
7
+ def download_audios():
8
+ wav_files = sorted(glob.glob(os.path.join("audios", '*.wav')))
9
+ if len(wav_files) == 0:
10
+ audios_targz_path = hf_hub_download(repo_id="omniway/Audio_speaker_needle_in_haystack", filename="audios.tar.gz", repo_type="dataset")
11
+ tar = tarfile.open(audios_targz_path, 'r:gz')
12
+ tar.extractall(path='.')
13
+ tar.close()
14
+
15
+ if __name__ == '__main__':
16
+ download_audios()
examples/seed1037_index0.wav ADDED
Binary file (164 kB). View file
 
examples/seed2_index0.wav ADDED
Binary file (290 kB). View file
 
examples/seed452_index1.wav ADDED
Binary file (270 kB). View file
 
examples/seed5_index6.wav ADDED
Binary file (486 kB). View file
 
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ffmpeg
2
+ libsox-dev
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ modelscope==1.15.0
2
+ torch==2.3.1
3
+ torchaudio==2.3.1
4
+ torchvision==0.18.1
5
+ tqdm
6
+ librosa
7
+ soundfile
webui_speaker_needle_in_haystack.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import operator
3
+ import glob
4
+ import librosa
5
+ import argparse
6
+ import hashlib
7
+ import gradio as gr
8
+ import numpy as np
9
+ import pickle
10
+
11
+ from tqdm import tqdm
12
+ from modelscope.pipelines import pipeline
13
+ from download_audios import download_audios
14
+
15
+ """
16
+
17
+ Audio Speaker needle in haystack
18
+ cronrpc
19
+ https://github.com/cronrpc
20
+
21
+ """
22
+
23
+ MAX_DISPLAY_AUDIO_NUMBER = 10
24
+ g_gr_audio_list = []
25
+
26
+
27
+ class Speaker_Needle_In_Haystack():
28
+ SAMPLE_RATE = 16000
29
+
30
+ def __init__(self, pickle_support = False) -> None:
31
+ self._load_model()
32
+ self.all_embs = {}
33
+ self.cosine_score = {}
34
+ self.pickle_support = pickle_support
35
+ pass
36
+
37
+ def set_audio_list_dir(self, dir_path):
38
+ self.audio_list_dir = dir_path
39
+
40
+ def _load_model(self) -> None:
41
+ # could switch model here
42
+
43
+ self.model_name = 'damo/speech_eres2netv2_sv_zh-cn_16k-common'
44
+ self.sv_pipline = pipeline(
45
+ task='speaker-verification',
46
+ model=self.model_name,
47
+ model_revision='v1.0.1'
48
+ )
49
+
50
+ # self.model_name = 'iic/speech_campplus_sv_zh-cn_3dspeaker_16k'
51
+ # self.sv_pipline = pipeline(
52
+ # task='speaker-verification',
53
+ # model=self.model_name
54
+ # )
55
+
56
+ def _get_emb(self, audio) -> None:
57
+ if isinstance(audio, str):
58
+ audio, sr = librosa.load(audio, sr=self.SAMPLE_RATE, mono=True)
59
+ return self.sv_pipline([audio], output_emb=True)['embs'] # (1,196) np array
60
+ elif isinstance(audio, list):
61
+ return self.sv_pipline(audio, output_emb=True)['embs'] # (n,196) np array
62
+ else:
63
+ return self.sv_pipline([audio], output_emb=True)['embs'] # (1,196) np array
64
+
65
+ def _cosine_similarity_compute(self, emb1, emb2):
66
+ emb1 = np.squeeze(emb1)
67
+ emb2 = np.squeeze(emb2)
68
+ dot_product = np.dot(emb1, emb2)
69
+ norm_vector1 = np.linalg.norm(emb1)
70
+ norm_vector2 = np.linalg.norm(emb2)
71
+ cosine_similarity = dot_product / (norm_vector1 * norm_vector2)
72
+ return cosine_similarity
73
+
74
+ def compute_all_embs(self, batch_size=1):
75
+ wav_files = sorted(glob.glob(os.path.join(self.audio_list_dir, '*.wav')))
76
+
77
+ # hash to skip
78
+ file_string = self.model_name + ''.join(wav_files)
79
+ hash_file = hashlib.sha256(file_string.encode()).hexdigest()[:15] + ".pkl"
80
+ if self.pickle_support:
81
+ cache_dir = os.path.join('cache','embs_cache')
82
+ os.makedirs(cache_dir, exist_ok=True)
83
+ hash_file = os.path.join(cache_dir, hash_file)
84
+ if os.path.exists(hash_file):
85
+ print("load pickle embs")
86
+ self.load_all_embs(hash_file)
87
+ return
88
+
89
+ self.all_embs = {}
90
+ num_files = len(wav_files)
91
+ num_batches = (num_files + batch_size - 1) // batch_size
92
+
93
+ for batch_idx in tqdm(range(num_batches)):
94
+ start_idx = batch_idx * batch_size
95
+ end_idx = min((batch_idx + 1) * batch_size, num_files)
96
+ batch_files = wav_files[start_idx:end_idx]
97
+ batch_audio = []
98
+
99
+ for file_path in batch_files:
100
+ audio, sr = librosa.load(file_path, sr=self.SAMPLE_RATE, mono=True)
101
+ batch_audio.append(audio)
102
+
103
+ embs = self._get_emb(batch_audio)
104
+
105
+ for i, file_path in enumerate(batch_files):
106
+ self.all_embs[file_path] = embs[i]
107
+
108
+ # save the self.all_embs in hash_value named file
109
+ if self.pickle_support:
110
+ self.save_all_embs(hash_file)
111
+
112
+ def compute_target_aduio_cosine_score(self, target_audio):
113
+ self.cosine_score = {}
114
+ target_emb = self._get_emb(target_audio)
115
+ for file_path, emb in self.all_embs.items():
116
+ self.cosine_score[file_path] = self._cosine_similarity_compute(target_emb, emb)
117
+
118
+ def get_cosine_next_top_k(self, k, start = 0):
119
+ top_subset = sorted(self.cosine_score.items(), key=operator.itemgetter(1), reverse=True)[start: start + k]
120
+ return top_subset
121
+
122
+ def save_all_embs(self, hash_file):
123
+ file_path = hash_file
124
+ with open(file_path, 'wb') as file:
125
+ pickle.dump(self.all_embs, file)
126
+
127
+ def load_all_embs(self, hash_file):
128
+ file_path = hash_file
129
+ with open(file_path, 'rb') as file:
130
+ self.all_embs = pickle.load(file)
131
+
132
+
133
+ def get_similar_score_audio(audio, start_index):
134
+ output = []
135
+ top_subset = []
136
+
137
+ if audio != None:
138
+ sr, y = audio
139
+ if len(y.shape) == 2:
140
+ y = np.mean(y, axis=-1)
141
+ audio_16k = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=snih.SAMPLE_RATE)
142
+ snih.compute_target_aduio_cosine_score(audio_16k)
143
+ top_subset = snih.get_cosine_next_top_k(MAX_DISPLAY_AUDIO_NUMBER, start=start_index)
144
+
145
+ for i in range(0, len(top_subset)):
146
+ path, score = top_subset[i]
147
+ file_name = os.path.basename(path)
148
+ output.append(
149
+ {
150
+ "__type__":"update",
151
+ "value":path,
152
+ "label":f"{start_index+i}:{file_name} score={score:.4f}"
153
+ }
154
+ )
155
+
156
+ for _ in range(0, MAX_DISPLAY_AUDIO_NUMBER - len(top_subset)):
157
+ output.append(
158
+ {
159
+ "__type__":"update",
160
+ "value":None,
161
+ "label":"None"
162
+ }
163
+ )
164
+
165
+ return *output, start_index
166
+
167
+
168
+ def get_next_index_zero(audio):
169
+ return get_similar_score_audio(audio, 0)
170
+
171
+
172
+ def get_next_index(audio, start_index):
173
+ return get_similar_score_audio(audio, start_index + 10)
174
+
175
+
176
+ def get_previous_index(audio, start_index):
177
+ return get_similar_score_audio(audio, max(start_index - 10, 0))
178
+
179
+
180
+ if __name__ == '__main__':
181
+
182
+ download_audios()
183
+
184
+ parser = argparse.ArgumentParser(description='Speaker_Needle_In_Haystack demo Launch')
185
+ parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
186
+ parser.add_argument('--server_port', type=int, default=8080, help='Server port')
187
+ parser.add_argument('--batch_size', type=int, default=4, help='the batch_size about embedding generate')
188
+ parser.add_argument('--audio_dir', type=str, default="audios", help='the audio dir which will be compared to target audio')
189
+ parser.add_argument('--disable_pickle_support', action='store_true', help="save emb by pickle")
190
+ args = parser.parse_args()
191
+
192
+ pickle_support = not args.disable_pickle_support
193
+ print("pickle support : ", pickle_support)
194
+ snih = Speaker_Needle_In_Haystack(pickle_support=pickle_support)
195
+
196
+ snih.set_audio_list_dir(args.audio_dir)
197
+ snih.compute_all_embs(batch_size = args.batch_size)
198
+
199
+ with gr.Blocks() as demo:
200
+ gr.Markdown("# 大海捞针 Audio Needle In Haystack")
201
+ with gr.Row():
202
+ audio_input = gr.Audio(
203
+ label= "Input Audio / 输入音频",
204
+ visible = True,
205
+ scale=5,
206
+ type="numpy",
207
+ format='wav'
208
+ )
209
+
210
+ with gr.Column():
211
+ wav_files = sorted(glob.glob(os.path.join("examples", '*.wav')))
212
+ gr.Examples(
213
+ examples=[
214
+ *wav_files
215
+ ],
216
+ inputs=[
217
+ audio_input
218
+ ]
219
+ )
220
+ input_index = gr.Number(value=0, label="Index")
221
+
222
+ btn_get_similar = gr.Button("获取相似音频 Get Similar Score Audio")
223
+ btn_get_previous_index = gr.Button("上一页 Previous Index")
224
+ btn_get_next_index = gr.Button("下一页 Next Index")
225
+
226
+
227
+ gr.Markdown("# 相似音频 similar audio")
228
+
229
+ with gr.Column():
230
+ for _ in range(0,MAX_DISPLAY_AUDIO_NUMBER):
231
+ audio_output = gr.Audio(
232
+ label= "Output Audio",
233
+ visible = True,
234
+ scale=5,
235
+ editable=False
236
+ )
237
+ g_gr_audio_list.append(audio_output)
238
+
239
+ btn_get_similar.click(
240
+ get_next_index_zero,
241
+ inputs=[
242
+ audio_input
243
+ ],
244
+ outputs=[
245
+ *g_gr_audio_list,
246
+ input_index
247
+ ]
248
+ )
249
+
250
+ btn_get_previous_index.click(
251
+ get_previous_index,
252
+ inputs=[
253
+ audio_input,
254
+ input_index
255
+ ],
256
+ outputs=[
257
+ *g_gr_audio_list,
258
+ input_index
259
+ ]
260
+ )
261
+
262
+ btn_get_next_index.click(
263
+ get_next_index,
264
+ inputs=[
265
+ audio_input,
266
+ input_index
267
+ ],
268
+ outputs=[
269
+ *g_gr_audio_list,
270
+ input_index
271
+ ]
272
+ )
273
+
274
+ #demo.launch(server_name=args.server_name, server_port=args.server_port)
275
+ demo.launch()