Kabatubare committed on
Commit
6aa52fc
·
verified ·
1 Parent(s): 05e6aba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
6
  from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
7
  import random
8
 
9
- # Model and feature extractor loading from the specified local path
10
  model = AutoModelForAudioClassification.from_pretrained("./")
11
  feature_extractor = ASTFeatureExtractor.from_pretrained("./")
12
 
@@ -16,7 +16,16 @@ def plot_waveform(waveform, sr):
16
  plt.ylabel('Amplitude')
17
  plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
18
  plt.xlabel('Time (s)')
19
- # Instead of plt.show(), we'll return the figure
 
 
 
 
 
 
 
 
 
20
  return plt.gcf()
21
 
22
  def custom_feature_extraction(audio, sr=16000, target_length=1024):
@@ -44,20 +53,24 @@ def predict_voice(audio_file_path):
44
  label = model.config.id2label[predicted_index.item()]
45
  confidence = torch.softmax(logits, dim=1).max().item() * 100
46
 
47
- # Plot the waveform using the modified function
48
  waveform_plot = plot_waveform(waveform, sample_rate)
 
49
 
50
- # Return both the label and the plot
51
- return f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%.", waveform_plot
 
 
 
52
  except Exception as e:
53
- return f"Error during processing: {e}", None
54
 
55
  iface = gr.Interface(
56
  fn=predict_voice,
57
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
58
  outputs=[
59
  gr.Textbox(label="Prediction"),
60
- gr.Plot(label="Waveform") # Gradio will handle the rendering of the matplotlib figure
 
61
  ],
62
  title="Voice Authenticity Detection",
63
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
 
6
  from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
7
  import random
8
 
9
+ # Model and feature extractor loading
10
  model = AutoModelForAudioClassification.from_pretrained("./")
11
  feature_extractor = ASTFeatureExtractor.from_pretrained("./")
12
 
 
16
  plt.ylabel('Amplitude')
17
  plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
18
  plt.xlabel('Time (s)')
19
+ return plt.gcf()
20
+
21
+ def plot_spectrogram(waveform, sr):
22
+ S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
23
+ S_DB = librosa.power_to_db(S, ref=np.max)
24
+ plt.figure(figsize=(10, 4))
25
+ librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
26
+ plt.title('Mel Spectrogram')
27
+ plt.colorbar(format='%+2.0f dB')
28
+ plt.tight_layout()
29
  return plt.gcf()
30
 
31
  def custom_feature_extraction(audio, sr=16000, target_length=1024):
 
53
  label = model.config.id2label[predicted_index.item()]
54
  confidence = torch.softmax(logits, dim=1).max().item() * 100
55
 
 
56
  waveform_plot = plot_waveform(waveform, sample_rate)
57
+ spectrogram_plot = plot_spectrogram(waveform, sample_rate)
58
 
59
+ return (
60
+ f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%.",
61
+ waveform_plot,
62
+ spectrogram_plot
63
+ )
64
  except Exception as e:
65
+ return f"Error during processing: {e}", None, None
66
 
67
  iface = gr.Interface(
68
  fn=predict_voice,
69
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
70
  outputs=[
71
  gr.Textbox(label="Prediction"),
72
+ gr.Plot(label="Waveform"),
73
+ gr.Plot(label="Spectrogram")
74
  ],
75
  title="Voice Authenticity Detection",
76
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."