import gradio as gr
import librosa
import librosa.display  # explicit import so librosa.display.specshow works on older librosa versions
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random

# Load the fine-tuned audio-classification model and its feature extractor
# from the root of this repository.
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
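
# Plot the raw waveform (amplitude over time) so the user can inspect the
# uploaded audio alongside the prediction.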
def plot_waveform(waveform, sr):
    plt.figure(figsize=(10, 3))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    return plt.gcf()
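
# Plot a mel spectrogram (in dB) of the same audio for a frequency-domain view.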
def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    return plt.gcf()
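
# Run the AST feature extractor on a raw waveform. The extractor pads or
# truncates the spectrogram frames to a fixed length (1024 frames is the AST
# default) and returns a batched tensor ready for the model.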
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt",
                                 padding="max_length", max_length=target_length)
    return features.input_values
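
# Simple time-shift augmentation: circularly shift the waveform by a random
# offset of up to max_shift_fraction of its length.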
def apply_time_shift(waveform, max_shift_fraction=0.1):
    max_shift = int(max_shift_fraction * len(waveform))
    shift = random.randint(-max_shift, max_shift)
    return np.roll(waveform, shift)
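
# Classify an uploaded audio file. Logits from the original waveform and a
# time-shifted copy are averaged, presumably to make the prediction a little
# more robust to where the voice sits within the clip.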
def predict_voice(audio_file_path):
    try:
        # Resample to the rate the feature extractor expects (16 kHz for AST).
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)

        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)

        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)

        # Average the two sets of logits, then take the most likely class.
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax(dim=-1)
        label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100

        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)

        return (
            f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None
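
# Gradio UI: a single audio input, with a text prediction and two plots as outputs.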
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Plot(label="Waveform"),
        gr.Plot(label="Spectrogram")
    ],
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
)

iface.launch()