# Hugging Face Spaces page header captured with the file: "Spaces: Running"
# Imports | |
from pathlib import Path | |
import tempfile | |
import os | |
import gradio as gr | |
import librosa | |
import tgt.core | |
import tgt.io3 | |
import soundfile as sf | |
from transformers import pipeline | |
# Constants

# Scratch directory that receives the TextGrid files offered for download.
# It is created once, when this module is imported.
TEXTGRID_DIR = tempfile.mkdtemp()

# Model selected in the dropdown (and loaded) when the demo starts.
DEFAULT_MODEL = "ginic/data_seed_bs64_4_wav2vec2-large-xlsr-53-buckeye-ipa"

# User-facing labels shared by several components.
TEXTGRID_DOWNLOAD_TEXT = "Download TextGrid file"
TEXTGRID_NAME_INPUT_LABEL = "TextGrid file name"

# Selection of models
# Kept as explicit literals so every offered model id is greppable.
VALID_MODELS = [
    "ctaguchi/wav2vec2-large-xlsr-japlmthufielta-ipa1000-ns",
    "ctaguchi/wav2vec2-large-xlsr-japlmthufielta-ipa-plus-2000",
    "ginic/data_seed_bs64_1_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/data_seed_bs64_2_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/data_seed_bs64_3_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/data_seed_bs64_4_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_30_female_1_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_30_female_2_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_30_female_3_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_30_female_4_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_30_female_5_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_70_female_1_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_70_female_2_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_70_female_3_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_70_female_4_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/gender_split_70_female_5_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/vary_individuals_old_only_1_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/vary_individuals_old_only_2_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/vary_individuals_old_only_3_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/vary_individuals_young_only_1_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/vary_individuals_young_only_2_wav2vec2-large-xlsr-53-buckeye-ipa",
    "ginic/vary_individuals_young_only_3_wav2vec2-large-xlsr-53-buckeye-ipa",
]
def load_model_and_predict(
    model_name: str,
    audio_in: str,
    model_state: dict,
):
    """Transcribe an audio file with the selected ASR model.

    The pipeline stored in ``model_state`` is reused when its
    ``"model_name"`` entry equals ``model_name``; otherwise a new pipeline is
    created and a fresh state dict replaces the old one.

    Args:
        model_name: Model identifier handed to ``pipeline(model=...)``.
        audio_in: Path of the audio file to transcribe, or ``None`` when no
            audio has been provided.
        model_state: Dict with the keys ``"loaded_model"`` (a callable
            pipeline) and ``"model_name"``.

    Returns:
        A ``(prediction, model_state)`` pair. ``prediction`` is the ``"text"``
        entry of the pipeline output, or ``""`` when ``audio_in`` is ``None``.

    Raises:
        gr.Error: If the model cannot be loaded or the transcription fails.
    """
    # The click handler in launch_demo binds this function to exactly two
    # outputs, so every return path must yield exactly two values.
    if audio_in is None:
        return "", model_state

    if model_state.get("model_name") != model_name:
        try:
            model_state = {
                "loaded_model": pipeline(
                    task="automatic-speech-recognition", model=model_name
                ),
                "model_name": model_name,
            }
        except Exception as e:
            raise gr.Error(f"Failed to load model: {str(e)}") from e

    # Inference failures get their own message so they are not reported as
    # model-loading problems.
    try:
        prediction = model_state["loaded_model"](audio_in)["text"]
    except Exception as e:
        raise gr.Error(f"Failed to transcribe audio: {str(e)}") from e

    return prediction, model_state
def get_textgrid_contents(audio_in, textgrid_tier_name, transcription_prediction):
    """Render a transcription as long-format TextGrid text.

    The result holds one interval tier named ``textgrid_tier_name`` with a
    single interval that runs from 0 to the audio duration and carries
    ``transcription_prediction`` as its text.

    Returns ``""`` when either the audio path or the prediction is ``None``.
    """
    inputs_present = audio_in is not None and transcription_prediction is not None
    if not inputs_present:
        return ""

    total_seconds = librosa.get_duration(path=audio_in)

    whole_file_interval = tgt.core.Interval(0, total_seconds, transcription_prediction)
    tier = tgt.core.IntervalTier(
        start_time=0, end_time=total_seconds, name=textgrid_tier_name
    )
    tier.add_annotation(whole_file_interval)

    grid = tgt.core.TextGrid()
    grid.add_tier(tier)
    return tgt.io3.export_to_long_textgrid(grid)
def write_textgrid(textgrid_contents, textgrid_filename):
    """Write the TextGrid contents to a named file in the temporary directory.

    Only the final path component of ``textgrid_filename`` is used, so the
    file is always created directly inside ``TEXTGRID_DIR``. When that
    component is missing or is ``".."`` the name ``output.TextGrid`` is used
    instead.

    Args:
        textgrid_contents: Text to store; ``None`` is written as an empty file.
        textgrid_filename: Desired file name (any directory part is ignored).

    Returns:
        The ``pathlib.Path`` of the written file, for download.
    """
    file_name = Path(textgrid_filename).name if textgrid_filename else ""
    if file_name in ("", ".."):
        # Without a usable name the target would resolve to a directory.
        file_name = "output.TextGrid"

    textgrid_path = Path(TEXTGRID_DIR) / file_name
    # IPA transcriptions are non-ASCII; pin UTF-8 rather than relying on the
    # platform's locale encoding.
    textgrid_path.write_text(textgrid_contents or "", encoding="utf-8")
    return textgrid_path
def get_interactive_download_button(textgrid_contents, textgrid_filename):
    """Return an enabled download button that serves the given TextGrid text.

    The text is first written to disk via ``write_textgrid``; the resulting
    path becomes the button's value.
    """
    download_path = write_textgrid(textgrid_contents, textgrid_filename)
    return gr.DownloadButton(
        value=download_path,
        interactive=True,
        variant="primary",
        label=TEXTGRID_DOWNLOAD_TEXT,
    )
def transcribe_intervals(audio_in, textgrid_path, source_tier, target_tier, model_state):
    """Transcribe each non-empty interval of a TextGrid tier into a new tier.

    For every interval of ``source_tier`` whose text is not blank, the
    matching stretch of audio is written to a temporary WAV file and passed to
    ``model_state["loaded_model"]``. The predicted text is stored in an
    interval with the same start and end time in a new tier named
    ``target_tier``, which is appended to the TextGrid.

    Args:
        audio_in: Path of the audio file, or ``None``.
        textgrid_path: Path string of the TextGrid file, or an object whose
            ``.name`` attribute is that path, or ``None``.
        source_tier: Name of the existing tier whose intervals are transcribed.
        target_tier: Name for the newly created tier.
        model_state: Dict whose ``"loaded_model"`` entry is a callable
            returning a mapping with a ``"text"`` key.

    Returns:
        The TextGrid (including the new tier) as long-format text, or an
        explanatory message when audio or TextGrid input is missing.

    Raises:
        gr.Error: If ``source_tier`` is not a tier of the TextGrid.

    Note:
        A failure on a single interval does not abort the run; the interval
        is annotated with ``[Error]: <message>`` instead.
    """
    if audio_in is None or textgrid_path is None:
        return "Missing audio or TextGrid input file."

    # Accept both a plain path string and an upload object exposing ``.name``.
    tg_file = textgrid_path if isinstance(textgrid_path, str) else textgrid_path.name
    tg = tgt.io.read_textgrid(tg_file)

    if source_tier not in [existing.name for existing in tg.tiers]:
        raise gr.Error(f"Tier '{source_tier}' was not found in the TextGrid.")
    tier = tg.get_tier_by_name(source_tier)
    ipa_tier = tgt.core.IntervalTier(name=target_tier)

    for interval in tier.intervals:
        if not interval.text.strip():  # Skip empty text intervals
            continue
        start, end = interval.start_time, interval.end_time
        clip_path = None
        try:
            y, sr = librosa.load(audio_in, sr=None, offset=start, duration=end - start)
            # Reserve a unique file name, then close the handle before the
            # clip is written and read through its path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                clip_path = temp_audio.name
            sf.write(clip_path, y, sr)
            text = model_state["loaded_model"](clip_path)["text"]
        except Exception as e:
            text = f"[Error]: {str(e)}"
        finally:
            # Remove the clip on success and on failure alike.
            if clip_path is not None and os.path.exists(clip_path):
                os.remove(clip_path)
        ipa_tier.add_annotation(tgt.core.Interval(start, end, text))

    tg.add_tier(ipa_tier)
    return tgt.io3.export_to_long_textgrid(tg)
def extract_tier_names(textgrid_file):
    """Build a dropdown update listing the tier names of a TextGrid file.

    Args:
        textgrid_file: Path string of the TextGrid file, or an object whose
            ``.name`` attribute is that path, or ``None`` when nothing is
            uploaded.

    Returns:
        A ``gr.update`` whose choices are the tier names, with the first tier
        preselected. Choices are empty and the value is ``None`` when no file
        is given, the file cannot be read, or it has no tiers.
    """
    if textgrid_file is None:
        return gr.update(choices=[], value=None)

    try:
        # Accept both a plain path string and an upload object exposing ``.name``.
        tg_file = textgrid_file if isinstance(textgrid_file, str) else textgrid_file.name
        tg = tgt.io.read_textgrid(tg_file)
        tier_names = [tier.name for tier in tg.tiers]
    except Exception:
        # Deliberately best effort: an unreadable file just empties the dropdown.
        return gr.update(choices=[], value=None)

    return gr.update(choices=tier_names, value=tier_names[0] if tier_names else None)
def validate_textgrid_for_intervals(audio_path, textgrid_file):
    """Check that an uploaded TextGrid fits the uploaded audio.

    Args:
        audio_path: Path of the audio file, or a falsy value when missing.
        textgrid_file: Path string of the TextGrid file, or an object whose
            ``.name`` attribute is that path, or a falsy value when missing.

    Returns:
        ``gr.update(interactive=False)`` while either input is missing,
        otherwise ``gr.update(interactive=True)`` once the files are accepted.

    Raises:
        gr.Error: If either file cannot be read, the TextGrid has no tiers, or
            the TextGrid ends more than ``epsilon`` seconds after the audio.

    A ``gr.Warning`` is issued when the TextGrid ends more than ``epsilon``
    seconds before the audio does.
    """
    if not audio_path or not textgrid_file:
        return gr.update(interactive=False)

    # Only the file reads are guarded, so the specific errors raised further
    # down are not re-wrapped as "Invalid TextGrid or audio file".
    try:
        audio_duration = librosa.get_duration(path=audio_path)
        tg_file = textgrid_file if isinstance(textgrid_file, str) else textgrid_file.name
        tg = tgt.io.read_textgrid(tg_file)
        tier_end_times = [tier.end_time for tier in tg.tiers]
    except Exception as e:
        raise gr.Error(f"Invalid TextGrid or audio file:\n{str(e)}") from e

    if not tier_end_times:
        raise gr.Error("Invalid TextGrid or audio file:\nThe TextGrid contains no tiers.")
    tg_end_time = max(tier_end_times)

    # Tolerance (seconds) for rounding differences between the two durations;
    # it applies to the hard error as well as to the warning.
    epsilon = 0.01
    if tg_end_time > audio_duration + epsilon:
        raise gr.Error(
            f"TextGrid ends at {tg_end_time:.2f}s but audio is only {audio_duration:.2f}s. "
            "Please upload matching files."
        )
    if audio_duration - tg_end_time > epsilon:
        gr.Warning(
            f"TextGrid ends at {tg_end_time:.2f}s but audio is {audio_duration:.2f}s. "
            "Only the annotated portion will be transcribed."
        )
    return gr.update(interactive=True)
def launch_demo():
    """Build the Gradio interface and launch the demo.

    The default model is loaded before the interface is built and kept in a
    ``gr.State`` so later requests can reuse it. The interface has two
    sections, toggled by the "Transcription Type" dropdown: whole-file
    transcription and per-interval transcription driven by an uploaded
    TextGrid. Uploads are limited to 100 MB.
    """
    initial_model = {
        "loaded_model": pipeline(
            task="automatic-speech-recognition", model=DEFAULT_MODEL
        ),
        "model_name": DEFAULT_MODEL,
    }

    def full_download_update(tg_text, audio_path):
        """Refresh the full-audio download button from the TextGrid text."""
        if not tg_text:
            # Cleared text (e.g. after Reset) must not re-enable the button
            # or write an empty file.
            return gr.update(interactive=False)
        filename = (
            Path(audio_path).with_suffix(".TextGrid").name
            if audio_path
            else "output.TextGrid"
        )
        return get_interactive_download_button(tg_text, filename)

    def interval_download_update(tg_text, audio_path):
        """Refresh the interval download button from the TextGrid text."""
        if not tg_text:
            return gr.update(interactive=False)
        # The audio may already be cleared when this fires; fall back to a
        # generic stem instead of building a path from None.
        stem = Path(audio_path).with_suffix("").name if audio_path else "output"
        return gr.update(
            value=write_textgrid(tg_text, stem + "_IPA.TextGrid"),
            interactive=True,
        )

    def transcribe_intervals_with_selected_model(
        selected_model, audio_path, textgrid_file, source_tier, new_tier, state
    ):
        """Run interval transcription with the model chosen in the dropdown."""
        inputs_present = audio_path is not None and textgrid_file is not None
        if inputs_present and state["model_name"] != selected_model:
            try:
                state = {
                    "loaded_model": pipeline(
                        task="automatic-speech-recognition", model=selected_model
                    ),
                    "model_name": selected_model,
                }
            except Exception as e:
                raise gr.Error(f"Failed to load model: {str(e)}") from e
        result = transcribe_intervals(
            audio_path, textgrid_file, source_tier, new_tier, state
        )
        return result, state

    with gr.Blocks() as demo:
        gr.Markdown(
            "# Automatic International Phonetic Alphabet Transcription\n"
            "This demo allows you to experiment with producing phonetic transcriptions of uploaded or recorded audio using a selected automatic speech recognition (ASR) model."
        )

        # Dropdown for model selection
        model_name = gr.Dropdown(
            VALID_MODELS,
            value=DEFAULT_MODEL,
            label="IPA transcription ASR model",
            info="Select the model to use for prediction.",
        )

        # Dropdown for transcription type selection
        transcription_type = gr.Dropdown(
            choices=["Full Audio", "TextGrid Interval"],
            label="Transcription Type",
            value=None,
            interactive=True,
        )

        model_state = gr.State(value=initial_model)

        # Full audio transcription section
        with gr.Column(visible=False) as full_audio_section:
            full_audio = gr.Audio(type="filepath", show_download_button=True, label="Upload Audio File")
            full_transcribe_btn = gr.Button("Transcribe Full Audio", interactive=False, variant="primary")
            full_prediction = gr.Textbox(label="IPA Transcription", show_copy_button=True)
            full_textgrid_tier = gr.Textbox(label="TextGrid Tier Name", value="transcription", interactive=True)
            full_textgrid_contents = gr.Textbox(label="TextGrid Contents", show_copy_button=True)
            full_download_btn = gr.DownloadButton(label=TEXTGRID_DOWNLOAD_TEXT, interactive=False, variant="primary")
            full_reset_btn = gr.Button("Reset", variant="secondary")

        # Interval transcription section
        with gr.Column(visible=False) as interval_section:
            interval_audio = gr.Audio(type="filepath", show_download_button=True, label="Upload Audio File")
            interval_textgrid_file = gr.File(file_types=[".TextGrid"], label="Upload TextGrid File")
            tier_names = gr.Dropdown(label="Source Tier (existing)", choices=[], interactive=True)
            target_tier = gr.Textbox(label="Target Tier (new)", value="IPATier", placeholder="e.g. IPATier")
            interval_transcribe_btn = gr.Button("Transcribe Intervals", interactive=False, variant="primary")
            interval_result = gr.Textbox(label="IPA Interval Transcription", show_copy_button=True, interactive=False)
            interval_download_btn = gr.DownloadButton(label=TEXTGRID_DOWNLOAD_TEXT, interactive=False, variant="primary")
            interval_reset_btn = gr.Button("Reset", variant="secondary")

        # Section visibility toggle
        transcription_type.change(
            fn=lambda t: (
                gr.update(visible=t == "Full Audio"),
                gr.update(visible=t == "TextGrid Interval"),
            ),
            inputs=transcription_type,
            outputs=[full_audio_section, interval_section],
        )

        # Enable full transcribe button after audio uploaded
        full_audio.change(
            fn=lambda audio: gr.update(interactive=audio is not None),
            inputs=full_audio,
            outputs=full_transcribe_btn,
        )

        # Full transcription logic
        full_transcribe_btn.click(
            fn=load_model_and_predict,
            inputs=[model_name, full_audio, model_state],
            outputs=[full_prediction, model_state],
        )
        full_prediction.change(
            fn=get_textgrid_contents,
            inputs=[full_audio, full_textgrid_tier, full_prediction],
            outputs=[full_textgrid_contents],
        )
        full_textgrid_contents.change(
            fn=full_download_update,
            inputs=[full_textgrid_contents, full_audio],
            outputs=[full_download_btn],
        )
        # One return value per output component: audio, prediction,
        # TextGrid contents, download button.
        full_reset_btn.click(
            fn=lambda: (None, "", "", gr.update(interactive=False)),
            outputs=[full_audio, full_prediction, full_textgrid_contents, full_download_btn],
        )

        # Enable interval transcribe button only when both files are uploaded
        interval_audio.change(
            fn=validate_textgrid_for_intervals,
            inputs=[interval_audio, interval_textgrid_file],
            outputs=[interval_transcribe_btn],
        )
        interval_textgrid_file.change(
            fn=validate_textgrid_for_intervals,
            inputs=[interval_audio, interval_textgrid_file],
            outputs=[interval_transcribe_btn],
        )

        # Interval logic
        interval_textgrid_file.change(
            fn=extract_tier_names,
            inputs=[interval_textgrid_file],
            outputs=[tier_names],
        )
        # The selected model is passed along so interval transcription does
        # not silently use whichever pipeline happens to be in the state.
        interval_transcribe_btn.click(
            fn=transcribe_intervals_with_selected_model,
            inputs=[model_name, interval_audio, interval_textgrid_file, tier_names, target_tier, model_state],
            outputs=[interval_result, model_state],
        )
        interval_result.change(
            fn=interval_download_update,
            inputs=[interval_result, interval_audio],
            outputs=[interval_download_btn],
        )
        interval_reset_btn.click(
            fn=lambda: (None, None, gr.update(choices=[]), "IPATier", "", gr.update(interactive=False)),
            outputs=[interval_audio, interval_textgrid_file, tier_names, target_tier, interval_result, interval_download_btn],
        )

    demo.launch(max_file_size="100mb")
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    launch_demo()