Kabatubare committed (verified)
Commit e02dec8 · Parent: 7045c5c

Update app.py

Files changed (1)
  1. app.py +34 -43
app.py CHANGED
@@ -2,58 +2,51 @@ import gradio as gr
  import librosa
  import numpy as np
  import torch
- import torch.nn.functional as F
  import logging
  from transformers import AutoModelForAudioClassification
+ import soundfile as sf

  # Configure logging for debugging and information
  logging.basicConfig(level=logging.INFO)

- # Model loading from the specified local path
- local_model_path = "./"
- model = AutoModelForAudioClassification.from_pretrained(local_model_path)
+ # Load the model
+ model_path = "./"
+ model = AutoModelForAudioClassification.from_pretrained(model_path)

- def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
-     """
-     Custom feature extraction using Mel spectrogram, tailored for models trained on datasets like AudioSet.
-     Args:
-         audio_file_path: Path to the audio file for prediction.
-         sr: Target sampling rate for the audio file.
-         n_mels: Number of Mel bands to generate.
-         n_fft: Length of the FFT window.
-         hop_length: Number of samples between successive frames.
-         target_length: Expected length of the Mel spectrogram in the time dimension.
-     Returns:
-         A tensor representation of the Mel spectrogram features.
-     """
-     waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
-     S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
-     S_DB = librosa.power_to_db(S, ref=np.max)
-     mel_tensor = torch.tensor(S_DB).float()
-
-     # Ensure the tensor matches the expected sequence length
-     current_length = mel_tensor.shape[1]
-     if current_length > target_length:
-         mel_tensor = mel_tensor[:, :target_length]  # Truncate if longer
-     elif current_length < target_length:
-         padding = target_length - current_length
-         mel_tensor = F.pad(mel_tensor, (0, padding), "constant", 0)  # Pad if shorter
-
-     mel_tensor = mel_tensor.unsqueeze(0)  # Add batch dimension for compatibility with model
-     return mel_tensor
+ def augment_and_extract_features(audio_path, output_path=None, sr=16000, n_mfcc=40, n_fft=2048, hop_length=512):
+     # Load and augment the audio file
+     y, sr = librosa.load(audio_path, sr=sr)
+     y_augmented = librosa.effects.pitch_shift(y, sr, n_steps=4)  # Pitch shifting
+     y_augmented = librosa.effects.time_stretch(y_augmented, rate=1.2)  # Time stretching
+
+     # Save the augmented audio if an output path is provided
+     if output_path is not None:
+         sf.write(output_path, y_augmented, sr)
+
+     # Extract features
+     mfcc = librosa.feature.mfcc(y=y_augmented, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
+     chroma = librosa.feature.chroma_stft(y=y_augmented, sr=sr, n_fft=n_fft, hop_length=hop_length)
+     mel = librosa.feature.melspectrogram(y=y_augmented, sr=sr, n_fft=n_fft, hop_length=hop_length)
+     contrast = librosa.feature.spectral_contrast(y=y_augmented, sr=sr, n_fft=n_fft, hop_length=hop_length)
+     tonnetz = librosa.feature.tonnetz(y=librosa.effects.harmonic(y_augmented), sr=sr)
+
+     # Combine all features
+     features = np.concatenate((mfcc, chroma, mel, contrast, tonnetz), axis=0)
+
+     # Normalize features
+     features = (features - np.mean(features, axis=1, keepdims=True)) / np.std(features, axis=1, keepdims=True)
+
+     # Convert to tensor
+     features_tensor = torch.tensor(features).float().unsqueeze(0)  # Add batch dimension
+     return features_tensor

  def predict_voice(audio_file_path):
-     """
-     Predicts the audio class using a pre-trained model and custom feature extraction.
-     Args:
-         audio_file_path: Path to the audio file for prediction.
-     Returns:
-         A string containing the predicted class and confidence level.
-     """
      try:
-         features = custom_feature_extraction(audio_file_path)
+         features_tensor = augment_and_extract_features(audio_file_path)
+
          with torch.no_grad():
-             outputs = model(features)
+             outputs = model(features_tensor)
+
          logits = outputs.logits
          predicted_index = logits.argmax()
          label = model.config.id2label[predicted_index.item()]
@@ -67,7 +60,6 @@ def predict_voice(audio_file_path):

      return result

- # Setting up the Gradio interface
  iface = gr.Interface(
      fn=predict_voice,
      inputs=gr.Audio(label="Upload Audio File", type="filepath"),
@@ -76,5 +68,4 @@ iface = gr.Interface(
      description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
  )

- # Launching the interface
  iface.launch()
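A note on the augmentation calls added above: since librosa 0.10, librosa.effects.pitch_shift and librosa.effects.time_stretch take sr, n_steps, and rate as keyword-only arguments, so the positional call pitch_shift(y, sr, n_steps=4) in this commit raises a TypeError on current releases. A minimal sketch of the keyword-argument form, reusing the commit's shift and stretch values (the helper name augment_audio is hypothetical, not part of app.py):

import librosa
import soundfile as sf

def augment_audio(y, sr, n_steps=4, rate=1.2, output_path=None):
    # Pitch-shift then time-stretch the waveform, mirroring the augmentation added in this commit
    y_aug = librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)  # sr and n_steps are keyword-only in librosa >= 0.10
    y_aug = librosa.effects.time_stretch(y_aug, rate=rate)
    if output_path is not None:
        sf.write(output_path, y_aug, sr)  # optionally persist the augmented clip
    return y_aug

On librosa 0.9.x the positional form still works, so the keyword form is the safer choice across versions.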
 
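The removed custom_feature_extraction padded or truncated its Mel spectrogram to a fixed 1024 frames before adding the batch dimension, while the new augment_and_extract_features returns a stacked MFCC/chroma/mel/contrast/tonnetz matrix with a different feature dimension and an unpadded time axis, so its shape may not match what the checkpoint loaded by AutoModelForAudioClassification expects. Below is a small sketch that restores the pad-or-truncate step from the removed code; the target_length=1024 default comes from that code, the helper name pad_or_truncate is hypothetical, and whether the resulting shape matches the checkpoint is an assumption to verify.

import torch
import torch.nn.functional as F

def pad_or_truncate(features_tensor, target_length=1024):
    # Force the time axis (last dimension) of a (1, n_features, n_frames) tensor to target_length
    n_frames = features_tensor.shape[-1]
    if n_frames > target_length:
        return features_tensor[..., :target_length]  # truncate longer clips
    if n_frames < target_length:
        return F.pad(features_tensor, (0, target_length - n_frames))  # zero-pad shorter clips
    return features_tensor

# Hypothetical smoke test before deploying the Space:
# features = pad_or_truncate(augment_and_extract_features("sample.wav"))
# print(features.shape)  # compare against the input shape the checkpoint was trained on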