"""
Audio Speaker needle in haystack
cronrpc
https://github.com/cronrpc
"""
import argparse
import glob
import hashlib
import operator
import os
import pickle

import gradio as gr
import librosa
import numpy as np
from modelscope.pipelines import pipeline
from tqdm import tqdm

from download_audios import download_audios

MAX_DISPLAY_AUDIO_NUMBER = 10
g_gr_audio_list = []


class Speaker_Needle_In_Haystack():
    """Scores every audio file in a directory against a target clip via speaker-embedding cosine similarity."""

    SAMPLE_RATE = 16000

    def __init__(self, pickle_support=False) -> None:
        self._load_model()
        self.all_embs = {}
        self.cosine_score = {}
        self.pickle_support = pickle_support

    def set_audio_list_dir(self, dir_path):
        self.audio_list_dir = dir_path

    def _load_model(self) -> None:
        # The speaker-verification model can be swapped here.
        self.model_name = 'damo/speech_eres2netv2_sv_zh-cn_16k-common'
        self.sv_pipeline = pipeline(
            task='speaker-verification',
            model=self.model_name,
            model_revision='v1.0.1'
        )
        # self.model_name = 'iic/speech_campplus_sv_zh-cn_3dspeaker_16k'
        # self.sv_pipeline = pipeline(
        #     task='speaker-verification',
        #     model=self.model_name
        # )

    def _get_emb(self, audio) -> np.ndarray:
        if isinstance(audio, str):
            audio, sr = librosa.load(audio, sr=self.SAMPLE_RATE, mono=True)
            return self.sv_pipeline([audio], output_emb=True)['embs']  # (1, 196) np array
        elif isinstance(audio, list):
            return self.sv_pipeline(audio, output_emb=True)['embs']  # (n, 196) np array
        else:
            return self.sv_pipeline([audio], output_emb=True)['embs']  # (1, 196) np array

    def _cosine_similarity_compute(self, emb1, emb2):
        emb1 = np.squeeze(emb1)
        emb2 = np.squeeze(emb2)
        dot_product = np.dot(emb1, emb2)
        norm_vector1 = np.linalg.norm(emb1)
        norm_vector2 = np.linalg.norm(emb2)
        cosine_similarity = dot_product / (norm_vector1 * norm_vector2)
        return cosine_similarity
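
    # np.squeeze above collapses the (1, D) embeddings returned by the pipeline to 1-D vectors,
    # so the dot-product / norm formula yields a scalar in [-1, 1]; higher means more similar.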

    def compute_all_embs(self, batch_size=1):
        wav_files = sorted(glob.glob(os.path.join(self.audio_list_dir, '*.wav')))
        # Hash the model name plus the sorted file list; if a matching pickle cache exists,
        # load it and skip recomputation. The key covers file paths, not file contents.
        file_string = self.model_name + ''.join(wav_files)
        hash_file = hashlib.sha256(file_string.encode()).hexdigest()[:15] + ".pkl"
        if self.pickle_support:
            cache_dir = os.path.join('cache', 'embs_cache')
            os.makedirs(cache_dir, exist_ok=True)
            hash_file = os.path.join(cache_dir, hash_file)
            if os.path.exists(hash_file):
                print("load pickle embs")
                self.load_all_embs(hash_file)
                return
        self.all_embs = {}
        num_files = len(wav_files)
        num_batches = (num_files + batch_size - 1) // batch_size
        for batch_idx in tqdm(range(num_batches)):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_files)
            batch_files = wav_files[start_idx:end_idx]
            batch_audio = []
            for file_path in batch_files:
                audio, sr = librosa.load(file_path, sr=self.SAMPLE_RATE, mono=True)
                batch_audio.append(audio)
            embs = self._get_emb(batch_audio)
            for i, file_path in enumerate(batch_files):
                self.all_embs[file_path] = embs[i]
        # Save self.all_embs to the hash-named cache file.
        if self.pickle_support:
            self.save_all_embs(hash_file)

    def compute_target_audio_cosine_score(self, target_audio):
        self.cosine_score = {}
        target_emb = self._get_emb(target_audio)
        for file_path, emb in self.all_embs.items():
            self.cosine_score[file_path] = self._cosine_similarity_compute(target_emb, emb)

    def get_cosine_next_top_k(self, k, start=0):
        top_subset = sorted(self.cosine_score.items(), key=operator.itemgetter(1), reverse=True)[start:start + k]
        return top_subset
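
    # Note: the full score dict is re-sorted on every page request; that is fine for the modest
    # number of clips this demo handles, but a cached sorted list would avoid the repeated work.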

    def save_all_embs(self, hash_file):
        file_path = hash_file
        with open(file_path, 'wb') as file:
            pickle.dump(self.all_embs, file)

    def load_all_embs(self, hash_file):
        file_path = hash_file
        with open(file_path, 'rb') as file:
            self.all_embs = pickle.load(file)
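

# --- Gradio callbacks ---
# get_similar_score_audio returns one update dict per display slot plus the current start index.
# The {"__type__": "update", ...} dicts are effectively the plain-dict form of gr.update(), so each
# output gr.Audio receives a new value and label from a single return tuple.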

def get_similar_score_audio(audio, start_index):
    output = []
    top_subset = []
    start_index = int(start_index)  # gr.Number may deliver a float; the slicing below needs an int
    if audio is not None:
        sr, y = audio
        if len(y.shape) == 2:
            y = np.mean(y, axis=-1)  # downmix stereo to mono
        audio_16k = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=snih.SAMPLE_RATE)
        snih.compute_target_audio_cosine_score(audio_16k)
        top_subset = snih.get_cosine_next_top_k(MAX_DISPLAY_AUDIO_NUMBER, start=start_index)
    for i in range(0, len(top_subset)):
        path, score = top_subset[i]
        file_name = os.path.basename(path)
        output.append(
            {
                "__type__": "update",
                "value": path,
                "label": f"{start_index + i}:{file_name} score={score:.4f}"
            }
        )
    # Pad the remaining display slots with empty updates.
    for _ in range(0, MAX_DISPLAY_AUDIO_NUMBER - len(top_subset)):
        output.append(
            {
                "__type__": "update",
                "value": None,
                "label": "None"
            }
        )
    return *output, start_index


def get_next_index_zero(audio):
    return get_similar_score_audio(audio, 0)


def get_next_index(audio, start_index):
    return get_similar_score_audio(audio, start_index + MAX_DISPLAY_AUDIO_NUMBER)


def get_previous_index(audio, start_index):
    return get_similar_score_audio(audio, max(start_index - MAX_DISPLAY_AUDIO_NUMBER, 0))


if __name__ == '__main__':
    download_audios()
    parser = argparse.ArgumentParser(description='Speaker_Needle_In_Haystack demo launch')
    parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
    parser.add_argument('--server_port', type=int, default=8080, help='Server port')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size used when generating embeddings')
    parser.add_argument('--audio_dir', type=str, default="audios", help='Directory of audio files compared against the target audio')
    parser.add_argument('--disable_pickle_support', action='store_true', help='Disable caching embeddings to disk with pickle')
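    # Example invocation (assuming this script is saved as app.py):
    #   python app.py --audio_dir audios --batch_size 4
    #   python app.py --disable_pickle_support   # recompute embeddings on every run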
    args = parser.parse_args()
    pickle_support = not args.disable_pickle_support
    print("pickle support:", pickle_support)

    snih = Speaker_Needle_In_Haystack(pickle_support=pickle_support)
    snih.set_audio_list_dir(args.audio_dir)
    snih.compute_all_embs(batch_size=args.batch_size)

    with gr.Blocks() as demo:
        gr.Markdown("# 大海捞针 Audio Needle In Haystack")
        with gr.Row():
            audio_input = gr.Audio(
                label="Input Audio / 输入音频",
                visible=True,
                scale=5,
                type="numpy",
                format='wav'
            )
            with gr.Column():
                wav_files = sorted(glob.glob(os.path.join("examples", '*.wav')))
                gr.Examples(
                    examples=[
                        *wav_files
                    ],
                    inputs=[
                        audio_input
                    ]
                )
                input_index = gr.Number(value=0, label="Index")
                btn_get_similar = gr.Button("获取相似音频 Get Similar Score Audio")
                btn_get_previous_index = gr.Button("上一页 Previous Index")
                btn_get_next_index = gr.Button("下一页 Next Index")

        gr.Markdown("# 相似音频 similar audio")
        with gr.Column():
            for _ in range(0, MAX_DISPLAY_AUDIO_NUMBER):
                audio_output = gr.Audio(
                    label="Output Audio",
                    visible=True,
                    scale=5,
                    editable=False
                )
                g_gr_audio_list.append(audio_output)

        btn_get_similar.click(
            get_next_index_zero,
            inputs=[
                audio_input
            ],
            outputs=[
                *g_gr_audio_list,
                input_index
            ]
        )
        btn_get_previous_index.click(
            get_previous_index,
            inputs=[
                audio_input,
                input_index
            ],
            outputs=[
                *g_gr_audio_list,
                input_index
            ]
        )
        btn_get_next_index.click(
            get_next_index,
            inputs=[
                audio_input,
                input_index
            ],
            outputs=[
                *g_gr_audio_list,
                input_index
            ]
        )

    # demo.launch(server_name=args.server_name, server_port=args.server_port)
    demo.launch()