import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
from scipy.special import expit
import json
import gradio as gr
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
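
# wav2vec2 backbone truncated to its first 9 transformer encoder layers;
# the layer-9 hidden states are used as segment embeddings downstream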
class CustomWav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config):
        super().__init__(config)
        self.encoder.layers = self.encoder.layers[:9]
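
# Thin wrapper pairing the truncated model with its AutoFeatureExtractor:
# call it with (audio, sr) to get the layer-9 hidden states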
class HuggingFaceFeatureExtractor:
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        return outputs.hidden_states[9]
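
# instantiate the truncated XLS-R 2B backbone and its matching feature extractor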
truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
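# pre-trained logistic-regression classifier, score scaler, and EER threshold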
classifier, scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')
def segment_audio(audio, sr, segment_duration):
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    segments_check = []
    for seg in segments:
        # if the segment is shorter than 0.7s, skip it to avoid complications inside wav2vec2
        if len(seg) > 0.7 * sr:
            segments_check.append(seg)
    return segments_check
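
# End-to-end pipeline for one file: load at 16 kHz, split into segments,
# embed each segment, classify it, then majority-vote across segments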
def process_audio(input_data, segment_duration=10):
    audio, sr = librosa.load(input_data, sr=16000)
    if len(audio.shape) > 1:
        audio = audio[0]
    segments = segment_audio(audio, sr, segment_duration)
    if not segments:
        return json.dumps({"error": "audio too short to analyze (needs > 0.7s)"}, indent=4)
    segment_predictions = []
    confidence_scores_fake_sum = 0
    fake_segments = 0
    confidence_scores_real_sum = 0
    real_segments = 0
    eer_threshold = thresh - 5e-3  # small margin of error due to feature extractor space differences
    for segment in segments:
        features = FEATURE_EXTRACTOR(segment, sr)
        features_avg = torch.mean(features, dim=1).cpu().numpy()
        features_avg = features_avg.reshape(1, -1)
        decision_score = classifier.decision_function(features_avg)
        decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
        decision_value = decision_score_scaled[0]
        pred = 1 if decision_value >= eer_threshold else 0  # 1 = real, 0 = fake
        if pred == 0:
            confidence_percentage = 1 - expit(decision_score).item()
            confidence_scores_fake_sum += confidence_percentage
            fake_segments += 1
        else:
            confidence_percentage = expit(decision_score).item()
            confidence_scores_real_sum += confidence_percentage
            real_segments += 1
        segment_predictions.append(pred)
    # majority vote across segments; report the mean confidence of the winning class
    is_real = sum(segment_predictions) > (len(segment_predictions) / 2)
    output_dict = {
        "label": "real" if is_real else "fake",
        "confidence score": f'{confidence_scores_real_sum / real_segments:.2f}' if is_real
                            else f'{confidence_scores_fake_sum / fake_segments:.2f}',
    }
    json_output = json.dumps(output_dict, indent=4)
    print(json_output)
    return json_output
def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    else:
        return "Please upload an audio file."
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check if it's AI-generated",
)
interface.launch(share=True)
# Local smoke test (bypasses the Gradio UI):
# print(process_audio('SSL_scripts/1.wav'))