andromeda01111 committed
Commit c6d010d · verified · 1 Parent(s): e1edbee

Update app.py

Files changed (1)
  1. app.py +28 -5
app.py CHANGED
@@ -12,11 +12,34 @@ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
 sampling_rate = feature_extractor.sampling_rate
 model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path)
 
-def speech_file_to_array_fn(path, sampling_rate):
-    speech_array, _sampling_rate = torchaudio.load(path)
-    resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
-    speech = resampler(speech_array).squeeze().numpy()
-    return speech
+# def speech_file_to_array_fn(path, sampling_rate):
+#     speech_array, _sampling_rate = torchaudio.load(path)
+#     resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
+#     speech = resampler(speech_array).squeeze().numpy()
+#     return speech
+
+def speech_file_to_array_fn(audio_path):
+    if audio_path is None:
+        return None  # Handle cases where no file is provided
+
+    try:
+        # Check if the input is a file path (upload) or direct audio data (recording)
+        if isinstance(audio_path, str):
+            speech_array, _sampling_rate = torchaudio.load(audio_path)
+        else:
+            # If it's recorded audio, Gradio provides it as a NumPy array
+            speech_array = torch.tensor(audio_path)
+            _sampling_rate = sampling_rate  # Use default sampling rate
+
+        # Resample to match model requirements
+        resampler = torchaudio.transforms.Resample(orig_freq=_sampling_rate, new_freq=sampling_rate)
+        speech = resampler(speech_array).squeeze().numpy()
+        return speech
+
+    except Exception as e:
+        print(f"Error processing audio: {e}")
+        return None
+
 
 def predict(audio_path):
     speech = speech_file_to_array_fn(audio_path, sampling_rate)
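
For context, here is a minimal sketch of how the new one-argument helper could be driven from a Gradio interface. Everything in it is an assumption for illustration, not part of this commit: it presumes gr.Audio(type="filepath"), so both uploads and microphone recordings reach predict as a path string and go through the torchaudio.load branch, and the body of predict beyond the line shown in the hunk is reconstructed from the usual Wav2Vec2 classification pattern. Note that the unchanged call site in the hunk still passes sampling_rate as a second argument, which the new speech_file_to_array_fn(audio_path) signature no longer accepts, so the call is adjusted in the sketch.

# Illustrative sketch only: the gr.Audio settings and the predict body below
# are assumptions, not taken from this commit. Reuses speech_file_to_array_fn,
# feature_extractor, sampling_rate, and model defined earlier in app.py.
import gradio as gr
import torch

def predict(audio_path):
    # With type="filepath", uploads and recordings both arrive as a path string.
    speech = speech_file_to_array_fn(audio_path)  # matches the new one-argument signature
    if speech is None:
        return "No audio provided or audio could not be read."

    inputs = feature_extractor(speech, sampling_rate=sampling_rate,
                               return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_id = int(torch.argmax(logits, dim=-1))
    return model.config.id2label[predicted_id]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
)
demo.launch()

Passing a file path keeps the resampling in one place, since torchaudio.load reports the true source sampling rate. With type="numpy", Gradio would instead hand predict a (sample_rate, data) tuple rather than a bare array, so the isinstance check in the committed helper would need to unpack that tuple before the else branch could resample recorded audio correctly.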