# Norphel's picture
# Update app.py
# 7f18afd verified
import streamlit as st
import soundfile as sf
import numpy as np
import torch
import librosa
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import VitsModel, AutoTokenizer
import tempfile
import os
# Token handed to the `from_pretrained(..., token=...)` calls in this file;
# read from the `hf_token` environment variable (None when it is unset).
api_key = os.environ.get("hf_token")

st.title("Dzongkha Speech-to-Text")

# Run on CUDA when PyTorch reports a usable GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.write(f"Using device: {device.upper()}")
# Cached by Streamlit so the checkpoint is loaded once rather than on every rerun.
@st.cache_resource
def load_asr_model():
    """Load the Wav2Vec2 CTC speech-recognition model and its processor.

    Returns:
        A ``(model, processor)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    checkpoint = "Norphel/wav2vec2-large-mms-1b-dzo-colab"
    acoustic_model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
    acoustic_model = acoustic_model.to(device)  # CPU or GPU, whichever was selected
    asr_processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    return acoustic_model, asr_processor
@st.cache_resource
def load_translation_model():
    """Load the seq2seq translation model and its tokenizer.

    Both are fetched from the ``Norphel/Dz_en`` checkpoint using ``api_key``
    as the access token.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is not moved to ``device``
        here.
    """
    checkpoint = "Norphel/Dz_en"
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, token=api_key)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=api_key)
    return seq2seq, seq2seq_tokenizer
@st.cache_resource
def load_tts_model():
    """Load the VITS text-to-speech model and its tokenizer.

    Both are fetched from the ``Norphel/MMS-TTS-Dzo-N3`` checkpoint using
    ``api_key`` as the access token.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is not moved to ``device``
        here.
    """
    checkpoint = "Norphel/MMS-TTS-Dzo-N3"
    vits = VitsModel.from_pretrained(checkpoint, token=api_key)
    vits_tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=api_key)
    return vits, vits_tokenizer
def generate_voice(text):
    """Synthesise speech for ``text`` with the module-level TTS model.

    Args:
        text: The string to be spoken.

    Returns:
        The ``waveform`` attribute of the ``tts_model`` output (a tensor,
        computed without gradient tracking).
    """
    encoded = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        synthesis = tts_model(**encoded)
    return synthesis.waveform
def translate(text):
    """Translate ``text`` with the module-level translation model.

    NOTE(review): presumably Dzongkha -> English, judging only by the
    ``Norphel/Dz_en`` checkpoint name; confirm against the model card.

    Args:
        text: Source-language string to translate.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    encoded = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    token_ids = encoded.input_ids.to(device)  # inputs must live where the model does
    translation_model.to(device)  # make sure the model is on the selected device
    generated = translation_model.generate(token_ids, max_new_tokens=512)
    return translation_tokenizer.decode(generated[0], skip_special_tokens=True)
# Instantiate the three model/tokenizer pairs used by the app. Each loader is
# decorated with st.cache_resource, so these calls are expected to hit the
# cache on reruns instead of reloading the checkpoints.
asr_model, processor = load_asr_model()
translation_model, translation_tokenizer = load_translation_model()
tts_model, tts_tokenizer = load_tts_model()
# Sample rate (Hz) the recording is converted to before it is handed to the
# ASR processor.
ASR_SAMPLE_RATE = 16000


# Run Speech-to-Text
def generate_text(audio):
    """Transcribe a waveform with the module-level ASR model.

    Args:
        audio: Float waveform already resampled to ``ASR_SAMPLE_RATE``.

    Returns:
        The text decoded from the highest-scoring token at each frame of the
        first (only) batch item.
    """
    input_dict = processor(
        audio, sampling_rate=ASR_SAMPLE_RATE, return_tensors="pt", padding=True
    )
    # Inference only: skip autograd bookkeeping, as generate_voice does.
    with torch.no_grad():
        logits = asr_model(input_dict.input_values.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(pred_ids)


# Audio Recording Widget
audio_value = st.audio_input("Record a voice message")

if audio_value:
    st.audio(audio_value, format="audio/wav")

    # Save the recording to a temporary file so soundfile can open it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(audio_value.getvalue())
        temp_filename = temp_file.name

    try:
        # Read the metadata and the samples (as float32) in a single pass.
        with sf.SoundFile(temp_filename) as audio_file:
            sample_rate = audio_file.samplerate
            subtype = audio_file.subtype  # Example: PCM_16
            audio_data = audio_file.read(dtype="float32")
    finally:
        # delete=False means nothing else removes the file; clean it up here
        # so recordings do not pile up in the temp directory.
        os.remove(temp_filename)

    st.write(f"Original Sample Rate: {sample_rate} Hz")
    st.write(f"Data Type: {subtype}")

    # soundfile returns (frames, channels) for multi-channel input; average
    # the channels so resampling and the ASR model both see a 1-D signal.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    # Convert to the ASR sample rate when the recording uses a different one.
    if sample_rate != ASR_SAMPLE_RATE:
        audio_data = librosa.resample(
            audio_data, orig_sr=sample_rate, target_sr=ASR_SAMPLE_RATE
        )

    # Get Transcription, its translation, and synthesised speech.
    transcription = generate_text(audio_data)
    translation = translate(transcription)
    audio = generate_voice(transcription)

    st.write(translation)

    # generate_voice returns a torch tensor, which st.audio cannot play
    # directly: hand it a 1-D NumPy array together with the TTS sample rate.
    waveform = audio.detach().squeeze().cpu().numpy()
    st.audio(waveform, sample_rate=tts_model.config.sampling_rate)