Kabatubare committed on
Commit
8b7f20a
·
verified ·
1 Parent(s): d99c4fb
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -11,22 +11,29 @@ from PIL import Image
11
 
12
  def plot_spectrogram(waveform, sample_rate):
13
  """Plot and return a spectrogram."""
 
 
 
 
14
  spectrogram_transform = T.Spectrogram()
15
  spectrogram = spectrogram_transform(waveform)
16
  spectrogram_db = torchaudio.transforms.AmplitudeToDB()(spectrogram)
17
 
18
  plt.figure(figsize=(10, 4))
19
- plt.imshow(spectrogram_db[0].numpy(), cmap='hot', aspect='auto', origin='lower')
 
 
 
20
  plt.colorbar(format='%+2.0f dB')
21
  plt.title('Spectrogram')
22
  plt.xlabel('Time Frame')
23
  plt.ylabel('Frequency')
24
-
25
  buf = io.BytesIO()
26
  plt.savefig(buf, format='png')
27
  plt.close()
28
  buf.seek(0)
29
-
30
  return Image.open(buf)
31
 
32
  def detect_watermark(audio_file_path, threshold=0.99):
 
11
 
12
  def plot_spectrogram(waveform, sample_rate):
13
  """Plot and return a spectrogram."""
14
+ # Ensure waveform is 2D (channels, time) after squeeze
15
+ if waveform.ndim == 1:
16
+ waveform = waveform.unsqueeze(0) # Add a channel dimension if it's missing
17
+
18
  spectrogram_transform = T.Spectrogram()
19
  spectrogram = spectrogram_transform(waveform)
20
  spectrogram_db = torchaudio.transforms.AmplitudeToDB()(spectrogram)
21
 
22
  plt.figure(figsize=(10, 4))
23
+ # Ensure we're plotting the first channel for 2D data
24
+ if spectrogram_db.ndim == 3:
25
+ spectrogram_db = spectrogram_db[0]
26
+ plt.imshow(spectrogram_db.numpy(), cmap='hot', aspect='auto', origin='lower')
27
  plt.colorbar(format='%+2.0f dB')
28
  plt.title('Spectrogram')
29
  plt.xlabel('Time Frame')
30
  plt.ylabel('Frequency')
31
+
32
  buf = io.BytesIO()
33
  plt.savefig(buf, format='png')
34
  plt.close()
35
  buf.seek(0)
36
+
37
  return Image.open(buf)
38
 
39
  def detect_watermark(audio_file_path, threshold=0.99):