Spaces:
Runtime error
Runtime error
import time | |
import logging | |
from tqdm import tqdm | |
from funasr_detach.register import tables | |
from funasr_detach.download.download_from_hub import download_model | |
from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank | |
from funasr_detach.auto.auto_model import prepare_data_iterator | |
from funasr_detach.auto.auto_model import prepare_data_iterator | |
class AutoFrontend:
    """Build a registered frontend and convert raw inputs into feature batches.

    The constructor resolves the configuration (downloading it when no
    ``model_conf`` is supplied) and instantiates the frontend class registered
    under the configured ``frontend`` name. Calling the instance loads the
    inputs, extracts features batch by batch and returns the batches.
    """

    # Keys that ``__call__`` passes to ``extract_fbank`` explicitly; they must
    # not also travel through ``**kwargs`` or the call raises ``TypeError``.
    _EXPLICIT_FBANK_KEYS = ("data_type", "frontend")

    def __init__(self, **kwargs):
        """Resolve the configuration and build the frontend.

        Args:
            **kwargs: Configuration options. ``model`` is required. When
                ``model_conf`` is absent the configuration is obtained from
                ``download_model``. ``frontend`` names the registered frontend
                class and ``frontend_conf`` holds its constructor arguments.

        Raises:
            AssertionError: If ``model`` is not supplied.
        """
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info(
                "download models from model hub: %s", kwargs.get("model_hub", "ms")
            )
            kwargs = download_model(**kwargs)

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs["frontend_conf"])
        self.frontend = frontend

        # The stored options are later splatted into ``extract_fbank`` next to
        # an explicit ``frontend=`` argument, so the name must not stay in them.
        kwargs.pop("frontend", None)
        self.kwargs = kwargs

    def _resolve_options(self, kwargs, cfg):
        """Return a fresh dict of call options: base options overlaid with cfg.

        A copy is returned so that per-call overrides never leak into
        ``self.kwargs`` or into a dict owned by the caller.
        """
        base = self.kwargs if kwargs is None else kwargs
        return {**base, **cfg}

    def __call__(self, input, input_len=None, kwargs=None, **cfg):
        """Extract features for ``input`` and return them as a list of batches.

        Args:
            input: Data handed to ``prepare_data_iterator``.
            input_len: Optional lengths forwarded to ``prepare_data_iterator``.
            kwargs: Optional option dict used instead of the options stored at
                construction time. It is not modified.
            **cfg: Per-call overrides applied on top of the base options.

        Returns:
            A list of dicts with keys ``"input"`` (features), ``"input_len"``
            (feature lengths) and ``"key"`` (the keys of the batch items).
        """
        options = self._resolve_options(kwargs, cfg)

        key_list, data_list = prepare_data_iterator(input, input_len=input_len)

        batch_size = options.get("batch_size", 1)
        device = options.get("device", "cpu")
        if device == "cpu":
            batch_size = 1

        data_type = options.get("data_type", "sound")
        # Loop-invariant: everything in the options except the keys that are
        # passed to ``extract_fbank`` by name.
        fbank_kwargs = {
            name: value
            for name, value in options.items()
            if name not in self._EXPLICIT_FBANK_KEYS
        }

        meta_data = {}
        result_list = []
        num_samples = len(data_list)

        time0 = time.perf_counter()
        # The context manager guarantees the bar is closed, even on error.
        with tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True) as pbar:
            for beg_idx in range(0, num_samples, batch_size):
                end_idx = min(num_samples, beg_idx + batch_size)
                data_batch = data_list[beg_idx:end_idx]
                key_batch = key_list[beg_idx:end_idx]

                # extract fbank feats
                time1 = time.perf_counter()
                audio_sample_list = load_audio_text_image_video(
                    data_batch,
                    fs=self.frontend.fs,
                    audio_fs=options.get("fs", 16000),
                )
                time2 = time.perf_counter()
                meta_data["load_data"] = f"{time2 - time1:0.3f}"

                speech, speech_lengths = extract_fbank(
                    audio_sample_list,
                    data_type=data_type,
                    frontend=self.frontend,
                    **fbank_kwargs,
                )
                time3 = time.perf_counter()
                meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
                # Total frames * frame_shift * lfr_n / 1000; presumably the
                # batch duration in seconds with frame_shift in ms -- TODO confirm.
                meta_data["batch_data_time"] = (
                    speech_lengths.sum().item()
                    * self.frontend.frame_shift
                    * self.frontend.lfr_n
                    / 1000
                )

                # ``.to`` returns the moved object rather than working in
                # place (assuming torch tensors), so the result must be kept.
                speech = speech.to(device=device)
                speech_lengths = speech_lengths.to(device=device)

                result_list.append(
                    {"input": speech, "input_len": speech_lengths, "key": key_batch}
                )

                # The bar total is counted in samples, so advance by the
                # number of samples in this batch.
                pbar.update(end_idx - beg_idx)
                pbar.set_description(f"{meta_data}, ")

            time_end = time.perf_counter()
            pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")

        return result_list