File size: 3,470 Bytes
44c8041
 
cbf2f18
83dc08b
60d0c64
44c8041
 
 
 
7f18afd
 
 
27ddc0e
44c8041
fcd28b3
44c8041
7603c38
44c8041
 
 
 
 
 
 
 
 
 
 
 
7f18afd
 
44c8041
 
 
 
7f18afd
 
44c8041
 
 
 
23f9512
44c8041
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import soundfile as sf
import numpy as np
import torch
import librosa
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import VitsModel, AutoTokenizer
import tempfile
import os

# Hugging Face access token taken from the `hf_token` environment variable
# (None if it is unset); passed as `token=` to the model downloads below.
api_key = os.environ.get("hf_token")

st.title("Dzongkha Speech-to-Text")

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.write("Using device: " + device.upper())

# Cached with st.cache_resource so Streamlit reruns reuse the loaded model.
@st.cache_resource
def load_asr_model():
    """Load the wav2vec2 CTC speech-recognition model and its processor.

    Returns:
        tuple: ``(model, processor)``; the model has been moved to ``device``.
    """
    checkpoint = "Norphel/wav2vec2-large-mms-1b-dzo-colab"
    asr_processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    asr = Wav2Vec2ForCTC.from_pretrained(checkpoint)
    asr = asr.to(device)  # CPU or GPU, whichever was selected above
    return asr, asr_processor

@st.cache_resource
def load_translation_model():
    """Load and cache the seq2seq translation model and its tokenizer.

    The repo name suggests Dzongkha-to-English -- confirm against the model
    card. Both downloads use ``api_key`` as the Hugging Face token.

    Returns:
        tuple: ``(model, tokenizer)``; the model is not moved to ``device`` here.
    """
    repo_id = "Norphel/Dz_en"
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=api_key)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, token=api_key)
    return seq2seq_model, seq2seq_tokenizer

@st.cache_resource
def load_tts_model():
    """Load and cache the VITS text-to-speech model and its tokenizer.

    Both downloads use ``api_key`` as the Hugging Face token.

    Returns:
        tuple: ``(model, tokenizer)``; the model is not moved to ``device`` here.
    """
    tts_repo = "Norphel/MMS-TTS-Dzo-N3"
    voice_tokenizer = AutoTokenizer.from_pretrained(tts_repo, token=api_key)
    voice_model = VitsModel.from_pretrained(tts_repo, token=api_key)
    return voice_model, voice_tokenizer

def generate_voice(text):
    """Synthesize speech for *text* with the cached VITS TTS model.

    Args:
        text: String for the TTS tokenizer (the app passes the Dzongkha
            ASR transcription).

    Returns:
        torch.Tensor: the model's ``waveform`` output, moved to the CPU.
        Presumably shaped (batch, samples) -- confirm against the VitsModel docs.
    """
    # Use the same device as the ASR/translation steps. Module.to() is a
    # no-op once the weights are already on that device.
    tts_model.to(device)
    inputs = tts_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # Return a CPU tensor so callers can convert it to NumPy directly.
    return waveform.cpu()

def translate(text):
    """Translate *text* with the cached seq2seq model.

    Args:
        text: Source-language string (the app passes the ASR transcription).

    Returns:
        str: The decoded first generated sequence, without special tokens.
    """
    # Place the model on the selected device (CPU or GPU). Module.to() is
    # effectively a no-op after the first call.
    translation_model.to(device)
    encoded = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    # Forward the attention mask with the ids so padded positions are
    # ignored instead of letting generate() guess the mask.
    outputs = translation_model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=512,
    )
    return translation_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Load every model the pipeline needs (served from Streamlit's resource cache
# after the first run).
asr_model, processor = load_asr_model()
translation_model, translation_tokenizer = load_translation_model()
tts_model, tts_tokenizer = load_tts_model()


def generate_text(audio):
    """Transcribe a mono 16 kHz float waveform with the CTC ASR model.

    Args:
        audio: 1-D float32 NumPy array sampled at 16 kHz.

    Returns:
        str: Greedy (argmax) CTC decoding of the first batch item.
    """
    input_dict = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    # Inference only: skip autograd bookkeeping, which would otherwise keep
    # activations of the large ASR model alive.
    with torch.no_grad():
        logits = asr_model(input_dict.input_values.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(pred_ids)


# Audio Recording Widget
audio_value = st.audio_input("Record a voice message")

if audio_value:
    st.audio(audio_value, format="audio/wav")

    # soundfile is given a path, so spill the recording to a temp file.
    # delete=False lets us reopen it by name (needed on Windows); we
    # remove it ourselves below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(audio_value.getvalue())
        temp_filename = temp_file.name

    try:
        # Read the metadata and the samples in a single open of the file.
        with sf.SoundFile(temp_filename) as audio_file:
            sample_rate = audio_file.samplerate
            dtype = audio_file.subtype  # Example: PCM_16
            audio_data = audio_file.read(dtype="float32")
    finally:
        # Otherwise every recording would leave a stray .wav in the temp dir.
        os.remove(temp_filename)

    st.write(f"Original Sample Rate: {sample_rate} Hz")
    st.write(f"Data Type: {dtype}")

    # Multi-channel files are read as (frames, channels). librosa.resample
    # works along the last axis and the ASR model expects a single channel,
    # so down-mix to mono first.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    # Convert to 16 kHz, the rate generate_text() declares to the processor.
    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)

    # Get Transcription
    transcription = generate_text(audio_data)
    if not transcription.strip():
        # Nothing to translate or synthesize; avoid feeding empty text to TTS.
        st.warning("No speech was recognized in the recording.")
    else:
        translation = translate(transcription)
        speech = generate_voice(transcription)
        st.write(translation)
        # st.audio needs a NumPy array plus an explicit sample rate for raw
        # waveforms. The VITS output is a torch tensor, so take the first
        # batch item and convert it.
        st.audio(
            speech[0].detach().cpu().numpy(),
            sample_rate=tts_model.config.sampling_rate,
        )