Spaces: Build error
import streamlit as st
import torch
import librosa
import soundfile
import nemo.collections.asr as nemo_asr
import tempfile
import os
import uuid
from pydub import AudioSegment
import numpy as np
import io

SAMPLE_RATE = 16000
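# The Conformer Transducer English model works on 16 kHz mono audio, so all
# recordings are converted to this rate before transcription.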
# Load the pre-trained model once and cache it across Streamlit reruns
@st.cache_resource
def load_model():
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("stt_en_conformer_transducer_large")
    asr_model.change_decoding_strategy(None)
    asr_model.eval()
    return asr_model

model = load_model()
def process_audio_data(audio_data):
    # Convert stereo to mono
    if audio_data.channels == 2:
        audio_data = audio_data.set_channels(1)
    # Convert the pydub AudioSegment to a float32 numpy array in [-1, 1]
    audio_np = np.array(audio_data.get_array_of_samples()).astype(np.float32)
    audio_np /= float(1 << (8 * audio_data.sample_width - 1))
    # Resample to the model's expected rate if necessary
    if audio_data.frame_rate != SAMPLE_RATE:
        audio_np = librosa.resample(audio_np, orig_sr=audio_data.frame_rate, target_sr=SAMPLE_RATE)
    return audio_np
def transcribe(audio_np):
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save audio data to a temporary WAV file
        audio_path = os.path.join(tmpdir, f'audio_{uuid.uuid4()}.wav')
        soundfile.write(audio_path, audio_np, SAMPLE_RATE)
        # Transcribe audio
        transcriptions = model.transcribe([audio_path])
        # Extract the best hypotheses if transcriptions form a tuple (RNNT models)
        if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
            transcriptions = transcriptions[0]
        return transcriptions[0]
st.title("Speech Recognition with NeMo Conformer Transducer Large - English")

# Record audio in the browser. st.audio_recorder is not part of the Streamlit
# API; st.audio_input (built into recent Streamlit releases) is used instead.
# It returns the recording as a WAV file-like object, or None until the user
# has recorded something.
st.write("Use the widget below to record audio, then wait for the transcription.")
recorded_audio = st.audio_input("Record audio")

if recorded_audio is not None:
    # Play back the recording
    st.audio(recorded_audio, format="audio/wav")
    # Process and transcribe the recorded audio
    audio_segment = AudioSegment.from_file(io.BytesIO(recorded_audio.getvalue()), format="wav")
    audio_np = process_audio_data(audio_segment)
    transcript = transcribe(audio_np)
    st.write("Transcription:")
    st.write(transcript)
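A Spaces build error with this app often comes from the environment rather than the script itself. A rough sketch of the dependency files a Streamlit Space running this code would need (package names only; exact versions and whether every entry is strictly required are assumptions):

requirements.txt
    streamlit
    torch
    Cython
    nemo_toolkit[asr]
    librosa
    soundfile
    pydub
    numpy

packages.txt
    ffmpeg

st.audio_input needs a recent Streamlit release, and ffmpeg is listed so pydub can decode audio inside the Space.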