Irpan commited on
Commit
f6cde70
1 Parent(s): 3c9ecf2
Files changed (1) hide show
  1. asr.py +17 -2
asr.py CHANGED
@@ -1,5 +1,7 @@
 
1
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
2
  import torch
 
3
  from umsc import UgMultiScriptConverter
4
  import util
5
 
@@ -13,9 +15,22 @@ asr_processor.tokenizer.set_target_lang("uig-script_latin")
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  asr_model = asr_model.to(device)
15
 
16
- def asr(user_audio):
17
  # Load and resample user audio
18
- audio_input, sampling_rate = util.load_and_resample_audio(audio_data = user_audio, target_rate=16000)
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Process audio through ASR model
21
  inputs = asr_processor(audio_input.squeeze(), sampling_rate=sampling_rate, return_tensors="pt", padding=True)
 
1
+ import numpy as np
2
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
3
  import torch
4
+ import torchaudio
5
  from umsc import UgMultiScriptConverter
6
  import util
7
 
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  asr_model = asr_model.to(device)
17
 
18
+ def asr(audio_data, target_rate = 16000):
19
  # Load and resample user audio
20
+ if isinstance(audio_data, tuple):
21
+ # microphone
22
+ sampling_rate, audio_input = audio_data
23
+ audio_input = (audio_input / 32768.0).astype(np.float32)
24
+ elif isinstance(audio_data, str):
25
+ # file upload
26
+ audio_input, sampling_rate = torchaudio.load(audio_data)
27
+ else:
28
+ return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
29
+
30
+ # Resample if needed
31
+ if sampling_rate != target_rate:
32
+ resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
33
+ audio_input = resampler(audio_input)
34
 
35
  # Process audio through ASR model
36
  inputs = asr_processor(audio_input.squeeze(), sampling_rate=sampling_rate, return_tensors="pt", padding=True)