Zw07 commited on
Commit
0209786
·
verified ·
1 Parent(s): 430043a

Upload 14 files

Browse files
src/audioseal/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Watermarking and detection for speech audios
9
+
10
+ A Pytorch-based localized algorithm for proactive detection
11
+ of the watermarkings in AI-generated audios, with very fast
12
+ detector.
13
+
14
+ """
15
+
16
+ __version__ = "0.1.4"
17
+
18
+
19
+ from audioseal import builder
20
+ from audioseal.loader import AudioSeal
21
+ from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
src/audioseal/builder.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import asdict, dataclass, field, is_dataclass
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from torch import device, dtype
12
+ from typing_extensions import TypeAlias
13
+
14
+ from audioseal.libs import audiocraft
15
+ from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
16
+
17
+ Device: TypeAlias = device
18
+
19
+ DataType: TypeAlias = dtype
20
+
21
+
22
+ @dataclass
23
+ class SEANetConfig:
24
+ """
25
+ Map common hparams of SEANet encoder and decoder.
26
+ """
27
+
28
+ channels: int
29
+ dimension: int
30
+ n_filters: int
31
+ n_residual_layers: int
32
+ ratios: List[int]
33
+ activation: str
34
+ activation_params: Dict[str, float]
35
+ norm: str
36
+ norm_params: Dict[str, Any]
37
+ kernel_size: int
38
+ last_kernel_size: int
39
+ residual_kernel_size: int
40
+ dilation_base: int
41
+ causal: bool
42
+ pad_mode: str
43
+ true_skip: bool
44
+ compress: int
45
+ lstm: int
46
+ disable_norm_outer_blocks: int
47
+
48
+
49
+ @dataclass
50
+ class DecoderConfig:
51
+ final_activation: Optional[str]
52
+ final_activation_params: Optional[dict]
53
+ trim_right_ratio: float
54
+
55
+
56
+ @dataclass
57
+ class DetectorConfig:
58
+ output_dim: int = 32
59
+
60
+
61
+ @dataclass
62
+ class AudioSealWMConfig:
63
+ nbits: int
64
+ seanet: SEANetConfig
65
+ decoder: DecoderConfig
66
+
67
+
68
+ @dataclass
69
+ class AudioSealDetectorConfig:
70
+ nbits: int
71
+ seanet: SEANetConfig
72
+ detector: DetectorConfig = field(default_factory=lambda: DetectorConfig())
73
+
74
+
75
+ def as_dict(obj: Any) -> Dict[str, Any]:
76
+ if isinstance(obj, dict):
77
+ return obj
78
+ if is_dataclass(obj) and not isinstance(obj, type):
79
+ return asdict(obj)
80
+ elif isinstance(obj, DictConfig):
81
+ return OmegaConf.to_container(obj) # type: ignore
82
+ else:
83
+ raise NotImplementedError(f"Unsupported type for config: {type(obj)}")
84
+
85
+
86
+ def create_generator(
87
+ config: AudioSealWMConfig,
88
+ *,
89
+ device: Optional[Device] = None,
90
+ dtype: Optional[DataType] = None,
91
+ ) -> AudioSealWM:
92
+ """Create a generator from hparams"""
93
+
94
+ # Currently the encoder hparams are the same as
95
+ # SEANet, but this can be changed in the future.
96
+ encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
97
+ encoder = encoder.to(device=device, dtype=dtype)
98
+
99
+ decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
100
+ decoder = audiocraft.modules.SEANetDecoder(**as_dict(decoder_config))
101
+ decoder = decoder.to(device=device, dtype=dtype)
102
+
103
+ msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
104
+ msgprocessor = msgprocessor.to(device=device, dtype=dtype)
105
+
106
+ return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msgprocessor)
107
+
108
+
109
+ def create_detector(
110
+ config: AudioSealDetectorConfig,
111
+ *,
112
+ device: Optional[Device] = None,
113
+ dtype: Optional[DataType] = None,
114
+ ) -> AudioSealDetector:
115
+ detector_config = {**as_dict(config.seanet), **as_dict(config.detector)}
116
+ detector = AudioSealDetector(nbits=config.nbits, **detector_config)
117
+ detector = detector.to(device=device, dtype=dtype)
118
+ return detector
src/audioseal/cards/audioseal_detector_16bits.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package __global__
2
+
3
+ name: audioseal_detector_16bits
4
+ model_type: seanet
5
+ checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/detector_base.pth"
6
+ nbits: 16
7
+ seanet:
8
+ activation: ELU
9
+ activation_params:
10
+ alpha: 1.0
11
+ causal: false
12
+ channels: 1
13
+ compress: 2
14
+ dilation_base: 2
15
+ dimension: 128
16
+ disable_norm_outer_blocks: 0
17
+ kernel_size: 7
18
+ last_kernel_size: 7
19
+ lstm: 2
20
+ n_filters: 32
21
+ n_residual_layers: 1
22
+ norm: weight_norm
23
+ norm_params: {}
24
+ pad_mode: constant
25
+ ratios:
26
+ - 8
27
+ - 5
28
+ - 4
29
+ - 2
30
+ residual_kernel_size: 3
31
+ true_skip: true
32
+ detector:
33
+ output_dim: 32
src/audioseal/cards/audioseal_wm_16bits.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ name: audioseal_wm_16bits
8
+ model_type: seanet
9
+ checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/generator_base.pth"
10
+ nbits: 16
11
+ seanet:
12
+ activation: ELU
13
+ activation_params:
14
+ alpha: 1.0
15
+ causal: false
16
+ channels: 1
17
+ compress: 2
18
+ dilation_base: 2
19
+ dimension: 128
20
+ disable_norm_outer_blocks: 0
21
+ kernel_size: 7
22
+ last_kernel_size: 7
23
+ lstm: 2
24
+ n_filters: 32
25
+ n_residual_layers: 1
26
+ norm: weight_norm
27
+ norm_params: {}
28
+ pad_mode: constant
29
+ ratios:
30
+ - 8
31
+ - 5
32
+ - 4
33
+ - 2
34
+ residual_kernel_size: 3
35
+ true_skip: true
36
+ decoder:
37
+ final_activation: null
38
+ final_activation_params: null
39
+ trim_right_ratio: 1.0
src/audioseal/libs/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
src/audioseal/libs/audiocraft/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
src/audioseal/libs/audiocraft/modules/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from .seanet import SEANetDecoder, SEANetEncoder, SEANetEncoderKeepDimension
src/audioseal/libs/audiocraft/modules/conv.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Vendor from https://github.com/facebookresearch/audiocraft
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm
17
+
18
+ try:
19
+ from torch.nn.utils.parametrizations import weight_norm
20
+ except ImportError:
21
+ # Old Pytorch
22
+ from torch.nn.utils import weight_norm
23
+
24
+
25
+ CONV_NORMALIZATIONS = frozenset(
26
+ ["none", "weight_norm", "spectral_norm", "time_group_norm"]
27
+ )
28
+
29
+
30
+ def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
31
+ assert norm in CONV_NORMALIZATIONS
32
+ if norm == "weight_norm":
33
+ return weight_norm(module)
34
+ elif norm == "spectral_norm":
35
+ return spectral_norm(module)
36
+ else:
37
+ # We already check was in CONV_NORMALIZATION, so any other choice
38
+ # doesn't need reparametrization.
39
+ return module
40
+
41
+
42
+ def get_norm_module(
43
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
44
+ ):
45
+ """Return the proper normalization module. If causal is True, this will ensure the returned
46
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
47
+ """
48
+ assert norm in CONV_NORMALIZATIONS
49
+ if norm == "time_group_norm":
50
+ if causal:
51
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
52
+ assert isinstance(module, nn.modules.conv._ConvNd)
53
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
54
+ else:
55
+ return nn.Identity()
56
+
57
+
58
+ def get_extra_padding_for_conv1d(
59
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
60
+ ) -> int:
61
+ """See `pad_for_conv1d`."""
62
+ length = x.shape[-1]
63
+ n_frames = (length - kernel_size + padding_total) / stride + 1
64
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
65
+ return ideal_length - length
66
+
67
+
68
+ def pad_for_conv1d(
69
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
70
+ ):
71
+ """Pad for a convolution to make sure that the last window is full.
72
+ Extra padding is added at the end. This is required to ensure that we can rebuild
73
+ an output of the same length, as otherwise, even with padding, some time steps
74
+ might get removed.
75
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
76
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
77
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
78
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
79
+ 1 2 3 4 # once you removed padding, we are missing one time step !
80
+ """
81
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
82
+ return F.pad(x, (0, extra_padding))
83
+
84
+
85
+ def pad1d(
86
+ x: torch.Tensor,
87
+ paddings: tp.Tuple[int, int],
88
+ mode: str = "constant",
89
+ value: float = 0.0,
90
+ ):
91
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
92
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
93
+ """
94
+ length = x.shape[-1]
95
+ padding_left, padding_right = paddings
96
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
97
+ if mode == "reflect":
98
+ max_pad = max(padding_left, padding_right)
99
+ extra_pad = 0
100
+ if length <= max_pad:
101
+ extra_pad = max_pad - length + 1
102
+ x = F.pad(x, (0, extra_pad))
103
+ padded = F.pad(x, paddings, mode, value)
104
+ end = padded.shape[-1] - extra_pad
105
+ return padded[..., :end]
106
+ else:
107
+ return F.pad(x, paddings, mode, value)
108
+
109
+
110
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
111
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
112
+ padding_left, padding_right = paddings
113
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
114
+ assert (padding_left + padding_right) <= x.shape[-1]
115
+ end = x.shape[-1] - padding_right
116
+ return x[..., padding_left:end]
117
+
118
+
119
+ class NormConv1d(nn.Module):
120
+ """Wrapper around Conv1d and normalization applied to this conv
121
+ to provide a uniform interface across normalization approaches.
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ *args,
127
+ causal: bool = False,
128
+ norm: str = "none",
129
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
130
+ **kwargs,
131
+ ):
132
+ super().__init__()
133
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
134
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
135
+ self.norm_type = norm
136
+
137
+ def forward(self, x):
138
+ x = self.conv(x)
139
+ x = self.norm(x)
140
+ return x
141
+
142
+
143
+ class NormConv2d(nn.Module):
144
+ """Wrapper around Conv2d and normalization applied to this conv
145
+ to provide a uniform interface across normalization approaches.
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ *args,
151
+ norm: str = "none",
152
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
153
+ **kwargs,
154
+ ):
155
+ super().__init__()
156
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
157
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
158
+ self.norm_type = norm
159
+
160
+ def forward(self, x):
161
+ x = self.conv(x)
162
+ x = self.norm(x)
163
+ return x
164
+
165
+
166
+ class NormConvTranspose1d(nn.Module):
167
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
168
+ to provide a uniform interface across normalization approaches.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ *args,
174
+ causal: bool = False,
175
+ norm: str = "none",
176
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
177
+ **kwargs,
178
+ ):
179
+ super().__init__()
180
+ self.convtr = apply_parametrization_norm(
181
+ nn.ConvTranspose1d(*args, **kwargs), norm
182
+ )
183
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
184
+ self.norm_type = norm
185
+
186
+ def forward(self, x):
187
+ x = self.convtr(x)
188
+ x = self.norm(x)
189
+ return x
190
+
191
+
192
+ class NormConvTranspose2d(nn.Module):
193
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
194
+ to provide a uniform interface across normalization approaches.
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ *args,
200
+ norm: str = "none",
201
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
202
+ **kwargs,
203
+ ):
204
+ super().__init__()
205
+ self.convtr = apply_parametrization_norm(
206
+ nn.ConvTranspose2d(*args, **kwargs), norm
207
+ )
208
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
209
+
210
+ def forward(self, x):
211
+ x = self.convtr(x)
212
+ x = self.norm(x)
213
+ return x
214
+
215
+
216
+ class StreamableConv1d(nn.Module):
217
+ """Conv1d with some builtin handling of asymmetric or causal padding
218
+ and normalization.
219
+ """
220
+
221
+ def __init__(
222
+ self,
223
+ in_channels: int,
224
+ out_channels: int,
225
+ kernel_size: int,
226
+ stride: int = 1,
227
+ dilation: int = 1,
228
+ groups: int = 1,
229
+ bias: bool = True,
230
+ causal: bool = False,
231
+ norm: str = "none",
232
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
233
+ pad_mode: str = "reflect",
234
+ ):
235
+ super().__init__()
236
+ # warn user on unusual setup between dilation and stride
237
+ if stride > 1 and dilation > 1:
238
+ warnings.warn(
239
+ "StreamableConv1d has been initialized with stride > 1 and dilation > 1"
240
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
241
+ )
242
+ self.conv = NormConv1d(
243
+ in_channels,
244
+ out_channels,
245
+ kernel_size,
246
+ stride,
247
+ dilation=dilation,
248
+ groups=groups,
249
+ bias=bias,
250
+ causal=causal,
251
+ norm=norm,
252
+ norm_kwargs=norm_kwargs,
253
+ )
254
+ self.causal = causal
255
+ self.pad_mode = pad_mode
256
+
257
+ def forward(self, x):
258
+ B, C, T = x.shape
259
+ kernel_size = self.conv.conv.kernel_size[0]
260
+ stride = self.conv.conv.stride[0]
261
+ dilation = self.conv.conv.dilation[0]
262
+ kernel_size = (
263
+ kernel_size - 1
264
+ ) * dilation + 1 # effective kernel size with dilations
265
+ padding_total = kernel_size - stride
266
+ extra_padding = get_extra_padding_for_conv1d(
267
+ x, kernel_size, stride, padding_total
268
+ )
269
+ if self.causal:
270
+ # Left padding for causal
271
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
272
+ else:
273
+ # Asymmetric padding required for odd strides
274
+ padding_right = padding_total // 2
275
+ padding_left = padding_total - padding_right
276
+ x = pad1d(
277
+ x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
278
+ )
279
+ return self.conv(x)
280
+
281
+
282
+ class StreamableConvTranspose1d(nn.Module):
283
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
284
+ and normalization.
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ in_channels: int,
290
+ out_channels: int,
291
+ kernel_size: int,
292
+ stride: int = 1,
293
+ causal: bool = False,
294
+ norm: str = "none",
295
+ trim_right_ratio: float = 1.0,
296
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
297
+ ):
298
+ super().__init__()
299
+ self.convtr = NormConvTranspose1d(
300
+ in_channels,
301
+ out_channels,
302
+ kernel_size,
303
+ stride,
304
+ causal=causal,
305
+ norm=norm,
306
+ norm_kwargs=norm_kwargs,
307
+ )
308
+ self.causal = causal
309
+ self.trim_right_ratio = trim_right_ratio
310
+ assert (
311
+ self.causal or self.trim_right_ratio == 1.0
312
+ ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
313
+ assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
314
+
315
+ def forward(self, x):
316
+ kernel_size = self.convtr.convtr.kernel_size[0]
317
+ stride = self.convtr.convtr.stride[0]
318
+ padding_total = kernel_size - stride
319
+
320
+ y = self.convtr(x)
321
+
322
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
323
+ # removed at the very end, when keeping only the right length for the output,
324
+ # as removing it here would require also passing the length at the matching layer
325
+ # in the encoder.
326
+ if self.causal:
327
+ # Trim the padding on the right according to the specified ratio
328
+ # if trim_right_ratio = 1.0, trim everything from right
329
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
330
+ padding_left = padding_total - padding_right
331
+ y = unpad1d(y, (padding_left, padding_right))
332
+ else:
333
+ # Asymmetric padding required for odd strides
334
+ padding_right = padding_total // 2
335
+ padding_left = padding_total - padding_right
336
+ y = unpad1d(y, (padding_left, padding_right))
337
+ return y
src/audioseal/libs/audiocraft/modules/lstm.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Vendor from https://github.com/facebookresearch/audiocraft
8
+
9
+ from torch import nn
10
+
11
+
12
+ class StreamableLSTM(nn.Module):
13
+ """LSTM without worrying about the hidden state, nor the layout of the data.
14
+ Expects input as convolutional layout.
15
+ """
16
+
17
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
18
+ super().__init__()
19
+ self.skip = skip
20
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
21
+
22
+ def forward(self, x):
23
+ x = x.permute(2, 0, 1)
24
+ y, _ = self.lstm(x)
25
+ if self.skip:
26
+ y = y + x
27
+ y = y.permute(1, 2, 0)
28
+ return y
src/audioseal/libs/audiocraft/modules/seanet.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Vendor from https://github.com/facebookresearch/audiocraft
8
+
9
+ import math
10
+ import typing as tp
11
+
12
+ import numpy as np
13
+ import torch.nn as nn
14
+
15
+ from audioseal.libs.audiocraft.modules.conv import (
16
+ StreamableConv1d,
17
+ StreamableConvTranspose1d,
18
+ )
19
+ from audioseal.libs.audiocraft.modules.lstm import StreamableLSTM
20
+
21
+
22
+ class SEANetResnetBlock(nn.Module):
23
+ """Residual block from SEANet model.
24
+
25
+ Args:
26
+ dim (int): Dimension of the input/output.
27
+ kernel_sizes (list): List of kernel sizes for the convolutions.
28
+ dilations (list): List of dilations for the convolutions.
29
+ activation (str): Activation function.
30
+ activation_params (dict): Parameters to provide to the activation function.
31
+ norm (str): Normalization method.
32
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
33
+ causal (bool): Whether to use fully causal convolution.
34
+ pad_mode (str): Padding mode for the convolutions.
35
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
36
+ true_skip (bool): Whether to use true skip connection or a simple
37
+ (streamable) convolution as the skip connection.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ kernel_sizes: tp.List[int] = [3, 1],
44
+ dilations: tp.List[int] = [1, 1],
45
+ activation: str = "ELU",
46
+ activation_params: dict = {"alpha": 1.0},
47
+ norm: str = "none",
48
+ norm_params: tp.Dict[str, tp.Any] = {},
49
+ causal: bool = False,
50
+ pad_mode: str = "reflect",
51
+ compress: int = 2,
52
+ true_skip: bool = True,
53
+ ):
54
+ super().__init__()
55
+ assert len(kernel_sizes) == len(
56
+ dilations
57
+ ), "Number of kernel sizes should match number of dilations"
58
+ act = getattr(nn, activation)
59
+ hidden = dim // compress
60
+ block = []
61
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
62
+ in_chs = dim if i == 0 else hidden
63
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
64
+ block += [
65
+ act(**activation_params),
66
+ StreamableConv1d(
67
+ in_chs,
68
+ out_chs,
69
+ kernel_size=kernel_size,
70
+ dilation=dilation,
71
+ norm=norm,
72
+ norm_kwargs=norm_params,
73
+ causal=causal,
74
+ pad_mode=pad_mode,
75
+ ),
76
+ ]
77
+ self.block = nn.Sequential(*block)
78
+ self.shortcut: nn.Module
79
+ if true_skip:
80
+ self.shortcut = nn.Identity()
81
+ else:
82
+ self.shortcut = StreamableConv1d(
83
+ dim,
84
+ dim,
85
+ kernel_size=1,
86
+ norm=norm,
87
+ norm_kwargs=norm_params,
88
+ causal=causal,
89
+ pad_mode=pad_mode,
90
+ )
91
+
92
+ def forward(self, x):
93
+ return self.shortcut(x) + self.block(x)
94
+
95
+
96
+ class SEANetEncoder(nn.Module):
97
+ """SEANet encoder.
98
+
99
+ Args:
100
+ channels (int): Audio channels.
101
+ dimension (int): Intermediate representation dimension.
102
+ n_filters (int): Base width for the model.
103
+ n_residual_layers (int): nb of residual layers.
104
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
105
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
106
+ that must match the decoder order. We use the decoder order as some models may only employ the decoder.
107
+ activation (str): Activation function.
108
+ activation_params (dict): Parameters to provide to the activation function.
109
+ norm (str): Normalization method.
110
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
111
+ kernel_size (int): Kernel size for the initial convolution.
112
+ last_kernel_size (int): Kernel size for the initial convolution.
113
+ residual_kernel_size (int): Kernel size for the residual layers.
114
+ dilation_base (int): How much to increase the dilation with each layer.
115
+ causal (bool): Whether to use fully causal convolution.
116
+ pad_mode (str): Padding mode for the convolutions.
117
+ true_skip (bool): Whether to use true skip connection or a simple
118
+ (streamable) convolution as the skip connection in the residual network blocks.
119
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
120
+ lstm (int): Number of LSTM layers at the end of the encoder.
121
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
122
+ For the encoder, it corresponds to the N first blocks.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ channels: int = 1,
128
+ dimension: int = 128,
129
+ n_filters: int = 32,
130
+ n_residual_layers: int = 3,
131
+ ratios: tp.List[int] = [8, 5, 4, 2],
132
+ activation: str = "ELU",
133
+ activation_params: dict = {"alpha": 1.0},
134
+ norm: str = "none",
135
+ norm_params: tp.Dict[str, tp.Any] = {},
136
+ kernel_size: int = 7,
137
+ last_kernel_size: int = 7,
138
+ residual_kernel_size: int = 3,
139
+ dilation_base: int = 2,
140
+ causal: bool = False,
141
+ pad_mode: str = "reflect",
142
+ true_skip: bool = True,
143
+ compress: int = 2,
144
+ lstm: int = 0,
145
+ disable_norm_outer_blocks: int = 0,
146
+ ):
147
+ super().__init__()
148
+ self.channels = channels
149
+ self.dimension = dimension
150
+ self.n_filters = n_filters
151
+ self.ratios = list(reversed(ratios))
152
+ del ratios
153
+ self.n_residual_layers = n_residual_layers
154
+ self.hop_length = np.prod(self.ratios)
155
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
156
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
157
+ assert (
158
+ self.disable_norm_outer_blocks >= 0
159
+ and self.disable_norm_outer_blocks <= self.n_blocks
160
+ ), (
161
+ "Number of blocks for which to disable norm is invalid."
162
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
163
+ )
164
+
165
+ act = getattr(nn, activation)
166
+ mult = 1
167
+ model: tp.List[nn.Module] = [
168
+ StreamableConv1d(
169
+ channels,
170
+ mult * n_filters,
171
+ kernel_size,
172
+ norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
173
+ norm_kwargs=norm_params,
174
+ causal=causal,
175
+ pad_mode=pad_mode,
176
+ )
177
+ ]
178
+ # Downsample to raw audio scale
179
+ for i, ratio in enumerate(self.ratios):
180
+ block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
181
+ # Add residual layers
182
+ for j in range(n_residual_layers):
183
+ model += [
184
+ SEANetResnetBlock(
185
+ mult * n_filters,
186
+ kernel_sizes=[residual_kernel_size, 1],
187
+ dilations=[dilation_base**j, 1],
188
+ norm=block_norm,
189
+ norm_params=norm_params,
190
+ activation=activation,
191
+ activation_params=activation_params,
192
+ causal=causal,
193
+ pad_mode=pad_mode,
194
+ compress=compress,
195
+ true_skip=true_skip,
196
+ )
197
+ ]
198
+
199
+ # Add downsampling layers
200
+ model += [
201
+ act(**activation_params),
202
+ StreamableConv1d(
203
+ mult * n_filters,
204
+ mult * n_filters * 2,
205
+ kernel_size=ratio * 2,
206
+ stride=ratio,
207
+ norm=block_norm,
208
+ norm_kwargs=norm_params,
209
+ causal=causal,
210
+ pad_mode=pad_mode,
211
+ ),
212
+ ]
213
+ mult *= 2
214
+
215
+ if lstm:
216
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
217
+
218
+ model += [
219
+ act(**activation_params),
220
+ StreamableConv1d(
221
+ mult * n_filters,
222
+ dimension,
223
+ last_kernel_size,
224
+ norm=(
225
+ "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
226
+ ),
227
+ norm_kwargs=norm_params,
228
+ causal=causal,
229
+ pad_mode=pad_mode,
230
+ ),
231
+ ]
232
+
233
+ self.model = nn.Sequential(*model)
234
+
235
+ def forward(self, x):
236
+ return self.model(x)
237
+
238
+
239
+ class SEANetEncoderKeepDimension(SEANetEncoder):
240
+ """
241
+ similar architecture to the SEANet encoder but with an extra step that
242
+ projects the output dimension to the same input dimension by repeating
243
+ the sequential
244
+
245
+ Args:
246
+ SEANetEncoder (_type_): _description_
247
+ """
248
+
249
+ def __init__(self, *args, **kwargs):
250
+
251
+ self.output_dim = kwargs.pop("output_dim")
252
+ super().__init__(*args, **kwargs)
253
+ # Adding a reverse convolution layer
254
+ self.reverse_convolution = nn.ConvTranspose1d(
255
+ in_channels=self.dimension,
256
+ out_channels=self.output_dim,
257
+ kernel_size=math.prod(self.ratios),
258
+ stride=math.prod(self.ratios),
259
+ padding=0,
260
+ )
261
+
262
+ def forward(self, x):
263
+ orig_nframes = x.shape[-1]
264
+ x = self.model(x)
265
+ x = self.reverse_convolution(x)
266
+ # make sure dim didn't change
267
+ return x[:, :, :orig_nframes]
268
+
269
+
270
+ class SEANetDecoder(nn.Module):
271
+ """SEANet decoder.
272
+
273
+ Args:
274
+ channels (int): Audio channels.
275
+ dimension (int): Intermediate representation dimension.
276
+ n_filters (int): Base width for the model.
277
+ n_residual_layers (int): nb of residual layers.
278
+ ratios (Sequence[int]): kernel size and stride ratios.
279
+ activation (str): Activation function.
280
+ activation_params (dict): Parameters to provide to the activation function.
281
+ final_activation (str): Final activation function after all convolutions.
282
+ final_activation_params (dict): Parameters to provide to the activation function.
283
+ norm (str): Normalization method.
284
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
285
+ kernel_size (int): Kernel size for the initial convolution.
286
+ last_kernel_size (int): Kernel size for the initial convolution.
287
+ residual_kernel_size (int): Kernel size for the residual layers.
288
+ dilation_base (int): How much to increase the dilation with each layer.
289
+ causal (bool): Whether to use fully causal convolution.
290
+ pad_mode (str): Padding mode for the convolutions.
291
+ true_skip (bool): Whether to use true skip connection or a simple.
292
+ (streamable) convolution as the skip connection in the residual network blocks.
293
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
294
+ lstm (int): Number of LSTM layers at the end of the encoder.
295
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
296
+ For the decoder, it corresponds to the N last blocks.
297
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
298
+ If equal to 1.0, it means that all the trimming is done at the right.
299
+ """
300
+
301
+ def __init__(
302
+ self,
303
+ channels: int = 1,
304
+ dimension: int = 128,
305
+ n_filters: int = 32,
306
+ n_residual_layers: int = 3,
307
+ ratios: tp.List[int] = [8, 5, 4, 2],
308
+ activation: str = "ELU",
309
+ activation_params: dict = {"alpha": 1.0},
310
+ final_activation: tp.Optional[str] = None,
311
+ final_activation_params: tp.Optional[dict] = None,
312
+ norm: str = "none",
313
+ norm_params: tp.Dict[str, tp.Any] = {},
314
+ kernel_size: int = 7,
315
+ last_kernel_size: int = 7,
316
+ residual_kernel_size: int = 3,
317
+ dilation_base: int = 2,
318
+ causal: bool = False,
319
+ pad_mode: str = "reflect",
320
+ true_skip: bool = True,
321
+ compress: int = 2,
322
+ lstm: int = 0,
323
+ disable_norm_outer_blocks: int = 0,
324
+ trim_right_ratio: float = 1.0,
325
+ ):
326
+ super().__init__()
327
+ self.dimension = dimension
328
+ self.channels = channels
329
+ self.n_filters = n_filters
330
+ self.ratios = ratios
331
+ del ratios
332
+ self.n_residual_layers = n_residual_layers
333
+ self.hop_length = np.prod(self.ratios)
334
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
335
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
336
+ assert (
337
+ self.disable_norm_outer_blocks >= 0
338
+ and self.disable_norm_outer_blocks <= self.n_blocks
339
+ ), (
340
+ "Number of blocks for which to disable norm is invalid."
341
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
342
+ )
343
+
344
+ act = getattr(nn, activation)
345
+ mult = int(2 ** len(self.ratios))
346
+ model: tp.List[nn.Module] = [
347
+ StreamableConv1d(
348
+ dimension,
349
+ mult * n_filters,
350
+ kernel_size,
351
+ norm=(
352
+ "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
353
+ ),
354
+ norm_kwargs=norm_params,
355
+ causal=causal,
356
+ pad_mode=pad_mode,
357
+ )
358
+ ]
359
+
360
+ if lstm:
361
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
362
+
363
+ # Upsample to raw audio scale
364
+ for i, ratio in enumerate(self.ratios):
365
+ block_norm = (
366
+ "none"
367
+ if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
368
+ else norm
369
+ )
370
+ # Add upsampling layers
371
+ model += [
372
+ act(**activation_params),
373
+ StreamableConvTranspose1d(
374
+ mult * n_filters,
375
+ mult * n_filters // 2,
376
+ kernel_size=ratio * 2,
377
+ stride=ratio,
378
+ norm=block_norm,
379
+ norm_kwargs=norm_params,
380
+ causal=causal,
381
+ trim_right_ratio=trim_right_ratio,
382
+ ),
383
+ ]
384
+ # Add residual layers
385
+ for j in range(n_residual_layers):
386
+ model += [
387
+ SEANetResnetBlock(
388
+ mult * n_filters // 2,
389
+ kernel_sizes=[residual_kernel_size, 1],
390
+ dilations=[dilation_base**j, 1],
391
+ activation=activation,
392
+ activation_params=activation_params,
393
+ norm=block_norm,
394
+ norm_params=norm_params,
395
+ causal=causal,
396
+ pad_mode=pad_mode,
397
+ compress=compress,
398
+ true_skip=true_skip,
399
+ )
400
+ ]
401
+
402
+ mult //= 2
403
+
404
+ # Add final layers
405
+ model += [
406
+ act(**activation_params),
407
+ StreamableConv1d(
408
+ n_filters,
409
+ channels,
410
+ last_kernel_size,
411
+ norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
412
+ norm_kwargs=norm_params,
413
+ causal=causal,
414
+ pad_mode=pad_mode,
415
+ ),
416
+ ]
417
+ # Add optional final activation to decoder (eg. tanh)
418
+ if final_activation is not None:
419
+ final_act = getattr(nn, final_activation)
420
+ final_activation_params = final_activation_params or {}
421
+ model += [final_act(**final_activation_params)]
422
+ self.model = nn.Sequential(*model)
423
+
424
+ def forward(self, z):
425
+ y = self.model(z)
426
+ return y
src/audioseal/loader.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import os
9
+ from dataclasses import fields
10
+ from hashlib import sha1
11
+ from pathlib import Path
12
+ from typing import ( # type: ignore[attr-defined]
13
+ Any,
14
+ Dict,
15
+ List,
16
+ Optional,
17
+ Tuple,
18
+ Type,
19
+ TypeVar,
20
+ Union,
21
+ cast,
22
+ )
23
+ from urllib.parse import urlparse # noqa: F401
24
+
25
+ import torch
26
+ from omegaconf import DictConfig, OmegaConf
27
+
28
+ import audioseal
29
+ from audioseal.builder import (
30
+ AudioSealDetectorConfig,
31
+ AudioSealWMConfig,
32
+ create_detector,
33
+ create_generator,
34
+ )
35
+ from audioseal.models import AudioSealDetector, AudioSealWM
36
+
37
+ AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)
38
+
39
+
40
+ class ModelLoadError(RuntimeError):
41
+ """Raised when the model loading fails"""
42
+
43
+
44
+ def _get_path_from_env(var_name: str) -> Optional[Path]:
45
+ pathname = os.getenv(var_name)
46
+ if not pathname:
47
+ return None
48
+
49
+ try:
50
+ return Path(pathname)
51
+ except ValueError as ex:
52
+ raise RuntimeError(f"Expect valid pathname, get '{pathname}'.") from ex
53
+
54
+
55
+ def _get_cache_dir(env_names: List[str]):
56
+ """Re-use cache dir from a list of existing caches"""
57
+ for env in env_names:
58
+ cache_dir = _get_path_from_env(env)
59
+ if cache_dir:
60
+ break
61
+ else:
62
+ cache_dir = Path("~/.cache").expanduser().resolve()
63
+
64
+ # Create a sub-dir to not mess up with existing caches
65
+ cache_dir = cache_dir / "audioseal"
66
+ cache_dir.mkdir(exist_ok=True, parents=True)
67
+
68
+ return cache_dir
69
+
70
+
71
+ def load_model_checkpoint(
72
+ model_path: Union[Path, str],
73
+ device: Union[str, torch.device] = "cpu",
74
+ ):
75
+ if Path(model_path).is_file():
76
+ return torch.load(model_path, map_location=device)
77
+
78
+ cache_dir = _get_cache_dir(
79
+ ["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
80
+ )
81
+ parts = urlparse(str(model_path))
82
+ if parts.scheme == "https":
83
+
84
+ hash_ = sha1(parts.path.encode()).hexdigest()[:24]
85
+ return torch.hub.load_state_dict_from_url(
86
+ str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
87
+ )
88
+ elif str(model_path).startswith("facebook/audioseal/"):
89
+ hf_filename = str(model_path)[len("facebook/audioseal/") :]
90
+
91
+ try:
92
+ from huggingface_hub import hf_hub_download
93
+ except ModuleNotFoundError:
94
+ print(
95
+ f"The model path {model_path} seems to be a direct HF path, "
96
+ "but you do not install Huggingface_hub. Install with for example "
97
+ "`pip install huggingface_hub` to use this feature."
98
+ )
99
+ file = hf_hub_download(
100
+ repo_id="facebook/audioseal",
101
+ filename=hf_filename,
102
+ cache_dir=cache_dir,
103
+ library_name="audioseal",
104
+ library_version=audioseal.__version__,
105
+ )
106
+ return torch.load(file, map_location=device)
107
+ else:
108
+ raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")
109
+
110
+
111
+ def load_local_model_config(model_card: str) -> Optional[DictConfig]:
112
+ config_file = Path(__file__).parent / "cards" / (model_card + ".yaml")
113
+ if Path(config_file).is_file():
114
+ return cast(DictConfig, OmegaConf.load(config_file.resolve()))
115
+ else:
116
+ return None
117
+
118
+
119
+ class AudioSeal:
120
+
121
+ @staticmethod
122
+ def parse_model(
123
+ model_card_or_path: str,
124
+ model_type: Type[AudioSealT],
125
+ nbits: Optional[int] = None,
126
+ ) -> Tuple[Dict[str, Any], AudioSealT]:
127
+ """
128
+ Parse the information from the model card or checkpoint path using
129
+ the schema `model_type` that defines the model type
130
+ """
131
+ # Get the raw checkpoint and config from the local model cards
132
+ config = load_local_model_config(model_card_or_path)
133
+
134
+ if config:
135
+ assert "checkpoint" in config, f"Checkpoint missing in {model_card_or_path}"
136
+ config_dict = OmegaConf.to_container(config)
137
+ assert isinstance(
138
+ config_dict, dict
139
+ ), f"Cannot parse config from {model_card_or_path}"
140
+ checkpoint = config_dict.pop("checkpoint")
141
+ checkpoint = load_model_checkpoint(checkpoint)
142
+
143
+ # Get the raw checkpoint and config from the checkpoint path
144
+ else:
145
+ config_dict = {}
146
+ checkpoint = load_model_checkpoint(model_card_or_path)
147
+
148
+ if "xp.cfg" in checkpoint:
149
+ config_dict = {**checkpoint["xp.cfg"], **config_dict} # type: ignore
150
+
151
+ model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits) # type: ignore
152
+
153
+ if "model" in checkpoint:
154
+ checkpoint = checkpoint["model"]
155
+
156
+ return checkpoint, model_config
157
+
158
+ @staticmethod
159
+ def parse_config(
160
+ config: Dict[str, Any],
161
+ config_type: Type[AudioSealT],
162
+ nbits: Optional[int] = None,
163
+ ) -> AudioSealT:
164
+
165
+ assert "seanet" in config, f"missing seanet backbone config in {config}"
166
+
167
+ # Patch 1: Resolve the variables in the checkpoint
168
+ config = OmegaConf.create(config) # type: ignore
169
+ OmegaConf.resolve(config) # type: ignore
170
+ config = OmegaConf.to_container(config) # type: ignore
171
+
172
+ # Patch 2: Put decoder, encoder and detector outside seanet
173
+ seanet_config = config["seanet"]
174
+ for key_to_patch in ["encoder", "decoder", "detector"]:
175
+ if key_to_patch in seanet_config:
176
+ config_to_patch = config.get(key_to_patch) or {}
177
+ config[key_to_patch] = {
178
+ **config_to_patch,
179
+ **seanet_config.pop(key_to_patch),
180
+ }
181
+
182
+ config["seanet"] = seanet_config
183
+
184
+ # Patch 3: Put nbits into config if specified
185
+ if nbits and "nbits" not in config:
186
+ config["nbits"] = nbits
187
+
188
+ # remove attributes not related to the model_type
189
+ result_config = {}
190
+ assert config, f"Empty config"
191
+ for field in fields(config_type):
192
+ if field.name in config:
193
+ result_config[field.name] = config[field.name]
194
+
195
+ schema = OmegaConf.structured(config_type)
196
+ schema.merge_with(result_config)
197
+ return schema
198
+
199
+ @staticmethod
200
+ def load_generator(
201
+ model_card_or_path: str,
202
+ nbits: Optional[int] = None,
203
+ ) -> AudioSealWM:
204
+ """Load the AudioSeal generator from the model card"""
205
+ checkpoint, config = AudioSeal.parse_model(
206
+ model_card_or_path,
207
+ AudioSealWMConfig,
208
+ nbits=nbits,
209
+ )
210
+
211
+ model = create_generator(config)
212
+ model.load_state_dict(checkpoint)
213
+ return model
214
+
215
+ @staticmethod
216
+ def load_detector(
217
+ model_card_or_path: str,
218
+ nbits: Optional[int] = None,
219
+ ) -> AudioSealDetector:
220
+ checkpoint, config = AudioSeal.parse_model(
221
+ model_card_or_path,
222
+ AudioSealDetectorConfig,
223
+ nbits=nbits,
224
+ )
225
+ model = create_detector(config)
226
+ model.load_state_dict(checkpoint)
227
+ return model
src/audioseal/models.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from typing import Optional, Tuple
9
+
10
+ import julius
11
+ import torch
12
+
13
+ from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension
14
+
15
+ logger = logging.getLogger("Audioseal")
16
+
17
+ COMPATIBLE_WARNING = """
18
+ AudioSeal is designed to work at a sample rate 16khz.
19
+ Implicit sampling rate usage is deprecated and will be removed in future version.
20
+ To remove this warning please add this argument to the function call:
21
+ sample_rate = your_sample_rate
22
+ """
23
+
24
+
25
+ class MsgProcessor(torch.nn.Module):
26
+ """
27
+ Apply the secret message to the encoder output.
28
+ Args:
29
+ nbits: Number of bits used to generate the message. Must be non-zero
30
+ hidden_size: Dimension of the encoder output
31
+ """
32
+
33
+ def __init__(self, nbits: int, hidden_size: int):
34
+ super().__init__()
35
+ assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
36
+ self.nbits = nbits
37
+ self.hidden_size = hidden_size
38
+ self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)
39
+
40
+ def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
41
+ """
42
+ Build the embedding map: 2 x k -> k x h, then sum on the first dim
43
+ Args:
44
+ hidden: The encoder output, size: batch x hidden x frames
45
+ msg: The secret message, size: batch x k
46
+ """
47
+ # create indices to take from embedding layer
48
+ indices = 2 * torch.arange(msg.shape[-1]).to(msg.device) # k: 0 2 4 ... 2k
49
+ indices = indices.repeat(msg.shape[0], 1) # b x k
50
+ indices = (indices + msg).long()
51
+ msg_aux = self.msg_processor(indices) # b x k -> b x k x h
52
+ msg_aux = msg_aux.sum(dim=-2) # b x k x h -> b x h
53
+ msg_aux = msg_aux.unsqueeze(-1).repeat(
54
+ 1, 1, hidden.shape[2]
55
+ ) # b x h -> b x h x t/f
56
+ hidden = hidden + msg_aux # -> b x h x t/f
57
+ return hidden
58
+
59
+
60
+ class AudioSealWM(torch.nn.Module):
61
+ """
62
+ Generate watermarking for a given audio signal
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ encoder: torch.nn.Module,
68
+ decoder: torch.nn.Module,
69
+ msg_processor: Optional[torch.nn.Module] = None,
70
+ ):
71
+ super().__init__()
72
+ self.encoder = encoder
73
+ self.decoder = decoder
74
+ # The build should take care of validating the dimensions between component
75
+ self.msg_processor = msg_processor
76
+ self._message: Optional[torch.Tensor] = None
77
+
78
+ @property
79
+ def message(self) -> Optional[torch.Tensor]:
80
+ return self._message
81
+
82
+ @message.setter
83
+ def message(self, message: torch.Tensor) -> None:
84
+ self._message = message
85
+
86
+ def get_watermark(
87
+ self,
88
+ x: torch.Tensor,
89
+ sample_rate: Optional[int] = None,
90
+ message: Optional[torch.Tensor] = None,
91
+ ) -> torch.Tensor:
92
+ """
93
+ Get the watermark from an audio tensor and a message.
94
+ If the input message is None, a random message of
95
+ n bits {0,1} will be generated.
96
+ Args:
97
+ x: Audio signal, size: batch x frames
98
+ sample_rate: The sample rate of the input audio (default 16khz as
99
+ currently supported by the main AudioSeal model)
100
+ message: An optional binary message, size: batch x k
101
+ """
102
+ length = x.size(-1)
103
+ if sample_rate is None:
104
+ logger.warning(COMPATIBLE_WARNING)
105
+ sample_rate = 16_000
106
+ assert sample_rate
107
+ if sample_rate != 16000:
108
+ x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
109
+ hidden = self.encoder(x)
110
+
111
+ if self.msg_processor is not None:
112
+ if message is None:
113
+ if self.message is None:
114
+ message = torch.randint(
115
+ 0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
116
+ )
117
+ else:
118
+ message = self.message.to(device=x.device)
119
+ else:
120
+ message = message.to(device=x.device)
121
+
122
+ hidden = self.msg_processor(hidden, message)
123
+
124
+ watermark = self.decoder(hidden)
125
+
126
+ if sample_rate != 16000:
127
+ watermark = julius.resample_frac(
128
+ watermark, old_sr=16000, new_sr=sample_rate
129
+ )
130
+
131
+ return watermark[..., :length] # trim output cf encodec codebase
132
+
133
+ def forward(
134
+ self,
135
+ x: torch.Tensor,
136
+ sample_rate: Optional[int] = None,
137
+ message: Optional[torch.Tensor] = None,
138
+ alpha: float = 1.0,
139
+ ) -> torch.Tensor:
140
+ """Apply the watermarking to the audio signal x with a tune-down ratio (default 1.0)"""
141
+ if sample_rate is None:
142
+ logger.warning(COMPATIBLE_WARNING)
143
+ sample_rate = 16_000
144
+ wm = self.get_watermark(x, sample_rate=sample_rate, message=message)
145
+ return x + alpha * wm
146
+
147
+
148
+ class AudioSealDetector(torch.nn.Module):
149
+ """
150
+ Detect the watermarking from an audio signal
151
+ Args:
152
+ SEANetEncoderKeepDimension (_type_): _description_
153
+ nbits (int): The number of bits in the secret message. The result will have size
154
+ of 2 + nbits, where the first two items indicate the possibilities of the
155
+ audio being watermarked (positive / negative scores), he rest is used to decode
156
+ the secret message. In 0bit watermarking (no secret message), the detector just
157
+ returns 2 values.
158
+ """
159
+
160
+ def __init__(self, *args, nbits: int = 0, **kwargs):
161
+ super().__init__()
162
+ encoder = SEANetEncoderKeepDimension(*args, **kwargs)
163
+ last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
164
+ self.detector = torch.nn.Sequential(encoder, last_layer)
165
+ self.nbits = nbits
166
+
167
+ def detect_watermark(
168
+ self,
169
+ x: torch.Tensor,
170
+ sample_rate: Optional[int] = None,
171
+ message_threshold: float = 0.5,
172
+ ) -> Tuple[float, torch.Tensor]:
173
+ """
174
+ A convenience function that returns a probability of an audio being watermarked,
175
+ together with its message in n-bits (binary) format. If the audio is not watermarked,
176
+ the message will be random.
177
+ Args:
178
+ x: Audio signal, size: batch x frames
179
+ sample_rate: The sample rate of the input audio
180
+ message_threshold: threshold used to convert the watermark output (probability
181
+ of each bits being 0 or 1) into the binary n-bit message.
182
+ """
183
+ if sample_rate is None:
184
+ logger.warning(COMPATIBLE_WARNING)
185
+ sample_rate = 16_000
186
+ result, message = self.forward(x, sample_rate=sample_rate) # b x 2+nbits
187
+ detected = (
188
+ torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
189
+ )
190
+ detect_prob = detected.cpu().item() # type: ignore
191
+ message = torch.gt(message, message_threshold).int()
192
+ return detect_prob, message
193
+
194
+ def decode_message(self, result: torch.Tensor) -> torch.Tensor:
195
+ """
196
+ Decode the message from the watermark result (batch x nbits x frames)
197
+ Args:
198
+ result: watermark result (batch x nbits x frames)
199
+ Returns:
200
+ The message of size batch x nbits, indicating probability of 1 for each bit
201
+ """
202
+ assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
203
+ self.dim() == 2 and result.shape[0] == self.nbits
204
+ ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
205
+ decoded_message = result.mean(dim=-1)
206
+ return torch.sigmoid(decoded_message)
207
+
208
+ def forward(
209
+ self,
210
+ x: torch.Tensor,
211
+ sample_rate: Optional[int] = None,
212
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
213
+ """
214
+ Detect the watermarks from the audio signal
215
+ Args:
216
+ x: Audio signal, size batch x frames
217
+ sample_rate: The sample rate of the input audio
218
+ """
219
+ if sample_rate is None:
220
+ logger.warning(COMPATIBLE_WARNING)
221
+ sample_rate = 16_000
222
+ assert sample_rate
223
+ if sample_rate != 16000:
224
+ x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
225
+ result = self.detector(x) # b x 2+nbits
226
+ # hardcode softmax on 2 first units used for detection
227
+ result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
228
+ message = self.decode_message(result[:, 2:, :])
229
+ return result[:, :2, :], message
src/audioseal/py.typed ADDED
File without changes
src/scripts/checkpoints.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from pathlib import Path
9
+
10
+ import torch
11
+
12
+
13
+ def convert(checkpoint: str, outdir: str, suffix: str = "base"):
14
+ """Convert the checkpoint to generator and detector"""
15
+ outdir_path = Path(outdir)
16
+ ckpt = torch.load(checkpoint)
17
+
18
+ # keep inference-related params only
19
+ infer_cfg = {
20
+ "seanet": ckpt["xp.cfg"]["seanet"],
21
+ "channels": ckpt["xp.cfg"]["channels"],
22
+ "dtype": ckpt["xp.cfg"]["dtype"],
23
+ "sample_rate": ckpt["xp.cfg"]["sample_rate"],
24
+ }
25
+
26
+ generator_ckpt = {"xp.cfg": infer_cfg, "model": {}}
27
+ detector_ckpt = {"xp.cfg": infer_cfg, "model": {}}
28
+
29
+ for layer in ckpt["model"].keys():
30
+ if layer.startswith("detector"):
31
+ new_layer = layer[9:]
32
+ detector_ckpt["model"][new_layer] = ckpt["model"][layer] # type: ignore
33
+ elif layer == "msg_processor.msg_processor.0.weight":
34
+ generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt[ # type: ignore
35
+ "model"
36
+ ][
37
+ layer
38
+ ]
39
+ else:
40
+ assert layer.startswith("generator"), f"Invalid layer: {layer}"
41
+ new_layer = layer[10:]
42
+ generator_ckpt["model"][new_layer] = ckpt["model"][layer] # type: ignore
43
+
44
+ torch.save(generator_ckpt, outdir_path / (f"checkpoint_generator_{suffix}.pth"))
45
+ torch.save(detector_ckpt, outdir_path / (f"checkpoint_detector_{suffix}.pth"))
46
+
47
+
48
+ if __name__ == "__main__":
49
+ import fire
50
+
51
+ fire.Fire(convert)