File size: 2,889 Bytes
ee91d94
0a26e54
1364a7f
6530ee3
f0dd070
09e98e6
f0dd070
09e98e6
f0dd070
ee91d94
09e98e6
e02dec8
411539a
09e98e6
 
 
411539a
09e98e6
 
0c35856
09e98e6
 
 
 
411539a
09e98e6
411539a
09e98e6
411539a
09e98e6
 
411539a
09e98e6
 
 
 
411539a
09e98e6
 
 
411539a
09e98e6
 
411539a
09e98e6
 
 
 
411539a
09e98e6
 
 
 
 
411539a
09e98e6
 
8a834c6
09e98e6
 
 
 
 
 
 
df3ef47
09e98e6
 
 
 
 
 
411539a
09e98e6
 
 
 
 
 
 
 
 
 
e8e81bf
09e98e6
 
ee91d94
09e98e6
 
ee91d94
09e98e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import librosa
import numpy as np
import torch
import logging
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

# Send INFO-and-above records through the root logger so the progress
# messages emitted by the helpers below are visible.
logging.basicConfig(level=logging.INFO)

# Directory the model weights/config and the feature-extractor config are
# read from (relative to the process's working directory).
model_path = "./"

# Both artifacts are loaded once at import time and shared as module globals
# by the functions below. A failure is logged and then re-raised, aborting
# start-up.
try:
    model, feature_extractor = (
        Wav2Vec2ForSequenceClassification.from_pretrained(model_path),
        Wav2Vec2FeatureExtractor.from_pretrained(model_path),
    )
    logging.info("Model and feature extractor loaded successfully.")
except Exception as load_error:
    logging.error(f"Model loading failed: {load_error}")
    raise load_error

def load_audio(audio_path, sr=16000):
    """
    Read an audio file with librosa, resampled to ``sr`` Hz.

    Only the sample array is returned; the sample rate that ``librosa.load``
    reports alongside it is discarded. Any loading error is logged and then
    re-raised to the caller.
    """
    try:
        waveform = librosa.load(audio_path, sr=sr)[0]
        logging.info("Audio file loaded and resampled.")
        return waveform
    except Exception as err:
        logging.error(f"Failed to load audio: {err}")
        raise err

def preprocess_audio(audio, sampling_rate=16000):
    """
    Convert a waveform into the ``input_values`` tensor the model consumes.

    Args:
        audio: Waveform samples, presumably the array returned by
            ``load_audio`` (TODO confirm if other callers exist).
        sampling_rate: Sample rate of ``audio`` in Hz. Defaults to 16000, the
            same default ``load_audio`` resamples to. If ``load_audio`` is
            called with a different ``sr``, pass that value here too so the
            feature extractor is told the true rate.

    Returns:
        The ``input_values`` PyTorch tensor produced by the feature extractor
        (padded to the longest input).

    Raises:
        Re-raises any exception from the feature extractor after logging it.
    """
    try:
        input_values = feature_extractor(
            audio,
            return_tensors="pt",
            padding="longest",
            sampling_rate=sampling_rate,
        ).input_values
        logging.info("Audio file preprocessed.")
        return input_values
    except Exception as e:
        logging.error(f"Audio preprocessing failed: {e}")
        # Bare raise re-raises the active exception unchanged.
        raise

def predict(input_values):
    """
    Run the classifier on preprocessed input and return the winning class id.

    The forward pass runs with gradient tracking disabled. The result is the
    argmax of the model's logits over their last dimension, returned as a
    tensor (not a Python int). Any failure is logged and re-raised.
    """
    try:
        with torch.no_grad():
            scores = model(input_values).logits
            best_class = scores.argmax(dim=-1)
            logging.info(f"Prediction made with id {best_class}")
        return best_class
    except Exception as err:
        logging.error(f"Prediction failed: {err}")
        raise err

def get_label(prediction_id, labels=None):
    """
    Convert a predicted class id to a human-readable label.

    Args:
        prediction_id: The class index, either a Python int or a
            single-element integer tensor such as the one ``predict`` returns
            for a single clip.
        labels: Optional mapping or sequence from class index to label. When
            omitted, the loaded model's own ``config.id2label`` mapping is
            used. This keeps the labels in sync with the checkpoint's
            classification head instead of relying on a placeholder list.

    Returns:
        The label associated with ``prediction_id``.

    Raises:
        Re-raises (after logging) any error from converting the id to an int
        or from looking it up, e.g. ``IndexError``/``KeyError`` for an id that
        has no label.
    """
    try:
        # int() accepts Python ints, NumPy integers and one-element tensors,
        # so lookups work the same for dict- and list-based label tables.
        index = int(prediction_id)
        if labels is None:
            labels = model.config.id2label
        label = labels[index]
        logging.info(f"Label obtained: {label}")
        return label
    except Exception as e:
        logging.error(f"Failed to get label: {e}")
        raise

def main(audio_file_path):
    """
    Classify one audio file: load, preprocess, predict, and map to a label.

    This is the Gradio callback, so it never raises. Any failure is logged
    with its traceback, and its message is returned as the output text.

    Args:
        audio_file_path: Path to the audio file to classify. May be None or
            empty when the user submits without providing audio.

    Returns:
        The predicted label, or a message describing what went wrong.
    """
    # An empty submission would otherwise fail deep inside the audio loader
    # with an obscure message; report it plainly instead.
    if not audio_file_path:
        logging.warning("No audio file provided.")
        return "No audio file provided."
    try:
        audio = load_audio(audio_file_path)
        input_values = preprocess_audio(audio)
        prediction_id = predict(input_values)
        label = get_label(prediction_id)
        return label
    except Exception as e:
        # The exception is swallowed here, so log it with logging.exception
        # to preserve the traceback that would otherwise be lost.
        logging.exception(f"Error in processing: {e}")
        return str(e)

# Set up the Gradio interface: an audio input handed to ``main`` as a file
# path, and the returned label shown as text.
# NOTE: the old ``gr.inputs.Audio`` namespace was deprecated and later removed
# (Gradio 4). The top-level ``gr.Audio`` component is the supported equivalent.
iface = gr.Interface(
    fn=main,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Audio Classification",
)

# Launch the interface
iface.launch()