# Wav2Txt / app.py
import gradio as gr
from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration
import torch
import librosa
import subprocess
from langdetect import detect_langs
import os
import warnings
from transformers import logging
import math
import json
from pyannote.audio import Pipeline
# Suppress warnings
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
# Read the Hugging Face token from the environment variable
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
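# Note: the token must have access to the gated pyannote/speaker-diarization model.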
# Updated models by language
MODELS = {
"es": [
"openai/whisper-large-v3",
"facebook/wav2vec2-large-xlsr-53-spanish",
"jonatasgrosman/wav2vec2-xls-r-1b-spanish"
],
"en": [
"openai/whisper-large-v3",
"facebook/wav2vec2-large-960h",
"microsoft/wav2vec2-base-960h"
],
"pt": [
"facebook/wav2vec2-large-xlsr-53-portuguese",
"openai/whisper-medium",
"jonatasgrosman/wav2vec2-large-xlsr-53-portuguese"
]
}
def convert_audio_to_wav(audio_path):
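    """Convert the input audio to 16 kHz mono WAV using ffmpeg (must be on PATH)."""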
try:
print("Converting audio to WAV format...")
wav_path = "converted_audio.wav"
command = ["ffmpeg", "-i", audio_path, "-ac", "1", "-ar", "16000", wav_path]
subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print(f"Audio converted to {wav_path}")
return wav_path
except Exception as e:
print(f"Error converting audio to WAV: {e}")
raise RuntimeError(f"Error converting audio to WAV: {e}")
def detect_language(audio_path):
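    """Transcribe the first 30 s with whisper-base, then run langdetect on the text."""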
try:
print("Detecting language...")
speech, _ = librosa.load(audio_path, sr=16000, duration=30)
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
input_features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features
predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
langs = detect_langs(transcription)
es_confidence = next((lang.prob for lang in langs if lang.lang == 'es'), 0)
pt_confidence = next((lang.prob for lang in langs if lang.lang == 'pt'), 0)
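        # langdetect often confuses Spanish and Portuguese on short samples; prefer Spanish when the scores are close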
if abs(es_confidence - pt_confidence) < 0.2:
print("Detected language: Spanish")
return 'es'
detected_language = max(langs, key=lambda x: x.prob).lang
print(f"Detected language: {detected_language}")
return detected_language
except Exception as e:
print(f"Error detecting language: {e}")
raise RuntimeError(f"Error detecting language: {e}")
def diarize_audio(wav_audio):
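    """Run pyannote speaker diarization; requires an HF token with access to the gated model."""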
try:
print("Performing diarization...")
        # Renamed to avoid shadowing the transformers `pipeline` import
        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HUGGINGFACE_TOKEN)
        diarization = diarization_pipeline(wav_audio)
print("Diarization complete.")
return diarization
except Exception as e:
print(f"Error in diarization: {e}")
raise RuntimeError(f"Error in diarization: {e}")
def transcribe_audio_stream(audio, model_name):
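    """Generator: convert and chunk the audio, yielding (transcriptions, progress) after each chunk.

    Each entry in transcriptions is a (timestamp, text, progress) tuple.
    """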
try:
wav_audio = convert_audio_to_wav(audio)
speech, rate = librosa.load(wav_audio, sr=16000)
duration = len(speech) / rate
transcriptions = []
if "whisper" in model_name:
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
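            # Whisper operates on 30-second windows, so chunk accordingly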
chunk_duration = 30 # seconds
for i in range(0, int(duration), chunk_duration):
end = min(i + chunk_duration, duration)
chunk = speech[int(i * rate):int(end * rate)]
input_features = processor(chunk, sampling_rate=16000, return_tensors="pt").input_features
predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
progress = min(100, (end / duration) * 100)
timestamp = i
transcriptions.append((timestamp, transcription, progress))
yield transcriptions, progress
else:
transcriber = pipeline("automatic-speech-recognition", model=model_name)
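            # CTC models such as wav2vec2 accept arbitrary lengths; shorter chunks keep memory use bounded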
chunk_duration = 10 # seconds
for i in range(0, int(duration), chunk_duration):
end = min(i + chunk_duration, duration)
chunk = speech[int(i * rate):int(end * rate)]
result = transcriber(chunk)
progress = min(100, (end / duration) * 100)
timestamp = i
transcriptions.append((timestamp, result["text"], progress))
yield transcriptions, progress
except Exception as e:
print(f"Error in transcription: {e}")
raise RuntimeError(f"Error in transcription: {e}")
def merge_diarization_with_transcription(transcriptions, diarization):
    """Attach transcribed chunks to diarized speaker turns by timestamp overlap."""
    try:
        print("Merging diarization with transcription...")
        speaker_transcriptions = []
        # itertracks yields (segment, track, label); segment bounds are already in seconds,
        # so no conversion by sample rate is needed
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            start_time = turn.start
            end_time = turn.end
            text_segment = ""
            for ts, text, _ in transcriptions:
                if start_time <= ts <= end_time:
                    text_segment += text + " "
            speaker_transcriptions.append((start_time, end_time, speaker, text_segment.strip()))
        print("Merge complete.")
        return speaker_transcriptions
    except Exception as e:
        print(f"Error merging diarization with transcription: {e}")
        raise RuntimeError(f"Error merging diarization with transcription: {e}")
def detect_and_select_model(audio):
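    """Detect the spoken language and return (language, candidate model list)."""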
try:
print("Detecting and selecting model...")
wav_audio = convert_audio_to_wav(audio)
language = detect_language(wav_audio)
model_options = MODELS.get(language, MODELS["en"])
print(f"Selected model: {model_options[0]}")
return language, model_options
except Exception as e:
print(f"Error detecting and selecting model: {e}")
raise RuntimeError(f"Error detecting and selecting model: {e}")
def save_transcription(transcriptions, file_format):
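    """Write the speaker-labelled transcription to /tmp as TXT or JSON and return the file path."""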
try:
print(f"Saving transcription to {file_format} format...")
if file_format == "txt":
file_path = "/tmp/transcription.txt"
with open(file_path, "w") as f:
for start, end, speaker, text in transcriptions:
f.write(f"[{start:.2f}-{end:.2f}] {speaker}: {text}\n")
print(f"Transcription saved to {file_path}")
return file_path
elif file_format == "json":
file_path = "/tmp/transcription.json"
with open(file_path, "w") as f:
json.dump(transcriptions, f)
print(f"Transcription saved to {file_path}")
return file_path
except Exception as e:
print(f"Error saving transcription: {e}")
raise RuntimeError(f"Error saving transcription: {e}")
def combined_interface(audio):
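    """Main Gradio generator: detect language, diarize, transcribe, and stream results to the UI."""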
try:
print("Starting combined interface...")
language, model_options = detect_and_select_model(audio)
selected_model = model_options[0]
        yield language, gr.update(choices=model_options, value=selected_model), selected_model, "", 0, "Initializing...", None, None
wav_audio = convert_audio_to_wav(audio)
diarization = diarize_audio(wav_audio)
transcriptions = []
for partial_transcriptions, progress in transcribe_audio_stream(audio, selected_model):
transcriptions = partial_transcriptions
            transcriptions_text = "\n".join(f"[{ts}s] {text}" for ts, text, _ in transcriptions)
progress_int = math.floor(progress)
status = f"Transcribing... {progress_int}% complete"
            yield language, gr.update(choices=model_options, value=selected_model), selected_model, transcriptions_text, progress_int, status, None, None
        speaker_transcriptions = merge_diarization_with_transcription(transcriptions, diarization)
transcriptions_text = "\n".join([f"[{start:.2f}-{end:.2f}] {speaker}: {text}" for start, end, speaker, text in speaker_transcriptions])
txt_file_path = save_transcription(speaker_transcriptions, "txt")
json_file_path = save_transcription(speaker_transcriptions, "json")
os.remove(wav_audio)
        yield language, gr.update(choices=model_options, value=selected_model), selected_model, transcriptions_text, 100, "Transcription complete!", txt_file_path, json_file_path
except Exception as e:
print(f"Error in combined interface: {e}")
        yield str(e), gr.update(choices=[]), "", "An error occurred during processing.", 0, "Error", None, None
iface = gr.Interface(
fn=combined_interface,
inputs=gr.Audio(type="filepath"),
outputs=[
gr.Textbox(label="Detected Language"),
gr.Dropdown(label="Available Models", choices=[]),
gr.Textbox(label="Selected Model"),
gr.Textbox(label="Transcription", lines=10),
gr.Slider(minimum=0, maximum=100, label="Progress", interactive=False),
gr.Textbox(label="Status"),
gr.File(label="Download Transcription (TXT)", type="filepath"),
gr.File(label="Download Transcription (JSON)", type="filepath")
],
title="Multilingual Audio Transcriber with Real-time Display, Timestamps, and Speaker Diarization",
description="Upload an audio file to detect the language, select the transcription model, and get the transcription with timestamps and speaker labels in real-time. Download the transcription as TXT or JSON. Optimized for Spanish, English, and Portuguese.",
live=True
)
if __name__ == "__main__":
iface.queue().launch()