import logging

import gradio as gr
import librosa
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor

logging.basicConfig(level=logging.INFO)

# Path to your Wav2Vec2 model and processor
model_path = "./wav2vec2-sequence-classification"

try:
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    logging.info("Model and processor loaded successfully.")
except Exception as e:
    logging.error(f"Loading model and processor failed: {e}")
    raise
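# Note: Wav2Vec2Processor bundles a feature extractor and a CTC tokenizer, so
# model_path must contain files for both; if this checkpoint ships only a
# feature extractor, loading Wav2Vec2FeatureExtractor instead may be needed.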


def preprocess_audio(file_path):
    """
    Load and preprocess the audio file.
    """
    # Load the audio file at its native sample rate (librosa returns mono float32)
    audio, sr = librosa.load(file_path, sr=None)
    # Resample to 16 kHz, the rate Wav2Vec2 models expect
    if sr != 16000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
        sr = 16000
    return audio, sr


def audio_to_features(audio, sr):
    """
    Convert the audio waveform to model input features.
    """
    # Use the processor to prepare padded input tensors for the model
    return processor(audio, sampling_rate=sr, return_tensors="pt", padding=True, truncation=True).input_values
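# Note: Wav2Vec2 consumes the raw waveform directly, so input_values is just
# the (optionally normalized) audio with shape (batch_size, num_samples);
# no spectrogram features are computed.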


def classify_audio(file_path):
    """
    Classify the content of the audio file.
    """
    try:
        audio, sr = preprocess_audio(file_path)
        input_values = audio_to_features(audio, sr)
        # Inference
        with torch.no_grad():
            logits = model(input_values).logits
        # Post-processing: apply softmax to the logits to get probabilities
        probabilities = torch.softmax(logits, dim=-1)[0].numpy()
        # Read the class labels from the model config; checkpoints that were
        # not given custom names fall back to "LABEL_0", "LABEL_1", etc.
        labels = [model.config.id2label[i] for i in range(len(probabilities))]
        # Format the prediction output, one "label: probability" per line
        return "\n".join(f"{label}: {prob:.4f}" for label, prob in zip(labels, probabilities))
    except Exception as e:
        logging.error(f"Error during classification: {e}")
        return f"Classification error: {e}"

# Gradio interface (gr.inputs was removed in Gradio 4; gr.Audio is the current API)
iface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs="text",
    title="Audio Classification with Wav2Vec2",
    description="Upload an audio file to classify its content using a Wav2Vec2 model.",
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()
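
# Quick local check (a sketch; "sample.wav" is a hypothetical file path --
# importing this module loads the model, so run it from a Python shell):
#
#     from app import classify_audio
#     print(classify_audio("sample.wav"))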