File size: 4,029 Bytes
6a662e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
import random

import torch


def pad(
    x_wm: torch.Tensor, central: bool = False
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Zero out part of a watermarked signal and return per-frame labels.

    For every batch element a random frame span ``[start, end)`` is drawn with
    ``start < 0.33 * frames`` and ``end >= 0.66 * frames``, so the span always
    covers at least roughly a third of the signal.

    Args:
        x_wm (torch.Tensor): watermarked audio, indexed as [batch, channels, frames].
        central (bool): If False (default), keep the span and zero the two tails
            (beginning and ending frames). If True, zero the span and keep the tails.

    Returns:
        padded (torch.Tensor): ``x_wm`` with the masked frames set to zero.
        true_predictions (torch.Tensor): ``cat([1 - mask, mask], dim=1)`` where
            ``mask`` is 1 on frames that still carry the watermark and 0 on zeroed
            frames. For mono input this is [batch, 2, frames], channel 0 being the
            non-watermarked class and channel 1 the watermarked class.
    """
    n_frames = x_wm.size(-1)
    batch = x_wm.size(0)
    # Keep at least ~33% of the signal inside the span. Clamp to 1 so that
    # torch.randint(0, max_start) has a non-empty range on very short signals
    # (fewer than 4 frames), which would otherwise raise.
    max_start = max(1, int(0.33 * n_frames))
    min_end = int(0.66 * n_frames)
    starts = torch.randint(0, max_start, size=(batch,))
    ends = torch.randint(min_end, n_frames, size=(batch,))

    # Build the span mask for the whole batch at once: [batch, frames].
    frames = torch.arange(n_frames, device=x_wm.device)
    starts = starts.to(x_wm.device).unsqueeze(1)
    ends = ends.to(x_wm.device).unsqueeze(1)
    in_span = (frames >= starts) & (frames < ends)
    # Broadcast over channels and match the input dtype: [batch, channels, frames].
    mask = in_span.unsqueeze(1).expand_as(x_wm).to(x_wm.dtype)
    if central:
        mask = 1 - mask
    padded = x_wm * mask
    true_predictions = torch.cat([1 - mask, mask], dim=1)
    return padded, true_predictions


def mix(
    x: torch.Tensor, x_wm: torch.Tensor, window_size: float = 0.5, shuffle: bool = False
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Replace a random window of the watermarked audio 'x_wm' with the same window of 'x'.

    Both tensors are indexed as [batch, channels, frames]. A window of
    ``int(window_size * frames)`` frames, starting at a random position shared by
    the whole batch, is copied from the non-watermarked audio 'x' into a detached
    copy of 'x_wm'. The result is watermarked everywhere except inside that window.

    Args:
        x (torch.Tensor): The non-watermarked audio signal tensor.
        x_wm (torch.Tensor): The watermarked audio signal tensor.
        window_size (float, optional): Fraction of the frames to copy from 'x'
            into 'x_wm', in (0, 1].
        shuffle (bool): If True, each batch element receives its window from a
            randomly chosen batch element of 'x' (drawn with replacement, so it
            may be itself). If False, element i receives the window of x[i].

    Returns:
        tuple: A tuple containing two tensors:
            - mixed (torch.Tensor): The resulting mixed audio signal tensor.
            - true_predictions (torch.Tensor): ``cat([1 - mask, mask], dim=1)``
              where ``mask`` is 1 on watermarked frames and 0 inside the copied
              window. For mono input this is [batch, 2, frames], matching the
              detector output: channel 0 is the non-watermarked class and
              channel 1 the watermarked class.

    Raises:
        AssertionError: If 'window_size' is not in (0, 1].
    """
    # Raise explicitly (instead of `assert`) so the check survives `python -O`,
    # while keeping the documented exception type.
    if not 0 < window_size <= 1:
        raise AssertionError("window_size should be between 0 and 1")

    total_frames = x.shape[-1]
    window_frames = int(window_size * total_frames)
    # Latest start that still fits the whole window inside the signal.
    max_start_point = total_frames - window_frames
    # NOTE: drawn from Python's `random` module, so it is seeded by `random.seed`,
    # not by `torch.manual_seed` (unlike the batch shuffle below).
    start_point = random.randint(0, max_start_point)
    window = slice(start_point, start_point + window_frames)

    mixed = x_wm.detach().clone()
    if shuffle:
        # Take the window from a random element of the batch; index only the
        # window instead of gathering the full tensor first.
        shuffle_idx = torch.randint(0, x.size(0), (x.size(0),))
        mixed[:, :, window] = x[shuffle_idx, :, window]
    else:
        mixed[:, :, window] = x[:, :, window]

    # Labels follow the same [non-watermarked, watermarked] layout as `pad`, for
    # every channel (the previous per-index writes were only correct for mono).
    watermarked = torch.ones_like(mixed)
    watermarked[:, :, window] = 0.0
    true_predictions = torch.cat([1 - watermarked, watermarked], dim=1)

    return mixed, true_predictions