#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import json | |
import os | |
from pathlib import Path | |
import random | |
import sys | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, "../../")) | |
import librosa | |
import numpy as np | |
from tqdm import tqdm | |
def get_args():
    """Parse command-line options for the dataset-building script.

    Returns:
        argparse.Namespace with directories, output jsonl names, window
        duration, SNR range, target sample rate, and row-count cap.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)
    parser.add_argument(
        "--noise_dir",
        type=str,
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
    )
    parser.add_argument(
        "--speech_dir",
        type=str,
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
    )
    parser.add_argument("--train_dataset", type=str, default="train.jsonl")
    parser.add_argument("--valid_dataset", type=str, default="valid.jsonl")
    parser.add_argument("--duration", type=float, default=4.0)
    parser.add_argument("--min_snr_db", type=float, default=-10)
    parser.add_argument("--max_snr_db", type=float, default=20)
    parser.add_argument("--target_sample_rate", type=int, default=8000)
    parser.add_argument("--max_count", type=int, default=10000)
    return parser.parse_args()
def filename_generator(data_dir: str):
    """Yield POSIX-style paths of every ``.wav`` file under *data_dir*, recursively."""
    root = Path(data_dir)
    yield from (wav_path.as_posix() for wav_path in root.glob("**/*.wav"))
def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    """Yield fixed-length, non-silent window descriptors over all .wav files.

    Recursively scans *data_dir* for ``.wav`` files, loads each at
    *sample_rate*, and slices it into consecutive non-overlapping windows of
    *duration* seconds. The whole sweep is repeated *max_epoch* times so the
    generator can feed ``zip`` against a longer partner stream.

    Args:
        data_dir: root directory to scan for ``.wav`` files.
        duration: window length in seconds (callers pass floats, e.g. 4.0).
        sample_rate: resample rate handed to ``librosa.load``.
        max_epoch: number of passes over the directory tree.

    Yields:
        dict with keys ``epoch_idx``, ``filename``, ``raw_duration``,
        ``offset`` and ``duration`` (offset/duration in seconds, rounded).

    Raises:
        AssertionError: if a loaded signal is not mono (ndim != 1).
    """
    data_dir = Path(data_dir)
    # Window size is invariant across files — compute it once.
    win_size = int(duration * sample_rate)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
            if raw_duration < duration:
                # File too short to supply even one window.
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
            signal_length = len(signal)
            # "+ 1" so a file that is exactly `duration` seconds long still
            # yields its single window (the original range excluded it).
            for begin in range(0, signal_length - win_size + 1, win_size):
                # Skip truly silent (all-zero) windows. np.any is used instead
                # of comparing the sum to zero, which wrongly discarded windows
                # whose positive and negative samples cancel out.
                if not np.any(signal[begin: begin + win_size]):
                    continue
                yield {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
def main():
    """Build train/valid jsonl manifests pairing noise and speech windows.

    Draws windows from the noise and speech trees in lockstep (noise repeats
    for up to 100000 epochs, speech for a single pass, so ``zip`` stops when
    speech is exhausted or ``max_count`` rows are written). Each row records
    the file/offset/duration of both sides plus a random SNR in
    [min_snr_db, max_snr_db], serialized as one JSON object per line.

    NOTE(review): the order of the random.random()/random.uniform() calls
    below determines the generated values under a fixed seed — do not
    reorder them.
    """
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    # Noise repeats many epochs so it never runs out before speech does.
    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    # Speech makes a single pass; it bounds the length of the zip below.
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    dataset = list()  # NOTE(review): never appended to — apparently vestigial.

    count = 0
    process_bar = tqdm(desc="build dataset excel")
    # Both manifests stay open for the whole run; rows stream straight to disk.
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            if count >= args.max_count:
                break
            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            # random1 is stored in the row (purpose unclear from here);
            # random2 drives the train/valid split below.
            random1 = random.random()
            random2 = random.random()

            row = {
                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            # Roughly 1-in-300 rows go to the validation manifest.
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),
            })

    return
# Script entry point: run the manifest builder when executed directly.
if __name__ == "__main__":
    main()