import shutil
import os
import tempfile
from collections import OrderedDict
from glob import glob

import numpy
import torch
import torchaudio
import torchaudio.functional as F
from pydub import AudioSegment
from tqdm import tqdm
from speechbrain.pretrained import VAD
from speechbrain.pretrained import EncoderASR
import gradio as gr

tempdir = tempfile.mkdtemp()


def read_and_resample(filename, outdir):
    # convert the upload to a 16 kHz mono wav with ffmpeg (via pydub)
    AudioSegment.from_file(filename).export(f"{filename}.wav", format='wav', parameters=["-ar", "16000", "-ac", '1'])
    filename = f"{filename}.wav"
    signal, sr = torchaudio.load(filename)

    if sr != 16_000:
        # safeguard: downsample to 16 kHz and mix down to mono
        resampled = F.resample(signal, sr, 16_000, lowpass_filter_width=128).mean(dim=0).view(1, -1).cpu()
    else:
        resampled = signal.mean(dim=0).view(1, -1).cpu()

    # strip the extension to build the chunk file names
    filename = os.path.basename(filename).split(".")[0]

    # yield chunks of 60 minutes (60 * 60 seconds at 16 kHz)
    c_size = 60 * 60 * 16_000
    for i, c in enumerate(range(0, resampled.shape[1], c_size)):
        tempaudio = os.path.join(outdir, f"{filename}-{i}.wav")
        # save the chunk to the temporary dir:
        torchaudio.save(tempaudio, resampled[:, c:c + c_size], 16_000)
        yield (tempaudio, resampled[:, c:c + c_size])


def segment_file(vad, file_id, prefix, filename, resampled, output_dir):
    min_chunk_size = 4       # seconds
    max_allowed_length = 12  # seconds
    margin = 0.15            # seconds of context kept around each boundary

    with torch.no_grad():
        audio_info = vad.get_speech_segments(filename, apply_energy_VAD=True, len_th=0.5,
                                             deactivation_th=0.4, double_check=False, close_th=0.25)

    # merge consecutive VAD segments until they exceed min_chunk_size,
    # then split anything longer than max_allowed_length into equal parts
    s = -1
    for _s, _e in audio_info:
        _s, _e = _s.item(), _e.item()
        _s = max(0, _s - margin)
        e = min(resampled.size(1) / 16_000, _e + margin)

        if s == -1:
            s = _s

        chunk_length = e - s
        if chunk_length > min_chunk_size:
            no_chunks = int(numpy.ceil(chunk_length / max_allowed_length))
            starts = numpy.linspace(s, e, no_chunks + 1).tolist()

            if chunk_length > max_allowed_length:
                print("WARNING: segment too long:", chunk_length)
                print(no_chunks, starts)

            for x in range(no_chunks):
                start = starts[x]
                end = starts[x + 1]
                local_chunk_length = end - start
                print(f"Saving segment: {start:08.2f}-{end:08.2f}, with length: {local_chunk_length:05.2f} secs")
                fname = f"{file_id}-{prefix}-{start:08.2f}-{end:08.2f}.wav"

                # convert from seconds to samples:
                start = int(start * 16_000)
                end = int(end * 16_000)

                # save the segment:
                torchaudio.save(os.path.join(output_dir, fname), resampled[:, start:end], 16_000)

            s = -1


def format_time(secs: float):
    # SRT timestamp: HH:MM:SS,mmm
    m, s = divmod(secs, 60)
    h, m = divmod(m, 60)
    return "%02d:%02d:%02d,%03d" % (h, m, s, int(secs * 1000 % 1000))


asr_model = EncoderASR.from_hparams(source="asafaya/hubert-large-arabic-transcribe")
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")


def main(filename, generate_srt=False):
    try:
        AudioSegment.from_file(filename)
    except Exception:
        return "Please upload a valid audio file"

    outdir = os.path.join(tempdir, filename.split("/")[-1].split(".")[0])
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    print("Applying VAD to", filename)

    # directory to save the VAD segments
    segments_dir = os.path.join(outdir, "segments")
    if os.path.exists(segments_dir):
        raise Exception(f"Segments directory already exists: {segments_dir}")

    os.mkdir(segments_dir)
    print("Saving segments to", segments_dir)

    for c, (tempaudio, resampled) in enumerate(read_and_resample(filename, outdir)):
        print(f"Segmenting file: {filename}, with length: {resampled.shape[1] / 16_000:05.2f} secs: {tempaudio}")
        segment_file(vad_model, os.path.basename(tempaudio), c, tempaudio, resampled, segments_dir)
        # os.remove(tempaudio)

    transcriptions = OrderedDict()
    files = glob(os.path.join(segments_dir, "*.wav"))
    print("Start transcribing")
    for f in tqdm(sorted(files)):
        try:
            transcriptions[os.path.basename(f).replace(".wav", "")] = asr_model.transcribe_file(f)
            # os.remove(os.path.basename(f))
        except Exception as e:
            print(e)
            print("Error transcribing file {}".format(f))
            print("Skipping...")

    # shutil.rmtree(outdir)

    fo = ""
    for i, key in enumerate(transcriptions):
        line = key  # e.g. segment-0-00148.72-00156.97
        if len(line) < 2:
            continue

        # the start/end times are encoded in the segment file name
        start_sec = float(line.split("-")[-2])
        end_sec = float(line.split("-")[-1])

        if generate_srt:
            fo += "{}\n".format(i + 1)
            fo += "{} --> ".format(format_time(start_sec))
            fo += "{}\n".format(format_time(end_sec))

        fo += "{}\n".format(transcriptions[key])
        fo += "\n" if generate_srt else ""

    return fo


outputs = gr.outputs.Textbox(label="Transcription")
title = "Arabic Speech Transcription"
description = "Simply upload your audio."

gr.Interface(main,
             [gr.inputs.Audio(label="Arabic Audio File", type="filepath"), "checkbox"],
             outputs,
             title=title,
             description=description,
             enable_queue=True).launch()
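
# Illustrative sketch of the output when the "generate_srt" checkbox is ticked:
# the timestamps are parsed back out of the segment file names and formatted by
# format_time(); the text on each block is whatever the ASR model returned.
# (The times and placeholder text below are made up for illustration.)
#
#   1
#   00:02:28,720 --> 00:02:36,970
#   <transcription of the first segment>
#
#   2
#   00:02:37,100 --> 00:02:45,430
#   <transcription of the second segment>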