Spaces:
Runtime error
Runtime error
File size: 2,889 Bytes
ee91d94 0a26e54 1364a7f 6530ee3 f0dd070 09e98e6 f0dd070 09e98e6 f0dd070 ee91d94 09e98e6 e02dec8 411539a 09e98e6 411539a 09e98e6 0c35856 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 411539a 09e98e6 8a834c6 09e98e6 df3ef47 09e98e6 411539a 09e98e6 e8e81bf 09e98e6 ee91d94 09e98e6 ee91d94 09e98e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import gradio as gr
import librosa
import numpy as np
import torch
import logging
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
# Configure root logging so INFO-level progress messages are emitted
logging.basicConfig(level=logging.INFO)

# Both the classification checkpoint and its feature-extractor config are
# loaded from the directory given by model_path; start-up aborts if either
# load fails.
model_path = "./"
try:
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
except Exception as load_error:
    logging.error(f"Model loading failed: {load_error}")
    raise
else:
    logging.info("Model and feature extractor loaded successfully.")
def load_audio(audio_path, sr=16000):
    """
    Read the audio file at ``audio_path`` via ``librosa.load`` with target rate ``sr``.

    Returns only the sample data (the first element of librosa's result);
    the sample rate librosa reports back is discarded. Any failure is
    logged and re-raised to the caller.
    """
    try:
        samples = librosa.load(audio_path, sr=sr)[0]
    except Exception as load_error:
        logging.error(f"Failed to load audio: {load_error}")
        raise
    logging.info("Audio file loaded and resampled.")
    return samples
def preprocess_audio(audio):
    """
    Run the module-level feature extractor over ``audio`` and return its ``input_values``.

    The extractor is asked for PyTorch tensors (``return_tensors="pt"``),
    ``"longest"`` padding, and is told the audio is sampled at 16000 Hz.
    Any failure is logged and re-raised to the caller.
    """
    extractor_options = {
        "return_tensors": "pt",
        "padding": "longest",
        "sampling_rate": 16000,
    }
    try:
        features = feature_extractor(audio, **extractor_options)
        model_inputs = features.input_values
    except Exception as prep_error:
        logging.error(f"Audio preprocessing failed: {prep_error}")
        raise
    logging.info("Audio file preprocessed.")
    return model_inputs
def predict(input_values):
    """
    Feed ``input_values`` to the module-level model and return the arg-max class id.

    The forward pass runs under ``torch.no_grad()``. The result is whatever
    ``argmax`` over the last dimension of the logits yields (a tensor, not
    a plain int). Any failure is logged and re-raised to the caller.
    """
    try:
        with torch.no_grad():
            outputs = model(input_values)
            best_class = outputs.logits.argmax(dim=-1)
    except Exception as inference_error:
        logging.error(f"Prediction failed: {inference_error}")
        raise
    logging.info(f"Prediction made with id {best_class}")
    return best_class
def get_label(prediction_id, labels=None):
    """
    Convert a predicted class id into its human-readable label.

    Args:
        prediction_id: Class id as a plain int, or any object exposing
            ``.item()`` that yields one (e.g. the single-element tensor
            returned by ``predict``).
        labels: Optional sequence or mapping from class id to label. When
            omitted, the ``id2label`` mapping stored in the loaded model's
            config is used, so the result matches the checkpoint's own
            classes instead of a hard-coded placeholder list.

    Returns:
        The label associated with ``prediction_id``.

    Raises:
        Exception: Whatever the id conversion or lookup raises (for example
            ``IndexError``/``KeyError`` for an unknown id); the error is
            logged before being re-raised.
    """
    try:
        if labels is None:
            labels = model.config.id2label
        # Collapse tensor / array scalars to a plain int: a tensor would not
        # match the integer keys of a mapping such as id2label.
        if hasattr(prediction_id, "item"):
            prediction_id = prediction_id.item()
        label = labels[int(prediction_id)]
    except Exception as e:
        logging.error(f"Failed to get label: {e}")
        raise
    logging.info(f"Label obtained: {label}")
    return label
def main(audio_file_path):
    """
    Classify the audio file at ``audio_file_path`` and return its label.

    Runs the full pipeline: load the audio, preprocess it, predict a class
    id, and map that id to a label.

    Args:
        audio_file_path: Path of the audio file to classify. May be ``None``
            or empty when the user submits the form without any audio.

    Returns:
        The predicted label on success. On failure the error message is
        returned as text (and logged) instead of raising, so the message is
        what appears in the interface's text output.
    """
    # Without this guard a missing upload reaches the audio loader and comes
    # back as an opaque low-level error; answer with a clear message instead.
    if not audio_file_path:
        message = "No audio file provided."
        logging.error(f"Error in processing: {message}")
        return message
    try:
        audio = load_audio(audio_file_path)
        input_values = preprocess_audio(audio)
        prediction_id = predict(input_values)
        return get_label(prediction_id)
    except Exception as e:
        logging.error(f"Error in processing: {e}")
        return str(e)
# Set up Gradio interface
# gr.Audio is the top-level audio component; the legacy gr.inputs namespace
# used previously is not available in current Gradio releases and makes the
# app fail at start-up. type="filepath" passes main() a path to the audio file.
iface = gr.Interface(
    fn=main,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Audio Classification",
)
# Launch the interface
iface.launch()
|