import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
from scipy.special import expit
import json
import os
import gradio as gr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



class CustomWav2Vec2Model(Wav2Vec2Model):
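    """Wav2Vec2 backbone truncated to its first 9 transformer layers; only the
    layer-9 hidden states are used downstream, so deeper layers are dropped."""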
    def __init__(self, config):
        super().__init__(config)
        self.encoder.layers = self.encoder.layers[:9]


class HuggingFaceFeatureExtractor:
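    """Pairs the pretrained audio feature extractor with the (truncated) model
    and returns the layer-9 hidden states for a raw waveform."""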
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # with the encoder truncated to 9 layers, hidden_states[9] is the output of the last layer
        return outputs.hidden_states[9]


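# The joblib bundle holds the downstream pieces: a logistic-regression classifier,
# the score scaler, and the EER decision threshold.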
truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
classifier, scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')



def segment_audio(audio, sr, segment_duration):
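    """Split the waveform into consecutive segments of `segment_duration` seconds,
    discarding any segment shorter than 0.7 s."""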
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    segments_check = []
    for seg in segments:
        # if the segment is shorter than 0.7s, skip it to avoid complications inside wav2vec2
        if len(seg) > 0.7 * sr:
            segments_check.append(seg)
    return segments_check

def process_audio(input_data, segment_duration=10):
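    """Score an audio file segment by segment and return a JSON string with the
    majority-vote label ("real"/"fake") and the mean confidence of that class."""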
    audio, sr = librosa.load(input_data, sr=16000)
    if len(audio.shape) > 1:
        audio = audio[0]  # keep only the first channel of multi-channel audio
    segments = segment_audio(audio, sr, segment_duration)
    if not segments:
        # every segment was shorter than 0.7 s, so there is nothing to classify
        return json.dumps({"label": "undetermined", "reason": "audio too short to analyze"}, indent=4)
    segment_predictions = []
    confidence_scores_fake_sum = 0
    fake_segments = 0
    confidence_scores_real_sum = 0
    real_segments = 0
    # subtract a small margin from the stored EER threshold to absorb minor
    # numerical differences in the feature-extractor output
    eer_threshold = thresh - 5e-3
    for segment in segments:
        # layer-9 features, mean-pooled over time, give one fixed-length vector per segment
        features = FEATURE_EXTRACTOR(segment, sr)
        features_avg = torch.mean(features, dim=1).cpu().numpy()
        features_avg = features_avg.reshape(1, -1)
        decision_score = classifier.decision_function(features_avg)
        decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
        decision_value = decision_score_scaled[0]
        pred = 1 if decision_value >= eer_threshold else 0  # 1 = real, 0 = fake
        if pred == 0:
            # the sigmoid of the raw margin serves as a per-segment pseudo-confidence
            confidence_percentage = 1 - expit(decision_score).item()
            confidence_scores_fake_sum += confidence_percentage
            fake_segments += 1
        else:
            confidence_percentage = expit(decision_score).item()
            confidence_scores_real_sum += confidence_percentage
            real_segments += 1
        segment_predictions.append(pred)
    is_real = sum(segment_predictions) > (len(segment_predictions) / 2)
    confidence = (confidence_scores_real_sum / real_segments) if is_real else (confidence_scores_fake_sum / fake_segments)
    output_dict = {
        "label": "real" if is_real else "fake",
        "confidence score": f'{confidence:.2f}',
    }
    json_output = json.dumps(output_dict, indent=4)
    print(json_output)
    return json_output

def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    else:
        return "Please upload an audio file."

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check if it's AI-generated",
)

interface.launch(share=True)
# print(process_audio('SSL_scripts/1.wav'))
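
# A minimal batch-scoring sketch (optional; the 'samples/*.wav' paths are placeholders).
# Comment out interface.launch(...) above if you only want local scoring:
#
# if __name__ == "__main__":
#     for path in ["samples/real.wav", "samples/fake.wav"]:
#         print(path, process_audio(path))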