# nx_denoise/examples/conv_tasnet/step_1_prepare_data.py
# Author: HoneyTian (commit 14f8597: "update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import random
import sys
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import librosa
import numpy as np
from tqdm import tqdm
def get_args(argv=None):
    """Parse the command-line options for building the train/valid JSONL manifests.

    :param argv: optional list of argument strings to parse. ``None`` (the
        default) parses ``sys.argv[1:]``, which is the previous behaviour.
    :return: ``argparse.Namespace`` with the directory, output-file, duration,
        SNR-range, sample-rate and row-count options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)
    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str,
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str,
    )
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
    parser.add_argument("--duration", default=4.0, type=float)
    # argparse applies `type` only to string defaults, so the defaults are
    # written as floats to keep the parsed values consistently float.
    parser.add_argument("--min_snr_db", default=-10.0, type=float)
    parser.add_argument("--max_snr_db", default=20.0, type=float)
    parser.add_argument("--target_sample_rate", default=8000, type=int)
    parser.add_argument("--max_count", default=10000, type=int)
    args = parser.parse_args(argv)
    return args
def filename_generator(data_dir: str):
    """Lazily yield the POSIX-style path of every ``.wav`` file under *data_dir*.

    The search is recursive (pattern ``**/*.wav``), and the order is whatever
    ``Path.glob`` produces.
    """
    wav_paths = Path(data_dir).glob("**/*.wav")
    yield from (wav_path.as_posix() for wav_path in wav_paths)
def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    """Yield fixed-length, non-silent window descriptors from the ``.wav`` files under *data_dir*.

    Each file found by a recursive ``**/*.wav`` glob is loaded with
    ``librosa.load`` at *sample_rate*. Files shorter than *duration* seconds
    are skipped. The rest are cut into consecutive non-overlapping windows of
    ``int(duration * sample_rate)`` samples, and all-zero windows are skipped.
    The whole directory is walked again for up to *max_epoch* passes.

    :param data_dir: directory searched recursively for ``.wav`` files.
    :param duration: window length in seconds.
    :param sample_rate: sample rate the audio is loaded/resampled at.
    :param max_epoch: maximum number of passes over the directory.
    :yield: dict with keys ``epoch_idx``, ``filename``, ``raw_duration``,
        ``offset`` (window start in seconds) and ``duration``.
    :raises AssertionError: if a loaded signal is not 1-D.
    """
    data_dir = Path(data_dir)
    win_size = int(duration * sample_rate)
    for epoch_idx in range(max_epoch):
        yielded_any = False
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
            if raw_duration < duration:
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
            signal_length = len(signal)
            # "+ 1" so the final full window (ending exactly at signal_length) is
            # kept; without it a clip of exactly `duration` seconds yields nothing.
            for begin in range(0, signal_length - win_size + 1, win_size):
                # Skip only truly silent (all-zero) windows. A sum-based test would
                # also drop windows whose samples happen to cancel out.
                if not np.any(signal[begin: begin + win_size]):
                    continue
                yielded_any = True
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row
        # Every pass walks the same files with the same rules. If one pass
        # produced nothing, later passes would too, so stop instead of
        # re-globbing and re-loading audio up to max_epoch times.
        if not yielded_any:
            return
def main():
    """Write train/valid JSONL manifests that pair noise windows with speech windows.

    Each output line describes one noise window and one speech window, each
    ``args.duration`` seconds long, as produced by
    ``target_second_signal_generator``. It also holds an ``snr_db`` drawn
    uniformly from [min_snr_db, max_snr_db] and a ``random1`` value from
    ``random.random()``. Roughly 1 row in 300 goes to the valid manifest and
    the rest go to the train manifest.

    Noise is cycled for up to 100000 passes while speech is read once. Writing
    therefore stops when the speech windows run out or ``args.max_count`` rows
    have been written, whichever comes first.
    """
    # Fraction of rows routed to the validation manifest (~0.33%).
    valid_ratio = 1 / 300 / 1

    args = get_args()

    file_dir = Path(args.file_dir)
    # NOTE(review): file_dir is created, but the manifests below are opened via
    # args.train_dataset / args.valid_dataset (relative to the CWD), not inside
    # file_dir. Confirm this is intended.
    file_dir.mkdir(parents=True, exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    count = 0
    # Context managers guarantee the progress bar and both files are closed.
    with tqdm(desc="build dataset excel") as process_bar, \
            open(args.train_dataset, "w", encoding="utf-8") as ftrain, \
            open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            if count >= args.max_count:
                break

            # Keep the draw order random1 -> random2 -> snr_db so seeded runs
            # reproduce the same manifests as before.
            random1 = random.random()
            random2 = random.random()
            row = {
                "noise_filename": noise["filename"],
                "noise_raw_duration": noise["raw_duration"],
                "noise_offset": noise["offset"],
                "noise_duration": noise["duration"],
                "speech_filename": speech["filename"],
                "speech_raw_duration": speech["raw_duration"],
                "speech_offset": speech["offset"],
                "speech_duration": speech["duration"],
                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
                "random1": random1,
            }
            line = json.dumps(row, ensure_ascii=False)

            if random2 < valid_ratio:
                fvalid.write(f"{line}\n")
            else:
                ftrain.write(f"{line}\n")

            count += 1
            duration_hours = count * args.duration / 3600
            process_bar.update(n=1)
            process_bar.set_postfix({
                "duration_hours": round(duration_hours, 4),
            })
# Run the manifest builder only when this file is executed as a script.
if __name__ == "__main__":
    main()