|
|
|
|
|
|
|
|
|
|
|
import argparse
import csv
import logging
import os
import subprocess
import tempfile
from collections import defaultdict
from pathlib import Path

import pandas as pd
import torchaudio
from tqdm import tqdm

try:
    import webrtcvad
except ImportError:
    raise ImportError("Please install py-webrtcvad: pip install webrtcvad")

import examples.speech_synthesis.preprocessing.denoiser.utils as utils
from examples.speech_synthesis.preprocessing.denoiser.pretrained import master64
from examples.speech_synthesis.preprocessing.vad import (
    frame_generator, vad_collector, read_wave, write_wave, FS_MS, THRESHOLD,
    SCALE
)
from examples.speech_to_text.data_utils import save_df_to_tsv
|
|
|
|
|
log = logging.getLogger(__name__)

# Output sub-directory names created under --output-dir:
# PATHS[0] receives denoised audio, PATHS[1] receives audio after VAD.
PATHS = ["after_denoise", "after_vad"]
# Minimum duration (seconds) a clip must retain after VAD to be kept.
MIN_T = 0.05
|
|
|
|
|
def generate_tmp_filename(extension="txt"):
    """Return a unique path in the system temp directory with the given extension.

    The original implementation used the private ``tempfile._get_default_tempdir``
    and ``tempfile._get_candidate_names`` helpers, which are undocumented,
    may change between Python versions, and are race-prone (the name was
    never reserved). ``tempfile.mkstemp`` atomically creates the file, so
    the name is reserved; the open descriptor is closed immediately since
    callers (sox / write_wave) reopen the path themselves.

    Args:
        extension: file extension without the leading dot (default "txt").

    Returns:
        Absolute path (str) of a newly created, empty temporary file.
    """
    fd, path = tempfile.mkstemp(suffix="." + extension)
    os.close(fd)
    return path
|
|
|
|
|
def convert_sr(inpath, sr, output_path=None):
    """Resample ``inpath`` to ``sr`` Hz with the external ``sox`` tool.

    Args:
        inpath: path to the input audio file.
        sr: target sample rate in Hz.
        output_path: optional destination path; a temporary ``.wav`` file is
            generated when omitted.

    Returns:
        The path (str) of the resampled output file.
    """
    if not output_path:
        output_path = generate_tmp_filename("wav")
    # Invoke sox with an argument list (shell=False) instead of the previous
    # os.system(f"sox {inpath} ...") call, so paths containing spaces or
    # shell metacharacters can neither break the command nor inject into it.
    result = subprocess.run(
        ["sox", inpath, "-r", str(sr), output_path],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        # os.system's exit status was silently discarded before; keep the
        # best-effort behavior (no exception) but surface the failure.
        log.warning(
            "sox failed on %s: %s",
            inpath,
            result.stderr.decode("utf-8", errors="replace"),
        )
    return output_path
|
|
|
|
|
def apply_vad(vad, inpath):
    """Run WebRTC VAD over a wav file and return the merged voiced audio.

    Voiced segments are concatenated in order; the silence between two
    consecutive segments is re-inserted as zero bytes, truncated to at most
    THRESHOLD so long pauses are shortened while short ones are preserved.

    Args:
        vad: a configured ``webrtcvad.Vad`` instance.
        inpath: path to a wav file readable by ``read_wave`` (webrtcvad
            supports only 8/16/32/48 kHz mono 16-bit PCM — the caller is
            responsible for resampling).

    Returns:
        Tuple ``(segment, sample_rate)`` where ``segment`` is the merged PCM
        byte string and ``sample_rate`` is the rate reported by ``read_wave``.
    """
    audio, sample_rate = read_wave(inpath)
    frames = frame_generator(FS_MS, audio, sample_rate)
    frames = list(frames)
    # NOTE(review): each yielded segment is indexed as
    # (pcm_bytes, start_time, end_time) below — confirm against
    # vad_collector's contract in examples.speech_synthesis.preprocessing.vad.
    segments = vad_collector(sample_rate, FS_MS, 300, vad, frames)
    merge_segments = list()
    timestamp_start = 0.0
    timestamp_end = 0.0

    for i, segment in enumerate(segments):
        merge_segments.append(segment[0])
        # From the second segment on (and only once a non-zero start has been
        # recorded), pad the gap between voiced chunks with silence.
        if i and timestamp_start:
            sil_duration = segment[1] - timestamp_end
            if sil_duration > THRESHOLD:
                # Cap long pauses at THRESHOLD worth of zero bytes. SCALE
                # presumably converts seconds to bytes of PCM — TODO confirm.
                merge_segments.append(int(THRESHOLD / SCALE) * (b'\x00'))
            else:
                merge_segments.append(int((sil_duration / SCALE)) * (b'\x00'))
        timestamp_start = segment[1]
        timestamp_end = segment[2]
    segment = b''.join(merge_segments)
    return segment, sample_rate
|
|
|
|
|
def write(wav, filename, sr=16_000):
    """Peak-normalize a waveform tensor and save it as 16-bit signed PCM.

    The waveform is divided by its absolute peak (floored at 1 so quiet
    signals are not amplified), moved to CPU, and written with torchaudio.
    """
    peak = wav.abs().max().item()
    normalized = wav / max(peak, 1)
    torchaudio.save(
        filename, normalized.cpu(), sr, encoding="PCM_S", bits_per_sample=16
    )
|
|
|
|
|
def process(args):
    """Run the denoise and/or VAD pipeline over an input audio manifest.

    For every row of the TSV manifest at ``args.audio_manifest``, optionally
    (a) denoises the audio with the pre-trained enhancement model and/or
    (b) applies WebRTC VAD, writing processed files under ``args.output_dir``
    (sub-directories named by PATHS). A new manifest with an added ``snr``
    column is saved under ``args.output_dir``. Samples shorter than MIN_T
    seconds after VAD are dropped from the output manifest.

    Fix: the too-short warning previously printed the literal "(unknown)"
    instead of the name of the file being skipped.
    """
    if not args.denoise and not args.vad:
        log.error("No denoise or vad is requested.")
        return

    log.info("Creating out directories...")
    if args.denoise:
        out_denoise = Path(args.output_dir).absolute().joinpath(PATHS[0])
        out_denoise.mkdir(parents=True, exist_ok=True)
    if args.vad:
        out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1])
        out_vad.mkdir(parents=True, exist_ok=True)

    log.info("Loading pre-trained speech enhancement model...")
    model = master64().to(args.device)

    log.info("Building the VAD model...")
    vad = webrtcvad.Vad(int(args.vad_agg_level))

    output_dict = defaultdict(list)

    log.info(f"Parsing input manifest: {args.audio_manifest}")
    with open(args.audio_manifest, "r") as f:
        manifest_dict = csv.DictReader(f, delimiter="\t")
        for row in tqdm(manifest_dict):
            filename = str(row["audio"])

            final_output = filename
            keep_sample = True
            n_frames = row["n_frames"]
            snr = -1
            if args.denoise:
                output_path_denoise = out_denoise.joinpath(Path(filename).name)
                # The enhancement model operates at 16 kHz.
                tmp_path = convert_sr(final_output, 16000)

                out, sr = torchaudio.load(tmp_path)
                out = out.to(args.device)
                estimate = model(out)
                # Mix a dry_wet fraction of the noisy signal back in to
                # reduce enhancement artifacts.
                estimate = (1 - args.dry_wet) * estimate + args.dry_wet * out
                write(estimate[0], str(output_path_denoise), sr)

                snr = utils.cal_snr(out, estimate)
                snr = snr.cpu().detach().numpy()[0][0]
                final_output = str(output_path_denoise)

            if args.vad:
                output_path_vad = out_vad.joinpath(Path(filename).name)
                sr = torchaudio.info(final_output).sample_rate
                # webrtcvad supports only 8/16/32/48 kHz; upsample to the
                # nearest supported rate when needed.
                if sr in [16000, 32000, 48000]:
                    tmp_path = final_output
                elif sr < 16000:
                    tmp_path = convert_sr(final_output, 16000)
                elif sr < 32000:
                    tmp_path = convert_sr(final_output, 32000)
                else:
                    tmp_path = convert_sr(final_output, 48000)

                segment, sample_rate = apply_vad(vad, tmp_path)
                if len(segment) < sample_rate * MIN_T:
                    keep_sample = False
                    # Name the file being skipped (previously "(unknown)").
                    print((
                        f"WARNING: skip {filename} because it is too short "
                        f"after VAD ({len(segment) / sample_rate} < {MIN_T})"
                    ))
                else:
                    if sample_rate != sr:
                        # VAD ran at a different rate; convert back to the
                        # original sample rate for the final output.
                        tmp_path = generate_tmp_filename("wav")
                        write_wave(tmp_path, segment, sample_rate)
                        convert_sr(tmp_path, sr,
                                   output_path=str(output_path_vad))
                    else:
                        write_wave(str(output_path_vad), segment, sample_rate)
                    final_output = str(output_path_vad)
                    # Recompute the frame count from the trimmed audio.
                    segment, _ = torchaudio.load(final_output)
                    n_frames = segment.size(1)

            if keep_sample:
                output_dict["id"].append(row["id"])
                output_dict["audio"].append(final_output)
                output_dict["n_frames"].append(n_frames)
                output_dict["tgt_text"].append(row["tgt_text"])
                output_dict["speaker"].append(row["speaker"])
                output_dict["src_text"].append(row["src_text"])
                output_dict["snr"].append(snr)

    out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name
    log.info(f"Saving manifest to {out_tsv_path.as_posix()}")
    save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path)
|
|
|
|
|
def main():
    """Build the command-line interface and hand the parsed args to process()."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-manifest", "-i", required=True, type=str,
                        help="path to the input manifest.")
    parser.add_argument("--output-dir", "-o", required=True, type=str,
                        help="path to the output dir. it will contain files "
                             "after denoising and vad")
    parser.add_argument("--vad-agg-level", "-a", type=int, default=2,
                        help="the aggresive level of the vad [0-3].")
    parser.add_argument("--dry-wet", "-dw", type=float, default=0.01,
                        help="the level of linear interpolation between noisy "
                             "and enhanced files.")
    parser.add_argument("--device", "-d", type=str, default="cpu",
                        help="the device to be used for the speech enhancement "
                             "model: cpu | cuda.")
    parser.add_argument("--denoise", action="store_true",
                        help="apply a denoising")
    parser.add_argument("--vad", action="store_true", help="apply a VAD")

    process(parser.parse_args())
|
|
|
|
|
# Standard script entry point.
if __name__ == "__main__":
    main()
|
|