import gradio as gr
import librosa
import librosa.display  # specshow lives in the librosa.display submodule and needs an explicit import
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random

# Load the fine-tuned AST classifier and its feature extractor from the repository root
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")

def plot_waveform(waveform, sr):
    # Render the raw waveform as amplitude over time
    plt.figure(figsize=(10, 3))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    return plt.gcf()

def plot_spectrogram(waveform, sr):
    # Compute a 128-band mel spectrogram and display it on a decibel scale
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    return plt.gcf()

def custom_feature_extraction(audio, sr=16000, target_length=1024):
    # The AST feature extractor converts the waveform to log-mel features and pads/truncates
    # to its configured max_length (1024 frames by default); the padding/max_length kwargs
    # are kept for explicitness, but the extractor's own configuration governs the output size.
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values

def apply_time_shift(waveform, max_shift_fraction=0.1):
    # Circularly shift the waveform by a random offset (light test-time augmentation)
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)

def predict_voice(audio_file_path):
    try:
        # Load at the extractor's expected sampling rate, then build a time-shifted copy for augmentation
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        
        # Average logits over the original and time-shifted inputs, then report the top class
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        
        return (
            f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None

iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Plot(label="Waveform"),
        gr.Plot(label="Spectrogram")
    ],
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
)

iface.launch()