Kabatubare committed
Commit 637d0ca · verified · 1 Parent(s): 15eca51

Update app.py

Files changed (1)
  1. app.py +58 -55
app.py CHANGED
@@ -1,77 +1,80 @@
  import librosa
  import numpy as np
  import torch
  import logging
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
- import gradio as gr

  logging.basicConfig(level=logging.INFO)

- # Path to your Wav2Vec2 model and processor
- model_path = "./wav2vec2-sequence-classification"
- try:
-     model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
-     processor = Wav2Vec2Processor.from_pretrained(model_path)
-     logging.info("Model and processor loaded successfully.")
- except Exception as e:
-     logging.error(f"Loading model and processor failed: {e}")
-     raise e

- def preprocess_audio(file_path):
      """
-     Load and preprocess the audio file.
      """
-     # Load the audio file using librosa
-     audio, sr = librosa.load(file_path, sr=None)
-     # Resample the audio to 16 kHz (if not already at this sample rate)
-     if sr != 16000:
-         audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
-         sr = 16000
-     return audio, sr

- def audio_to_features(audio, sr):
-     """
-     Convert audio waveform to model features.
-     """
-     # Use the processor to prepare the features for the model
-     return processor(audio, sampling_rate=sr, return_tensors="pt", padding=True, truncation=True).input_values

- def classify_audio(file_path):
      """
-     Classify the content of the audio file.
      """
      try:
-         audio, sr = preprocess_audio(file_path)
-         input_values = audio_to_features(audio, sr)
-
-         # Inference
          with torch.no_grad():
-             logits = model(input_values).logits
-
-         # Post-processing: Convert logits to softmax to get probabilities
-         probabilities = torch.softmax(logits, dim=1).detach().numpy()
-
-         # Assuming you have a binary classification model for simplicity
-         # Modify this part based on your actual number of classes and labels
-         labels = ['Class 0', 'Class 1']  # Example labels
-         predictions = dict(zip(labels, probabilities[0]))

-         # Format the prediction output
-         prediction_output = "\n".join([f"{label}: {prob:.4f}" for label, prob in predictions.items()])
-         return prediction_output
      except Exception as e:
-         logging.error(f"Error during classification: {e}")
-         return f"Classification error: {e}"

- # Gradio interface
  iface = gr.Interface(
-     fn=classify_audio,
-     inputs=gr.inputs.Audio(source="upload", type="filepath"),
-     outputs="text",
-     title="Audio Classification with Wav2Vec2",
-     description="Upload an audio file to classify its content using a Wav2Vec2 model."
  )

- # Launch the interface
- if __name__ == "__main__":
-     iface.launch()
+ import gradio as gr
  import librosa
  import numpy as np
  import torch
+ import torch.nn.functional as F
  import logging
+ from transformers import AutoModelForAudioClassification

+ # Configure logging for debugging and information
  logging.basicConfig(level=logging.INFO)

+ # Model loading from the specified local path
+ local_model_path = "./"
+ model = AutoModelForAudioClassification.from_pretrained(local_model_path)

+ def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
      """
+     Custom feature extraction using a Mel spectrogram, tailored for models trained on datasets like AudioSet.
+     Args:
+         audio_file_path: Path to the audio file for prediction.
+         sr: Target sampling rate for the audio file.
+         n_mels: Number of Mel bands to generate.
+         n_fft: Length of the FFT window.
+         hop_length: Number of samples between successive frames.
+         target_length: Expected length of the Mel spectrogram in the time dimension.
+     Returns:
+         A tensor representation of the Mel spectrogram features.
      """
+     waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
+     S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
+     S_DB = librosa.power_to_db(S, ref=np.max)
+     mel_tensor = torch.tensor(S_DB).float()

+     # Ensure the tensor matches the expected sequence length
+     current_length = mel_tensor.shape[1]
+     if current_length > target_length:
+         mel_tensor = mel_tensor[:, :target_length]  # Truncate if longer
+     elif current_length < target_length:
+         padding = target_length - current_length
+         mel_tensor = F.pad(mel_tensor, (0, padding), "constant", 0)  # Pad if shorter

+     mel_tensor = mel_tensor.unsqueeze(0)  # Add batch dimension for compatibility with the model
+     return mel_tensor
+
+ def predict_voice(audio_file_path):
      """
+     Predicts the audio class using a pre-trained model and custom feature extraction.
+     Args:
+         audio_file_path: Path to the audio file for prediction.
+     Returns:
+         A string containing the predicted class and confidence level.
      """
      try:
+         features = custom_feature_extraction(audio_file_path)
          with torch.no_grad():
+             outputs = model(features)
+             logits = outputs.logits
+             predicted_index = logits.argmax()
+             label = model.config.id2label[predicted_index.item()]
+             confidence = torch.softmax(logits, dim=1).max().item() * 100

+         result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
+         logging.info("Prediction successful.")
      except Exception as e:
+         result = f"Error during processing: {e}"
+         logging.error(result)
+
+     return result

+ # Setting up the Gradio interface
  iface = gr.Interface(
+     fn=predict_voice,
+     inputs=gr.Audio(label="Upload Audio File", type="filepath"),
+     outputs=gr.Textbox(label="Prediction"),
+     title="Voice Authenticity Detection",
+     description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
  )

+ # Launching the interface
+ iface.launch()
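
A quick way to sanity-check the new feature pipeline outside the Space is to run the Mel-spectrogram extraction on a local file and inspect the tensor shape before it reaches the model. The sketch below is not part of the commit; "sample.wav" is a hypothetical input file, and the helper simply mirrors what custom_feature_extraction in app.py does.

import librosa
import numpy as np
import torch
import torch.nn.functional as F

def mel_features(path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
    # Mirror app.py: load at 16 kHz, build a 128-band log-Mel spectrogram,
    # then pad or truncate the time axis to a fixed 1024 frames.
    waveform, _ = librosa.load(path, sr=sr)
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels,
                                       n_fft=n_fft, hop_length=hop_length)
    mel = torch.tensor(librosa.power_to_db(S, ref=np.max)).float()
    if mel.shape[1] > target_length:
        mel = mel[:, :target_length]
    elif mel.shape[1] < target_length:
        mel = F.pad(mel, (0, target_length - mel.shape[1]), "constant", 0)
    return mel.unsqueeze(0)  # shape: (1, n_mels, time)

features = mel_features("sample.wav")  # hypothetical local file
print(features.shape)                  # expected: torch.Size([1, 128, 1024])

Note that the commit hands the model a (batch, n_mels, time) tensor. Some AudioSet-trained classifiers, such as AST, instead expect (batch, time, n_mels); if the checkpoint in "./" is one of those, a features.transpose(1, 2) before the forward pass would be needed, so it is worth checking the expected input layout against the model's config.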