# Spaces: Sleeping  (page-status residue captured with the source; not part of the program)
# streamlit_app.py
import os
import tempfile

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import streamlit as st
import tensorflow as tf
from keras.models import load_model
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    auc,
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)
from tensorflow.keras.utils import to_categorical
# Audio front-end and model settings shared by every task below.
SAMPLE_RATE = 16000   # sampling rate handed to librosa when loading audio
DURATION = 5          # seconds of audio read from each file
N_MELS = 128          # number of mel bands per spectrogram
MAX_TIME_STEPS = 109  # spectrogram frames kept after padding / truncation
NUM_CLASSES = 2       # width of the one-hot targets and of the softmax output

# Streamlit App: page title plus a sidebar selector that decides which of the
# task sections further down the script runs on this pass.
st.title("Audio Spoofing Detection App")
st.sidebar.header("Model Options")
task = st.sidebar.selectbox(
    "Select Task",
    ["Train Model", "Evaluate Model", "Visualize Spectrogram"],
)
# --- Task: train a new CNN on uploaded FLAC files -------------------------
# Reads a whitespace-separated labels file, converts every labelled upload to
# a fixed-size log-mel spectrogram, and (on button press) trains a small CNN
# that is saved to "audio_classifier.h5".
if task == "Train Model":
    st.header("Train a New Model")
    uploaded_files = st.file_uploader(
        "Upload FLAC Training Files", accept_multiple_files=True, type='flac'
    )
    label_file = st.file_uploader("Upload Labels File (txt)", type="txt")

    if uploaded_files and label_file:
        # Parse the label file. The last column is the class key
        # ("bonafide" -> 1, anything else -> 0). With three or more columns
        # the utterance id is the second column; a two-column "<id> <label>"
        # line uses the first column because the second is the label itself.
        # Blank / one-token lines are skipped instead of raising IndexError.
        labels = {}
        for line in label_file.getvalue().decode("utf-8").splitlines():
            parts = line.split()
            if len(parts) < 2:
                continue
            file_name = parts[1] if len(parts) > 2 else parts[0]
            labels[file_name] = 1 if parts[-1] == "bonafide" else 0

        X, y = [], []
        unlabelled = []
        for file in uploaded_files:
            # splitext keeps stems that themselves contain dots intact.
            file_name = os.path.splitext(file.name)[0]
            if file_name not in labels:
                unlabelled.append(file.name)
                continue

            # Load audio file
            audio, _ = librosa.load(file, sr=SAMPLE_RATE, duration=DURATION)

            # Extract Mel spectrogram (power -> dB relative to the clip peak)
            mel_spectrogram = librosa.feature.melspectrogram(
                y=audio, sr=SAMPLE_RATE, n_mels=N_MELS
            )
            mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

            # Pad or truncate the time axis to exactly MAX_TIME_STEPS frames.
            # NOTE(review): constant padding fills with 0, which on this dB
            # scale is the peak level; left unchanged so training features
            # stay identical to what the evaluation path computes.
            n_frames = mel_spectrogram.shape[1]
            if n_frames < MAX_TIME_STEPS:
                mel_spectrogram = np.pad(
                    mel_spectrogram,
                    ((0, 0), (0, MAX_TIME_STEPS - n_frames)),
                    mode='constant',
                )
            else:
                mel_spectrogram = mel_spectrogram[:, :MAX_TIME_STEPS]

            X.append(mel_spectrogram)
            y.append(labels[file_name])

        if unlabelled:
            st.warning(
                f"Skipped {len(unlabelled)} file(s) with no entry in the labels file: "
                + ", ".join(unlabelled)
            )

        if len(X) < 2:
            # An 80/20 split needs at least one sample on each side.
            st.error("At least two labelled audio files are required to train a model.")
        else:
            # Add the trailing channel axis the Conv2D stack expects.
            X = np.array(X)[..., np.newaxis]
            y = np.array(y)
            y_encoded = to_categorical(y, NUM_CLASSES)

            # Split into train and validation sets (first 80% / last 20%).
            split_index = int(0.8 * len(X))
            X_train, X_val = X[:split_index], X[split_index:]
            y_train, y_val = y_encoded[:split_index], y_encoded[split_index:]

            # Train the model
            if st.button("Start Training"):
                # Define CNN model; built only when training is requested so
                # ordinary reruns of the script do not construct a new graph.
                input_shape = X_train.shape[1:]  # (N_MELS, MAX_TIME_STEPS, 1)
                model_input = tf.keras.Input(shape=input_shape)
                x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(model_input)
                x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
                x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
                x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
                x = tf.keras.layers.Flatten()(x)
                x = tf.keras.layers.Dense(128, activation='relu')(x)
                x = tf.keras.layers.Dropout(0.5)(x)
                model_output = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
                model = tf.keras.Model(inputs=model_input, outputs=model_output)
                model.compile(
                    optimizer='adam',
                    loss='categorical_crossentropy',
                    metrics=['accuracy'],
                )

                st.write("Training in progress...")
                model.fit(
                    X_train,
                    y_train,
                    batch_size=32,
                    epochs=10,
                    validation_data=(X_val, y_val),
                )
                model.save("audio_classifier.h5")
                st.success("Training Complete. Model Saved!")
# --- Task: evaluate an uploaded model on labelled FLAC files ---------------
# Scores every uploaded test file that has a label in the protocol file and
# renders a confusion matrix, ROC curve and precision-recall curve.
if task == "Evaluate Model":
    st.header("Evaluate a Trained Model")
    model_file = st.file_uploader("Upload Model (h5)", type='h5')
    test_files = st.file_uploader(
        "Upload Test FLAC Files", accept_multiple_files=True, type='flac'
    )
    protocol_file = st.file_uploader("Upload Protocol File (txt)", type='txt')

    if model_file and test_files and protocol_file:
        # Parse the true labels. The last column is the class key
        # ("bonafide" -> 1, anything else -> 0). With three or more columns
        # the utterance id is the second column, the column the training
        # label parser in this file reads; two-column "<id> <label>" lines
        # use the first column.
        true_labels = {}
        for line in protocol_file.getvalue().decode("utf-8").splitlines():
            parts = line.split()
            if len(parts) < 2:
                continue
            file_name = parts[1] if len(parts) > 2 else parts[0]
            true_labels[file_name] = 1 if parts[-1] == "bonafide" else 0

        # Prepare test data. Labels are looked up per uploaded file so that
        # y_true[i] always describes X_test[i], regardless of the order or
        # number of lines in the protocol file.
        X_test, y_true = [], []
        unlabelled = []
        for file in test_files:
            file_name = os.path.splitext(file.name)[0]
            if file_name not in true_labels:
                unlabelled.append(file.name)
                continue

            audio, _ = librosa.load(file, sr=SAMPLE_RATE, duration=DURATION)
            mel_spectrogram = librosa.feature.melspectrogram(
                y=audio, sr=SAMPLE_RATE, n_mels=N_MELS
            )
            mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

            # Pad or truncate the time axis to exactly MAX_TIME_STEPS frames.
            n_frames = mel_spectrogram.shape[1]
            if n_frames < MAX_TIME_STEPS:
                mel_spectrogram = np.pad(
                    mel_spectrogram,
                    ((0, 0), (0, MAX_TIME_STEPS - n_frames)),
                    mode='constant',
                )
            else:
                mel_spectrogram = mel_spectrogram[:, :MAX_TIME_STEPS]

            X_test.append(mel_spectrogram)
            y_true.append(true_labels[file_name])

        if unlabelled:
            st.warning(
                f"Skipped {len(unlabelled)} file(s) with no entry in the protocol file: "
                + ", ".join(unlabelled)
            )

        if not X_test:
            st.error("None of the uploaded test files appear in the protocol file.")
        else:
            # Load Model. The upload is an in-memory buffer, while load_model
            # is documented to take a path, so the bytes are spooled to a
            # temporary ".h5" file that is removed once loading has finished.
            with tempfile.TemporaryDirectory() as tmp_dir:
                model_path = os.path.join(tmp_dir, "uploaded_model.h5")
                with open(model_path, "wb") as model_fh:
                    model_fh.write(model_file.getvalue())
                model = load_model(model_path)

            X_test = np.array(X_test)
            y_true = np.array(y_true)
            y_pred = model.predict(X_test)
            y_pred_classes = np.argmax(y_pred, axis=1)

            # Confusion Matrix. labels=[0, 1] keeps the matrix 2x2 even when
            # only one class occurs, so it always matches the display labels.
            cm = confusion_matrix(y_true, y_pred_classes, labels=[0, 1])
            fig, ax = plt.subplots()
            ConfusionMatrixDisplay(
                confusion_matrix=cm, display_labels=["Spoof", "Bonafide"]
            ).plot(cmap=plt.cm.Blues, ax=ax)
            st.pyplot(fig)
            plt.close(fig)  # release the figure; the script reruns on every interaction

            if np.unique(y_true).size < 2:
                # ROC and precision-recall are undefined for a single class.
                st.warning(
                    "ROC and precision-recall curves need both bonafide and "
                    "spoof samples; skipping them."
                )
            else:
                # Column 1 of the softmax output is the "bonafide" score.
                y_pred_prob = y_pred[:, 1]

                # ROC Curve
                fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
                roc_auc = auc(fpr, tpr)
                fig, ax = plt.subplots()
                ax.plot(fpr, tpr, color='darkorange', lw=2,
                        label=f'ROC curve (AUC = {roc_auc:.2f})')
                ax.set_xlabel("False Positive Rate")
                ax.set_ylabel("True Positive Rate")
                ax.legend(loc="lower right")
                st.pyplot(fig)
                plt.close(fig)

                # Precision-Recall Curve
                precision, recall, _ = precision_recall_curve(y_true, y_pred_prob)
                avg_precision = average_precision_score(y_true, y_pred_prob)
                fig, ax = plt.subplots()
                ax.plot(recall, precision, color='darkorange', lw=2,
                        label=f'Avg. Precision = {avg_precision:.2f}')
                ax.set_xlabel("Recall")
                ax.set_ylabel("Precision")
                ax.legend(loc="lower left")  # the label was set but never drawn before
                st.pyplot(fig)
                plt.close(fig)
# --- Task: plot the log-mel spectrogram of each uploaded FLAC file ---------
if task == "Visualize Spectrogram":
    st.header("Visualize Mel Spectrogram")
    test_files = st.file_uploader(
        "Upload Test FLAC Files", accept_multiple_files=True, type='flac'
    )

    if test_files:
        for file in test_files:
            audio, _ = librosa.load(file, sr=SAMPLE_RATE, duration=DURATION)
            mel_spectrogram = librosa.feature.melspectrogram(
                y=audio, sr=SAMPLE_RATE, n_mels=N_MELS
            )
            mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

            # Draw on an explicit Figure/Axes rather than pyplot's global
            # state, and close it after rendering so figures do not pile up
            # across files and script reruns.
            fig, ax = plt.subplots(figsize=(10, 6))
            mesh = librosa.display.specshow(
                mel_spectrogram, x_axis='time', y_axis='mel', sr=SAMPLE_RATE, ax=ax
            )
            fig.colorbar(mesh, ax=ax, format='%+2.0f dB')
            ax.set_title(f'Mel Spectrogram - {file.name}')
            st.pyplot(fig)
            plt.close(fig)