Kabatubare committed
Commit d75aa1b · verified · 1 Parent(s): 14693fc

Update app.py

Files changed (1): app.py (+28 −11)
app.py CHANGED
@@ -2,33 +2,48 @@ import gradio as gr
 import librosa
 import numpy as np
 import torch
-import torch.nn.functional as F
 import logging
 from transformers import AutoModelForAudioClassification
 
-# Configure logging
+# Configure logging for debugging and information
 logging.basicConfig(level=logging.INFO)
 
-# Model loading
+# Model loading from the specified local path
 local_model_path = "./"
 model = AutoModelForAudioClassification.from_pretrained(local_model_path)
 
-def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
+def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512):
+    """
+    Custom feature extraction using Mel spectrogram, tailored for models trained on datasets like AudioSet.
+    Args:
+        audio_file_path: Path to the audio file for prediction.
+        sr: Target sampling rate for the audio file.
+        n_mels: Number of Mel bands to generate.
+        n_fft: Length of the FFT window.
+        hop_length: Number of samples between successive frames.
+    Returns:
+        A tensor representation of the Mel spectrogram features.
+    """
     waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
     S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
     S_DB = librosa.power_to_db(S, ref=np.max)
     S_DB_tensor = torch.tensor(S_DB).float().unsqueeze(0)  # Add batch dimension
-
-    # Resizing the tensor to match the model's expected input size
-    S_DB_tensor_resized = F.interpolate(S_DB_tensor, size=(n_mels, target_length), mode='nearest')
-    return S_DB_tensor_resized
+    return S_DB_tensor
 
 def predict_voice(audio_file_path):
+    """
+    Predicts the audio class using a pre-trained model and custom feature extraction.
+    Args:
+        audio_file_path: Path to the audio file for prediction.
+    Returns:
+        A string containing the predicted class and confidence level.
+    """
     try:
         features = custom_feature_extraction(audio_file_path)
 
         with torch.no_grad():
-            outputs = model(features)
+            # Adjust the model prediction line if necessary to match your model's expected input
+            outputs = model(inputs=features)
 
         logits = outputs.logits
         predicted_index = logits.argmax()
@@ -36,13 +51,14 @@ def predict_voice(audio_file_path):
         confidence = torch.softmax(logits, dim=1).max().item() * 100
 
         result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
-        logging.info(f"Prediction: {result}")
+        logging.info("Prediction successful.")
     except Exception as e:
         result = f"Error during processing: {e}"
         logging.error(result)
 
     return result
 
+# Setting up the Gradio interface
 iface = gr.Interface(
     fn=predict_voice,
     inputs=gr.Audio(label="Upload Audio File", type="filepath"),
@@ -51,4 +67,5 @@ iface = gr.Interface(
     description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
 )
 
+# Launching the interface
 iface.launch()
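
Two behaviors of the updated code are worth sanity-checking. First, with the F.interpolate resizing removed, the time dimension of the extracted features now varies with clip length, so the underlying model must accept a variable-length time axis. A minimal sketch mirroring the new custom_feature_extraction (assuming librosa and torch are installed; "sample.wav" is a hypothetical stand-in for any local audio file):

import librosa
import numpy as np
import torch

# Same pipeline as the new custom_feature_extraction in app.py
waveform, sr = librosa.load("sample.wav", sr=16000)  # "sample.wav" is a stand-in path
S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128, n_fft=2048, hop_length=512)
S_DB = librosa.power_to_db(S, ref=np.max)
features = torch.tensor(S_DB).float().unsqueeze(0)
print(features.shape)  # torch.Size([1, 128, T]); T grows with clip duration now that resizing is gone

Second, the keyword in model(inputs=features) depends on the concrete class that AutoModelForAudioClassification resolves to; many audio-classification heads in transformers (for example Wav2Vec2 and AST) name the argument input_values, and a mismatched keyword raises a TypeError. A quick check, run after the model loads in app.py, lists the accepted names:

import inspect
print(inspect.signature(model.forward))  # shows whether the forward argument is input_values, etc.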