File size: 7,728 Bytes
83940d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional, Tuple
import librosa
import numpy as np
import torch

from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension

logger = logging.getLogger("Audioseal")

# Logged by the forward passes below whenever the caller omits `sample_rate`
# (they then fall back to 16 kHz).
COMPATIBLE_WARNING = (
    "\n"
    "AudioSeal is designed to work at a sample rate 16khz.\n"
    "Implicit sampling rate usage is deprecated and will be removed in future version.\n"
    "To remove this warning please add this argument to the function call:\n"
    "sample_rate = your_sample_rate\n"
)

class MsgProcessor(torch.nn.Module):
    """Add a learned embedding of a binary message to hidden states.

    The embedding table has ``2 * nbits`` rows: bit number ``k`` with value
    ``b`` selects row ``2 * k + b``. The selected rows are summed over all
    bits and the resulting vector is added at every position of the last
    axis of ``hidden``.
    """

    def __init__(self, nbits: int, hidden_size: int):
        super().__init__()
        assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
        self.nbits = nbits
        self.hidden_size = hidden_size
        # Two rows per bit position: one for a 0 bit, one for a 1 bit.
        self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """Return ``hidden`` plus the summed embedding of ``msg``.

        Args:
            hidden: 3-D tensor whose dim 1 has size ``hidden_size``;
                presumably ``(batch, hidden_size, frames)`` — TODO confirm.
            msg: ``(batch, nbits)`` tensor of bits; values are expected to be
                0 or 1 so that ``2 * k + bit`` stays inside bit ``k``'s rows.

        Returns:
            ``hidden`` with the message embedding broadcast along dim 2.
        """
        batch_size, num_bits = msg.shape[0], msg.shape[-1]

        # Row index for each bit: base offset 2*k, plus the bit value.
        offsets = torch.arange(num_bits).to(msg.device) * 2
        rows = (offsets.repeat(batch_size, 1) + msg).long()

        # (batch, nbits, hidden_size) -> (batch, hidden_size)
        summed = self.msg_processor(rows).sum(dim=-2)

        # Tile the per-item vector across every step of hidden's dim 2.
        per_step = summed.unsqueeze(-1).repeat(1, 1, hidden.shape[2])
        return hidden + per_step

def compute_stft_energy(audio: torch.Tensor, sr: int, n_fft: int = 2048, hop_length: int = 512) -> torch.Tensor:
    """Compute per-frame spectral energy for every item of a batch via librosa.

    Each ``audio[i]`` is detached, moved to the CPU and passed through
    ``librosa.stft``. The squared STFT magnitudes are then summed over the
    frequency axis, which leaves one energy value per STFT frame.

    Args:
        audio: batched audio; each ``audio[i]`` must be accepted by
            ``librosa.stft`` (a 1-D signal, or ``(..., samples)`` multichannel).
        sr: sample rate of ``audio``. Not used by the computation; kept for
            API compatibility.
        n_fft: STFT window size.
        hop_length: STFT hop length in samples.

    Returns:
        Tensor on ``audio.device``. Its shape is ``(batch, frames)`` for 2-D
        input, or ``(batch, *leading_dims, frames)`` when items are
        multichannel.

    Raises:
        RuntimeError: from ``torch.stack`` when the batch is empty.
    """
    # detach(): .numpy() refuses tensors that require grad; the energy
    # envelope is not meant to be differentiable anyway.
    batch_np = audio.detach().cpu().numpy()

    frame_energies = []
    for item in batch_np:
        magnitude = np.abs(librosa.stft(item, n_fft=n_fft, hop_length=hop_length))
        # librosa.stft returns (..., freq_bins, frames). Summing over -2 (not 0)
        # reduces the frequency axis even when a leading channel axis exists.
        energy = np.sum(magnitude ** 2, axis=-2)
        frame_energies.append(torch.tensor(energy, device=audio.device))

    return torch.stack(frame_energies, dim=0)

def compute_adaptive_alpha_librosa(energy_values: torch.Tensor, min_alpha: float = 0.5, max_alpha: float = 1.5) -> torch.Tensor:
    """Map frame energies linearly onto ``[min_alpha, max_alpha]``.

    Min-max normalisation is done independently for every sequence along the
    last (frame) axis. The lowest-energy frame of a sequence therefore gets
    ``min_alpha`` and the highest gets (almost exactly) ``max_alpha``.

    Args:
        energy_values: energies with frames on the last axis, e.g.
            ``(batch, frames)`` or ``(batch, channels, frames)``.
        min_alpha: value assigned to the lowest-energy frame.
        max_alpha: value approached by the highest-energy frame.

    Returns:
        Tensor with the same shape as ``energy_values``.
    """
    # Normalise over the last axis: for 3-D input dim=1 would be the channel
    # axis, not the frames. For 2-D input dim=-1 is the same as dim=1.
    lowest = energy_values.min(dim=-1, keepdim=True).values
    highest = energy_values.max(dim=-1, keepdim=True).values
    # The 1e-6 avoids 0/0 when every frame has the same energy; such
    # sequences then map to min_alpha everywhere.
    normalized_energy = (energy_values - lowest) / (highest - lowest + 1e-6)
    return min_alpha + normalized_energy * (max_alpha - min_alpha)

class AudioSealWM(torch.nn.Module):
    """Watermark generator: ``encoder`` -> optional ``msg_processor`` -> ``decoder``.

    ``forward`` returns the caller's audio with the generated watermark added.
    The watermark is scaled sample-by-sample by an energy-adaptive factor in
    ``[min_alpha, max_alpha]``.
    """

    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module, msg_processor: Optional[torch.nn.Module] = None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.msg_processor = msg_processor
        # Default message used by forward() when no message is passed.
        self._message: Optional[torch.Tensor] = None
        # Message actually embedded by the most recent forward() call.
        self._original_payload: Optional[torch.Tensor] = None

    @property
    def message(self) -> Optional[torch.Tensor]:
        """Default message embedded when ``forward`` gets ``message=None``."""
        return self._message

    @message.setter
    def message(self, message: torch.Tensor) -> None:
        self._message = message

    def get_original_payload(self) -> Optional[torch.Tensor]:
        """Return the message embedded by the last ``forward`` call.

        Returns ``None`` if no message has been embedded yet, which is always
        the case when the model has no ``msg_processor``.
        """
        return self._original_payload

    def get_watermark(self, x: torch.Tensor, sample_rate: Optional[int] = None, message: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``forward`` with the default adaptive-scaling parameters.

        NOTE(review): despite its name this returns the *watermarked audio*
        (``x`` plus the scaled watermark), not the bare watermark signal.
        That behaviour is kept for backward compatibility.
        """
        return self.forward(x, sample_rate, message)

    def forward(self, x: torch.Tensor, sample_rate: Optional[int] = None, message: Optional[torch.Tensor] = None,
                n_fft: int = 2048, hop_length: int = 512, min_alpha: float = 0.5, max_alpha: float = 1.5) -> torch.Tensor:
        """Generate a watermark for ``x`` and return ``x`` with it added.

        Args:
            x: batched audio at ``sample_rate`` with samples on the last axis;
                presumably ``(batch, channels, samples)`` as the encoder
                expects — TODO confirm.
            sample_rate: rate of ``x``. ``None`` logs ``COMPATIBLE_WARNING``
                and assumes 16 kHz. Any other rate is resampled to 16 kHz
                with librosa for the model and back afterwards. That path
                goes through NumPy, so no gradient flows through it.
            message: bits to embed, ``(batch, nbits)``. Falls back to
                ``self.message``, then to random bits. Ignored without a
                ``msg_processor``.
            n_fft: STFT window size for the energy envelope.
            hop_length: STFT hop size, in samples, for the energy envelope.
            min_alpha: lower bound of the per-sample watermark scale.
            max_alpha: upper bound of the per-sample watermark scale.

        Returns:
            ``x + alpha * watermark``. The watermark is cropped or zero-padded
            to ``x``'s length. ``alpha`` comes from the channel-averaged
            energy of ``x`` and is broadcast over channels.
        """
        logger.debug("AudioSealWM.forward called")
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000

        # The watermark is added to the caller's audio at the caller's rate,
        # so keep the original rather than overwriting it with the 16 kHz copy.
        original_x = x
        num_samples = original_x.shape[-1]

        if sample_rate != 16000:
            resampled_x = librosa.resample(x.detach().cpu().numpy(), orig_sr=sample_rate, target_sr=16000)
            x = torch.tensor(resampled_x, device=original_x.device)

        hidden = self.encoder(x)

        if self.msg_processor is not None:
            if message is None:
                if self.message is None:
                    message = torch.randint(0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device)
                else:
                    message = self.message.to(device=x.device)
            else:
                message = message.to(device=x.device)

            hidden = self.msg_processor(hidden, message)
            self._original_payload = message

        watermark = self.decoder(hidden)

        if sample_rate != 16000:
            resampled_wm = librosa.resample(watermark.detach().cpu().numpy(), orig_sr=16000, target_sr=sample_rate)
            watermark = torch.tensor(resampled_wm, device=watermark.device)

        # A resampling round trip can change the length by a sample or so;
        # align the watermark with the caller's audio before adding it.
        if watermark.shape[-1] > num_samples:
            watermark = watermark[..., :num_samples]
        elif watermark.shape[-1] < num_samples:
            watermark = torch.nn.functional.pad(watermark, (0, num_samples - watermark.shape[-1]))

        # Energy envelope of a detached, channel-averaged (batch, samples) copy
        # of the caller's audio at the caller's rate, so one hop_length frame
        # matches hop_length watermark samples.
        mono = original_x.detach().reshape(original_x.shape[0], -1, num_samples).mean(dim=1)
        energy_values = compute_stft_energy(mono, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
        adaptive_alpha = compute_adaptive_alpha_librosa(energy_values, min_alpha=min_alpha, max_alpha=max_alpha)

        # One alpha per STFT frame -> one per sample, cropped on the sample axis.
        stretched_alpha = torch.repeat_interleave(adaptive_alpha, hop_length, dim=-1)[..., :num_samples]
        # (batch, samples) -> (batch, 1, ..., samples) so it broadcasts over
        # the watermark's channel axes.
        stretched_alpha = stretched_alpha.reshape(
            stretched_alpha.shape[0], *([1] * (watermark.dim() - 2)), num_samples
        )
        logger.debug("stretched_alpha shape: %s", tuple(stretched_alpha.shape))

        return original_x + stretched_alpha * watermark

class AudioSealDetector(torch.nn.Module):
    """Watermark detector: a SEANet encoder followed by a 1x1 convolution.

    The convolution emits ``2 + nbits`` channels. The first two are detection
    logits, softmaxed in ``forward``; ``detect_watermark`` reads channel 1 as
    the per-frame "watermarked" probability. The remaining ``nbits`` channels
    are message logits.
    """

    def __init__(self, *args, nbits: int = 0, **kwargs):
        super().__init__()
        encoder = SEANetEncoderKeepDimension(*args, **kwargs)
        last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
        self.detector = torch.nn.Sequential(encoder, last_layer)
        self.nbits = nbits

    def detect_watermark(self, x: torch.Tensor, sample_rate: Optional[int] = None, message_threshold: float = 0.5,
                         detection_threshold: float = 0.5) -> Tuple[float, torch.Tensor]:
        """Score how much of ``x`` carries a watermark and decode the message.

        Args:
            x: batched audio passed to ``forward``.
            sample_rate: rate of ``x`` (see ``forward``).
            message_threshold: decoded bit is 1 where its sigmoid exceeds this.
            detection_threshold: a frame counts as watermarked where its
                channel-1 probability exceeds this.

        Returns:
            ``(detect_prob, message)``. ``detect_prob`` is the fraction of all
            frames, across the whole batch, counted as watermarked.
            ``message`` is an int tensor of decoded bits, ``(batch, nbits)``.
        """
        result, message = self.forward(x, sample_rate=sample_rate)
        # Average over batch *and* frames. The former count/num_frames summed
        # over the batch and could exceed 1 when batch > 1.
        detected = torch.gt(result[:, 1, :], detection_threshold).float().mean()
        detect_prob = detected.cpu().item()
        message = torch.gt(message, message_threshold).int()
        return detect_prob, message

    def decode_message(self, result: torch.Tensor) -> torch.Tensor:
        """Average message logits over frames and apply a sigmoid.

        Args:
            result: ``(batch, nbits, frames)`` or ``(nbits, frames)`` logits.

        Returns:
            Per-bit probabilities with the frame axis reduced.
        """
        assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
            result.dim() == 2 and result.shape[0] == self.nbits
        ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
        decoded_message = result.mean(dim=-1)
        return torch.sigmoid(decoded_message)

    def forward(self, x: torch.Tensor, sample_rate: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the detector network.

        Args:
            x: batched audio. If ``sample_rate`` is not 16000 it is resampled
                to 16 kHz with librosa, via NumPy, so no gradient flows
                through that path.
            sample_rate: rate of ``x``. ``None`` logs ``COMPATIBLE_WARNING``
                and assumes 16 kHz.

        Returns:
            ``(detection, message)``. ``detection`` is the softmaxed 2-channel
            detection output ``(batch, 2, frames)``. ``message`` holds the
            per-bit probabilities from ``decode_message``.
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000

        if sample_rate != 16000:
            resampled_x = librosa.resample(x.detach().cpu().numpy(), orig_sr=sample_rate, target_sr=16000)
            x = torch.tensor(resampled_x, device=x.device)

        result = self.detector(x)
        # Softmax across the two detection channels, computed out of place
        # instead of writing back into the network output.
        detection = torch.softmax(result[:, :2, :], dim=1)
        message = self.decode_message(result[:, 2:, :])
        return detection, message