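# Page header captured together with this file: "Spaces: Runtime error".
# It is extraction residue from the hosting page, not part of the program.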
import gradio as gr | |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification | |
import numpy as np | |
import torch | |
from torch.nn.functional import softmax | |
import soundfile as sf | |
# Directory that holds the model files. "./" is relative to the process's
# working directory, so the app must be started from the folder containing
# the model and feature-extractor files.
local_model_path = "./"

# Load the feature extractor and the classification model once, at import
# time, so every call to predict_voice reuses the same objects.
extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
model = AutoModelForAudioClassification.from_pretrained(local_model_path)
def predict_voice(audio_file):
    """
    Predicts whether a voice is real or spoofed from an audio file.

    Args:
        audio_file: Path of the audio file to classify, as delivered by a
            ``gr.Audio(type="filepath")`` input. A ``(name, path)`` tuple or
            list is also accepted, in which case the second item is used as
            the path. ``None`` or an empty value means nothing was uploaded.

    Returns:
        A string with the predicted label and its confidence as a
        percentage, or a short message asking for a file when no audio was
        provided.
    """
    # Submitting the form without an upload yields no value; answer with a
    # readable message instead of failing on the lookup below.
    if not audio_file:
        return "No audio file was provided. Please upload an audio file."

    # With type="filepath" the value is already the path string. Indexing a
    # string with [1] would select a single character of the path, so only
    # unpack when a tuple/list was actually passed.
    if isinstance(audio_file, (tuple, list)):
        audio_file_path = audio_file[1]
    else:
        audio_file_path = audio_file

    # Read the samples and the file's own sample rate.
    waveform, sample_rate = sf.read(audio_file_path)

    # Multi-channel files are returned as a 2-D (frames, channels) array;
    # average the channels so the extractor receives a single 1-D signal.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)

    # NOTE(review): the extractor may reject a sampling_rate that differs
    # from the one it was configured with - confirm, and resample before
    # this call if uploads with other rates must be supported.
    inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # One file is classified per call, so take the first (only) row of the
    # per-class probabilities and read label and confidence from it.
    probabilities = softmax(outputs.logits, dim=-1)[0]
    predicted_index = int(probabilities.argmax())
    label = model.config.id2label[predicted_index]
    confidence = probabilities[predicted_index].item() * 100

    return f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
# Build the Gradio interface: one audio upload in, one text box out.
iface = gr.Interface(
    fn=predict_voice,
    # type="filepath" makes Gradio hand the upload to predict_voice as a
    # path on disk rather than as decoded samples.
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Prediction"),
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
    theme="huggingface",
)

# Start the web server for the interface.
iface.launch()