Kabatubare committed
Commit 09e98e6 · verified · 1 Parent(s): 411539a

Update app.py

Files changed (1)
  1. app.py +65 -71
app.py CHANGED
@@ -3,96 +3,90 @@ import librosa
  import numpy as np
  import torch
  import logging
- from transformers import AutoModelForAudioClassification
- from torch.nn.functional import interpolate
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 
- # Set up logging to help diagnose issues and track progress
+ # Initialize logging
  logging.basicConfig(level=logging.INFO)
 
- # Load the pretrained model for audio classification
+ # Load the model and feature extractor
  model_path = "./"
  try:
-     model = AutoModelForAudioClassification.from_pretrained(model_path)
-     logging.info("Model loaded successfully.")
+     model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
+     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
+     logging.info("Model and feature extractor loaded successfully.")
  except Exception as e:
-     logging.error(f"Failed to load model: {e}")
+     logging.error(f"Model loading failed: {e}")
+     raise e
 
- # Function to preprocess audio file
- def preprocess_audio(audio_path, target_sr=16000):
+ def load_audio(audio_path, sr=16000):
+     """
+     Load an audio file and resample it to the target sample rate.
+     """
      try:
-         y, sr = librosa.load(audio_path, sr=target_sr)
+         audio, _ = librosa.load(audio_path, sr=sr)
          logging.info("Audio file loaded and resampled.")
-         return y, sr
+         return audio
      except Exception as e:
-         logging.error(f"Error in audio preprocessing: {e}")
-         return None, None
+         logging.error(f"Failed to load audio: {e}")
+         raise e
 
- # Function to extract features from audio
- def extract_features(y, sr, n_mfcc=40, n_fft=2048, hop_length=512):
+ def preprocess_audio(audio):
+     """
+     Preprocess the audio to the format expected by the Wav2Vec2 model.
+     """
      try:
-         mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
-         logging.info("MFCC features extracted.")
-         return mfcc
+         input_values = feature_extractor(audio, return_tensors="pt", padding="longest", sampling_rate=16000).input_values
+         logging.info("Audio file preprocessed.")
+         return input_values
      except Exception as e:
-         logging.error(f"Error extracting MFCC features: {e}")
-         return None
+         logging.error(f"Audio preprocessing failed: {e}")
+         raise e
 
- # Function to normalize and pad features
- def normalize_and_pad_features(mfcc, target_size=512):
+ def predict(input_values):
+     """
+     Make a prediction with the Wav2Vec2 model.
+     """
      try:
-         # Normalize features
-         mfcc_normalized = (mfcc - np.mean(mfcc, axis=1, keepdims=True)) / np.std(mfcc, axis=1, keepdims=True)
-         logging.info("Features normalized.")
-
-         # Pad features
-         if mfcc_normalized.shape[1] < target_size:
-             padding = target_size - mfcc_normalized.shape[1]
-             mfcc_padded = np.pad(mfcc_normalized, ((0, 0), (0, padding)), 'constant')
-             logging.info("Features padded.")
-         else:
-             mfcc_padded = mfcc_normalized[:, :target_size]
-         return mfcc_padded
+         with torch.no_grad():
+             logits = model(input_values).logits
+         predicted_id = torch.argmax(logits, dim=-1).item()  # .item() yields a plain int usable as a list index
+         logging.info(f"Prediction made with id {predicted_id}")
+         return predicted_id
      except Exception as e:
-         logging.error(f"Error in normalization and padding: {e}")
-         return None
+         logging.error(f"Prediction failed: {e}")
+         raise e
 
- # Prediction function
- def predict_voice(audio_file_path):
+ def get_label(prediction_id):
+     """
+     Convert the prediction ID to a meaningful label.
+     """
+     # Example of converting a predicted id to a label
+     # This should be adapted to your specific model's labels
+     labels = ["label1", "label2"]  # Dummy label list for demonstration
      try:
-         # Preprocess and extract features
-         y, sr = preprocess_audio(audio_file_path)
-         if y is None or sr is None:
-             return "Error in audio preprocessing."
-         mfcc = extract_features(y, sr)
-         if mfcc is None:
-             return "Error extracting features."
-         features = normalize_and_pad_features(mfcc)
-         if features is None:
-             return "Error in feature normalization and padding."
-
-         # Convert to tensor and add batch dimension
-         features_tensor = torch.tensor(features).float().unsqueeze(0)
-
-         # Ensure the input tensor matches the model's expected dimensions
-         if features_tensor.dim() == 2:
-             features_tensor = features_tensor.unsqueeze(0)  # Add a channel dimension
+         label = labels[prediction_id]
+         logging.info(f"Label obtained: {label}")
+         return label
+     except Exception as e:
+         logging.error(f"Failed to get label: {e}")
+         raise e
 
-         # Predict
-         with torch.no_grad():
-             outputs = model(features_tensor)
-             logits = outputs.logits
-             predicted_index = logits.argmax().item()
-             label = model.config.id2label[predicted_index]
-             confidence = torch.softmax(logits, dim=1).max().item() * 100
-             return f"Classified as '{label}' with {confidence:.2f}% confidence."
+ def main(audio_file_path):
+     """
+     Load audio, preprocess, predict, and return the label.
+     """
+     try:
+         audio = load_audio(audio_file_path)
+         input_values = preprocess_audio(audio)
+         prediction_id = predict(input_values)
+         label = get_label(prediction_id)
+         return label
      except Exception as e:
-         logging.error(f"Prediction error: {e}")
-         return "Error during prediction."
+         logging.error(f"Error in processing: {e}")
+         return str(e)
 
- # Gradio interface
- iface = gr.Interface(fn=predict_voice, inputs=gr.inputs.Audio(type="filepath"), outputs="text",
-                      title="Audio Classification", description="Classify audio files with a pretrained model.")
+ # Set up Gradio interface
+ iface = gr.Interface(fn=main, inputs=gr.inputs.Audio(type="filepath"), outputs="text", title="Audio Classification")
 
- # Launch the Gradio app
- if __name__ == "__main__":
-     iface.launch()
+ # Launch the interface
+ iface.launch()
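
Note on the new get_label: it ships with a dummy two-entry label list, and the diff's own comment says to adapt it. The version being removed read label names from model.config.id2label, which still works after this change. A minimal sketch of that variant, reusing the model loaded at the top of app.py and assuming the checkpoint in ./ defines id2label in its config (the {0: "real", 1: "fake"} mapping in the comment is only illustrative):

def get_label(prediction_id):
    """
    Look up the label name in the model config instead of a hard-coded list.
    """
    # model.config.id2label maps integer class ids to label strings,
    # e.g. {0: "real", 1: "fake"} for a two-class audio classifier (illustrative).
    return model.config.id2label[prediction_id]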
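
Because main is a plain function, the rewritten pipeline can be smoke-tested without launching the Gradio UI. A quick sketch; "sample.wav" is a placeholder path, not a file in this repo:

# Local smoke test of the load -> preprocess -> predict -> label pipeline.
# "sample.wav" is a placeholder; point it at any real audio file.
if __name__ == "__main__":
    print(main("sample.wav"))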
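
One caveat on the interface line: gr.inputs.Audio is the legacy Gradio 2/3 namespace and was removed in Gradio 4, where components are top-level classes. If the Space runs a current Gradio, an equivalent construction would be:

import gradio as gr

# Gradio 4.x: gr.inputs / gr.outputs no longer exist; use top-level components.
iface = gr.Interface(
    fn=main,
    inputs=gr.Audio(type="filepath"),  # hands fn a temporary file path
    outputs="text",
    title="Audio Classification",
)

iface.launch()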