Norphel committed · verified
Commit 83dc08b · 1 Parent(s): 93df753

Update app.py

Files changed (1): app.py (+28 -18)
app.py CHANGED
@@ -1,35 +1,45 @@
 import numpy as np
+import torch
 import gradio as gr
-from transformers import Wav2Vec2ForCTC,Wav2Vec2Processor
+import torchaudio
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
+# Load ASR model & processor
 asr_model_id = "Norphel/wav2vec2-large-mms-1b-dzo-colab"
-asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id, target_lang="dzo")
+asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id, target_lang="dzo").to("cuda")
 asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_id)
 asr_processor.tokenizer.set_target_lang("dzo")
 
+# Function to process audio & generate text
 def generate_text(audio):
-    sr, data = audio
-    print(data)
-    input_dict = asr_processor(aud_arr, sampling_rate=16_000, return_tensors="pt", padding=True)
-    logits = asr_model(input_dict.input_values.to("cuda")).logits
+    if audio is None:
+        return "No audio recorded."
+
+    sr, data = audio  # Gradio provides (sample_rate, waveform)
+
+    # Resample to 16kHz if needed
+    if sr != 16000:
+        data = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(torch.tensor(data))
+
+    # Convert to model input format
+    input_dict = asr_processor(data.numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
+
+    # Run model inference
+    with torch.no_grad():
+        logits = asr_model(input_dict.input_values.to("cuda")).logits
     pred_ids = torch.argmax(logits, dim=-1)[0]
 
+    # Decode prediction
     return asr_processor.decode(pred_ids)
 
-input_audio = gr.Audio(
-    sources=["microphone"],
-    waveform_options=gr.WaveformOptions(
-        waveform_color="#01C6FF",
-        waveform_progress_color="#0066B4",
-        skip_length=2,
-        show_controls=False,
-    ),
-)
+# Gradio interface
 demo = gr.Interface(
     fn=generate_text,
-    inputs=input_audio,
-    outputs="text"
+    inputs=gr.Audio(type="numpy"),  # Automatically enables recording
+    outputs="text",
+    title="Dzongkha Speech-to-Text",
+    description="Record your voice and get transcriptions in Dzongkha."
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
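
One caveat in the updated generate_text: gr.Audio(type="numpy") delivers a raw NumPy array (normally int16 from the microphone), so data.numpy() only exists on the resampling branch, where data has been wrapped in a torch tensor; when the browser already records at 16 kHz it raises AttributeError, and torchaudio's Resample also expects floating-point input rather than int16 PCM. Below is a minimal sketch of a preprocessing helper that would cover both cases; prepare_waveform, the (samples, channels) layout assumption, and the 32768 int16 scaling constant are illustrative, not part of the commit:

import numpy as np
import torch
import torchaudio


def prepare_waveform(sr: int, data: np.ndarray) -> np.ndarray:
    # Hypothetical helper (not part of the commit): normalizes Gradio's
    # microphone audio to mono float32 at 16 kHz for the Wav2Vec2 processor.
    waveform = torch.as_tensor(data)
    if waveform.ndim == 2:          # assumed (samples, channels) -> mono
        waveform = waveform.float().mean(dim=1)
    if data.dtype == np.int16:      # int16 PCM -> [-1.0, 1.0] floats
        waveform = waveform.float() / 32768.0
    else:
        waveform = waveform.float()
    if sr != 16000:                 # resample only when needed
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
    return waveform.numpy()

With such a helper, the resample/data.numpy() pair inside generate_text would reduce to input_dict = asr_processor(prepare_waveform(sr, data), sampling_rate=16000, return_tensors="pt", padding=True), and the function would behave the same at any recording rate.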