Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
|
|
2 |
import librosa
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
-
import
|
6 |
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
|
7 |
import random
|
8 |
|
@@ -10,8 +10,16 @@ import random
|
|
10 |
model = AutoModelForAudioClassification.from_pretrained("./")
|
11 |
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
|
12 |
|
13 |
-
def
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
|
16 |
return features.input_values
|
17 |
|
@@ -35,14 +43,22 @@ def predict_voice(audio_file_path):
|
|
35 |
predicted_index = logits.argmax()
|
36 |
label = model.config.id2label[predicted_index.item()]
|
37 |
confidence = torch.softmax(logits, dim=1).max().item() * 100
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
39 |
except Exception as e:
|
40 |
-
return f"Error during processing: {e}"
|
41 |
|
42 |
iface = gr.Interface(
|
43 |
fn=predict_voice,
|
44 |
inputs=gr.Audio(label="Upload Audio File", type="filepath"),
|
45 |
-
outputs=
|
|
|
|
|
|
|
46 |
title="Voice Authenticity Detection",
|
47 |
description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
|
48 |
)
|
|
|
2 |
import librosa
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
|
7 |
import random
|
8 |
|
|
|
10 |
model = AutoModelForAudioClassification.from_pretrained("./")
|
11 |
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
|
12 |
|
13 |
+
def plot_waveform(waveform, sr):
    """Build a matplotlib figure of ``waveform`` amplitude against time.

    Args:
        waveform: Sequence of audio samples; ``len(waveform)`` is taken as the
            number of samples (presumably a 1-D array -- confirm with caller).
        sr: Sampling rate used to convert sample indices to seconds.

    Returns:
        The matplotlib ``Figure`` holding the plot. Nothing is shown or saved
        here; the caller decides how to render it.
    """
    # Local import so the figure is created without pyplot's global figure
    # registry: figures made via plt.figure() stay referenced until they are
    # explicitly closed, which accumulates across requests in a server.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 3))
    ax = fig.subplots()
    ax.set_title('Waveform')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')

    # Sample i is taken at i / sr seconds, so the axis runs from 0 up to
    # (len - 1) / sr rather than being stretched to len / sr.
    times = np.arange(len(waveform)) / sr
    ax.plot(times, waveform)
    return fig
|
21 |
+
|
22 |
+
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Run the module-level ``feature_extractor`` on ``audio``.

    Args:
        audio: Raw audio handed to the feature extractor unchanged.
        sr: Sampling rate, forwarded as ``sampling_rate``.
        target_length: Forwarded as ``max_length`` together with
            ``padding="max_length"``.

    Returns:
        The ``input_values`` attribute of the extractor's output, requested
        with ``return_tensors="pt"``.
    """
    extractor_options = {
        "sampling_rate": sr,
        "return_tensors": "pt",
        "padding": "max_length",
        "max_length": target_length,
    }
    encoded = feature_extractor(audio, **extractor_options)
    return encoded.input_values
|
25 |
|
|
|
43 |
predicted_index = logits.argmax()
|
44 |
label = model.config.id2label[predicted_index.item()]
|
45 |
confidence = torch.softmax(logits, dim=1).max().item() * 100
|
46 |
+
|
47 |
+
# Plot the waveform using the modified function
|
48 |
+
waveform_plot = plot_waveform(waveform, sample_rate)
|
49 |
+
|
50 |
+
# Return both the label and the plot
|
51 |
+
return f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%.", waveform_plot
|
52 |
except Exception as e:
|
53 |
+
return f"Error during processing: {e}", None
|
54 |
|
55 |
# Gradio UI: a single uploaded audio file (passed to predict_voice as a file
# path) in; a text verdict and a waveform plot out, in that order.
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        # Gradio will handle the rendering of the matplotlib figure
        gr.Plot(label="Waveform"),
    ],
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
)
|