File size: 3,967 Bytes
879c4b9
 
 
 
 
 
 
 
 
 
b2b5493
879c4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711f553
879c4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2b5493
995694d
b2b5493
 
 
 
 
 
 
 
 
 
c3a5025
b2b5493
879c4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
import numpy as np
from sklearn.linear_model import LogisticRegression
import gradio as gr
import os
import torch.nn.functional as F
from scipy.special import expit
import json


# Inference device: the default CUDA device when one is present, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class CustomWav2Vec2Model(Wav2Vec2Model):
    """Wav2Vec2Model whose transformer encoder keeps only its first 9 layers.

    The parent constructor builds the full layer stack from ``config`` and the
    stack is cut down immediately afterwards. ``config.num_hidden_layers`` is
    NOT updated, so it still reports the original depth.
    NOTE(review): when loaded with ``from_pretrained``, checkpoint weights for
    the dropped layers presumably end up reported as unused — confirm.
    """

    def __init__(self, config):
        super().__init__(config)
        kept_layers = 9
        # Slicing an nn.ModuleList yields a new ModuleList holding the kept layers.
        self.encoder.layers = self.encoder.layers[:kept_layers]

# Pretrained XLS-R 2B checkpoint loaded into the 9-layer architecture defined above.
truncated_model = CustomWav2Vec2Model.from_pretrained(
    pretrained_model_name_or_path="facebook/wav2vec2-xls-r-2b"
)

class HuggingFaceFeatureExtractor:
    """Turn raw waveforms into hidden-state embeddings of a Hugging Face audio model.

    Args:
        model: The audio model to run. It is switched to eval mode and moved
            to the module-level ``device``.
        feature_extractor_name: Hub id or path passed to
            ``AutoFeatureExtractor.from_pretrained`` to build the waveform
            preprocessor.
        layer_index: Which entry of ``outputs.hidden_states`` ``__call__``
            returns. Index 0 is the embedding output and index k is the output
            of encoder layer k. The default of 9 is the last entry for an
            encoder truncated to 9 layers.
    """

    def __init__(self, model, feature_extractor_name, layer_index=9):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)
        self.layer_index = layer_index

    def __call__(self, audio, sr):
        """Return the selected hidden-state tensor for ``audio`` sampled at ``sr`` Hz.

        The result stays on ``self.device``; callers move it to the CPU themselves.
        """
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        return outputs.hidden_states[self.layer_index]

# Shared embedding front-end built on the truncated XLS-R model.
FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(
    model=truncated_model,
    feature_extractor_name="facebook/wav2vec2-xls-r-2b",
)
# NOTE: joblib.load unpickles the file, so only load this trusted, bundled artifact.
# It unpacks into (classifier, scaler for its decision scores, decision threshold).
classifier, scaler, thresh = joblib.load("logreg_margin_pruning_ALL_with_scaler+threshold.joblib")

def segment_audio(audio, sr, segment_duration, min_duration=0.7):
    """Split ``audio`` into consecutive, non-overlapping chunks.

    Args:
        audio: Sliceable sequence of samples (list or NumPy array). The chunks
            are slices of it and keep its type.
        sr: Sampling rate in Hz.
        segment_duration: Chunk length in seconds.
        min_duration: Chunks of this many seconds or fewer are dropped, to
            avoid complications inside wav2vec2 with very short inputs. When
            ``segment_duration > min_duration``, only the trailing remainder
            can be dropped. Defaults to 0.7.

    Returns:
        List of kept chunks. Every chunk except possibly the last has
        ``int(segment_duration * sr)`` samples.

    Raises:
        ValueError: If ``segment_duration * sr`` amounts to less than one sample.
    """
    segment_samples = int(segment_duration * sr)
    if segment_samples <= 0:
        raise ValueError(
            f"segment_duration={segment_duration!r} s at sr={sr!r} Hz "
            "gives no samples per segment"
        )
    min_samples = min_duration * sr
    chunks = (
        audio[start:start + segment_samples]
        for start in range(0, len(audio), segment_samples)
    )
    return [chunk for chunk in chunks if len(chunk) > min_samples]

def process_audio(input_data, segment_duration=10):
    """Classify an audio file as real or fake, segment by segment.

    Processing steps:
    1. Load the file resampled to 16 kHz.
    2. Cut it into ``segment_duration``-second chunks with ``segment_audio``.
    3. Embed each chunk with ``FEATURE_EXTRACTOR`` and mean-pool the embedding.
    4. Score it with the loaded ``classifier``.

    The scaled decision score is compared with the stored threshold (minus a
    small margin) to give a per-segment prediction: 1 = "real", 0 = "fake".
    The file label is the majority vote, and ties are reported as "fake".

    Args:
        input_data: Audio file path (anything ``librosa.load`` accepts).
        segment_duration: Segment length in seconds.

    Returns:
        A JSON string with the overall ``label`` and a ``segments`` list of
        ``{"segment": 1-based index, "prediction": "real"/"fake",
        "confidence": percent}``. If no segment is long enough to be scored,
        ``label`` is ``"undetermined"`` and ``segments`` is empty. The JSON is
        also printed.
    """
    audio, sr = librosa.load(input_data, sr=16000)
    if audio.ndim > 1:
        # Multi-channel input: keep only the first channel.
        audio = audio[0]
    segments = segment_audio(audio, sr, segment_duration)

    # Small margin error due to feature extractor space differences.
    eer_threshold = thresh - 5e-3
    segment_predictions = []
    confidence_scores = []
    for segment in segments:
        features = FEATURE_EXTRACTOR(segment, sr)
        # Mean-pool over dim 1 (presumably the frame axis of (batch, frames, hidden) — TODO confirm).
        features_avg = torch.mean(features, dim=1).cpu().numpy().reshape(1, -1)
        decision_score = classifier.decision_function(features_avg)
        decision_value = scaler.transform(decision_score.reshape(-1, 1)).flatten()[0]
        pred = 1 if decision_value >= eer_threshold else 0
        # NOTE(review): confidence is the sigmoid of the *unscaled* decision score,
        # while the prediction uses the scaled one; kept as in the original.
        sigmoid_score = expit(decision_score).item()
        segment_predictions.append(pred)
        confidence_scores.append(sigmoid_score if pred == 1 else 1 - sigmoid_score)

    if segment_predictions:
        label = "real" if sum(segment_predictions) > (len(segment_predictions) / 2) else "fake"
    else:
        # Every segment was too short (e.g. a clip under ~0.7 s): nothing to vote on.
        label = "undetermined"
    output_dict = {
        "label": label,
        "segments": [
            {
                "segment": idx + 1,
                "prediction": "real" if pred == 1 else "fake",
                "confidence": round(conf * 100, 2),
            }
            for idx, (pred, conf) in enumerate(zip(segment_predictions, confidence_scores))
        ],
    }
    json_output = json.dumps(output_dict, indent=4)
    print(json_output)
    return json_output

def gradio_interface(audio):
    """Gradio callback: run the detector on the upload, or ask for a file.

    ``audio`` is the file path supplied by the ``gr.Audio`` input
    (``type="filepath"``). A falsy value means nothing was uploaded.
    """
    if not audio:
        return "please upload an audio file"
    return process_audio(audio)

# Web UI: the uploaded file's path is passed to gradio_interface, and the text it
# returns (the JSON report or the upload prompt) is shown as-is.
interface = gr.Interface(
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check if it's AI-generated",
    fn=gradio_interface,
    inputs=[gr.Audio(label="Upload Audio", type="filepath")],
    outputs="text",
)

# share=True additionally asks Gradio for a public share link.
interface.launch(share=True)
# Quick check without the UI: print(process_audio('SSL_scripts/1.wav'))