|
import gradio as gr |
|
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor |
|
import torch |
|
import torchaudio |
|
|
|
# Checkpoint identifier handed to ``from_pretrained`` for both the processor
# and the classifier.
# NOTE(review): Hugging Face Hub repo ids normally look like "namespace/name";
# confirm this three-segment value resolves (it may be intended as a local
# path or need a ``subfolder`` argument).
model_name = "Mrkomiljon/voiceGUARD/wav2vec2_finetuned_model"

# Processor that turns raw waveforms into model inputs.
processor = Wav2Vec2Processor.from_pretrained(model_name)

# Sequence-classification model, switched to evaluation mode right after
# loading (``eval()`` returns the module itself).
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name).eval()
|
|
|
def classify_audio(audio_file):
    """Classify one audio file with the module-level Wav2Vec2 model.

    The file is loaded with torchaudio, resampled to 16 kHz when its native
    rate differs, trimmed or right-padded to exactly 10 seconds, reduced to
    its first channel, and then run through ``processor`` and ``model``.

    Args:
        audio_file: Path of the audio file to load.

    Returns:
        The index of the highest-scoring class as a Python int.
    """
    target_rate = 16000
    target_len = target_rate * 10  # fixed 10-second window, in samples

    clip, native_rate = torchaudio.load(audio_file)
    if native_rate != target_rate:
        to_target_rate = torchaudio.transforms.Resample(
            orig_freq=native_rate, new_freq=target_rate
        )
        clip = to_target_rate(clip)

    # Force every clip to the same length: cut the tail or pad the end.
    n_samples = clip.size(1)
    if n_samples > target_len:
        clip = clip[:, :target_len]
    elif n_samples < target_len:
        clip = torch.nn.functional.pad(clip, (0, target_len - n_samples))

    # Keep only the first channel when a channel dimension is present.
    if clip.ndim > 1:
        clip = clip[0]

    features = processor(
        clip.numpy(), sampling_rate=target_rate, return_tensors="pt"
    )

    # Inference only; gradients are not needed.
    with torch.no_grad():
        scores = model(**features).logits

    # NOTE(review): this is a bare class index; which index means
    # "AI-generated" vs "Real" is not visible here -- confirm against the
    # model's label mapping.
    return scores.argmax(dim=-1).item()
|
|
|
# Build the upload widget. Gradio 4 replaced the ``source`` keyword of
# ``gr.Audio`` with ``sources`` (a list), so the original spelling is tried
# first and the newer one is used only when the old keyword is rejected.
try:
    _audio_input = gr.Audio(source="upload", type="filepath")
except TypeError:
    _audio_input = gr.Audio(sources=["upload"], type="filepath")

# Web UI: one uploaded audio file in (passed to ``classify_audio`` as a file
# path), one label out.
interface = gr.Interface(
    fn=classify_audio,
    inputs=_audio_input,
    outputs="label",
    title="Audio Classifier",
    description="Upload an audio file to classify its label as AI-generated or Real.",
)

# Start the Gradio server.
interface.launch()
|
|