Kabatubare commited on
Commit
14693fc
·
verified ·
1 Parent(s): e3b710b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import librosa
3
  import numpy as np
4
  import torch
 
5
  import logging
6
  from transformers import AutoModelForAudioClassification
7
 
@@ -12,19 +13,21 @@ logging.basicConfig(level=logging.INFO)
12
  local_model_path = "./"
13
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
14
 
15
- def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512):
16
  waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
17
  S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
18
  S_DB = librosa.power_to_db(S, ref=np.max)
19
  S_DB_tensor = torch.tensor(S_DB).float().unsqueeze(0) # Add batch dimension
20
- return S_DB_tensor
 
 
 
21
 
22
  def predict_voice(audio_file_path):
23
  try:
24
  features = custom_feature_extraction(audio_file_path)
25
 
26
  with torch.no_grad():
27
- # Directly pass the features tensor to the model
28
  outputs = model(features)
29
 
30
  logits = outputs.logits
 
2
  import librosa
3
  import numpy as np
4
  import torch
5
+ import torch.nn.functional as F
6
  import logging
7
  from transformers import AutoModelForAudioClassification
8
 
 
13
  local_model_path = "./"
14
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
15
 
16
+ def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
17
  waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
18
  S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
19
  S_DB = librosa.power_to_db(S, ref=np.max)
20
  S_DB_tensor = torch.tensor(S_DB).float().unsqueeze(0) # Add batch dimension
21
+
22
+ # Resizing the tensor to match the model's expected input size
23
+ S_DB_tensor_resized = F.interpolate(S_DB_tensor, size=(n_mels, target_length), mode='nearest')
24
+ return S_DB_tensor_resized
25
 
26
  def predict_voice(audio_file_path):
27
  try:
28
  features = custom_feature_extraction(audio_file_path)
29
 
30
  with torch.no_grad():
 
31
  outputs = model(features)
32
 
33
  logits = outputs.logits