import gradio as gr
import librosa
import numpy as np
import torch
import logging
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

# Initialize logging
logging.basicConfig(level=logging.INFO)

# Load the model and feature extractor from the repository root
# (expects config.json, the model weights, and preprocessor_config.json alongside this script)
model_path = "./"
try:
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
    logging.info("Model and feature extractor loaded successfully.")
except Exception as e:
    logging.error(f"Model loading failed: {e}")
    raise

def load_audio(audio_path, sr=16000):
    """
    Load an audio file and resample it to the target sample rate.
    """
    try:
        audio, _ = librosa.load(audio_path, sr=sr)
        logging.info("Audio file loaded and resampled.")
        return audio
    except Exception as e:
        logging.error(f"Failed to load audio: {e}")
        raise

def preprocess_audio(audio):
    """
    Preprocess the audio into the input format expected by the Wav2Vec2 model.
    """
    try:
        input_values = feature_extractor(
            audio, return_tensors="pt", padding="longest", sampling_rate=16000
        ).input_values
        logging.info("Audio file preprocessed.")
        return input_values
    except Exception as e:
        logging.error(f"Audio preprocessing failed: {e}")
        raise

def predict(input_values):
    """
    Run the Wav2Vec2 model and return the predicted class id as an int.
    """
    try:
        with torch.no_grad():
            logits = model(input_values).logits
            # .item() converts the single-element tensor to a plain int,
            # which is needed for the label lookup below.
            predicted_id = torch.argmax(logits, dim=-1).item()
        logging.info(f"Prediction made with id {predicted_id}")
        return predicted_id
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        raise

def get_label(prediction_id):
    """
    Convert the predicted class id to a human-readable label.
    """
    # Dummy label list for demonstration; replace with your model's actual labels
    # (see the id2label sketch below this function).
    labels = ["label1", "label2"]
    try:
        label = labels[prediction_id]
        logging.info(f"Label obtained: {label}")
        return label
    except Exception as e:
        logging.error(f"Failed to get label: {e}")
        raise
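
# Alternative lookup (a minimal sketch, assuming the fine-tuned model was saved with an
# id2label mapping in its config; verify that mapping exists before relying on it):
# def get_label(prediction_id):
#     label = model.config.id2label[prediction_id]
#     logging.info(f"Label obtained: {label}")
#     return label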

def main(audio_file_path):
    """
    Load the audio, preprocess it, run the model, and return the predicted label.
    """
    try:
        audio = load_audio(audio_file_path)
        input_values = preprocess_audio(audio)
        prediction_id = predict(input_values)
        label = get_label(prediction_id)
        return label
    except Exception as e:
        logging.error(f"Error in processing: {e}")
        return str(e)

# Set up the Gradio interface
# (gr.inputs.Audio has been removed in recent Gradio releases; gr.Audio is the current component)
iface = gr.Interface(
    fn=main,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Audio Classification",
)

# Launch the interface
iface.launch()
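
# Optional local smoke test (a sketch; "sample.wav" is a hypothetical path and must point
# to a real audio file). Uncomment to classify a single file without the web UI:
# print(main("sample.wav"))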