# Spaces:
# Runtime error
# Runtime error
import librosa | |
import numpy as np | |
import torch | |
import logging | |
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor | |
import gradio as gr | |
logging.basicConfig(level=logging.INFO)

# Path to your Wav2Vec2 model and processor
model_path = "./wav2vec2-sequence-classification"

# Load the model and processor once at import time. Only the two loading
# calls sit inside the ``try`` so that nothing else can be mistaken for a
# loading failure.
try:
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
    processor = Wav2Vec2Processor.from_pretrained(model_path)
except Exception as e:
    # ``logging.exception`` records the traceback alongside the message; the
    # bare ``raise`` re-raises the original exception unchanged so startup
    # still fails loudly.
    logging.exception("Loading model and processor failed: %s", e)
    raise
else:
    logging.info("Model and processor loaded successfully.")
def preprocess_audio(file_path):
    """
    Read an audio file and return it as a 16 kHz waveform.

    The file is decoded with librosa at its native sampling rate and is
    resampled only when that rate differs from 16 kHz.

    Returns:
        A ``(audio, sampling_rate)`` tuple; the sampling rate is always 16000.
    """
    target_rate = 16000

    # sr=None keeps the file's native sampling rate instead of librosa's default.
    waveform, native_rate = librosa.load(file_path, sr=None)

    # Already at the target rate: hand the samples back untouched.
    if native_rate == target_rate:
        return waveform, native_rate

    resampled = librosa.resample(waveform, orig_sr=native_rate, target_sr=target_rate)
    return resampled, target_rate
def audio_to_features(audio, sr):
    """
    Convert an audio waveform into the input tensor passed to the model.

    Args:
        audio: Waveform samples, as returned by ``preprocess_audio``.
        sr: Sampling rate of ``audio`` in Hz.

    Returns:
        The ``input_values`` PyTorch tensor produced by the processor.
    """
    # ``truncation=True`` is deliberately not passed: the feature extractor
    # rejects truncation when no ``max_length`` is supplied, which made every
    # call fail. A single clip needs no truncation to be batched.
    features = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    return features.input_values
def classify_audio(file_path):
    """
    Classify an audio file and return the class probabilities as text.

    Args:
        file_path: Path to the audio file to classify.

    Returns:
        One ``"<label>: <probability>"`` line per model output class, with
        probabilities formatted to four decimal places. If any step fails,
        the error is logged and a ``"Classification error: ..."`` message is
        returned instead of raising, so the caller always receives text.
    """
    try:
        audio, sr = preprocess_audio(file_path)
        input_values = audio_to_features(audio, sr)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            logits = model(input_values).logits

        # Softmax over the class dimension; [0] selects the single clip in
        # the batch.
        probabilities = torch.softmax(logits, dim=1)[0].tolist()

        # Generate one placeholder label per output class. A fixed two-item
        # list combined with zip() silently dropped the probabilities of any
        # additional classes. Replace these with real names if available.
        labels = [f"Class {index}" for index in range(len(probabilities))]

        # Format the prediction output
        return "\n".join(
            f"{label}: {prob:.4f}" for label, prob in zip(labels, probabilities)
        )
    except Exception as e:
        # Broad catch is intentional at this UI boundary; the traceback is
        # kept in the log while the user sees a short message.
        logging.exception("Error during classification: %s", e)
        return f"Classification error: {e}"
# Gradio interface
# ``gr.Audio`` replaces the legacy ``gr.inputs.Audio``: the ``gr.inputs``
# namespace was deprecated in Gradio 3 and removed in Gradio 4. The
# ``source="upload"`` keyword is omitted because it was renamed to ``sources``
# in Gradio 4; file upload remains available by default in both versions.
iface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),  # hands classify_audio a path on disk
    outputs="text",
    title="Audio Classification with Wav2Vec2",
    description="Upload an audio file to classify its content using a Wav2Vec2 model.",
)

# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()