xiaoyao9184 committed on
Commit
ef7bf13
·
verified ·
1 Parent(s): cb6069f

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (1) hide show
  1. gradio_app.py +24 -15
gradio_app.py CHANGED
@@ -50,22 +50,27 @@ def generate_msg_pt_by_format_string(format_string, bytes_count):
50
  msg_pt = torch.tensor(binary_list, dtype=torch.int32)
51
  return msg_pt.to(device)
52
 
53
- def embed_watermark(audio, sr, msg):
54
- # We add the batch dimension to the single audio to mimic the batch watermarking
55
- original_audio = audio.unsqueeze(0).to(device)
56
 
57
  # If the audio has more than one channel, average all channels to 1 channel
58
  if original_audio.shape[0] > 1:
59
- original_audio = torch.mean(original_audio, dim=0, keepdim=True)
 
 
 
 
 
60
 
61
- watermark = generator.get_watermark(original_audio, sr, message=msg)
62
 
63
- watermarked_audio = original_audio + watermark
64
 
65
  # Alternatively, you can call the forward() function directly with a different tune-down / tune-up rate
66
  # watermarked_audio = generator(audios, sample_rate=sr, alpha=1)
67
 
68
- return watermarked_audio
 
69
 
70
  def generate_format_string_by_msg_pt(msg_pt, bytes_count):
71
  hex_length = bytes_count * 2
@@ -80,20 +85,24 @@ def generate_format_string_by_msg_pt(msg_pt, bytes_count):
80
  return hex_string, format_hex
81
 
82
  def detect_watermark(audio, sr):
83
- # We add the batch dimension to the single audio to mimic the batch watermarking
84
- watermarked_audio = audio.unsqueeze(0).to(device)
85
 
86
  # If the audio has more than one channel, average all channels to 1 channel
87
  if watermarked_audio.shape[0] > 1:
88
- watermarked_audio = torch.mean(watermarked_audio, dim=0, keepdim=True)
 
 
 
 
 
89
 
90
- result, message = detector.detect_watermark(watermarked_audio, sr)
91
 
92
  # pred_prob is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
93
  # A watermarked audio should have pred_prob[:, 1, :] > 0.5
94
  # message_prob is a tensor of size batch x 16, indicating the probability of each bit being 1.
95
  # message will be a random tensor if the detector detects no watermarking from the audio
96
- pred_prob, message_prob = detector(watermarked_audio, sr)
97
 
98
  return result, message, pred_prob, message_prob
99
 
@@ -195,11 +204,11 @@ with gr.Blocks(title="AudioSeal") as demo:
195
  audio_original, rate = load_audio(file)
196
  msg_pt = generate_msg_pt_by_format_string(msg, generator_nbytes)
197
  audio_watermarked = embed_watermark(audio_original, rate, msg_pt)
198
- output = rate, audio_watermarked.squeeze().detach().cpu().numpy().astype(np.float32)
199
 
200
  if show_specgram:
201
- fig_original = get_waveform_and_specgram(audio_original.squeeze(), rate)
202
- fig_watermarked = get_waveform_and_specgram(audio_watermarked.squeeze(), rate)
203
  return [
204
  output,
205
  gr.update(visible=True, value=fig_original),
 
50
  msg_pt = torch.tensor(binary_list, dtype=torch.int32)
51
  return msg_pt.to(device)
52
 
53
+ def embed_watermark(audio, sr, msg_pt):
54
+ original_audio = audio.to(device)
 
55
 
56
  # If the audio has more than one channel, average all channels to 1 channel
57
  if original_audio.shape[0] > 1:
58
+ mono_audio = torch.mean(original_audio, dim=0, keepdim=True)
59
+ else:
60
+ mono_audio = original_audio
61
+
62
+ # We add the batch dimension to the single audio to mimic the batch watermarking
63
+ batched_audio = mono_audio.unsqueeze(0)
64
 
65
+ watermark = generator.get_watermark(batched_audio, sr, message=msg_pt)
66
 
67
+ watermarked_audio = batched_audio + watermark
68
 
69
  # Alternatively, you can call the forward() function directly with a different tune-down / tune-up rate
70
  # watermarked_audio = generator(audios, sample_rate=sr, alpha=1)
71
 
72
+ # Remove the batch dimension and move the result to the CPU
73
+ return watermarked_audio.squeeze(0).detach().cpu()
74
 
75
  def generate_format_string_by_msg_pt(msg_pt, bytes_count):
76
  hex_length = bytes_count * 2
 
85
  return hex_string, format_hex
86
 
87
  def detect_watermark(audio, sr):
88
+ watermarked_audio = audio.to(device)
 
89
 
90
  # If the audio has more than one channel, average all channels to 1 channel
91
  if watermarked_audio.shape[0] > 1:
92
+ mono_audio = torch.mean(watermarked_audio, dim=0, keepdim=True)
93
+ else:
94
+ mono_audio = watermarked_audio
95
+
96
+ # We add the batch dimension to the single audio to mimic the batch watermarking
97
+ batched_audio = mono_audio.unsqueeze(0)
98
 
99
+ result, message = detector.detect_watermark(batched_audio, sr)
100
 
101
  # pred_prob is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
102
  # A watermarked audio should have pred_prob[:, 1, :] > 0.5
103
  # message_prob is a tensor of size batch x 16, indicating of the probability of each bit to be 1.
104
  # message will be a random tensor if the detector detects no watermarking from the audio
105
+ pred_prob, message_prob = detector(batched_audio, sr)
106
 
107
  return result, message, pred_prob, message_prob
108
 
 
204
  audio_original, rate = load_audio(file)
205
  msg_pt = generate_msg_pt_by_format_string(msg, generator_nbytes)
206
  audio_watermarked = embed_watermark(audio_original, rate, msg_pt)
207
+ output = rate, audio_watermarked.squeeze().numpy().astype(np.float32)
208
 
209
  if show_specgram:
210
+ fig_original = get_waveform_and_specgram(audio_original, rate)
211
+ fig_watermarked = get_waveform_and_specgram(audio_watermarked, rate)
212
  return [
213
  output,
214
  gr.update(visible=True, value=fig_original),