Kabatubare committed
Commit 7045c5c · verified · 1 Parent(s): 780b961

Update app.py

Files changed (1): app.py (+16, -7)
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import librosa
 import numpy as np
 import torch
+import torch.nn.functional as F
 import logging
 from transformers import AutoModelForAudioClassification
 
@@ -12,7 +13,7 @@ logging.basicConfig(level=logging.INFO)
 local_model_path = "./"
 model = AutoModelForAudioClassification.from_pretrained(local_model_path)
 
-def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512):
+def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
     """
     Custom feature extraction using Mel spectrogram, tailored for models trained on datasets like AudioSet.
     Args:
@@ -21,14 +22,25 @@ def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048,
     n_mels: Number of Mel bands to generate.
     n_fft: Length of the FFT window.
     hop_length: Number of samples between successive frames.
+    target_length: Expected length of the Mel spectrogram in the time dimension.
     Returns:
     A tensor representation of the Mel spectrogram features.
     """
     waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
     S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
     S_DB = librosa.power_to_db(S, ref=np.max)
-    S_DB_tensor = torch.tensor(S_DB).float().unsqueeze(0)  # Add batch dimension
-    return S_DB_tensor
+    mel_tensor = torch.tensor(S_DB).float()
+
+    # Ensure the tensor matches the expected sequence length
+    current_length = mel_tensor.shape[1]
+    if current_length > target_length:
+        mel_tensor = mel_tensor[:, :target_length]  # Truncate if longer
+    elif current_length < target_length:
+        padding = target_length - current_length
+        mel_tensor = F.pad(mel_tensor, (0, padding), "constant", 0)  # Pad if shorter
+
+    mel_tensor = mel_tensor.unsqueeze(0)  # Add batch dimension for compatibility with model
+    return mel_tensor
 
 def predict_voice(audio_file_path):
     """
@@ -40,11 +52,8 @@ def predict_voice(audio_file_path):
     """
     try:
         features = custom_feature_extraction(audio_file_path)
-
         with torch.no_grad():
-            # Corrected: Directly pass the features tensor to the model
             outputs = model(features)
-
         logits = outputs.logits
         predicted_index = logits.argmax()
         label = model.config.id2label[predicted_index.item()]
@@ -68,4 +77,4 @@ iface = gr.Interface(
 )
 
 # Launching the interface
-iface.launch()
+iface.launch()
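
The committed padding/truncation logic is easy to sanity-check in isolation. The sketch below is a minimal reproduction of custom_feature_extraction run on synthetic waveforms, so it needs no audio file or model checkpoint; the mel_features name and the 3 s / 60 s test durations are illustrative assumptions, while the 16 kHz rate, 128 Mel bands, hop length 512, and 1024-frame target are taken from the diff.

import numpy as np
import librosa
import torch
import torch.nn.functional as F

def mel_features(waveform, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
    # Same pipeline as the committed custom_feature_extraction, minus librosa.load.
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
    S_DB = librosa.power_to_db(S, ref=np.max)
    mel = torch.tensor(S_DB).float()
    if mel.shape[1] > target_length:
        mel = mel[:, :target_length]                         # truncate long clips
    elif mel.shape[1] < target_length:
        mel = F.pad(mel, (0, target_length - mel.shape[1]))  # zero-pad short clips
    return mel.unsqueeze(0)                                  # shape (1, n_mels, target_length)

for seconds in (3, 60):  # one clip shorter, one longer than the target window
    wave = np.random.randn(16000 * seconds).astype(np.float32)
    print(seconds, "->", mel_features(wave).shape)  # torch.Size([1, 128, 1024]) both times

At these settings the fixed window covers roughly 1024 × 512 / 16000 ≈ 33 s of audio, so the truncation branch silently drops anything past that point; zero-padding shorter clips keeps the input shape constant, which is what lets the batch dimension added by unsqueeze(0) feed the classifier without shape errors.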