Irpan committed
Commit 3a18b3b · 1 Parent(s): 1959ce1

Files changed (2):
  1. app.py +8 -2
  2. asr.py +49 -11
app.py CHANGED
@@ -1,11 +1,17 @@
 import gradio as gr
-from asr import transcribe
+import asr
 # from tts import synthesize


 mms_transcribe = gr.Interface(
-    fn=transcribe,
+    fn=asr.transcribe,
     inputs=[
+        gr.Dropdown(
+            choices=list(asr.models_info.keys()),  # registry keys are the model ids
+            label="Select Model for ASR",
+            value="ixxan/whisper-small-uyghur-common-voice",  # must be an uncommented registry key
+            interactive=True
+        ),
        gr.Audio()
     ],
     outputs="text",
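
Note: asr.models_info (defined in asr.py below) is a plain dict keyed by model id, so the dropdown's choices are its keys, and the default value has to be one of the uncommented registry entries or the models_info[model_id] lookup inside transcribe raises a KeyError. A quick check of that assumption:

    import asr

    # The registry keys double as the model ids shown in the dropdown.
    print(list(asr.models_info.keys()))
    # ['openai/whisper-small-uzbek', 'ixxan/whisper-small-thugy20',
    #  'ixxan/whisper-small-uyghur-common-voice', 'facebook/mms-1b-all']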
asr.py CHANGED
@@ -1,15 +1,45 @@
 import torchaudio
 import torch
-from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from transformers import (
+    WhisperProcessor,
+    AutoProcessor,
+    AutoModelForSpeechSeq2Seq,
+    AutoModelForCTC,
+    Wav2Vec2Processor,
+    Wav2Vec2ForCTC
+)
 import numpy as np

 # Load processor and model
-processor = AutoProcessor.from_pretrained("ixxan/whisper-small-common-voice-ug")
-model = AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-common-voice-ug")
+models_info = {
+    "openai/whisper-small-uzbek": {
+        "processor": WhisperProcessor.from_pretrained("openai/whisper-small", language="uzbek", task="transcribe"),
+        "model": AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
+        "ctc_model": False
+    },
+    "ixxan/whisper-small-thugy20": {
+        "processor": AutoProcessor.from_pretrained("ixxan/whisper-small-thugy20"),
+        "model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-thugy20"),
+        "ctc_model": False
+    },
+    "ixxan/whisper-small-uyghur-common-voice": {
+        "processor": AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
+        "model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
+        "ctc_model": False
+    },
+    "facebook/mms-1b-all": {
+        "processor": AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic'),
+        "model": AutoModelForCTC.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic', ignore_mismatched_sizes=True),
+        "ctc_model": True
+    },
+    # "ixxan/wav2vec2-large-mms-1b-uyghur-latin": {
+    #     "processor": Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
+    #     "model": Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin"),
+    #     "ctc_model": True
+    # },
+}

-target_sr = processor.feature_extractor.sampling_rate
-
-def transcribe(audio_data) -> str:
+def transcribe(audio_data, model_id) -> str:
     """
     Transcribes audio to text using the Whisper model for Uyghur.
     Args:
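
Note: the registry above instantiates every processor/model pair at import time, so all checkpoints, including the 1B-parameter MMS model, are downloaded and held in memory before the first request. A lazy-loading alternative is sketched below under stated assumptions; load_model is an illustrative name, not part of this commit, and it covers only the seq2seq checkpoints:

    from functools import lru_cache

    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

    @lru_cache(maxsize=None)
    def load_model(model_id: str):
        # Hypothetical lazy loader: fetch a checkpoint the first time it is
        # selected, then serve the cached pair on every later call.
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
        return processor, model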
@@ -35,13 +65,18 @@ def transcribe(audio_data) -> str:
         return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))


+    model = models_info[model_id]["model"]
+    processor = models_info[model_id]["processor"]
+    target_sr = processor.feature_extractor.sampling_rate
+    ctc_model = models_info[model_id]["ctc_model"]
+
     # Resample if needed
     if sampling_rate != target_sr:
         resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
         audio_input = resampler(audio_input)

     # Preprocess the audio input
-    inputs = processor(audio_input.squeeze(), sampling_rate=target_sr, return_tensors="pt")
+    inputs = processor(audio_input.squeeze(), sampling_rate=target_sr, return_tensors="pt", padding=True)

     # Move model to GPU if available
     device = "cuda" if torch.cuda.is_available() else "cpu"
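
Note: target_sr is now taken from the selected model's feature extractor; both the Whisper and MMS checkpoints expect 16 kHz input. The resampling step in isolation, using an assumed 44.1 kHz stand-in waveform:

    import torch
    import torchaudio

    waveform = torch.zeros(1, 44100)  # one second at 44.1 kHz (illustrative)
    resampler = torchaudio.transforms.Resample(44100, 16000)
    print(resampler(waveform).shape)  # torch.Size([1, 16000])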
@@ -50,9 +85,12 @@ def transcribe(audio_data) -> str:

     # Generate transcription
     with torch.no_grad():
-        generated_ids = model.generate(inputs["input_features"], max_length=225)
-
-        # Decode the output to get the transcription text
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        if ctc_model:
+            logits = model(**inputs).logits
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription = processor.batch_decode(predicted_ids)[0]
+        else:
+            generated_ids = model.generate(inputs["input_features"], max_length=225)
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

     return transcription
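
Note: the branch mirrors the two decoding styles in the registry: the MMS checkpoint is a CTC model, so a single forward pass yields per-frame logits that are greedily argmax-decoded, while the Whisper checkpoints generate token ids autoregressively through generate(). A hedged usage sketch of the updated entry point, assuming transcribe accepts gr.Audio's default (sample_rate, numpy array) tuple:

    import numpy as np

    import asr

    # One second of silence as a stand-in waveform (illustrative only).
    audio = (16000, np.zeros(16000, dtype=np.float32))
    print(asr.transcribe(audio, "facebook/mms-1b-all"))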
 