import gradio as gr
import librosa
import librosa.display  # required for librosa.display.specshow below
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are rendered off-screen on the server
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile

# Model and feature extractor loading
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
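
# Note: "./" resolves to the Space's repository root, which must contain the
# fine-tuned weights and preprocessing config. To load from the Hub instead,
# pass a repo id (illustrative example only, not this app's checkpoint):
# model = AutoModelForAudioClassification.from_pretrained(
#     "MIT/ast-finetuned-audioset-10-10-0.4593"
# )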

def plot_waveform(waveform, sr):
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    # Save the plot to a temporary file and return its path for gr.Image
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    temp_file.close()  # release the handle; the file persists (delete=False)
    plt.close()  # close the figure to free memory
    return temp_file.name

def plot_spectrogram(waveform, sr):
    # Compute a mel spectrogram and convert power to decibels for display
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    # Save the plot to a temporary file and return its path for gr.Image
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    temp_file.close()  # release the handle; the file persists (delete=False)
    plt.close()  # close the figure to free memory
    return temp_file.name

def custom_feature_extraction(audio, sr=16000):
    # The AST feature extractor pads/truncates to its configured max_length
    # (1024 frames by default), so no explicit padding arguments are needed.
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    return features.input_values
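
# Illustrative shape check (assumes the default AST configuration):
# custom_feature_extraction(np.zeros(16000, dtype=np.float32)).shape
# -> torch.Size([1, 1024, 128])  (batch, time frames, mel bins)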

def apply_time_shift(waveform, max_shift_fraction=0.1):
    # Shift by a random offset of up to ±max_shift_fraction of the waveform
    # length; np.roll wraps samples around rather than zero-padding.
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
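
# Illustrative: with the default 10% fraction, a 10-sample signal is rolled
# by -1, 0, or 1 samples, e.g.
# apply_time_shift(np.arange(10.0))  # -> array([9., 0., 1., ..., 8.]) for shift=1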

def predict_voice(audio_file_path):
    try:
        # Load at the extractor's expected sampling rate (16 kHz for AST)
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        # Simple test-time augmentation: average the logits over the original
        # waveform and a randomly time-shifted copy
        augmented_waveform = apply_time_shift(waveform)
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        # Map the model's label names to user-facing labels
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")  # fall back if the label is unrecognized
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None
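
# Quick local sanity check (illustrative; "sample.wav" is a hypothetical path):
# text, wave_png, spec_png = predict_voice("sample.wav")
# print(text)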

iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Waveform"),     # rendered from the saved PNG path
        gr.Image(label="Spectrogram")   # rendered from the saved PNG path
    ],
    title="Voice Clone Detection",
    description="Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results."
)

iface.launch()