# app.py — Hugging Face Space by Kabatubare
# Last commit: "Update app.py" (09e98e6, verified) — 2.89 kB
import gradio as gr
import librosa
import numpy as np
import torch
import logging
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
# Initialize logging
logging.basicConfig(level=logging.INFO)

# Load the model and feature extractor once, at import time, so every request
# handled by the interface reuses the same instances.
# NOTE(review): model_path is the directory holding the saved checkpoint and
# preprocessor config — presumably the Space repo root; confirm.
model_path = "./"
try:
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
except Exception as e:
    # logging.exception records the full traceback (logging.error did not);
    # a bare ``raise`` re-raises the original exception unchanged so the app
    # still fails loudly at startup when the checkpoint cannot be loaded.
    logging.exception("Model loading failed: %s", e)
    raise
else:
    logging.info("Model and feature extractor loaded successfully.")
def load_audio(audio_path, sr=16000):
    """
    Load an audio file and resample it to the target sample rate.

    Args:
        audio_path: Path to the audio file to decode.
        sr: Target sample rate in Hz; ``librosa.load`` resamples to this rate.

    Returns:
        The decoded waveform returned by ``librosa.load`` (the sample rate it
        also returns is discarded, since it equals ``sr``).

    Raises:
        Whatever ``librosa.load`` raises (e.g. a missing or undecodable file).
        The error is logged with its traceback and re-raised unchanged.
    """
    # Keep the try body to the single call that can fail.
    try:
        audio, _ = librosa.load(audio_path, sr=sr)
    except Exception as e:
        logging.exception("Failed to load audio: %s", e)
        raise
    logging.info("Audio file loaded and resampled.")
    return audio
def preprocess_audio(audio, sampling_rate=16000):
    """
    Convert a raw waveform into the input tensor expected by the Wav2Vec2 model.

    Args:
        audio: Waveform samples, presumably the 1-D array produced by
            ``load_audio`` — TODO confirm for any other caller.
        sampling_rate: Sample rate of ``audio`` in Hz, forwarded to the
            feature extractor. Defaults to 16000, matching ``load_audio``'s
            default ``sr``. Pass the same value used when loading the audio.

    Returns:
        The ``input_values`` PyTorch tensor built by the feature extractor
        (``return_tensors="pt"``, padded to the longest item).

    Raises:
        Any feature-extractor error, logged with its traceback and re-raised.
    """
    try:
        input_values = feature_extractor(
            audio,
            return_tensors="pt",
            padding="longest",
            sampling_rate=sampling_rate,
        ).input_values
    except Exception as e:
        logging.exception("Audio preprocessing failed: %s", e)
        raise
    logging.info("Audio file preprocessed.")
    return input_values
def predict(input_values):
    """
    Classify preprocessed audio with the loaded Wav2Vec2 model.

    Args:
        input_values: Model input tensor, as produced by ``preprocess_audio``.

    Returns:
        Tensor of predicted class indices (arg-max over the last logits axis).

    Raises:
        Re-raises any failure from the forward pass after logging it.
    """
    try:
        # Pure inference, so skip building the autograd graph.
        with torch.no_grad():
            scores = model(input_values).logits
            best_class = scores.argmax(dim=-1)
        logging.info(f"Prediction made with id {best_class}")
        return best_class
    except Exception as err:
        logging.error(f"Prediction failed: {err}")
        raise err
def get_label(prediction_id, labels=None):
    """
    Convert a predicted class index into a human-readable label.

    Args:
        prediction_id: Class index, either a Python int or a single-element
            tensor such as the one returned by ``predict``.
        labels: Optional index -> label lookup (a list or a dict keyed by int).
            When omitted, the loaded model's ``config.id2label`` is used, so
            labels come from the checkpoint rather than a placeholder list.

    Returns:
        The label associated with ``prediction_id``.

    Raises:
        Conversion or lookup errors (e.g. an index not present in ``labels``).
        They are logged with a traceback and re-raised.
    """
    try:
        # int() accepts plain ints and single-element tensors alike, yielding a
        # key usable for both list indexing and dict lookup.
        index = int(prediction_id)
        mapping = model.config.id2label if labels is None else labels
        label = mapping[index]
    except Exception as e:
        logging.exception("Failed to get label: %s", e)
        raise
    logging.info("Label obtained: %s", label)
    return label
def main(audio_file_path):
    """
    Classify one audio file and return its label (the Gradio handler).

    Args:
        audio_file_path: Path to the audio to classify. May be ``None`` or
            empty when the user submits without audio (presumably how the
            ``type="filepath"`` Audio input reports a missing value — confirm).

    Returns:
        The predicted label on success. Otherwise a message string describing
        the problem, so the interface displays it instead of crashing.
    """
    if not audio_file_path:
        # Fail fast with a clear message instead of letting the audio loader
        # raise an opaque error on a missing path.
        logging.warning("No audio file provided.")
        return "No audio file provided."
    try:
        audio = load_audio(audio_file_path)
        input_values = preprocess_audio(audio)
        prediction_id = predict(input_values)
        return get_label(prediction_id)
    except Exception as e:
        # UI boundary: the helpers log their own failures, so only a summary
        # is logged here before the message is surfaced to the user.
        logging.error("Error in processing: %s", e)
        return str(e)
# Set up Gradio interface.
# ``gr.Audio`` replaces the legacy ``gr.inputs.Audio`` namespace, which was
# deprecated in Gradio 3 and removed in Gradio 4. ``type="filepath"`` makes
# Gradio pass ``main`` a path to the audio file rather than raw samples.
iface = gr.Interface(
    fn=main,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Audio Classification",
)
# Launch the interface
iface.launch()