# Audio Spoof Detection — Streamlit app (deployed as a Hugging Face Space).
import streamlit as st | |
import librosa | |
import numpy as np | |
import onnxruntime as ort | |
import os | |
import requests | |
# Audio padding function
def pad(x, max_len=64600):
    """
    Pad or trim a 1-D audio array to a fixed length.

    Args:
        x: 1-D numpy array of audio samples.
        max_len: Target length in samples (default 64600, i.e. ~4 s at 16 kHz).

    Returns:
        1-D numpy array of exactly ``max_len`` samples.

    Raises:
        ValueError: If ``x`` is empty — the original code would have hit an
            obscure ZeroDivisionError at the ``max_len // x_len`` step.
    """
    x_len = x.shape[0]
    if x_len == 0:
        raise ValueError("Cannot pad an empty audio segment")
    if x_len >= max_len:
        return x[:max_len]  # Trim if longer
    # Repeat the signal enough times to cover max_len, then slice exactly.
    num_repeats = (max_len // x_len) + 1
    return np.tile(x, num_repeats)[:max_len]
# Preprocess audio for a single segment
def preprocess_audio_segment(segment, cut=64600):
    """
    Prepare one audio segment for model inference.

    The segment is padded/trimmed to ``cut`` samples, cast to float32,
    and given a leading batch dimension of 1.
    """
    fixed_length = pad(segment, max_len=cut)
    as_float32 = np.array(fixed_length, dtype=np.float32)
    # Shape becomes (1, cut): the model expects a batch axis up front.
    return as_float32[np.newaxis, :]
# Download ONNX model from Hugging Face
def download_model(url, local_path="RawNet_model.onnx"):
    """
    Download the ONNX model from ``url`` to ``local_path`` unless it is
    already present locally.

    Args:
        url: Direct download URL for the ONNX model.
        local_path: Where to store the model on disk.

    Returns:
        str: Path to the local model file.

    Raises:
        Exception: If the server does not answer with HTTP 200.
        requests.exceptions.RequestException: On connection failure/timeout.
    """
    if not os.path.exists(local_path):
        with st.spinner("Downloading ONNX model..."):
            # stream=True avoids holding the whole model in memory;
            # the timeout keeps a stalled connection from hanging the app.
            response = requests.get(url, stream=True, timeout=60)
            if response.status_code == 200:
                with open(local_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
                st.success("Model downloaded successfully!")
            else:
                raise Exception("Failed to download ONNX model")
    return local_path
# Sliding window prediction function
def predict_with_sliding_window(audio_path, onnx_model_path, window_size=64600, step_size=64600, sample_rate=16000):
    """
    Classify an audio file as Real or Fake by running the ONNX model over
    consecutive windows and aggregating the per-window decisions.

    Args:
        audio_path: Path to the audio file to analyze.
        onnx_model_path: Path to the ONNX model on disk.
        window_size: Samples per inference window.
        step_size: Hop between window starts (== window_size: no overlap).
        sample_rate: Rate the audio is resampled to before inference.

    Returns:
        tuple: (majority label "Real"/"Fake", mean confidence in percent).

    Raises:
        ValueError: If the decoded audio contains no samples — previously
            this surfaced as an opaque ``max()`` error on an empty set.
    """
    # Load ONNX runtime session
    ort_session = ort.InferenceSession(onnx_model_path)
    # Load audio file (librosa resamples to the requested rate)
    waveform, _ = librosa.load(audio_path, sr=sample_rate)
    if len(waveform) == 0:
        raise ValueError("Audio file contains no samples")
    # Hoisted out of the loop: the model's input name never changes.
    input_name = ort_session.get_inputs()[0].name
    segment_labels = []
    segment_probs = []
    # Sliding window processing
    for start in range(0, len(waveform), step_size):
        segment = waveform[start:start + window_size]
        audio_tensor = preprocess_audio_segment(segment)
        outputs = ort_session.run(None, {input_name: audio_tensor})
        # NOTE(review): np.exp assumes the model emits log-probabilities
        # (log-softmax output) — confirm against the exported model.
        probabilities = np.exp(outputs[0])
        prediction = int(np.argmax(probabilities))
        # Class index 1 is treated as genuine speech.
        segment_labels.append("Real" if prediction == 1 else "Fake")
        segment_probs.append(probabilities[0][prediction])
    # Final aggregation: majority vote over windows, mean confidence in %.
    majority_class = max(set(segment_labels), key=segment_labels.count)
    avg_probability = np.mean(segment_probs) * 100
    return majority_class, avg_probability
# Streamlit app
st.set_page_config(page_title="Audio Spoof Detection", page_icon="🎵", layout="centered")
# Header Section
st.markdown("<h1 style='text-align: center; color: blue;'>Audio Spoof Detection</h1>", unsafe_allow_html=True)
st.markdown(
    """
    <p style='text-align: center;'>
    Detect whether an uploaded audio file is <strong>Real</strong> or <strong>Fake</strong> using an ONNX model.
    </p>
    """,
    unsafe_allow_html=True,
)
# Sidebar
st.sidebar.header("Instructions")
st.sidebar.write(
    """
    - Upload an audio file in WAV or MP3 format.
    - Wait for the model to process the file.
    - View the prediction result and confidence score.
    """
)
st.sidebar.markdown("### About the Model")
st.sidebar.info(
    """
    The model is trained to classify audio as Real or Fake using a RawNet-based architecture.
    """
)
# File uploader
uploaded_file = st.file_uploader("Upload your audio file (WAV or MP3)", type=["wav", "mp3"])
# ONNX model URL (replace with your actual Hugging Face model URL)
onnx_model_url = "https://huggingface.co/Mrkomiljon/DeepVoiceGuard/resolve/main/RawNet_model.onnx"
# Ensure ONNX model is downloaded locally
onnx_model_path = download_model(onnx_model_url)
if uploaded_file is not None:
    st.markdown("<h3 style='text-align: center;'>Processing Your File...</h3>", unsafe_allow_html=True)
    # Save uploaded file temporarily (librosa decodes by content, so the
    # .wav name is safe even for MP3 uploads).
    temp_path = "temp_audio_file.wav"
    with open(temp_path, "wb") as f:
        f.write(uploaded_file.read())
    try:
        # Perform prediction
        with st.spinner("Running the model..."):
            result, avg_probability = predict_with_sliding_window(temp_path, onnx_model_path)
        # Display results
        st.success(f"Prediction: {result}")
        st.metric(label="Confidence", value=f"{avg_probability:.2f}%", delta=None)
    finally:
        # Remove the temp file even if inference raises, so failed
        # uploads never accumulate on disk.
        if os.path.exists(temp_path):
            os.remove(temp_path)
# Footer
st.markdown(
    """
    <hr>
    <p style='text-align: center; font-size: small;'>
    Created with ❤️ using Streamlit.
    </p>
    """,
    unsafe_allow_html=True,
)