Kabatubare committed on
Commit
05e6aba
·
verified ·
1 Parent(s): 84de51b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import librosa
3
  import numpy as np
4
  import torch
5
- import torch.nn.functional as F
6
  from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
7
  import random
8
 
@@ -10,8 +10,16 @@ import random
10
  model = AutoModelForAudioClassification.from_pretrained("./")
11
  feature_extractor = ASTFeatureExtractor.from_pretrained("./")
12
 
13
- def custom_feature_extraction(audio, sr=16000, n_mels=128, target_length=1024):
14
- # Using the loaded feature extractor
 
 
 
 
 
 
 
 
15
  features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
16
  return features.input_values
17
 
@@ -35,14 +43,22 @@ def predict_voice(audio_file_path):
35
  predicted_index = logits.argmax()
36
  label = model.config.id2label[predicted_index.item()]
37
  confidence = torch.softmax(logits, dim=1).max().item() * 100
38
- return f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
 
 
 
 
 
39
  except Exception as e:
40
- return f"Error during processing: {e}"
41
 
42
  iface = gr.Interface(
43
  fn=predict_voice,
44
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
45
- outputs=gr.Textbox(label="Prediction"),
 
 
 
46
  title="Voice Authenticity Detection",
47
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
48
  )
 
2
  import librosa
3
  import numpy as np
4
  import torch
5
+ import matplotlib.pyplot as plt
6
  from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
7
  import random
8
 
 
10
  model = AutoModelForAudioClassification.from_pretrained("./")
11
  feature_extractor = ASTFeatureExtractor.from_pretrained("./")
12
 
13
def plot_waveform(waveform, sr):
    """Render an audio waveform as a matplotlib figure.

    The figure is built as a standalone ``Figure`` object instead of going
    through ``pyplot``. ``pyplot`` keeps a global registry of every figure it
    creates, so calling ``plt.figure()`` once per request in a long-running
    web app accumulates figures that are never closed.

    Args:
        waveform: One-dimensional sequence of audio samples (must support
            ``len`` and be plottable by matplotlib).
        sr: Sampling rate in Hz; must be positive.

    Returns:
        A ``matplotlib.figure.Figure`` with amplitude plotted against time
        in seconds.

    Raises:
        ValueError: If ``sr`` is not positive.
    """
    if sr <= 0:
        raise ValueError(f"sr must be a positive sampling rate, got {sr!r}")

    # Local import: only this function needs the Figure class, and it keeps
    # the rendering independent of pyplot's global "current figure" state.
    from matplotlib.figure import Figure

    n_samples = len(waveform)
    # Sample i is taken at i / sr seconds, so the last sample sits at
    # (n_samples - 1) / sr rather than at the full clip duration.
    times = np.arange(n_samples) / sr

    fig = Figure(figsize=(10, 3))
    ax = fig.subplots()
    ax.plot(times, waveform)
    ax.set_title('Waveform')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')
    return fig
21
+
22
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Run the module-level feature extractor on raw audio.

    Args:
        audio: Audio samples to encode; passed through unchanged to
            ``feature_extractor``.
        sr: Sampling rate forwarded as ``sampling_rate``.
        target_length: Value forwarded as ``max_length`` together with
            ``padding="max_length"``.

    Returns:
        The ``input_values`` attribute of the extractor output (requested
        with ``return_tensors="pt"``, so presumably a PyTorch tensor).
    """
    encoded = feature_extractor(
        audio,
        sampling_rate=sr,
        return_tensors="pt",
        padding="max_length",
        max_length=target_length,
    )
    return encoded.input_values
25
 
 
43
  predicted_index = logits.argmax()
44
  label = model.config.id2label[predicted_index.item()]
45
  confidence = torch.softmax(logits, dim=1).max().item() * 100
46
+
47
+ # Plot the waveform using the modified function
48
+ waveform_plot = plot_waveform(waveform, sample_rate)
49
+
50
+ # Return both the label and the plot
51
+ return f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%.", waveform_plot
52
  except Exception as e:
53
+ return f"Error during processing: {e}", None
54
 
# Gradio front end: a single audio upload goes in; a text verdict and a
# waveform plot come out.
iface = gr.Interface(
    fn=predict_voice,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        # Gradio renders the matplotlib figure returned by the callback.
        gr.Plot(label="Waveform"),
    ],
    title="Voice Authenticity Detection",
    description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
)