# (Copied from a Hugging Face Spaces page; the page's status banner read "Spaces: Sleeping".)
# app.py
# -*- coding: utf-8 -*-
"""
Vietnamese end-to-end speech recognition using Wav2Vec 2.0 with speaker diarization.
Streamlit application that shows speaker segments, merged per speaker, with timestamps.
Speaker labels come from a lightweight energy-based segmentation heuristic, not a trained
diarization model.
"""
import logging
import os
import tempfile
import zipfile

import kenlm
import librosa
import numpy as np
import soundfile as sf
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from pyctcdecode import Alphabet, BeamSearchDecoderCTC, LanguageModel
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# Configure the root logger at import time to emit INFO and above
# (basicConfig accepts the level name string in place of the numeric constant).
logging.basicConfig(level="INFO")
@st.cache_resource
def load_model_and_tokenizer(cache_dir='./cache/'):
    """Load the Vietnamese wav2vec 2.0 processor/model and locate its 4-gram KenLM file.

    Wrapped in ``st.cache_resource`` so the model weights are loaded once per
    ``cache_dir`` for the Streamlit server process, instead of on every button
    click. The language-model archive is downloaded and extracted only when
    ``vi_lm_4grams.bin`` is not already present in ``cache_dir``.

    Args:
        cache_dir: Directory used for Hugging Face downloads and for the extracted
            language-model binary.

    Returns:
        Tuple ``(processor, model, lm_file)``. ``lm_file`` is the path to
        ``vi_lm_4grams.bin`` inside ``cache_dir``.

    Raises:
        FileNotFoundError: If the extracted archive did not produce
            ``vi_lm_4grams.bin`` directly inside ``cache_dir``.
    """
    repo_id = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"

    st.info("Loading processor and model...")
    processor = Wav2Vec2Processor.from_pretrained(repo_id, cache_dir=cache_dir)
    model = Wav2Vec2ForCTC.from_pretrained(repo_id, cache_dir=cache_dir)

    lm_file = os.path.join(cache_dir, 'vi_lm_4grams.bin')
    # The extracted binary is large; re-extracting it on every call was pure overhead.
    if not os.path.isfile(lm_file):
        st.info("Downloading language model...")
        lm_zip_file = hf_hub_download(
            repo_id=repo_id,
            filename="vi_lm_4grams.bin.zip",
            cache_dir=cache_dir,
        )
        st.info("Extracting language model...")
        with zipfile.ZipFile(lm_zip_file, 'r') as zip_ref:
            zip_ref.extractall(cache_dir)

    # The archive is expected to contain the .bin at its root; fail loudly if not.
    if not os.path.isfile(lm_file):
        raise FileNotFoundError(f"Language model file not found: {lm_file}")

    st.success("Processor, model, and language model loaded successfully.")
    return processor, model, lm_file
@st.cache_resource
def get_decoder_ngram_model(_tokenizer, ngram_lm_path):
    """Build a pyctcdecode beam-search decoder backed by a KenLM n-gram model.

    Wrapped in ``st.cache_resource`` so the KenLM binary is loaded once per
    ``ngram_lm_path`` instead of on every button click. The leading underscore
    on ``_tokenizer`` is Streamlit's convention for excluding an argument from
    the cache key, so Streamlit does not try to hash the tokenizer.

    Args:
        _tokenizer: Tokenizer exposing ``get_vocab()`` (token -> id mapping);
            its vocabulary defines the CTC labels.
        ngram_lm_path: Path to a KenLM model file loadable by ``kenlm.Model``.

    Returns:
        A ``BeamSearchDecoderCTC`` combining the label alphabet and the n-gram LM.
    """
    st.info("Building decoder with n-gram language model...")
    vocab_dict = _tokenizer.get_vocab()
    # Order labels by token id so that label i lines up with logit column i.
    sorted_vocab = sorted((token_id, token) for (token, token_id) in vocab_dict.items())
    # Drop the two highest-id tokens. NOTE(review): assumes these are special
    # tokens (e.g. <s>/</s>) with no matching logit column in the CTC head —
    # confirm against the model's vocab.json before reusing with another model.
    vocab_list = [token for _, token in sorted_vocab][:-2]
    alphabet = Alphabet.build_alphabet(vocab_list)
    lm_model = kenlm.Model(ngram_lm_path)
    decoder = BeamSearchDecoderCTC(alphabet, language_model=LanguageModel(lm_model))
    st.success("Decoder built successfully.")
    return decoder
def transcribe_chunk(model, processor, decoder, speech_chunk, sampling_rate, beam_width=500):
    """Transcribe one audio chunk with wav2vec 2.0 logits and CTC beam search.

    The chunk is down-mixed to mono, cast to float32, resampled to 16 kHz when
    needed, and zero-padded to at least 0.5 s before going through the processor
    and model.

    Args:
        model: CTC model called as ``model(input_values)``; its ``.logits`` are
            indexed with ``[0]`` (first batch item). If it has a ``device``
            attribute, the inputs are moved there.
        processor: Feature extractor called as
            ``processor(audio, sampling_rate=..., return_tensors="pt")``.
        decoder: Decoder exposing ``decode(logits, beam_width=...)``.
        speech_chunk: Audio samples (array-like). 2-D input is averaged over
            axis 1, i.e. assumed to be (frames, channels) as returned by
            ``soundfile.read``.
        sampling_rate: Sample rate of ``speech_chunk`` in Hz.
        beam_width: Beam size for the decoder. Larger values are slower and may be
            more accurate. Defaults to the previously hard-coded 500.

    Returns:
        Whatever ``decoder.decode`` returns (the transcript string for pyctcdecode).
    """
    speech_chunk = np.asarray(speech_chunk)
    if speech_chunk.ndim > 1:
        speech_chunk = np.mean(speech_chunk, axis=1)
    speech_chunk = speech_chunk.astype(np.float32)

    target_sr = 16000
    if sampling_rate != target_sr:
        # Nothing to resample in an empty chunk; it is padded with silence below anyway.
        if speech_chunk.size:
            speech_chunk = librosa.resample(speech_chunk, orig_sr=sampling_rate, target_sr=target_sr)
        sampling_rate = target_sr

    # Very short inputs are padded with trailing silence to a 0.5 s minimum.
    min_samples = int(0.5 * sampling_rate)
    if len(speech_chunk) < min_samples:
        speech_chunk = np.pad(speech_chunk, (0, min_samples - len(speech_chunk)), 'constant')

    input_values = processor(
        speech_chunk, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_values
    # Keep inputs on the same device as the model so a GPU-placed model also works.
    device = getattr(model, "device", None)
    if device is not None:
        input_values = input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits[0]

    return decoder.decode(logits.cpu().numpy(), beam_width=beam_width)
def _merge_close_intervals(intervals, max_gap_samples):
    # Join consecutive (start, end) sample intervals whose gap is shorter than max_gap_samples.
    merged = []
    for start, end in intervals:
        if merged and (start - merged[-1][1]) < max_gap_samples:
            merged[-1][1] = end
        else:
            merged.append([start, end])
    return merged


def alternative_speaker_diarization(audio_file, num_speakers=2, top_db=30, min_gap_duration=0.5):
    """Split ``audio_file`` into speech segments and attach speaker ids.

    NOTE: this is not voice-based diarization. The steps are:

    1. Find non-silent regions with ``librosa.effects.split`` (threshold ``top_db``).
    2. Bridge regions separated by gaps shorter than ``min_gap_duration`` seconds.
    3. Assign speaker ids round-robin (0, 1, ..., num_speakers - 1, 0, ...).

    If any of that raises, an error is shown in the UI. The file is then cut into
    ``num_speakers`` equal-length pieces, one per speaker id.

    Args:
        audio_file: Path to the audio file.
        num_speakers: Number of speaker ids to cycle through; must be >= 1.
        top_db: Decibels below the reference (peak) level treated as silence,
            forwarded to ``librosa.effects.split``.
        min_gap_duration: Gaps between speech regions shorter than this many
            seconds are merged into a single region.

    Returns:
        List of ``(start_seconds, end_seconds, speaker_id)`` tuples. The list is
        empty when no non-silent region is found.
    """
    y, sr = None, None
    try:
        # sr=None keeps the file's native sample rate (librosa down-mixes to mono).
        y, sr = librosa.load(audio_file, sr=None)
        intervals = librosa.effects.split(y, top_db=top_db)
        merged_intervals = _merge_close_intervals(intervals, int(min_gap_duration * sr))
        return [
            (start / sr, end / sr, index % num_speakers)
            for index, (start, end) in enumerate(merged_intervals)
        ]
    except Exception as e:
        st.error(f"Speaker diarization failed: {e}")
        # Fallback: equal-length segmentation. Reuse the decoded audio if the
        # failure happened after loading; otherwise read the file directly.
        if y is not None:
            total_duration = len(y) / sr
        else:
            audio, file_sr = sf.read(audio_file)
            total_duration = len(audio) / file_sr
        segment_duration = total_duration / num_speakers
        return [
            (i * segment_duration, (i + 1) * segment_duration, i)
            for i in range(num_speakers)
        ]
def process_segments(audio_file, segments, model, processor, decoder, sampling_rate=16000):
    """Transcribe each non-overlapping segment of ``audio_file``.

    The file is decoded once with ``librosa.load``, the same loader the
    segmentation step uses, down-mixed to mono and resampled to
    ``sampling_rate``. This lets the app handle formats the uploader accepts but
    ``soundfile`` may not decode (e.g. ``.m4a``). It also resamples the whole
    file once instead of once per chunk. Previously ``sampling_rate`` was ignored.

    Args:
        audio_file: Path to the audio file.
        segments: Iterable of ``(start_seconds, end_seconds, speaker_id)``.
            Segments are sorted by start time. A segment that starts before the
            previously kept segment ends is dropped.
        model, processor, decoder: Forwarded to ``transcribe_chunk``.
        sampling_rate: Rate (Hz) at which to load the audio. ``None`` keeps the
            file's native rate.

    Returns:
        List of ``(start_seconds, end_seconds, speaker_id, transcript)`` for
        segments whose transcript is not blank.
    """
    speech, sr = librosa.load(audio_file, sr=sampling_rate)

    # Keep only segments that do not overlap an earlier kept segment.
    unique_segments = []
    for segment in sorted(segments, key=lambda seg: seg[0]):
        if not unique_segments or segment[0] >= unique_segments[-1][1]:
            unique_segments.append(segment)

    final_transcriptions = []
    for start, end, speaker_id in unique_segments:
        speech_chunk = speech[int(start * sr):int(end * sr)]
        if speech_chunk.size == 0:
            # Zero-length or out-of-range segment: nothing to transcribe.
            continue
        transcript = transcribe_chunk(model, processor, decoder, speech_chunk, sr)
        # Only keep segments that produced non-blank text.
        if transcript.strip():
            final_transcriptions.append((start, end, speaker_id, transcript))
    return final_transcriptions
def format_timestamp(seconds):
    """Render a time offset in seconds as zero-padded ``MM:SS``.

    Fractional seconds are truncated by ``int``. Minutes are not rolled over into
    hours, so offsets of an hour or more display as e.g. ``62:05``.
    """
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"
def merge_speaker_segments(final_transcriptions):
    """Collapse runs of consecutive segments that share the same speaker id.

    Args:
        final_transcriptions: Sequence of ``(start, end, speaker_id, text)``.

    Returns:
        List of ``(start, end, speaker_id, text)`` tuples. Each merged entry keeps
        the first segment's start, the last segment's end, and the texts joined
        with single spaces. An empty input yields an empty list.
    """
    merged = []
    for start, end, speaker, text in final_transcriptions:
        if merged and speaker == merged[-1][2]:
            # Same speaker as the previous entry: extend it in place.
            run_start, _, _, run_text = merged[-1]
            merged[-1] = (run_start, end, speaker, run_text + " " + text)
        else:
            # Speaker changed (or first segment): start a new entry.
            merged.append((start, end, speaker, text))
    return merged
def _render_results(merged_results):
    # Show each merged segment as a markdown line and offer the full transcript as a .txt download.
    st.success("Transcription Completed!")
    lines = []
    for start_time, end_time, speaker_id, transcript in merged_results:
        start_str = format_timestamp(start_time)
        end_str = format_timestamp(end_time)
        line = f"{start_str} - {end_str} - Speaker {speaker_id + 1}: {transcript}"
        st.markdown(line)
        lines.append(line)
    st.download_button(
        label="Download Transcription",
        data="".join(f"{line}\n" for line in lines),
        file_name="transcription.txt",
        mime="text/plain",
    )


def _run_transcription(audio_path, num_speakers):
    # Load models, segment the audio, transcribe each segment, merge per speaker and render the result.
    processor, model, lm_file = load_model_and_tokenizer()
    decoder = get_decoder_ngram_model(processor.tokenizer, lm_file)

    segments = alternative_speaker_diarization(audio_path, num_speakers=num_speakers)
    if not segments:
        st.warning("No speech segments detected.")
        return

    final_transcriptions = process_segments(audio_path, segments, model, processor, decoder)
    merged_results = merge_speaker_segments(final_transcriptions)
    if merged_results:
        _render_results(merged_results)
    else:
        st.warning("No transcriptions available.")


def main():
    """Streamlit entry point: upload audio, pick a speaker count, transcribe on demand.

    The upload is written to a uniquely named temporary file only when
    "Transcribe" is pressed. The file is always removed afterwards, including on
    the early "no speech segments" exit. Errors during processing are logged
    with a traceback and shown in the UI.
    """
    st.title("🇻🇳 Vietnamese Speech Recognition with Speaker Diarization (with merging & timestamps)")
    st.write("""
Upload an audio file, select the number of speakers, and get the transcribed text with timestamps and merged segments for each speaker.
""")

    # Sidebar for inputs
    st.sidebar.header("Input Parameters")
    uploaded_file = st.sidebar.file_uploader("Upload Audio File", type=["wav", "mp3", "flac", "m4a"])
    num_speakers = st.sidebar.slider("Number of Speakers", min_value=1, max_value=5, value=2, step=1)

    if uploaded_file is None:
        st.info("Please upload an audio file to get started.")
        return

    # Use the upload's real MIME type rather than always claiming WAV.
    st.audio(uploaded_file, format=uploaded_file.type or 'audio/wav')

    if not st.button("Transcribe"):
        return

    # A per-request temp file avoids concurrent sessions overwriting a shared path.
    # The upload's extension is kept in case a decoder relies on it.
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_audio_path = tmp.name

    try:
        with st.spinner("Processing..."):
            try:
                _run_transcription(temp_audio_path, num_speakers)
            except Exception as e:
                # Top-level UI boundary: keep the traceback in the server log, show a summary to the user.
                logging.exception("Transcription failed")
                st.error(f"An error occurred during processing: {e}")
    finally:
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)


if __name__ == '__main__':
    main()