# Page header captured with this file from Hugging Face Spaces -- Space status: Runtime error
import streamlit as st | |
import soundfile as sf | |
import numpy as np | |
import torch | |
import librosa | |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from transformers import VitsModel, AutoTokenizer | |
import tempfile | |
import os | |
# Value of the `hf_token` environment variable (None when unset); it is passed
# as `token=` to the `from_pretrained` calls for the translation and TTS models.
api_key = os.environ.get("hf_token")

st.title("Dzongkha Speech-to-Text")

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.write(f"Using device: {device.upper()}")
# Load the model only once (for performance).
@st.cache_resource
def load_asr_model():
    """Load the wav2vec2 CTC speech-recognition model and its processor.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the weights are loaded
    once and the same objects are returned on subsequent reruns.

    Returns:
        tuple: ``(model, processor)`` where ``model`` has been moved to the
        module-level ``device``.
    """
    model_id = "Norphel/wav2vec2-large-mms-1b-dzo-colab"
    model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)  # Use CPU or GPU
    processor = Wav2Vec2Processor.from_pretrained(model_id)
    return model, processor
@st.cache_resource
def load_translation_model():
    """Load the ``Norphel/Dz_en`` seq2seq model and its tokenizer.

    Cached with ``st.cache_resource`` so the checkpoint is loaded once
    instead of on every Streamlit script rerun. Both downloads authenticate
    with the module-level ``api_key``.

    Returns:
        tuple: ``(model, tokenizer)``. The model is left on its default
        device; ``translate`` moves it to ``device`` before generating.
    """
    model_id = "Norphel/Dz_en"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=api_key)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=api_key)
    return model, tokenizer
@st.cache_resource
def load_tts_model():
    """Load the ``Norphel/MMS-TTS-Dzo-N3`` VITS model and its tokenizer.

    Cached with ``st.cache_resource`` so the checkpoint is loaded once
    instead of on every Streamlit script rerun. Both downloads authenticate
    with the module-level ``api_key``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model_id = "Norphel/MMS-TTS-Dzo-N3"
    model = VitsModel.from_pretrained(model_id, token=api_key)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=api_key)
    return model, tokenizer
def generate_voice(text):
    """Synthesize speech for ``text`` with the module-level TTS model.

    The text is tokenized with ``tts_tokenizer`` and fed to ``tts_model``
    with gradient tracking disabled.

    Args:
        text: The string to synthesize.

    Returns:
        The ``waveform`` attribute of the model output.
    """
    encoded = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        result = tts_model(**encoded)
    return result.waveform
def translate(text):
    """Translate ``text`` with the module-level translation model.

    NOTE(review): the checkpoint name ``Dz_en`` suggests Dzongkha -> English;
    confirm against the model card.

    Args:
        text: The source string to translate.

    Returns:
        str: The first generated sequence, decoded with special tokens
        stripped.
    """
    # Model and inputs must live on the same device before generation.
    translation_model.to(device)
    encoded = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    input_ids = encoded.input_ids.to(device)
    generated = translation_model.generate(input_ids, max_new_tokens=512)
    return translation_tokenizer.decode(generated[0], skip_special_tokens=True)
# Instantiate the ASR, translation and TTS models used by the pipeline below.
asr_model, processor = load_asr_model()
translation_model, translation_tokenizer = load_translation_model()
tts_model, tts_tokenizer = load_tts_model()


def generate_text(audio):
    """Transcribe a waveform with the ASR model.

    Args:
        audio: Audio samples at 16 kHz, as passed to ``processor``.

    Returns:
        str: The decoded greedy (argmax) CTC prediction for the first item
        in the batch.
    """
    input_dict = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = asr_model(input_dict.input_values.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(pred_ids)


# Audio Recording Widget
audio_value = st.audio_input("Record a voice message")
if audio_value:
    st.audio(audio_value, format="audio/wav")

    # Save the recording to a temporary file so soundfile can open it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(audio_value.getvalue())
        temp_filename = temp_file.name

    # Read the metadata and the samples in a single pass, then always remove
    # the temporary file (it was created with delete=False).
    try:
        with sf.SoundFile(temp_filename) as audio_file:
            sample_rate = audio_file.samplerate
            dtype = audio_file.subtype  # Example: PCM_16
            audio_data = audio_file.read(dtype="float32")
    finally:
        os.remove(temp_filename)

    st.write(f"Original Sample Rate: {sample_rate} Hz")
    st.write(f"Data Type: {dtype}")

    # soundfile returns (frames, channels) for multi-channel input; downmix to
    # mono so resampling and the ASR processor both see a 1-D signal.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    # Convert to 16kHz Float32
    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)

    # Get Transcription
    transcription = generate_text(audio_data)

    if transcription.strip():
        translation = translate(transcription)
        audio = generate_voice(transcription)
        st.write(translation)
        # st.audio accepts raw samples only as a NumPy array together with an
        # explicit sample rate, so convert the tensor returned by the TTS model.
        st.audio(
            audio.squeeze().cpu().numpy(),
            sample_rate=tts_model.config.sampling_rate,
        )
    else:
        # Nothing was recognised; skip translation and synthesis.
        st.warning("No speech was recognised in the recording.")