import gradio as gr
import librosa
import librosa.display  # needed for librosa.display.specshow
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile

# Model and feature extractor loading (from the files in the current directory)
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")


def plot_waveform(waveform, sr):
    """Plot the time-domain waveform and save it as a temporary PNG."""
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    # Save plot to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()  # Close the figure to free memory
    return temp_file.name


def plot_spectrogram(waveform, sr):
    """Plot a mel spectrogram in dB and save it as a temporary PNG."""
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    # Save plot to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()  # Close the figure to free memory
    return temp_file.name


def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Run the AST feature extractor and return the model-ready input tensor."""
    features = feature_extractor(
        audio,
        sampling_rate=sr,
        return_tensors="pt",
        padding="max_length",
        max_length=target_length
    )
    return features.input_values


def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Apply a random circular time shift as simple test-time augmentation."""
    shift = random.randint(
        -int(max_shift_fraction * len(waveform)),
        int(max_shift_fraction * len(waveform))
    )
    return np.roll(waveform, shift)


def predict_voice(audio_file_path):
    """Classify an audio file as a real human voice or an AI-generated clone."""
    try:
        waveform, sample_rate = librosa.load(
            audio_file_path, sr=feature_extractor.sampling_rate, mono=True
        )
        augmented_waveform = apply_time_shift(waveform)

        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)

        # Average the logits of the original and time-shifted inputs
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        logits = (outputs_original.logits + outputs_augmented.logits) / 2

        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100

        # Map the model's labels to user-facing labels
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")  # Default to "Unknown" if label not found

        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)

        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None


iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Waveform"),
        gr.Image(label="Spectrogram")
    ],
    title="Voice Clone Detection",
    description="Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results."
)

iface.launch()
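
# --- Optional local check (hedged sketch, not part of the Gradio app) --------
# To exercise predict_voice() directly instead of launching the web UI, run
# something like the commented lines below in place of iface.launch().
# "sample.wav" is a hypothetical placeholder path; replace it with a real
# audio file on disk.
#
# message, waveform_png, spectrogram_png = predict_voice("sample.wav")
# print(message)
# print("Waveform image:", waveform_png)
# print("Spectrogram image:", spectrogram_png)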