Kabatubare committed
Commit 15eca51 · verified · 1 Parent(s): 09e98e6

Update app.py

Files changed (1): app.py +50 -65
app.py CHANGED
@@ -1,92 +1,77 @@
-import gradio as gr
 import librosa
 import numpy as np
 import torch
 import logging
-from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
+from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
+import gradio as gr
 
-# Initialize logging
 logging.basicConfig(level=logging.INFO)
 
-# Load the model and feature extractor
-model_path = "./"
+# Path to your Wav2Vec2 model and processor
+model_path = "./wav2vec2-sequence-classification"
 try:
     model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
-    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
-    logging.info("Model and feature extractor loaded successfully.")
+    processor = Wav2Vec2Processor.from_pretrained(model_path)
+    logging.info("Model and processor loaded successfully.")
 except Exception as e:
-    logging.error(f"Model loading failed: {e}")
+    logging.error(f"Loading model and processor failed: {e}")
     raise e
 
-def load_audio(audio_path, sr=16000):
+def preprocess_audio(file_path):
     """
-    Load an audio file and resample to the target sample rate.
+    Load and preprocess the audio file.
     """
-    try:
-        audio, _ = librosa.load(audio_path, sr=sr)
-        logging.info("Audio file loaded and resampled.")
-        return audio
-    except Exception as e:
-        logging.error(f"Failed to load audio: {e}")
-        raise e
+    # Load the audio file using librosa
+    audio, sr = librosa.load(file_path, sr=None)
+    # Resample the audio to 16 kHz (if not already at this sample rate)
+    if sr != 16000:
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
+        sr = 16000
+    return audio, sr
 
-def preprocess_audio(audio):
+def audio_to_features(audio, sr):
     """
-    Preprocess the audio file to the format expected by Wav2Vec2 model.
+    Convert audio waveform to model features.
     """
-    try:
-        input_values = feature_extractor(audio, return_tensors="pt", padding="longest", sampling_rate=16000).input_values
-        logging.info("Audio file preprocessed.")
-        return input_values
-    except Exception as e:
-        logging.error(f"Audio preprocessing failed: {e}")
-        raise e
+    # Use the processor to prepare the features for the model
+    return processor(audio, sampling_rate=sr, return_tensors="pt", padding=True, truncation=True).input_values
 
-def predict(input_values):
+def classify_audio(file_path):
     """
-    Make a prediction with the Wav2Vec2 model.
+    Classify the content of the audio file.
     """
     try:
+        audio, sr = preprocess_audio(file_path)
+        input_values = audio_to_features(audio, sr)
+
+        # Inference
         with torch.no_grad():
             logits = model(input_values).logits
-        predicted_id = torch.argmax(logits, dim=-1)
-        logging.info(f"Prediction made with id {predicted_id}")
-        return predicted_id
-    except Exception as e:
-        logging.error(f"Prediction failed: {e}")
-        raise e
-
-def get_label(prediction_id):
-    """
-    Convert the prediction ID to a meaningful label.
-    """
-    # Example of converting predicted id to label
-    # This should be adapted based on your specific model's labels
-    labels = ["label1", "label2"]  # Dummy label list for demonstration
-    try:
-        label = labels[prediction_id]
-        logging.info(f"Label obtained: {label}")
-        return label
-    except Exception as e:
-        logging.error(f"Failed to get label: {e}")
-        raise e
-
-def main(audio_file_path):
-    """
-    Load audio, preprocess, predict, and return the label.
-    """
-    try:
-        audio = load_audio(audio_file_path)
-        input_values = preprocess_audio(audio)
-        prediction_id = predict(input_values)
-        label = get_label(prediction_id)
-        return label
+
+        # Post-processing: Convert logits to softmax to get probabilities
+        probabilities = torch.softmax(logits, dim=1).detach().numpy()
+
+        # Assuming you have a binary classification model for simplicity
+        # Modify this part based on your actual number of classes and labels
+        labels = ['Class 0', 'Class 1']  # Example labels
+        predictions = dict(zip(labels, probabilities[0]))
+
+        # Format the prediction output
+        prediction_output = "\n".join([f"{label}: {prob:.4f}" for label, prob in predictions.items()])
+        return prediction_output
     except Exception as e:
-        logging.error(f"Error in processing: {e}")
-        return str(e)
+        logging.error(f"Error during classification: {e}")
+        return f"Classification error: {e}"
 
-# Set up Gradio interface
-iface = gr.Interface(fn=main, inputs=gr.inputs.Audio(type="filepath"), outputs="text", title="Audio Classification")
+# Gradio interface
+iface = gr.Interface(
+    fn=classify_audio,
+    inputs=gr.inputs.Audio(source="upload", type="filepath"),
+    outputs="text",
+    title="Audio Classification with Wav2Vec2",
+    description="Upload an audio file to classify its content using a Wav2Vec2 model."
+)
 
 # Launch the interface
-iface.launch()
+if __name__ == "__main__":
+    iface.launch()
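
One caveat for anyone reusing this commit: gr.inputs.Audio(source="upload", type="filepath") is the legacy Gradio 3.x input namespace, and gr.inputs was removed in Gradio 4, where the equivalent component is gr.Audio(sources=["upload"], type="filepath"). The hardcoded label list can also be read from the id2label mapping that transformers keeps on the model config. The sketch below shows the same app against the current Gradio API; it is illustrative only, and the model path and label handling are assumptions rather than part of the committed code.

import torch
import librosa
import gradio as gr
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor

model_path = "./wav2vec2-sequence-classification"  # assumed local checkpoint, as in the commit
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
processor = Wav2Vec2Processor.from_pretrained(model_path)

def classify_audio(file_path):
    # librosa resamples to 16 kHz on load, replacing the manual resample step above
    audio, sr = librosa.load(file_path, sr=16000)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    probs = torch.softmax(logits, dim=-1)[0]
    # id2label comes from the checkpoint config; transformers fills in
    # generic LABEL_0/LABEL_1 names when the config does not define any
    return {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}

iface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(sources=["upload"], type="filepath"),  # Gradio 4.x form
    outputs=gr.Label(num_top_classes=2),
    title="Audio Classification with Wav2Vec2",
)

if __name__ == "__main__":
    iface.launch()

Returning a label-to-probability dictionary lets gr.Label render the scores directly, which replaces the manual string formatting in the committed classify_audio.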