mrfakename's picture
Super-squash branch 'main' using huggingface_hub
0102e16 verified
import time
import logging
from tqdm import tqdm
from funasr_detach.register import tables
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr_detach.auto.auto_model import prepare_data_iterator
from funasr_detach.auto.auto_model import prepare_data_iterator
class AutoFrontend:
    """Standalone feature-extraction frontend resolved from a model config.

    The constructor resolves the model configuration (obtaining it through
    ``download_model`` when no ``model_conf`` is supplied) and instantiates the
    frontend class registered under ``kwargs["frontend"]``. Calling the
    instance loads the given inputs and extracts features batch by batch.
    """

    def __init__(self, **kwargs):
        """Resolve the configuration and build the frontend.

        Args:
            **kwargs: Model configuration. Must contain ``"model"``. If
                ``"model_conf"`` is absent, the configuration is replaced by
                the result of ``download_model(**kwargs)``. When
                ``"frontend"`` is set, it is looked up in
                ``tables.frontend_classes`` and constructed with
                ``kwargs["frontend_conf"]``.

        Raises:
            ValueError: If ``"frontend"`` names a class that is not registered
                in ``tables.frontend_classes``.
        """
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info(
                "download models from model hub: {}".format(
                    kwargs.get("model_hub", "ms")
                )
            )
            kwargs = download_model(**kwargs)

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            if frontend_class is None:
                # Fail with a readable message instead of the opaque
                # "'NoneType' object is not callable".
                raise ValueError(
                    f"frontend {frontend!r} is not registered in "
                    "tables.frontend_classes"
                )
            frontend = frontend_class(**kwargs["frontend_conf"])
        self.frontend = frontend

        # The frontend name has been consumed; drop it so it is not forwarded
        # again alongside the explicit ``frontend=`` argument in __call__.
        if "frontend" in kwargs:
            del kwargs["frontend"]
        self.kwargs = kwargs

    def __call__(self, input, input_len=None, kwargs=None, **cfg):
        """Load ``input`` and extract features for it in batches.

        Args:
            input: Data accepted by ``prepare_data_iterator``.
            input_len: Optional lengths forwarded to ``prepare_data_iterator``.
            kwargs: Configuration to use instead of ``self.kwargs``. It is
                copied, not modified.
            **cfg: Per-call overrides merged on top of the configuration for
                this call only (e.g. ``batch_size``, ``device``, ``fs``,
                ``data_type``).

        Returns:
            A list with one dict per batch holding ``"input"`` (features),
            ``"input_len"`` (feature lengths) and ``"key"`` (the batch keys).
        """
        # Copy so per-call overrides do not leak into self.kwargs (or the
        # caller's dict) and silently change later calls.
        kwargs = dict(self.kwargs if kwargs is None else kwargs)
        kwargs.update(cfg)

        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        if device == "cpu":
            batch_size = 1
        # A zero/negative step would make range() below raise.
        batch_size = max(int(batch_size), 1)

        # Loop-invariant settings. ``data_type`` and ``frontend`` are passed to
        # extract_fbank explicitly, so they must not also arrive through
        # **kwargs (Python would raise "got multiple values for keyword").
        audio_fs = kwargs.get("fs", 16000)
        data_type = kwargs.get("data_type", "sound")
        fbank_kwargs = {
            k: v for k, v in kwargs.items() if k not in ("data_type", "frontend")
        }

        meta_data = {}
        result_list = []
        num_samples = len(data_list)
        # The bar advances once per batch, so its total is the batch count.
        num_batches = (num_samples + batch_size - 1) // batch_size

        time0 = time.perf_counter()
        with tqdm(colour="blue", total=num_batches, dynamic_ncols=True) as pbar:
            for beg_idx in range(0, num_samples, batch_size):
                end_idx = min(num_samples, beg_idx + batch_size)
                data_batch = data_list[beg_idx:end_idx]
                key_batch = key_list[beg_idx:end_idx]

                # extract fbank feats
                time1 = time.perf_counter()
                audio_sample_list = load_audio_text_image_video(
                    data_batch, fs=self.frontend.fs, audio_fs=audio_fs
                )
                time2 = time.perf_counter()
                meta_data["load_data"] = f"{time2 - time1:0.3f}"
                speech, speech_lengths = extract_fbank(
                    audio_sample_list,
                    data_type=data_type,
                    frontend=self.frontend,
                    **fbank_kwargs,
                )
                time3 = time.perf_counter()
                meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
                # Presumably frame_shift is in milliseconds (hence / 1000),
                # giving the batch's audio duration in seconds — TODO confirm.
                meta_data["batch_data_time"] = (
                    speech_lengths.sum().item()
                    * self.frontend.frame_shift
                    * self.frontend.lfr_n
                    / 1000
                )

                # Tensor.to() returns a new tensor rather than moving in
                # place, so the result must be reassigned.
                speech = speech.to(device=device)
                speech_lengths = speech_lengths.to(device=device)
                result_list.append(
                    {"input": speech, "input_len": speech_lengths, "key": key_batch}
                )

                pbar.update(1)
                pbar.set_description(f"{meta_data}, ")

            time_end = time.perf_counter()
            pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")

        return result_list