import gradio as gr
import librosa
import librosa.display  # explicit import so librosa.display.specshow is available
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
# Model and feature extractor loading
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
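# Note: "./" assumes the fine-tuned AST checkpoint (config.json, model weights,
# and preprocessor_config.json) was uploaded to the Space's root directory next
# to app.py. The config's id2label is expected to contain labels such as
# "Bonafide" and "Spoof", which predict_voice maps to friendlier names below.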
def plot_waveform(waveform, sr):
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, num=len(waveform)), waveform)
    plt.xlabel('Time (s)')
    # Save the plot to a temporary file so Gradio can display it by path
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()  # Close the figure to free memory
    return temp_file.name
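# plot_spectrogram renders a 128-bin mel spectrogram in dB, roughly mirroring the
# log-mel representation that the AST feature extractor feeds to the model.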
def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    # Save the plot to a temporary file so Gradio can display it by path
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()  # Close the figure to free memory
    return temp_file.name
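# custom_feature_extraction delegates to the ASTFeatureExtractor, which converts the
# raw waveform into a log-mel spectrogram and pads or truncates it to a fixed number
# of frames (1024 by default for AST).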
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values
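# apply_time_shift implements a light test-time augmentation: the waveform is rolled
# by a random offset, and predict_voice averages the logits of the original and the
# shifted version so the prediction is less sensitive to alignment.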
def apply_time_shift(waveform, max_shift_fraction=0.1):
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
def predict_voice(audio_file_path):
    try:
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        # Average the logits of the original and time-shifted waveforms
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        # Map the model's original labels to user-facing labels
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")  # Fall back to "Unknown" for unexpected labels
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Waveform"),
        gr.Image(label="Spectrogram")
    ],
    title="Voice Clone Detection",
    description="Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results."
)
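# launch() starts the Gradio server; no extra arguments are needed on Hugging Face
# Spaces, while a local run could pass share=True for a temporary public URL.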
iface.launch()