Norphel committed on
Commit 44c8041 · verified · 1 Parent(s): bb5cc1d

Update app.py

Files changed (1)
  1. app.py +86 -52
app.py CHANGED
@@ -1,59 +1,93 @@
  import numpy as np
  import torch
- import gradio as gr
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
  import librosa

- # Load ASR model & processor
- asr_model_id = "Norphel/wav2vec2-large-mms-1b-dzo-colab"
- asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id, target_lang="dzo")
- asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_id)
- asr_processor.tokenizer.set_target_lang("dzo")

- # Use CPU if no GPU is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
- asr_model.to(device)
-
- # Function to process audio & generate text
- def generate_text(audio):
-     if audio is None:
-         return "No audio received"
-
-     sr, data = audio  # Unpack the tuple (sample rate, numpy array)
-     print(f"Original sample rate: {sr}, dtype: {data.dtype}")
-
-     # Convert to float32
-     data = data.astype(np.float32)
-
-     # Resample to 16kHz if necessary
-     target_sr = 16000
-     if sr != target_sr:
-         data = librosa.resample(data, orig_sr=sr, target_sr=target_sr)
-         sr = target_sr
-
-     print(f"Processed sample rate: {sr}, dtype: {data.dtype}")
-
-     # Tokenize and run inference
-     inputs = asr_processor(data, sampling_rate=sr, return_tensors="pt", padding=True)
-
      with torch.no_grad():
-         outputs = asr_model(**inputs).logits
-     pred_ids = torch.argmax(outputs, dim=-1)[0]
-
-     # Decode the prediction
-     return asr_processor.decode(pred_ids)
-
- # Ensure we get a NumPy array from Gradio
- input_audio = gr.Audio(
-     sources=["microphone"],
-     type="numpy",  # Ensures function gets (sr, np.ndarray)
- )
-
- demo = gr.Interface(
-     fn=generate_text,
-     inputs=input_audio,
-     outputs="text"
- )
-
- if __name__ == "__main__":
-     demo.launch()

+ import streamlit as st
+ import soundfile as sf
  import numpy as np
  import torch
  import librosa
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from transformers import VitsModel, AutoTokenizer
+ import tempfile

+ st.title("Dzongkha Speech-to-Text")

+ # Check if a GPU is available
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ st.write(f"Using device: {device.upper()}")
+
+ # Load the model only once (for performance)
+ @st.cache_resource
+ def load_asr_model():
+     model_id = "Norphel/wav2vec2-large-mms-1b-dzo-colab"
+     model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)  # Use CPU or GPU
+     processor = Wav2Vec2Processor.from_pretrained(model_id)
+     return model, processor
+
+ @st.cache_resource
+ def load_translation_model():
+     model = AutoModelForSeq2SeqLM.from_pretrained("Norphel/Dz_en", token="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")
+     tokenizer = AutoTokenizer.from_pretrained("Norphel/Dz_en", token="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")
+     return model, tokenizer
+
+ @st.cache_resource
+ def load_tts_model():
+     model = VitsModel.from_pretrained("Norphel/MMS-TTS-Dzo-N3")
+     tokenizer = AutoTokenizer.from_pretrained("Norphel/MMS-TTS-Dzo-N3")
+     return model, tokenizer
+
+ def generate_voice(text):
+     inputs = tts_tokenizer(text, return_tensors="pt")
      with torch.no_grad():
+         output = tts_model(**inputs).waveform
+     return output
+
+ def translate(text):
+     inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)  # Move inputs to GPU
+     translation_model.to(device)  # Move model to GPU
+     outputs = translation_model.generate(inputs, max_new_tokens=512)
+     decoded_output = translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return decoded_output
+
+ # Corrected function to load the ASR model
+ asr_model, processor = load_asr_model()
+ translation_model, translation_tokenizer = load_translation_model()
+ tts_model, tts_tokenizer = load_tts_model()
+
+ # Audio Recording Widget
+ audio_value = st.audio_input("Record a voice message")
+
+ if audio_value:
+     st.audio(audio_value, format="audio/wav")
+
+     # Save the uploaded audio to a temporary file
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+         temp_file.write(audio_value.getvalue())
+         temp_filename = temp_file.name
+
+     # Read audio file using soundfile
+     with sf.SoundFile(temp_filename) as audio_file:
+         sample_rate = audio_file.samplerate
+         dtype = audio_file.subtype  # Example: PCM_16
+
+     st.write(f"Original Sample Rate: {sample_rate} Hz")
+     st.write(f"Data Type: {dtype}")
+
+     # Convert to 16kHz Float32
+     with sf.SoundFile(temp_filename) as audio_file:
+         audio_data = audio_file.read(dtype="float32")
+
+     if sample_rate != 16000:
+         audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
+
+     # Run Speech-to-Text
+     def generate_text(audio):
+         input_dict = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
+         logits = asr_model(input_dict.input_values.to(device)).logits
+         pred_ids = torch.argmax(logits, dim=-1)[0]
+         return processor.decode(pred_ids)
+
+     # Get Transcription
+     transcription = generate_text(audio_data)
+     translation = translate(transcription)
+     audio = generate_voice(transcription)
+     st.write(translation)
+     st.audio(audio)
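
For playback of the TTS result, st.audio generally expects a NumPy array together with an explicit sample_rate when it is given raw samples rather than a file-like object, so the waveform tensor returned by generate_voice usually needs a small conversion step. A minimal sketch is shown below; the helper name play_tts_output and the use of tts_model.config.sampling_rate are illustrative assumptions, not code from this commit.

import numpy as np
import streamlit as st

# Hypothetical helper (not part of the commit): convert the (1, num_samples)
# torch waveform from VitsModel(...).waveform into a float32 NumPy array and
# pass the model's sampling rate so st.audio plays it back at the right speed.
def play_tts_output(waveform, sampling_rate):
    audio_np = waveform.squeeze().cpu().numpy().astype(np.float32)
    st.audio(audio_np, sample_rate=sampling_rate)

# Example usage inside the app (names taken from the diff above):
# play_tts_output(generate_voice(transcription), tts_model.config.sampling_rate)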