Spaces:
Paused
Paused
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1B. Voice activity detection.ipynb. | |
# %% auto 0 | |
__all__ = [] | |
# %% ../nbs/1B. Voice activity detection.ipynb 3 | |
import os | |
import torch | |
import torchaudio | |
from pathlib import Path | |
from fastprogress import progress_bar | |
from fastcore.script import call_parse | |
import whisperx | |
import random | |
import numpy as np | |
import webdataset as wds | |
# %% ../nbs/1B. Voice activity detection.ipynb 5 | |
# some of the original file names have a dot in their name | |
# webdataset does not like it so let's patch it | |
def fix_dots_in_names(name):
    """Replace every dot in *name* except the last one with ``_``.

    webdataset derives a sample key from a member's file name up to its first
    dot. Names like ``a.b.flac`` would therefore be grouped wrongly, so this
    rewrites them to ``a_b.flac``. It is meant to be passed as the
    ``rename_files`` hook of ``load_dataset``.

    A name that contains no dot at all is returned unchanged instead of
    raising.
    """
    stem, dot, ext = name.rpartition('.')
    if not dot:
        # nothing to split off as an extension -> leave the name alone
        return name
    return f"{stem.replace('.', '_')}.{ext}"
def load_dataset(url, decode=True, rename_files=None):
    """Open a WebDataset shard, optionally decoding its audio members.

    Args:
        url: shard URL/path handed to ``wds.WebDataset``.
        decode: if truthy, attach ``wds.torch_audio`` decoding. Otherwise the
            undecoded dataset is returned as-is.
        rename_files: optional callable forwarded to ``wds.WebDataset`` as
            ``rename_files`` (e.g. ``fix_dots_in_names``).
    """
    dataset = wds.WebDataset(url, rename_files=rename_files)
    return dataset.decode(wds.torch_audio) if decode else dataset
# %% ../nbs/1B. Voice activity detection.ipynb 7 | |
def extract_segments(vad_result, max_duration):
    """Binarize a VAD result and list its regions as ``(start, end)`` tuples.

    A ``whisperx.vad.Binarize`` configured with *max_duration* is applied to
    *vad_result*. The timeline of the returned object is then flattened into
    a list of ``(start, end)`` pairs.
    """
    annotation = whisperx.vad.Binarize(max_duration=max_duration)(vad_result)
    timeline = annotation.get_timeline()
    return [(seg.start, seg.end) for seg in timeline]
def segment_audio(vad_model, audio, sr=16000, max_duration=30):
    """Run *vad_model* on *audio* and return its speech segments.

    Args:
        vad_model: callable VAD pipeline. It is invoked with a
            ``{"waveform": ..., "sample_rate": ...}`` dict.
        audio: waveform forwarded unchanged as ``"waveform"``.
        sr: sample rate of *audio*, forwarded as ``"sample_rate"``.
        max_duration: limit passed to ``extract_segments`` (default 30).
            Presumably seconds, matching Whisper's 30 s window; confirm
            against ``whisperx.vad.Binarize``.

    Returns:
        The list of ``(start, end)`` tuples produced by ``extract_segments``.
    """
    vad_result = vad_model({"waveform": audio, "sample_rate": sr})
    return extract_segments(vad_result, max_duration)
# %% ../nbs/1B. Voice activity detection.ipynb 13 | |
def flac_to_vad_name(input):
    """Derive the output VAD shard file name from an input shard URL/path.

    Only the last ``/``-separated component of *input* is kept, so the
    result is a bare file name. A plain file name with no ``/`` is accepted
    as-is.

    If *input* contains ``-flac-``, every ``flac`` in that file name becomes
    ``vad``. Otherwise every ``raw`` becomes ``vad``. ``.gz`` is appended in
    both cases.
    """
    # [-1] rather than [1]: a bare local file name has no '/' to split on.
    basename = input.rsplit("/", 1)[-1]
    # The '-flac-' test looks at the whole input; the replacement touches
    # only the basename.
    token = 'flac' if '-flac-' in input else 'raw'
    return basename.replace(token, 'vad') + ".gz"
def process_shard(
    input:str,           # input shard URL/path
    output:str=None,     # output shard URL/path
    fix_dots:bool=False, # fix dots in LibriLight filenames
):
    """Run voice activity detection over every sample of one shard.

    For each sample, the ``flac`` audio (or, failing that, ``wav``) is run
    through the whisperx VAD model loaded on ``'cuda'``. The resulting
    ``(start, end)`` segments are written, under the sample's ``__key__``, as
    a float16 ``vad.npy`` array into a new tar shard. Samples without audio
    are skipped with a printed warning.

    Output is written to ``<output>.tmp`` and renamed to *output* only after
    the whole shard has been processed. If an exception escapes the loop,
    the ``.tmp`` file is left behind and *output* is not created.

    NOTE(review): float16 has an 11-bit significand, so stored timestamps
    above 1024 s are rounded to whole seconds (2 s steps above 2048 s).
    Confirm this precision is acceptable for long recordings.
    """
    if output is None:
        output = flac_to_vad_name(input)
    renamer = fix_dots_in_names if fix_dots else None
    samples = torch.utils.data.DataLoader(
        load_dataset(input, rename_files=renamer),
        num_workers=2,
        batch_size=None,
    )
    vad_model = whisperx.vad.load_vad_model('cuda')
    partial_path = output + ".tmp"
    with wds.TarWriter(partial_path) as sink:
        for sample in progress_bar(samples, total='noinfer'):
            key = sample['__key__']
            audio, sr = sample.get('flac', sample.get('wav', (None, None)))
            if audio is None:
                print(f"warning: '{key}' does not contain an audio file")
                continue
            segments = segment_audio(vad_model, audio, sr=sr)
            sink.write({
                "__key__": key,
                "vad.npy": np.array(segments, dtype=np.float16),
            })
    os.rename(partial_path, output)