Kabatubare's picture
Update app.py
5bde6bc verified
raw
history blame
2.28 kB
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import numpy as np
import torch
from torch.nn.functional import softmax
import soundfile as sf
# Directory holding the model files. "./" is the current working directory,
# i.e. the model is expected to sit next to this script inside the Space.
local_model_path = "./"
# Load the feature extractor (audio preprocessing) and the audio
# classification model once at import time, so every call to the prediction
# function reuses the same objects instead of reloading them per request.
extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
model = AutoModelForAudioClassification.from_pretrained(local_model_path)
def predict_voice(audio_file):
    """
    Predict whether a voice is real or spoofed from an audio file.

    Args:
        audio_file: Path to the audio file to classify. The Gradio input in
            this file is declared with ``type="filepath"``, so this is
            normally a plain path string. A ``(file_name, file_path)`` pair
            is still accepted for backward compatibility. ``None`` means no
            audio was supplied.

    Returns:
        A string with the predicted label and its confidence as a
        percentage, or a short message asking for input when no audio was
        provided.
    """
    # Submitting the form without an upload yields None; answer with a
    # readable message instead of crashing on the indexing below.
    if audio_file is None:
        return "No audio file was provided. Please upload an audio file."

    # Only unpack when actually handed a pair. Indexing a path *string* with
    # [1] would return its second character, not the path.
    if isinstance(audio_file, (tuple, list)):
        audio_file_path = audio_file[1]
    else:
        audio_file_path = audio_file

    # Load the samples and the file's native sampling rate.
    waveform, sample_rate = sf.read(audio_file_path)

    # soundfile returns a 2-D (frames, channels) array for multi-channel
    # files. Average the channels down to mono, since the extractor is
    # assumed to expect a single 1-D waveform per example.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)

    # Convert the raw waveform into the tensors the model consumes.
    # NOTE(review): the extractor may reject a sampling rate that differs
    # from the one the model was trained with; no resampling is done here.
    inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # One example goes in, so row 0 holds its class probabilities. Taking
    # the max along the class axis gives the confidence and the winning
    # index from the same distribution.
    probabilities = softmax(outputs.logits, dim=-1)[0]
    top_probability, predicted_index = probabilities.max(dim=-1)

    # Translate the class index into its human-readable label.
    label = model.config.id2label[predicted_index.item()]
    confidence = top_probability.item() * 100

    result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
    return result
# Declare the input and output widgets up front so the Interface call below
# reads as a short list of named parts.
audio_input = gr.Audio(type="filepath", label="Upload Audio File")
prediction_output = gr.Textbox(label="Prediction")

# Wire predict_voice into a single-page app: one audio upload in, one text
# box out.
iface = gr.Interface(
    fn=predict_voice,
    inputs=audio_input,
    outputs=prediction_output,
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
    theme="huggingface",
)

# Start serving the app.
iface.launch()