File size: 3,900 Bytes
fe0bcff
 
dfabd2f
30a5efb
05e6aba
84de51b
fe0bcff
ccce4a0
f0dd070
6aa52fc
fe0bcff
84de51b
ee91d94
05e6aba
ccce4a0
05e6aba
 
 
 
ccce4a0
 
 
 
 
6aa52fc
 
 
 
ccce4a0
6aa52fc
 
 
 
ccce4a0
 
 
 
 
05e6aba
 
84de51b
 
0c35856
fe0bcff
 
 
411539a
50facbf
411539a
84de51b
fe0bcff
84de51b
a29043b
fe0bcff
84de51b
53b1abc
fe0bcff
 
84de51b
fe0bcff
 
16802ac
fe0bcff
05e6aba
16802ac
 
 
 
 
 
 
 
05e6aba
6aa52fc
05e6aba
accc14f
16802ac
accc14f
 
 
e8e81bf
6aa52fc
ee91d94
15eca51
637d0ca
30c595f
05e6aba
accc14f
9f2fc99
 
05e6aba
b860c29
accc14f
fe0bcff
 
90d10cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import logging
import random
import tempfile

import gradio as gr
import librosa
import librosa.display  # explicit: `import librosa` alone does not guarantee specshow is loaded
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor

# Load the audio classifier and its matching feature extractor from the same
# checkpoint directory (resolved relative to the current working directory).
MODEL_DIR = "./"
model = AutoModelForAudioClassification.from_pretrained(MODEL_DIR)
feature_extractor = ASTFeatureExtractor.from_pretrained(MODEL_DIR)

def _save_figure_to_temp_png(fig):
    """Write *fig* to a new PNG file in the current directory and return its path.

    The file is created with ``delete=False`` because callers need to read it
    after this function returns.  Its handle is closed *before* matplotlib
    writes to the path, so no file descriptor is leaked and the write also
    works on platforms that refuse to reopen an already-open temp file.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./') as tmp:
        path = tmp.name
    fig.savefig(path)
    return path


def plot_waveform(waveform, sr):
    """Plot amplitude against time and return the path of the saved PNG.

    Args:
        waveform: 1-D sequence of audio samples.
        sr: Sampling rate of *waveform* in Hz.

    Returns:
        Path (str) of a PNG file in the current directory.
    """
    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        # Sample i occurs at i / sr seconds.  linspace(0, n / sr, n) would
        # include the end point and stretch the axis by one sample period.
        times = np.arange(len(waveform)) / sr
        ax.plot(times, waveform)
        ax.set_title('Waveform')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        fig.tight_layout()
        return _save_figure_to_temp_png(fig)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def plot_spectrogram(waveform, sr):
    """Plot a 128-band mel spectrogram (dB scale) and return the saved PNG path.

    Args:
        waveform: 1-D array of audio samples.
        sr: Sampling rate of *waveform* in Hz.

    Returns:
        Path (str) of a PNG file in the current directory.
    """
    # Compute before creating the figure so a failure here cannot leak one.
    mel_power = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        mesh = librosa.display.specshow(mel_db, sr=sr, x_axis='time', y_axis='mel', ax=ax)
        ax.set_title('Mel Spectrogram')
        fig.colorbar(mesh, ax=ax, format='%+2.0f dB')
        fig.tight_layout()
        return _save_figure_to_temp_png(fig)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Convert a raw waveform into model-ready input features.

    Args:
        audio: Raw mono waveform to featurize.
        sr: Sampling rate of *audio* in Hz, forwarded to the feature extractor.
        target_length: Requested padded length, forwarded as ``max_length``.

    Returns:
        The ``input_values`` PyTorch tensor produced by the feature extractor.
    """
    # NOTE(review): ASTFeatureExtractor appears to pad/truncate to its own
    # configured ``max_length`` and may ignore the ``padding``/``max_length``
    # kwargs passed here -- confirm before relying on ``target_length``.
    extracted = feature_extractor(
        audio,
        sampling_rate=sr,
        return_tensors="pt",
        padding="max_length",
        max_length=target_length,
    )
    return extracted.input_values

def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Circularly shift *waveform* by a random number of samples.

    The shift is drawn uniformly (inclusive) from ``[-limit, limit]`` where
    ``limit = int(max_shift_fraction * len(waveform))``.  Samples pushed off
    one end wrap around to the other, so the length and values are preserved.
    """
    limit = int(max_shift_fraction * len(waveform))
    offset = random.randint(-limit, limit)
    return np.roll(waveform, offset)

def predict_voice(audio_file_path):
    try:
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        
        # Map original labels to new labels
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        # Use the original label to get the new label
        new_label = label_mapping.get(original_label, "Unknown")  # Default to "Unknown" if label not found
        
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None

# Gradio UI: a single audio upload goes in.  The text verdict and the two
# diagnostic plots (PNG paths returned by predict_voice) come out.
_audio_input = gr.Audio(label="Upload Audio File", type="filepath")
_prediction_outputs = [
    gr.Textbox(label="Prediction"),
    gr.Image(label="Waveform"),
    gr.Image(label="Spectrogram"),
]

iface = gr.Interface(
    fn=predict_voice,
    inputs=_audio_input,
    outputs=_prediction_outputs,
    title="Voice Clone Detection",
    description=(
        "Detects whether a voice is real or an AI-generated clone. "
        "Upload an audio file to see the results."
    ),
)

# Start serving the app as soon as this module runs.
iface.launch()