import argparse
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import librosa
import numpy as np
import tqdm
import whisper

from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
from soundstorm.utils import check_txt_file


def process_sentence(args,
                     fp: Path,
                     train_dump_dir: Path,
                     dev_dump_dir: Path,
                     test_dump_dir: Path,
                     VAD_dict):
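    """Transcribe the VAD sub-segments of one big wav with Whisper, write one txt
    file per sub-segment, and return a list of {utt_id, txt_path, subset} records."""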
    asr_model = whisper.load_model("tiny.en")
    utt_id = fp.stem
    sr = args.sr
    record = []
    train_txt_dir = train_dump_dir / "txt"
    train_txt_dir.mkdir(parents=True, exist_ok=True)
    dev_txt_dir = dev_dump_dir / "txt"
    dev_txt_dir.mkdir(parents=True, exist_ok=True)
    test_txt_dir = test_dump_dir / "txt"
    test_txt_dir.mkdir(parents=True, exist_ok=True)
    try:
        # get sub_dataset / speaker / book info from the path
        wav_path_list = str(fp).strip().split('/')
        sub_dataset, spk_id, book_name = (wav_path_list[-4], wav_path_list[-3],
                                          wav_path_list[-2])
        # drop the '.flac' extension
        wav_name = wav_path_list[-1][:-5]
        assert wav_name == utt_id
        # key_name for the big wav
        key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
        # skip this audio file if it has no entry in the VAD dict
        if key_name not in VAD_dict.keys():
            print(key_name, 'not in VAD_dict !')
            return record
        wav = None
        sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
        len_dict = len(sorted_split_VAD_dict)
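        # split policy per big wav: the last sub-segment goes to test, the
        # second-to-last to dev, and all remaining sub-segments to train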
        for index, item in enumerate(sorted_split_VAD_dict):
            split_name, value = item
            start, end = value
            # train | dev | test
            if index == len_dict - 1:
                subset = 'test'
                txt_path = test_txt_dir / (split_name + ".txt")
            elif index == len_dict - 2:
                subset = 'dev'
                txt_path = dev_txt_dir / (split_name + ".txt")
            else:
                subset = 'train'
                txt_path = train_txt_dir / (split_name + ".txt")
            if os.path.exists(txt_path) and check_txt_file(txt_path):
                # print(txt_path, 'exists!')
                pass
            else:
                # this check ensures the big wav is loaded only once inside the sub-wav loop
                if wav is None:
                    # load the big wav lazily: loading it at the outermost level
                    # would waste time when all sub-wav results already exist
                    wav, _ = librosa.load(str(fp), sr=sr)
                sub_wav = wav[int(start * sr):int(end * sr)]
                asr_result = asr_model.transcribe(sub_wav)["text"]
                with open(txt_path, 'w') as f:
                    f.write(asr_result)
            sub_record = {
                "utt_id": split_name,
                "txt_path": txt_path,
                "subset": subset
            }
            # record becomes a List of Dict
            record.append(sub_record)
    except Exception:
        print("Exception occurred")
        traceback.print_exc()
        # record may be an incomplete List here
        return record
    return record


def process_sentences(args,
                      fps: list,
                      train_dump_dir: Path,
                      dev_dump_dir: Path,
                      test_dump_dir: Path,
                      VAD_dict,
                      nprocs: int=1):
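    """Run process_sentence over all wav files (serially or via a thread pool),
    then collect the per-sub-segment transcripts into train/dev/test dicts and
    save them as sharded .npy files."""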
print("nprocs:", nprocs) | |
if nprocs == 1: | |
results = [] | |
for fp in tqdm.tqdm(fps, total=len(fps)): | |
record = process_sentence( | |
args=args, | |
fp=fp, | |
train_dump_dir=train_dump_dir, | |
dev_dump_dir=dev_dump_dir, | |
test_dump_dir=test_dump_dir, | |
VAD_dict=VAD_dict) | |
if record: | |
results.append(record) | |
else: | |
with ThreadPoolExecutor(nprocs) as pool: | |
futures = [] | |
with tqdm.tqdm(total=len(fps)) as progress: | |
for fp in fps: | |
future = pool.submit(process_sentence, args, fp, | |
train_dump_dir, dev_dump_dir, | |
test_dump_dir, VAD_dict) | |
future.add_done_callback(lambda p: progress.update()) | |
futures.append(future) | |
results = [] | |
for ft in futures: | |
record = ft.result() | |
if record: | |
results.append(record) | |
    # aggregate the results and save them with np.save() as large `.npy` files
    txt_dict = dict()
    txt_dict['train'] = {}
    txt_dict['dev'] = {}
    txt_dict['test'] = {}
    # results is a List of record; each record is a List of Dict:
    # one record per big wav, one sub_record per sub wav
    print(f"start to save {args.rank}_{args.nshard}.npy ...")
    save_start_time = time.time()
    for record in tqdm.tqdm(results, total=len(results), colour='green'):
        for sub_record in record:
            # wrap in try/except because the txt file may be corrupted
            try:
                utt_id = sub_record["utt_id"]
                subset = sub_record["subset"]
                asr_result = check_txt_file(sub_record["txt_path"])
                if asr_result is not False:
                    txt_dict[subset][utt_id] = asr_result
                else:
                    print(f'asr result of {utt_id} is False')
            except Exception:
                print(f"Exception occurred for {utt_id}")
                traceback.print_exc()
                continue
    train_filename = train_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
    dev_filename = dev_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
    test_filename = test_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
    np.save(train_filename, txt_dict['train'])
    print(f"npy file '{train_filename}' saved")
    np.save(dev_filename, txt_dict['dev'])
    print(f"npy file '{dev_filename}' saved")
    np.save(test_filename, txt_dict['test'])
    print(f"npy file '{test_filename}' saved")
    print('time of save stage:', time.time() - save_start_time)


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features for LibriLight.")
    parser.add_argument(
        "--data_dir", default=None, type=str, help="directory of the dataset.")
    parser.add_argument(
        "--dump_dir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of processes.")
    parser.add_argument(
        '--sr', type=int, default=16000, help='sample rate of the model')
    # for the LibriLight dataset
    parser.add_argument(
        "--sub_dataset",
        default="small",
        type=str,
        help="name of the LibriLight sub-dataset",
        choices=['small', 'medium', 'large', 'duplicate'], )
    parser.add_argument(
        "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
    parser.add_argument("--nshard", type=int, default=3)
    parser.add_argument("--rank", type=int, default=0)
    args = parser.parse_args()
    data_dir = Path(args.data_dir).expanduser()
    dump_dir = Path(args.dump_dir).expanduser()
    # use absolute path
    dump_dir = dump_dir.resolve()
    dump_dir.mkdir(parents=True, exist_ok=True)
    assert data_dir.is_dir()
    # sub_dataset here
    sub_dataset_dir = data_dir / args.sub_dataset
    # only spk_id in the list, sorted in lexicographical order
    speaker_list = sorted(os.listdir(sub_dataset_dir))
    start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
    # speaker_list for this rank
    speaker_list = speaker_list[start:end]
    all_wav_files = []
    for speaker in speaker_list:
        wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
        # filter out ._*.flac
        wav_files = [
            file for file in wav_files if not file.name.startswith('._')
        ]
        all_wav_files += wav_files
    print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
    # get VAD info
    VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
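    # each VAD_dict entry maps key_name -> {split_name: (start_sec, end_sec)},
    # as unpacked in process_sentence()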
    sub_dataset_dump_dir = dump_dir / args.sub_dataset
    sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
    train_dump_dir = sub_dataset_dump_dir / "train"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = sub_dataset_dump_dir / "dev"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = sub_dataset_dump_dir / "test"
    test_dump_dir.mkdir(parents=True, exist_ok=True)
    # each big wav contributes one dev and one test sub-segment,
    # giving a train/dev/test ratio of roughly 96:2:2
    if all_wav_files:
        process_sentences(
            args=args,
            fps=all_wav_files,
            train_dump_dir=train_dump_dir,
            dev_dump_dir=dev_dump_dir,
            test_dump_dir=test_dump_dir,
            VAD_dict=VAD_dict,
            nprocs=args.num_cpu)


if __name__ == "__main__":
    main()
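
# Example invocation (the script name and paths below are illustrative, not
# taken from the original repo):
#   python get_txt_librilight.py \
#       --data_dir=/path/to/librilight \
#       --dump_dir=/path/to/dump \
#       --sub_dataset=small --num-cpu=4 --nshard=3 --rank=0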