import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import logging
from transformers import AutoModelForAudioClassification

# Configure logging for debugging and information
logging.basicConfig(level=logging.INFO)

# Load the fine-tuned audio classification model from the local directory
local_model_path = "./"
model = AutoModelForAudioClassification.from_pretrained(local_model_path)
model.eval()  # Inference mode: disables dropout and other training-only behavior
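
# Hedged alternative (a sketch, assuming the checkpoint also ships a
# preprocessor_config.json): the stock Hugging Face extractor could replace
# the custom pipeline below, e.g.
#   from transformers import AutoFeatureExtractor
#   feature_extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
#   inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")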

def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
    """
    Custom feature extraction using a log-Mel spectrogram, tailored for models trained on datasets like AudioSet.
    Args:
        audio_file_path: Path to the audio file for prediction.
        sr: Target sampling rate for the audio file.
        n_mels: Number of Mel bands to generate.
        n_fft: Length of the FFT window.
        hop_length: Number of samples between successive frames.
        target_length: Expected length of the Mel spectrogram in the time dimension.
    Returns:
        A tensor representation of the Mel spectrogram features.
    """
    waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
    S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
    S_DB = librosa.power_to_db(S, ref=np.max)
    mel_tensor = torch.tensor(S_DB).float()

    # Ensure the tensor matches the expected sequence length
    current_length = mel_tensor.shape[1]
    if current_length > target_length:
        mel_tensor = mel_tensor[:, :target_length]  # Truncate if longer
    elif current_length < target_length:
        padding = target_length - current_length
        mel_tensor = F.pad(mel_tensor, (0, padding), "constant", 0)  # Pad if shorter

    # Add a batch dimension: final shape is (1, n_mels, target_length).
    # Note: some architectures (e.g., AST) instead expect (batch, time, n_mels),
    # which would require mel_tensor.transpose(1, 2) before the forward pass.
    mel_tensor = mel_tensor.unsqueeze(0)
    return mel_tensor
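
# Shape sanity check (illustrative; "sample.wav" is a hypothetical path):
# a 10-second clip at 16 kHz yields ~313 frames (1 + 160000 // 512), which the
# padding step extends to target_length, so the output is always
# (1, n_mels, target_length) = (1, 128, 1024) with the defaults above.
#   features = custom_feature_extraction("sample.wav")
#   print(features.shape)  # torch.Size([1, 128, 1024])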

def predict_voice(audio_file_path):
    """
    Predicts the audio class using a pre-trained model and custom feature extraction.
    Args:
        audio_file_path: Path to the audio file for prediction.
    Returns:
        A string containing the predicted class and confidence level.
    """
    try:
        features = custom_feature_extraction(audio_file_path)
        with torch.no_grad():
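            # The spectrogram is passed as the first positional argument, which
            # AutoModelForAudioClassification maps to input_values.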
            outputs = model(features)
        logits = outputs.logits
        predicted_index = logits.argmax(dim=-1)
        label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=-1).max().item() * 100
        
        result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
        logging.info("Prediction successful.")
    except Exception as e:
        result = f"Error during processing: {e}"
        logging.error(result)
    
    return result
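
# Illustrative direct call outside Gradio ("sample.wav" is a hypothetical path):
#   print(predict_voice("sample.wav"))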

# Setting up the Gradio interface
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=gr.Textbox(label="Prediction"),
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
)

# Launching the interface
iface.launch()