# Author: DavidCombei
# Last commit: 7f6fdf0 ("Update app.py"), verified
import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
from scipy.special import expit
import json
import os
import gradio as gr
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class CustomWav2Vec2Model(Wav2Vec2Model):
    """``Wav2Vec2Model`` whose encoder keeps only its first nine layers.

    The parent constructor builds the full encoder; the ``encoder.layers`` list is
    then replaced by a slice of its first nine entries. ``config`` itself is not
    modified, so ``config.num_hidden_layers`` still reports the original depth.
    """

    def __init__(self, config):
        super().__init__(config)
        kept_layers = self.encoder.layers[:9]
        self.encoder.layers = kept_layers
class HuggingFaceFeatureExtractor:
    """Callable that maps raw audio to one hidden-state tensor of ``model``.

    Preprocessing uses the Hugging Face feature extractor named
    ``feature_extractor_name``. At construction time the model is put in eval
    mode and moved to the module-level ``device``.
    """

    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        model.eval()
        model.to(device)
        self.model = model

    def __call__(self, audio, sr):
        """Return ``hidden_states[9]`` produced by the model for ``audio`` at ``sr`` Hz.

        The forward pass runs under ``torch.no_grad()``. Index 9 is hard-coded to
        match the nine-layer truncated encoder (``CustomWav2Vec2Model``) —
        NOTE(review): revisit if a model with a different depth is passed in.
        """
        encoded = self.feature_extractor(
            audio, sampling_rate=sr, return_tensors="pt", padding=True
        )
        model_inputs = {}
        for key, tensor in encoded.items():
            model_inputs[key] = tensor.to(self.device)
        with torch.no_grad():
            result = self.model(**model_inputs, output_hidden_states=True)
        hidden_states = result.hidden_states
        return hidden_states[9]
# Hub checkpoint used for both the encoder weights and its feature extractor;
# defined once so the two cannot drift apart.
MODEL_ID = "facebook/wav2vec2-xls-r-2b"
# Resolve the classifier bundle next to this file instead of the current working
# directory, so the app also starts when launched from another directory.
CLASSIFIER_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "logreg_margin_pruning_ALL_with_scaler+threshold.joblib",
)

truncated_model = CustomWav2Vec2Model.from_pretrained(MODEL_ID)
FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, MODEL_ID)
# The bundle unpacks to (classifier, scaler, threshold), all used by process_audio.
# joblib.load unpickles arbitrary objects: only ever point this at a trusted artifact.
classifier, scaler, thresh = joblib.load(CLASSIFIER_PATH)
def segment_audio(audio, sr, segment_duration, min_duration=0.7):
    """Split ``audio`` into consecutive, non-overlapping chunks of ``segment_duration`` seconds.

    The last chunk may be shorter than ``segment_duration``. Any chunk that is not
    longer than ``min_duration`` seconds is dropped, to avoid complications with
    very short inputs inside wav2vec2.

    Args:
        audio: Sliceable 1-D sample sequence (e.g. the array returned by ``librosa.load``).
        sr: Sample rate in Hz.
        segment_duration: Chunk length in seconds.
        min_duration: Chunks with ``min_duration * sr`` samples or fewer are discarded.

    Returns:
        The kept chunks, in order. Empty input gives an empty list.

    Raises:
        ValueError: If ``segment_duration * sr`` truncates to fewer than one sample.
    """
    segment_samples = int(segment_duration * sr)
    if segment_samples <= 0:
        # A zero step would make range() fail cryptically; a negative one would
        # silently yield no segments at all.
        raise ValueError(
            f"segment_duration={segment_duration!r} at sr={sr!r} gives "
            f"{segment_samples} samples per segment; need at least 1"
        )
    min_samples = min_duration * sr
    starts = range(0, len(audio), segment_samples)
    chunks = (audio[i:i + segment_samples] for i in starts)
    return [chunk for chunk in chunks if len(chunk) > min_samples]
def _score_segment(segment, sr, eer_threshold):
    """Classify one audio segment; return ``(pred, confidence)`` with pred 1 = real, 0 = fake."""
    features = FEATURE_EXTRACTOR(segment, sr)
    # Mean over dim 1 (presumably the time axis of a (batch, time, hidden) tensor —
    # TODO confirm) to get one fixed-size feature row per segment.
    pooled = torch.mean(features, dim=1).cpu().numpy().reshape(1, -1)
    decision_score = classifier.decision_function(pooled)
    decision_value = scaler.transform(decision_score.reshape(-1, 1)).flatten()[0]
    pred = 1 if decision_value >= eer_threshold else 0
    # NOTE(review): the vote compares the *scaled* score with the threshold, but the
    # confidence is the sigmoid of the *raw* decision score, so a segment can be
    # voted "real" with confidence below 0.5 (or vice versa). Kept unchanged so
    # reported scores stay the same — confirm which behaviour is intended.
    prob_real = expit(decision_score).item()
    confidence = prob_real if pred == 1 else 1 - prob_real
    return pred, confidence


def process_audio(input_data, segment_duration=10):
    """Classify an audio file as real or fake by majority vote over fixed-length segments.

    The file is loaded at 16 kHz and split with ``segment_audio``. Each segment is
    scored by ``classifier`` on mean-pooled wav2vec2 features and then votes real
    or fake.

    Args:
        input_data: Anything ``librosa.load`` accepts (the Gradio UI passes a file path).
        segment_duration: Segment length in seconds.

    Returns:
        A JSON string (indent 4) that is also printed. It has keys ``"label"``
        (``"real"``/``"fake"``; ties count as ``"fake"``) and ``"confidence score:"``:
        the mean confidence of the segments that voted for the winning label, with
        two decimals. If no segment is long enough to score, the JSON instead holds
        a single ``"error"`` key.
    """
    audio, sr = librosa.load(input_data, sr=16000)
    if len(audio.shape) > 1:
        # Defensive: keep only the first row of a multi-dimensional array.
        audio = audio[0]
    segments = segment_audio(audio, sr, segment_duration)
    if not segments:
        # Nothing to vote on; without this guard the averages below divide by zero.
        json_output = json.dumps(
            {"error": "audio too short: no segment longer than 0.7 s to analyse"},
            indent=4,
        )
        print(json_output)
        return json_output

    # Small margin for feature-extractor space differences relative to the
    # setup the stored threshold was computed with.
    eer_threshold = thresh - 5e-3

    real_confidences = []
    fake_confidences = []
    for segment in segments:
        pred, confidence = _score_segment(segment, sr, eer_threshold)
        if pred == 1:
            real_confidences.append(confidence)
        else:
            fake_confidences.append(confidence)

    # Strict majority of "real" votes wins; the winning group is never empty here.
    is_real = len(real_confidences) > len(segments) / 2
    winning = real_confidences if is_real else fake_confidences
    output_dict = {
        "label": "real" if is_real else "fake",
        "confidence score:": f"{sum(winning) / len(winning):.2f}",
    }
    json_output = json.dumps(output_dict, indent=4)
    print(json_output)
    return json_output
def gradio_interface(audio):
    """Gradio callback: run ``process_audio`` on the uploaded file, or ask for one.

    ``audio`` is whatever the ``gr.Audio`` input delivers; a falsy value (nothing
    uploaded) yields a plain prompt string instead of a JSON result.
    """
    if not audio:
        return "please upload an audio file"
    return process_audio(audio)
# Single-input UI: the uploaded clip is handed to gradio_interface as a file path
# and the JSON verdict is shown as plain text.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check if it's AI-generated",
)

if __name__ == "__main__":
    # launch() blocks the calling thread, so only start the server when this file
    # is run as a script. Importing the module then stays side-effect free apart
    # from model loading, so it can be used for local checks without the UI, e.g.
    #     print(process_audio('SSL_scripts/1.wav'))
    interface.launch(share=True)