import gradio as gr
import librosa
import librosa.display  # needed for librosa.display.specshow on older librosa versions
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
# Model and feature extractor loading
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
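
# Plot the raw waveform and save it to a temporary PNG so Gradio can display it.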
def plot_waveform(waveform, sr):
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    # Save plot to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()  # Close the figure to free memory
    return temp_file.name
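
# Compute a mel spectrogram (in dB) and save it to a temporary PNG for display.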
def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    # Save plot to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()  # Close the figure to free memory
    return temp_file.name
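
# Wrap the AST feature extractor: convert a raw waveform into the padded
# spectrogram-patch tensor the model expects.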
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt",
                                 padding="max_length", max_length=target_length)
    return features.input_values
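
# Light augmentation: circularly shift the waveform by a random offset of up to
# max_shift_fraction of its length.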
def apply_time_shift(waveform, max_shift_fraction=0.1):
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
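
# Classify an uploaded audio file. Both the original waveform and a time-shifted
# copy are scored, and their logits are averaged (simple test-time augmentation)
# before the label and confidence are read off.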
def predict_voice(audio_file_path):
    try:
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)

        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)

        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)

        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100

        # Map original labels to new labels
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }

        # Use the original label to get the new label
        new_label = label_mapping.get(original_label, "Unknown")  # Default to "Unknown" if label not found

        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)

        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Waveform"),  # Adjusted to remove unsupported 'tool' argument
        gr.Image(label="Spectrogram")  # Adjusted to remove unsupported 'tool' argument
    ],
    title="Voice Clone Detection",
    description="Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results."
)

iface.launch()