File size: 4,933 Bytes
59b5f81
 
 
 
aa9e812
 
59b5f81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa9e812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59b5f81
aa9e812
59b5f81
 
 
 
aa9e812
59b5f81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1ffa5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59b5f81
 
 
 
aa9e812
 
59b5f81
aa9e812
 
 
 
d1ffa5e
 
59b5f81
 
 
 
 
d1ffa5e
aa9e812
59b5f81
 
 
d1ffa5e
59b5f81
 
 
aa9e812
d1ffa5e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import tempfile
from collections import Counter

import librosa
import numpy as np
import onnxruntime as ort
import requests
import streamlit as st

# Audio padding function
def pad(x, max_len=64600):
    """
    Pad or trim an audio segment to exactly ``max_len`` samples.

    Segments at least ``max_len`` long are truncated. Shorter segments are
    filled by cyclically repeating the segment itself (not zero padding).

    Args:
        x: Audio samples; assumed to be a 1-D array (e.g. mono audio).
        max_len: Target number of samples.

    Returns:
        A 1-D array of length ``max_len`` with the same dtype as ``x``.
        An empty ``x`` yields an all-zero array, since there is nothing to
        repeat.
    """
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]  # Trim if longer
    if x_len == 0:
        # Nothing to repeat; computing a repeat count would divide by zero.
        return np.zeros(max_len, dtype=x.dtype)
    # Tile just enough whole copies to cover max_len (ceil division), then cut.
    num_repeats = -(-max_len // x_len)
    return np.tile(x, num_repeats)[:max_len]

# Preprocess audio for a single segment
def preprocess_audio_segment(segment, cut=64600):
    """
    Turn one raw audio segment into a model-ready batch of size one.

    The segment is forced to ``cut`` samples via ``pad`` (trimmed when longer,
    cyclically repeated when shorter). It is then copied into a new float32
    array and given a leading batch axis. For a 1-D segment the result has
    shape ``(1, cut)``.
    """
    fixed_length = pad(segment, max_len=cut)
    as_float32 = np.array(fixed_length, dtype=np.float32)
    # Prepend the batch dimension.
    return as_float32[np.newaxis, ...]

# Download ONNX model from Hugging Face
def download_model(url, local_path="RawNet_model.onnx", timeout=60):
    """
    Download the ONNX model from a URL if it doesn't already exist locally.

    The response body is streamed in chunks to ``local_path + ".part"``. That
    file is moved onto ``local_path`` only once the download completes, so an
    interrupted download never leaves a truncated file behind. Otherwise a
    later run would see the partial file and skip the download.

    Args:
        url: Direct download URL of the model file.
        local_path: Where the model is stored. If a file already exists here,
            it is reused and nothing is downloaded.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes, so a stalled server cannot hang the app indefinitely.

    Returns:
        ``local_path``.

    Raises:
        Exception: if the server responds with a status other than 200.
        requests.RequestException: on connection errors or timeouts.
    """
    if os.path.exists(local_path):
        return local_path

    tmp_path = local_path + ".part"
    with st.spinner("Downloading ONNX model..."):
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    raise Exception("Failed to download ONNX model")
                with open(tmp_path, "wb") as f:
                    # Stream in 1 MiB chunks rather than holding the whole model in memory.
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            # Atomic on the same filesystem: local_path is either absent or complete.
            os.replace(tmp_path, local_path)
        finally:
            # The partial file only still exists if something above failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        st.success("Model downloaded successfully!")
    return local_path

# Sliding window prediction function
def predict_with_sliding_window(audio_path, onnx_model_path, window_size=64600, step_size=64600, sample_rate=16000):
    """
    Classify an audio file as "Real" or "Fake" by voting over fixed-size windows.

    The file is loaded at ``sample_rate`` and cut into windows of
    ``window_size`` samples starting every ``step_size`` samples (the defaults
    give back-to-back, non-overlapping windows). Each window is padded or
    trimmed to ``window_size``, run through the ONNX model and labelled. The
    most common label wins.

    Args:
        audio_path: Path of an audio file readable by ``librosa.load``.
        onnx_model_path: Path of the ONNX model file.
        window_size: Samples per window and per model input.
        step_size: Hop between consecutive window starts, in samples.
        sample_rate: Sample rate passed to ``librosa.load``.

    Returns:
        ``(majority_class, avg_probability)``. ``majority_class`` is "Real" or
        "Fake". ``avg_probability`` is the mean, in percent, of each window's
        probability for the label that window predicted.

    Raises:
        ValueError: if the loaded audio contains no samples.
    """
    ort_session = ort.InferenceSession(onnx_model_path)
    # The input name is fixed for the session, so look it up once.
    input_name = ort_session.get_inputs()[0].name

    waveform, _ = librosa.load(audio_path, sr=sample_rate)
    if len(waveform) == 0:
        raise ValueError("Audio file contains no samples")

    total_segments = []
    total_probabilities = []

    for start in range(0, len(waveform), step_size):
        segment = waveform[start:start + window_size]

        # Pad/trim to the window length itself so a non-default window_size is
        # honoured instead of silently resized to the default 64600.
        audio_tensor = preprocess_audio_segment(segment, cut=window_size)

        outputs = ort_session.run(None, {input_name: audio_tensor})
        # NOTE(review): exp() yields probabilities only if the model emits
        # log-probabilities (e.g. a log-softmax head) — confirm for this model.
        probabilities = np.exp(outputs[0])
        prediction = np.argmax(probabilities)

        # Class index 1 is treated as "Real"; any other index as "Fake".
        total_segments.append("Real" if prediction == 1 else "Fake")
        total_probabilities.append(probabilities[0][prediction])

    # Majority vote. Counter.most_common orders equal counts by first
    # occurrence, so a tied vote gives the same answer on every run. Iterating
    # a set of strings does not: its order depends on hash randomization.
    majority_class = Counter(total_segments).most_common(1)[0][0]
    # NOTE(review): this averages each window's confidence in its *own* label,
    # including windows that voted against majority_class.
    avg_probability = np.mean(total_probabilities) * 100

    return majority_class, avg_probability

# Streamlit app
st.set_page_config(page_title="Audio Spoof Detection", page_icon="🎵", layout="centered")

# Header Section
st.markdown("<h1 style='text-align: center; color: blue;'>Audio Spoof Detection</h1>", unsafe_allow_html=True)
st.markdown(
    """
    <p style='text-align: center;'>
    Detect whether an uploaded audio file is <strong>Real</strong> or <strong>Fake</strong> using an ONNX model.
    </p>
    """,
    unsafe_allow_html=True,
)

# Sidebar
st.sidebar.header("Instructions")
st.sidebar.write(
    """
    - Upload an audio file in WAV or MP3 format.
    - Wait for the model to process the file.
    - View the prediction result and confidence score.
    """
)
st.sidebar.markdown("### About the Model")
st.sidebar.info(
    """
    The model is trained to classify audio as Real or Fake using a RawNet-based architecture.
    """
)

# File uploader
uploaded_file = st.file_uploader("Upload your audio file (WAV or MP3)", type=["wav", "mp3"])

# ONNX model URL (replace with your actual Hugging Face model URL)
onnx_model_url = "https://huggingface.co/Mrkomiljon/DeepVoiceGuard/resolve/main/RawNet_model.onnx"

# Ensure ONNX model is downloaded locally
onnx_model_path = download_model(onnx_model_url)

if uploaded_file is not None:
    st.markdown("<h3 style='text-align: center;'>Processing Your File...</h3>", unsafe_allow_html=True)

    # Save the upload to a uniquely named temp file. A fixed filename would be
    # shared, and overwritten, by concurrent sessions of the app. Keep the
    # upload's own extension so an MP3 is not saved under a .wav name.
    # getvalue() returns the whole upload regardless of the stream position.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
        temp_audio_path = tmp.name

    try:
        # Perform prediction
        with st.spinner("Running the model..."):
            result, avg_probability = predict_with_sliding_window(temp_audio_path, onnx_model_path)
    finally:
        # Remove the temp file even when decoding or inference raises.
        os.remove(temp_audio_path)

    # Display results
    st.success(f"Prediction: {result}")
    st.metric(label="Confidence", value=f"{avg_probability:.2f}%", delta=None)

# Footer
st.markdown(
    """
    <hr>
    <p style='text-align: center; font-size: small;'>
    Created with ❤️ using Streamlit.
    </p>
    """,
    unsafe_allow_html=True,
)