import gradio as gr
import librosa
import soundfile
import tempfile
import os
import uuid
import json

import jieba

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models import ASRModel
from nemo.utils import logging

from align import main, AlignmentConfig, ASSFileConfig

SAMPLE_RATE = 16000
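
# Download and cache all of the ASR models used by the demo before the UI starts,
# so that the first request for each language does not have to wait for a download.
# NeMo logging is temporarily silenced to keep the startup logs readable.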
logging.setLevel(logging.ERROR)
for tmp_model_name in [
    "stt_en_fastconformer_hybrid_large_pc",
    "stt_de_fastconformer_hybrid_large_pc",
    "stt_es_fastconformer_hybrid_large_pc",
    "stt_fr_conformer_ctc_large",
    "stt_zh_citrinet_1024_gamma_0_25",
]:
    tmp_model = ASRModel.from_pretrained(tmp_model_name, map_location='cpu')
    del tmp_model
logging.setLevel(logging.INFO)


def get_audio_data_and_duration(file):
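    """Load an audio file, convert it to mono 16 kHz, and return (samples, duration in seconds)."""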
    data, sr = librosa.load(file)

    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    data = librosa.to_mono(data)

    duration = librosa.get_duration(y=data, sr=SAMPLE_RATE)
    return data, duration


def get_char_tokens(text, model):
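    """Convert `text` into token IDs using a character-based model's vocabulary.

    Characters that are not in the vocabulary are all mapped to the same
    out-of-vocabulary index (len(vocabulary)), which is why adjacent
    out-of-vocabulary characters count as token repetitions downstream.
    """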
    tokens = []
    for character in text:
        if character in model.decoder.vocabulary:
            tokens.append(model.decoder.vocabulary.index(character))
        else:
            tokens.append(len(model.decoder.vocabulary))

    return tokens


def get_S_prime_and_T(text, model_name, model, audio_duration):
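    """Estimate whether `text` can be aligned to audio of length `audio_duration`.

    S_prime is the number of tokens the model must emit: the token count plus the
    number of adjacent repeated tokens (CTC needs a blank frame between repeats).
    T is an estimate of the number of output timesteps the model produces for the
    audio, based on the approximate timestep duration of each model family.
    Alignment is only feasible when S_prime <= T.
    """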
    # Approximate duration (in seconds) of one model output timestep for each model family.
    if "citrinet" in model_name or "_fastconformer_" in model_name:
        output_timestep_duration = 0.08
    elif "_conformer_" in model_name:
        output_timestep_duration = 0.04
    elif "quartznet" in model_name:
        output_timestep_duration = 0.02
    else:
        raise RuntimeError("unexpected model name")

    T = int(audio_duration / output_timestep_duration) + 1

    if hasattr(model, 'tokenizer'):
        all_tokens = model.tokenizer.text_to_ids(text)
    elif hasattr(model.decoder, "vocabulary"):  # i.e. a character-based model
        all_tokens = get_char_tokens(text, model)
    else:
        raise RuntimeError("cannot obtain tokens from this model")

    n_token_repetitions = 0
    for i_tok in range(1, len(all_tokens)):
        if all_tokens[i_tok] == all_tokens[i_tok - 1]:
            n_token_repetitions += 1

    S_prime = len(all_tokens) + n_token_repetitions

    return S_prime, T


def hex_to_rgb_list(hex_string):
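    """Convert a "#RRGGBB" hex color string (as returned by gr.ColorPicker) into an [R, G, B] list of ints."""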
    hex_string = hex_string.lstrip("#")
    r = int(hex_string[:2], 16)
    g = int(hex_string[2:4], 16)
    b = int(hex_string[4:], 16)
    return [r, g, b]


def delete_mp4s_except_given_filepath(filepath):
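    """Delete any leftover .mp4 files from previous runs in the working directory,
    except the file at `filepath` (the output of the current run).
    """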
    files_in_dir = os.listdir()
    mp4_files_in_dir = [x for x in files_in_dir if x.endswith(".mp4")]
    for mp4_file in mp4_files_in_dir:
        if mp4_file != filepath:
            os.remove(mp4_file)


def align(lang, Microphone, File_Upload, text, col1, col2, col3, progress=gr.Progress()):
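    """Handle one "Submit" click end-to-end: validate the inputs, load the ASR model,
    transcribe the audio if no reference text was given, run NeMo Forced Aligner,
    and render a video with the aligned text burned in as subtitles.

    Returns the video filepath (for the video component), an info message, and the
    video filepath again (stored in a State so it can be cleaned up afterwards).
    """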
    utt_id = uuid.uuid4()
    output_video_filepath = f"{utt_id}.mp4"
    # The output video is written to the current working directory (not the temporary
    # directory used below); remove any videos left over from previous requests.
    delete_mp4s_except_given_filepath(output_video_filepath)

    output_info = ""

    progress(0, desc="Validating input")
    # Pick the pretrained NeMo ASR model that matches the selected language.
    if lang in ["en", "de", "es"]:
        model_name = f"stt_{lang}_fastconformer_hybrid_large_pc"
    elif lang in ["fr"]:
        model_name = f"stt_{lang}_conformer_ctc_large"
    elif lang in ["zh"]:
        model_name = f"stt_{lang}_citrinet_1024_gamma_0_25"
    # Exactly one of the two audio inputs must be provided.
    if (Microphone is not None) and (File_Upload is not None):
        raise gr.Error("Please use either the microphone or the file upload input - not both")

    elif (Microphone is None) and (File_Upload is None):
        raise gr.Error("Please either record audio with the microphone or upload an audio file")

    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload
    audio_data, duration = get_audio_data_and_duration(file)

    if duration > 4 * 60:
        raise gr.Error(
            f"Detected that the uploaded audio is {duration/60:.1f} mins long - please only upload audio shorter than 4 mins"
        )

    progress(0.1, desc="Loading speech recognition model")
    model = ASRModel.from_pretrained(model_name)
    if text:
        # Check whether the provided text can plausibly be aligned to audio of this duration.
        S_prime, T = get_S_prime_and_T(text, model_name, model, duration)

        if S_prime > T:
            raise gr.Error(
                "The input text contains too many tokens for the duration of the audio."
                f" For this audio, the model can handle at most {T} tokens + token repetitions;"
                f" you have provided {S_prime} tokens + token repetitions."
                " (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
            )
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save the (possibly resampled) audio as a 16 kHz wav file for NFA to read.
        audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)
        # If no reference text was provided, fall back to the ASR model's own transcription.
        if not text:
            progress(0.2, desc="Transcribing audio")
            text = model.transcribe([audio_path])[0]
            if 'hybrid' in model_name:
                # Hybrid models return an extra level of nesting; unwrap to the transcription string.
                text = text[0]

            if text == "":
                raise gr.Error(
                    "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
                )

            output_info += (
                "You did not enter any input text, so the ASR model's transcription will be used:\n"
                "--------------------------\n"
                f"{text}\n"
                "--------------------------\n"
                "You could try pasting the transcription into the text input box, correcting any"
                " transcription errors, and clicking 'Submit' again."
            )
        # Mandarin text is usually written without spaces; use jieba to insert word
        # boundaries so NFA receives space-separated words.
        if lang == "zh" and " " not in text:
            text = " ".join(jieba.cut(text))
        # NFA reads its input from a NeMo-style manifest: one JSON object per line
        # containing the audio filepath and the reference text.
        data = {
            "audio_filepath": audio_path,
            "text": text,
        }
        manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
        with open(manifest_path, 'w') as fout:
            fout.write(f"{json.dumps(data)}\n")
        # If the user marked segment boundaries with '|', respect them exactly;
        # otherwise let NFA resegment the text to fill the available subtitle space.
        if "|" in text:
            resegment_text_to_fill_space = False
        else:
            resegment_text_to_fill_space = True
        alignment_config = AlignmentConfig(
            pretrained_name=model_name,
            manifest_filepath=manifest_path,
            output_dir=f"{tmpdir}/nfa_output/",
            audio_filepath_parts_in_utt_id=1,
            batch_size=1,
            use_local_attention=True,
            additional_segment_grouping_separator="|",
            save_output_file_formats=["ass"],
            ass_file_config=ASSFileConfig(
                fontsize=45,
                resegment_text_to_fill_space=resegment_text_to_fill_space,
                max_lines_per_segment=4,
                text_already_spoken_rgb=hex_to_rgb_list(col1),
                text_being_spoken_rgb=hex_to_rgb_list(col2),
                text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
            ),
        )

        progress(0.5, desc="Aligning audio")
        main(alignment_config)

        progress(0.95, desc="Saving generated alignments")
        # Mandarin uses the token-level .ass subtitles (per-character highlighting);
        # the other languages use the word-level subtitles.
        if lang == "zh":
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/tokens/{utt_id}.ass"
        else:
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"
        # Render the final video with ffmpeg: a plain white 1280x720 background,
        # the input audio, and the .ass subtitles burned in.
        ffmpeg_command = (
            f"ffmpeg -y -i {audio_path} "
            "-f lavfi -i color=c=white:s=1280x720:r=50 "
            "-crf 1 -shortest -vcodec libx264 -pix_fmt yuv420p "
            f"-vf 'ass={ass_file_for_video}' "
            f"{output_video_filepath}"
        )
        os.system(ffmpeg_command)

    return output_video_filepath, gr.update(value=output_info, visible=True), output_video_filepath


def delete_non_tmp_video(video_path):
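    """Delete the on-disk copy of the generated video after a request completes.

    Gradio serves its own temporary copy of the output video, so the copy written to
    the working directory can be removed to avoid accumulating files between requests.
    """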
    if video_path:
        if os.path.exists(video_path):
            os.remove(video_path)
    return None
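

# Build the Gradio UI: inputs (language, audio, optional reference text, colors) in the
# left column and outputs (video, info message) in the right column.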
with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
    non_tmp_output_video_filepath = gr.State([])
    with gr.Row():
        with gr.Column():
            gr.Markdown("# NeMo Forced Aligner")
            gr.Markdown(
                "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
                "Upload audio and (optionally) the text spoken in the audio to generate a video where each part "
                "of the text will be highlighted as it is spoken."
            )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            lang_drop = gr.Dropdown(choices=["de", "en", "es", "fr", "zh"], value="en", label="Audio language")

            mic_in = gr.Audio(source="microphone", type='filepath', label="Microphone input (max 4 mins)")
            audio_file_in = gr.Audio(source="upload", type='filepath', label="File upload (max 4 mins)")
            ref_text = gr.Textbox(
                label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
                "Leave this field blank to use an ASR model's transcription as the reference text instead."
            )

            gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
            with gr.Row():
                col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
                col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
                col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")

            submit_button = gr.Button("Submit")

        with gr.Column(scale=1):
            gr.Markdown("## Output")
            video_out = gr.Video(label="output video")
            text_out = gr.Textbox(label="output info", visible=False)
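
    # On "Submit": run align() and display its outputs, then delete the on-disk copy of
    # the video whose path was stored in the State, so files do not pile up between runs.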
    submit_button.click(
        fn=align,
        inputs=[lang_drop, mic_in, audio_file_in, ref_text, col1, col2, col3],
        outputs=[video_out, text_out, non_tmp_output_video_filepath],
    ).then(
        fn=delete_non_tmp_video, inputs=[non_tmp_output_video_filepath], outputs=None,
    )

demo.queue()
demo.launch()