Kabatubare committed (verified)
Commit 411539a · 1 Parent(s): 38963c6

Update app.py

Files changed (1)
  1. app.py +79 -63
app.py CHANGED
@@ -3,80 +3,96 @@ import librosa
 import numpy as np
 import torch
 import logging
-from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
-from pydub import AudioSegment
-import os
-import tempfile
-import soundfile as sf
+from transformers import AutoModelForAudioClassification
+from torch.nn.functional import interpolate
 
-# Setup logging
+# Set up logging to help diagnose issues and track progress
 logging.basicConfig(level=logging.INFO)
 
-# Load model and processor
+# Load the pretrained model for audio classification
 model_path = "./"
-model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
-processor = Wav2Vec2Processor.from_pretrained(model_path)
+try:
+    model = AutoModelForAudioClassification.from_pretrained(model_path)
+    logging.info("Model loaded successfully.")
+except Exception as e:
+    logging.error(f"Failed to load model: {e}")
 
-def preprocess_audio(audio_file_path, target_sampling_rate=16000):
-    """
-    Preprocess the input audio file to the target sampling rate and format.
-    """
-    # Convert audio to target sampling rate using librosa
-    y, sr = librosa.load(audio_file_path, sr=target_sampling_rate)
-    return y, sr
-
-def predict_audio_class(audio_file_path):
-    """
-    Predict the class of the input audio file using Wav2Vec 2.0 model.
-    """
+# Function to preprocess audio file
+def preprocess_audio(audio_path, target_sr=16000):
+    try:
+        y, sr = librosa.load(audio_path, sr=target_sr)
+        logging.info("Audio file loaded and resampled.")
+        return y, sr
+    except Exception as e:
+        logging.error(f"Error in audio preprocessing: {e}")
+        return None, None
+
+# Function to extract features from audio
+def extract_features(y, sr, n_mfcc=40, n_fft=2048, hop_length=512):
+    try:
+        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
+        logging.info("MFCC features extracted.")
+        return mfcc
+    except Exception as e:
+        logging.error(f"Error extracting MFCC features: {e}")
+        return None
+
+# Function to normalize and pad features
+def normalize_and_pad_features(mfcc, target_size=512):
+    try:
+        # Normalize features
+        mfcc_normalized = (mfcc - np.mean(mfcc, axis=1, keepdims=True)) / np.std(mfcc, axis=1, keepdims=True)
+        logging.info("Features normalized.")
+
+        # Pad features
+        if mfcc_normalized.shape[1] < target_size:
+            padding = target_size - mfcc_normalized.shape[1]
+            mfcc_padded = np.pad(mfcc_normalized, ((0, 0), (0, padding)), 'constant')
+            logging.info("Features padded.")
+        else:
+            mfcc_padded = mfcc_normalized[:, :target_size]
+        return mfcc_padded
+    except Exception as e:
+        logging.error(f"Error in normalization and padding: {e}")
+        return None
+
+# Prediction function
+def predict_voice(audio_file_path):
     try:
-        # Preprocess audio
-        audio, sr = preprocess_audio(audio_file_path, target_sampling_rate=16000)
-
-        # Prepare the audio for the model
-        inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True, truncation=True)
+        # Preprocess and extract features
+        y, sr = preprocess_audio(audio_file_path)
+        if y is None or sr is None:
+            return "Error in audio preprocessing."
+        mfcc = extract_features(y, sr)
+        if mfcc is None:
+            return "Error extracting features."
+        features = normalize_and_pad_features(mfcc)
+        if features is None:
+            return "Error in feature normalization and padding."
+
+        # Convert to tensor and add batch dimension
+        features_tensor = torch.tensor(features).float().unsqueeze(0)
+
+        # Ensure the input tensor matches the model's expected dimensions
+        if features_tensor.dim() == 2:
+            features_tensor = features_tensor.unsqueeze(0)  # Add a channel dimension
 
         # Predict
         with torch.no_grad():
-            outputs = model(**inputs)
-
+            outputs = model(features_tensor)
         logits = outputs.logits
-        predicted_index = logits.argmax(dim=1).item()
-        confidence = torch.softmax(logits, dim=1).max().item() * 100
+        predicted_index = logits.argmax().item()
         label = model.config.id2label[predicted_index]
-
-        return f"Predicted class: {label} with confidence: {confidence:.2f}%"
+        confidence = torch.softmax(logits, dim=1).max().item() * 100
+        return f"Classified as '{label}' with {confidence:.2f}% confidence."
     except Exception as e:
-        logging.error(f"Error during processing: {e}")
-        return "Prediction failed due to an error."
-
-def save_temp_audio(file):
-    """
-    Saves a temporary audio file, returns the path.
-    """
-    temp_dir = tempfile.gettempdir()
-    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir, suffix=".wav")
-    temp_file_path = temp_file.name
-    # Convert to WAV for consistency
-    AudioSegment.from_file(file).export(temp_file_path, format="wav")
-    return temp_file_path
-
-def handle_audio_input(file_info):
-    """
-    Handles the input audio file for prediction.
-    """
-    audio_file_path = save_temp_audio(file_info)
-    prediction = predict_audio_class(audio_file_path)
-    os.unlink(audio_file_path)  # Clean up temp file
-    return prediction
+        logging.error(f"Prediction error: {e}")
+        return "Error during prediction."
 
-# Setup Gradio interface
-iface = gr.Interface(
-    fn=handle_audio_input,
-    inputs=gr.inputs.Audio(source="upload", type="file", label="Upload Audio"),
-    outputs="text",
-    title="Audio Class Prediction",
-    description="Predicts the class of uploaded audio files using a fine-tuned Wav2Vec 2.0 model."
-)
+# Gradio interface
+iface = gr.Interface(fn=predict_voice, inputs=gr.inputs.Audio(type="filepath"), outputs="text",
+                     title="Audio Classification", description="Classify audio files with a pretrained model.")
 
-iface.launch()
+# Launch the Gradio app
+if __name__ == "__main__":
+    iface.launch()
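
The new feature pipeline is easy to sanity-check outside Gradio. Below is a minimal smoke test that mirrors preprocess_audio, extract_features, and normalize_and_pad_features from the updated app.py and prints the tensor shape that predict_voice hands to the model; the file name sample.wav is a placeholder, not something this commit ships.

import librosa
import numpy as np
import torch

# Mirrors preprocess_audio(): load and resample to 16 kHz
y, sr = librosa.load("sample.wav", sr=16000)  # "sample.wav" is a placeholder

# Mirrors extract_features(): 40 MFCCs with a 2048-point FFT and a hop of 512
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40, n_fft=2048, hop_length=512)

# Mirrors normalize_and_pad_features(): per-coefficient z-score, then pad or
# truncate the time axis to 512 frames
mfcc = (mfcc - mfcc.mean(axis=1, keepdims=True)) / mfcc.std(axis=1, keepdims=True)
if mfcc.shape[1] < 512:
    mfcc = np.pad(mfcc, ((0, 0), (0, 512 - mfcc.shape[1])), 'constant')
else:
    mfcc = mfcc[:, :512]

# Same batch dimension predict_voice() adds before calling the model
features_tensor = torch.tensor(mfcc).float().unsqueeze(0)
print(features_tensor.shape)  # torch.Size([1, 40, 512])

One caveat: checkpoints loaded through AutoModelForAudioClassification are commonly Wav2Vec2-style models that expect raw waveform input_values of shape (batch, num_samples), so whether this (1, 40, 512) MFCC tensor is accepted depends on the checkpoint in ./ and is worth verifying before trusting the reported confidence.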
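
Also worth noting: gr.inputs.Audio is the legacy Gradio namespace (deprecated in Gradio 3 and removed in Gradio 4), so the interface line kept in this commit breaks if the Space's Gradio pin is ever bumped. A sketch of the equivalent interface with the current top-level component, assuming predict_voice from the updated app.py is in scope:

import gradio as gr

# gr.Audio replaces the legacy gr.inputs.Audio; type="filepath" still hands
# predict_voice a path to the uploaded file, so the function needs no change.
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Audio Classification",
    description="Classify audio files with a pretrained model.",
)

if __name__ == "__main__":
    iface.launch()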