# Origin: Hugging Face file view of MusicLM/audiocraft/modules/watermark.py
# (uploaded by SunilGopal, commit 6a662e6 "Upload 214 files", 4.03 kB).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import random
import torch
def pad(
    x_wm: torch.Tensor, central: bool = False
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Pad a watermarked signal at the beginning and the end.

    For every batch element a random window ``[start, end)`` is drawn, with
    ``start`` in ``[0, int(0.33 * frames))`` and ``end`` in
    ``[int(0.66 * frames), frames)``, so roughly a third of the signal or more
    is always kept. Frames outside the window are zeroed; with ``central=True``
    the window itself is zeroed and the two tails are kept instead.

    Args:
        x_wm (torch.Tensor): watermarked audio, indexed as [batch, channels, frames].
        central (bool): If True, mask the middle of the wave (roughly a third or
            more of it) instead of the two tails (beginning and ending frames).
    Returns:
        padded (torch.Tensor): padded signal, same shape as ``x_wm``.
        true_predictions (torch.Tensor): ``cat([1 - mask, mask], dim=1)`` where
            ``mask`` is 1 for kept (watermarked) frames and 0 for zeroed
            (non-watermarked) frames; shape [batch, 2 * channels, frames].
    """
    batch = x_wm.size(0)
    frames = x_wm.size(-1)
    # Keep roughly a third of the watermarked signal or more. Clamp the start
    # bound to 1 so signals shorter than 4 frames do not ask torch.randint for
    # the empty range [0, 0), which raises.
    max_start = max(int(0.33 * frames), 1)
    min_end = int(0.66 * frames)
    starts = torch.randint(0, max_start, size=(batch,))
    # NOTE: randint's upper bound is exclusive and so is the window end, so the
    # final frame never falls inside the window.
    ends = torch.randint(min_end, frames, size=(batch,))

    # Build every element's [start, end) window with one broadcast comparison
    # instead of a Python loop over the batch (one slice assignment per row).
    positions = torch.arange(frames, device=x_wm.device)
    inside = (positions >= starts.to(x_wm.device).unsqueeze(1)) & (
        positions < ends.to(x_wm.device).unsqueeze(1)
    )  # [batch, frames]
    # Same window for every channel, in the dtype of the input.
    mask = inside.unsqueeze(1).to(x_wm.dtype).expand_as(x_wm)
    if central:
        mask = 1 - mask
    padded = x_wm * mask
    true_predictions = torch.cat([1 - mask, mask], dim=1)
    return padded, true_predictions
def mix(
    x: torch.Tensor, x_wm: torch.Tensor, window_size: float = 0.5, shuffle: bool = False
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Mix a window of the non-watermarked audio signal 'x' into the watermarked audio signal 'x_wm'.

    Both tensors are indexed as [batch, channels, frames] and must have the same shape.
    A window of ``int(window_size * frames)`` frames, starting at a random position shared by
    the whole batch, is copied from 'x' into a detached copy of 'x_wm'. The result is therefore
    watermarked everywhere except inside that window.

    Args:
        x (torch.Tensor): The non-watermarked audio signal tensor.
        x_wm (torch.Tensor): The watermarked audio signal tensor.
        window_size (float, optional): Fraction of the frames of 'x' to copy into 'x_wm'
            (0 < window_size <= 1).
        shuffle (bool): If True, each batch element receives the window from a randomly chosen
            element of 'x' (sampled with replacement) instead of from the same element.
    Returns:
        tuple: A tuple containing two tensors:
            - mixed_tensor (torch.Tensor): The resulting mixed audio signal tensor.
            - true_predictions (torch.Tensor): ``cat([1 - mask, mask], dim=1)`` where ``mask`` is
              1 for watermarked frames and 0 for the copied non-watermarked window. For mono
              input this matches the detector output shape of [bsz, 2, frames].
    Raises:
        AssertionError: If 'window_size' is not in (0, 1].
    """
    assert 0 < window_size <= 1, "window_size should be between 0 and 1"
    total_frames = x.shape[-1]
    window_frames = int(window_size * total_frames)
    # Random start chosen so that the whole window fits inside the signal.
    start_point = random.randint(0, total_frames - window_frames)
    window = slice(start_point, start_point + window_frames)

    # Copy the random window from 'x' into a detached copy of 'x_wm'.
    mixed = x_wm.detach().clone()
    if shuffle:
        # Take the window from a random element of the batch.
        shuffle_idx = torch.randint(0, x.size(0), (x.size(0),))
        mixed[:, :, window] = x[shuffle_idx, :, window]
    else:
        mixed[:, :, window] = x[:, :, window]

    # mask: 1 = watermarked, 0 = non-watermarked window. Stack both classes the
    # same way `pad` does. Writing class labels into channels 0/1 of a
    # [bsz, 2 * channels, frames] tensor was only correct for mono input.
    mask = torch.ones_like(mixed)
    mask[:, :, window] = 0.0
    true_predictions = torch.cat([1 - mask, mask], dim=1)
    return mixed, true_predictions