Zw07 committed
Commit 430043a · verified · 1 Parent(s): 19cea60

Delete src
src/audioseal/__init__.py DELETED
@@ -1,21 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- """
- Watermarking and detection for speech audio.
-
- A PyTorch-based localized algorithm for proactive detection
- of watermarks in AI-generated audio, with a very fast detector.
- """
-
- __version__ = "0.1.4"
-
-
- from audioseal import builder
- from audioseal.loader import AudioSeal
- from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
 
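Taken together, the deleted package exposed a small public surface: the `AudioSeal` loader plus the `AudioSealWM` generator and `AudioSealDetector` models. A minimal usage sketch of that surface, assuming the bundled 16-bit model cards below and a random tensor standing in for real 16 kHz mono audio:

import torch
from audioseal import AudioSeal

# Card names resolve to the YAML files under src/audioseal/cards below.
model = AudioSeal.load_generator("audioseal_wm_16bits")
detector = AudioSeal.load_detector("audioseal_detector_16bits")

wav = torch.randn(1, 1, 16000)  # batch of 1, mono, 1 s at 16 kHz
watermarked = model(wav, sample_rate=16000)  # this fork returns watermarked audio directly
prob, message = detector.detect_watermark(watermarked, sample_rate=16000)
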
src/audioseal/builder.py DELETED
@@ -1,118 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from dataclasses import asdict, dataclass, field, is_dataclass
- from typing import Any, Dict, List, Optional
-
- from omegaconf import DictConfig, OmegaConf
- from torch import device, dtype
- from typing_extensions import TypeAlias
-
- from audioseal.libs import audiocraft
- from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
-
- Device: TypeAlias = device
-
- DataType: TypeAlias = dtype
-
-
- @dataclass
- class SEANetConfig:
-     """
-     Map common hparams of the SEANet encoder and decoder.
-     """
-
-     channels: int
-     dimension: int
-     n_filters: int
-     n_residual_layers: int
-     ratios: List[int]
-     activation: str
-     activation_params: Dict[str, float]
-     norm: str
-     norm_params: Dict[str, Any]
-     kernel_size: int
-     last_kernel_size: int
-     residual_kernel_size: int
-     dilation_base: int
-     causal: bool
-     pad_mode: str
-     true_skip: bool
-     compress: int
-     lstm: int
-     disable_norm_outer_blocks: int
-
-
- @dataclass
- class DecoderConfig:
-     final_activation: Optional[str]
-     final_activation_params: Optional[dict]
-     trim_right_ratio: float
-
-
- @dataclass
- class DetectorConfig:
-     output_dim: int = 32
-
-
- @dataclass
- class AudioSealWMConfig:
-     nbits: int
-     seanet: SEANetConfig
-     decoder: DecoderConfig
-
-
- @dataclass
- class AudioSealDetectorConfig:
-     nbits: int
-     seanet: SEANetConfig
-     detector: DetectorConfig = field(default_factory=lambda: DetectorConfig())
-
-
- def as_dict(obj: Any) -> Dict[str, Any]:
-     if isinstance(obj, dict):
-         return obj
-     if is_dataclass(obj) and not isinstance(obj, type):
-         return asdict(obj)
-     elif isinstance(obj, DictConfig):
-         return OmegaConf.to_container(obj)  # type: ignore
-     else:
-         raise NotImplementedError(f"Unsupported type for config: {type(obj)}")
-
-
- def create_generator(
-     config: AudioSealWMConfig,
-     *,
-     device: Optional[Device] = None,
-     dtype: Optional[DataType] = None,
- ) -> AudioSealWM:
-     """Create a generator from hparams"""
-
-     # Currently the encoder hparams are the same as
-     # SEANet, but this can be changed in the future.
-     encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
-     encoder = encoder.to(device=device, dtype=dtype)
-
-     decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
-     decoder = audiocraft.modules.SEANetDecoder(**decoder_config)
-     decoder = decoder.to(device=device, dtype=dtype)
-
-     msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
-     msgprocessor = msgprocessor.to(device=device, dtype=dtype)
-
-     return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msgprocessor)
-
-
- def create_detector(
-     config: AudioSealDetectorConfig,
-     *,
-     device: Optional[Device] = None,
-     dtype: Optional[DataType] = None,
- ) -> AudioSealDetector:
-     detector_config = {**as_dict(config.seanet), **as_dict(config.detector)}
-     detector = AudioSealDetector(nbits=config.nbits, **detector_config)
-     detector = detector.to(device=device, dtype=dtype)
-     return detector
 
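The builder maps structured configs onto the model constructors. A sketch of building a randomly initialized detector purely from the dataclasses above, with hyperparameter values copied from the 16-bit model cards below (no checkpoint is loaded here):

from audioseal.builder import (
    AudioSealDetectorConfig, DetectorConfig, SEANetConfig, create_detector,
)

seanet = SEANetConfig(
    channels=1, dimension=128, n_filters=32, n_residual_layers=1,
    ratios=[8, 5, 4, 2], activation="ELU", activation_params={"alpha": 1.0},
    norm="weight_norm", norm_params={}, kernel_size=7, last_kernel_size=7,
    residual_kernel_size=3, dilation_base=2, causal=False,
    pad_mode="constant", true_skip=True, compress=2, lstm=2,
    disable_norm_outer_blocks=0,
)
config = AudioSealDetectorConfig(
    nbits=16, seanet=seanet, detector=DetectorConfig(output_dim=32)
)
detector = create_detector(config)  # randomly initialized, no weights loaded
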
src/audioseal/cards/audioseal_detector_16bits.yaml DELETED
@@ -1,33 +0,0 @@
- # @package __global__
-
- name: audioseal_detector_16bits
- model_type: seanet
- checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/detector_base.pth"
- nbits: 16
- seanet:
-   activation: ELU
-   activation_params:
-     alpha: 1.0
-   causal: false
-   channels: 1
-   compress: 2
-   dilation_base: 2
-   dimension: 128
-   disable_norm_outer_blocks: 0
-   kernel_size: 7
-   last_kernel_size: 7
-   lstm: 2
-   n_filters: 32
-   n_residual_layers: 1
-   norm: weight_norm
-   norm_params: {}
-   pad_mode: constant
-   ratios:
-     - 8
-     - 5
-     - 4
-     - 2
-   residual_kernel_size: 3
-   true_skip: true
- detector:
-   output_dim: 32
src/audioseal/cards/audioseal_wm_16bits.yaml DELETED
@@ -1,39 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
-
- name: audioseal_wm_16bits
- model_type: seanet
- checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/generator_base.pth"
- nbits: 16
- seanet:
-   activation: ELU
-   activation_params:
-     alpha: 1.0
-   causal: false
-   channels: 1
-   compress: 2
-   dilation_base: 2
-   dimension: 128
-   disable_norm_outer_blocks: 0
-   kernel_size: 7
-   last_kernel_size: 7
-   lstm: 2
-   n_filters: 32
-   n_residual_layers: 1
-   norm: weight_norm
-   norm_params: {}
-   pad_mode: constant
-   ratios:
-     - 8
-     - 5
-     - 4
-     - 2
-   residual_kernel_size: 3
-   true_skip: true
- decoder:
-   final_activation: null
-   final_activation_params: null
-   trim_right_ratio: 1.0
 
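Both cards point their `checkpoint` field at a URL, but `load_model_checkpoint` in loader.py also accepts a local file or a `facebook/audioseal/<file>` Hugging Face shorthand, so the cards can be bypassed entirely. A sketch with a hypothetical local path:

from audioseal import AudioSeal

# The path below is hypothetical; nbits must be given explicitly when the
# checkpoint config does not carry it.
generator = AudioSeal.load_generator("/tmp/generator_base.pth", nbits=16)
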
src/audioseal/libs/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
src/audioseal/libs/audiocraft/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
src/audioseal/libs/audiocraft/modules/__init__.py DELETED
@@ -1,8 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
-
- from .seanet import SEANetDecoder, SEANetEncoder, SEANetEncoderKeepDimension
src/audioseal/libs/audiocraft/modules/conv.py DELETED
@@ -1,337 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- #
- # Vendored from https://github.com/facebookresearch/audiocraft
-
- import math
- import typing as tp
- import warnings
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.nn.utils import spectral_norm
-
- try:
-     from torch.nn.utils.parametrizations import weight_norm
- except ImportError:
-     # Old PyTorch
-     from torch.nn.utils import weight_norm
-
-
- CONV_NORMALIZATIONS = frozenset(
-     ["none", "weight_norm", "spectral_norm", "time_group_norm"]
- )
-
-
- def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
-     assert norm in CONV_NORMALIZATIONS
-     if norm == "weight_norm":
-         return weight_norm(module)
-     elif norm == "spectral_norm":
-         return spectral_norm(module)
-     else:
-         # We already checked that norm is in CONV_NORMALIZATIONS, so any
-         # other choice doesn't need reparametrization.
-         return module
-
-
- def get_norm_module(
-     module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
- ):
-     """Return the proper normalization module. If causal is True, this will ensure the returned
-     module is causal, or return an error if the normalization doesn't support causal evaluation.
-     """
-     assert norm in CONV_NORMALIZATIONS
-     if norm == "time_group_norm":
-         if causal:
-             raise ValueError("GroupNorm doesn't support causal evaluation.")
-         assert isinstance(module, nn.modules.conv._ConvNd)
-         return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
-     else:
-         return nn.Identity()
-
-
- def get_extra_padding_for_conv1d(
-     x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
- ) -> int:
-     """See `pad_for_conv1d`."""
-     length = x.shape[-1]
-     n_frames = (length - kernel_size + padding_total) / stride + 1
-     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
-     return ideal_length - length
-
-
- def pad_for_conv1d(
-     x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
- ):
-     """Pad for a convolution to make sure that the last window is full.
-     Extra padding is added at the end. This is required to ensure that we can rebuild
-     an output of the same length, as otherwise, even with padding, some time steps
-     might get removed.
-     For instance, with total padding = 4, kernel size = 4, stride = 2:
-         0 0 1 2 3 4 5 0 0   # (0s are padding)
-         1   2   3           # (output frames of a convolution, last 0 is never used)
-         0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
-             1 2 3 4         # once padding is removed, we are missing one time step!
-     """
-     extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
-     return F.pad(x, (0, extra_padding))
-
-
- def pad1d(
-     x: torch.Tensor,
-     paddings: tp.Tuple[int, int],
-     mode: str = "constant",
-     value: float = 0.0,
- ):
-     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
-     In that case, we insert extra 0 padding to the right before the reflection happens.
-     """
-     length = x.shape[-1]
-     padding_left, padding_right = paddings
-     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
-     if mode == "reflect":
-         max_pad = max(padding_left, padding_right)
-         extra_pad = 0
-         if length <= max_pad:
-             extra_pad = max_pad - length + 1
-             x = F.pad(x, (0, extra_pad))
-         padded = F.pad(x, paddings, mode, value)
-         end = padded.shape[-1] - extra_pad
-         return padded[..., :end]
-     else:
-         return F.pad(x, paddings, mode, value)
-
-
- def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
-     """Remove padding from x, properly handling zero padding. Only for 1d!"""
-     padding_left, padding_right = paddings
-     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
-     assert (padding_left + padding_right) <= x.shape[-1]
-     end = x.shape[-1] - padding_right
-     return x[..., padding_left:end]
-
-
- class NormConv1d(nn.Module):
-     """Wrapper around Conv1d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-
-     def __init__(
-         self,
-         *args,
-         causal: bool = False,
-         norm: str = "none",
-         norm_kwargs: tp.Dict[str, tp.Any] = {},
-         **kwargs,
-     ):
-         super().__init__()
-         self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
-         self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
-         self.norm_type = norm
-
-     def forward(self, x):
-         x = self.conv(x)
-         x = self.norm(x)
-         return x
-
-
- class NormConv2d(nn.Module):
-     """Wrapper around Conv2d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-
-     def __init__(
-         self,
-         *args,
-         norm: str = "none",
-         norm_kwargs: tp.Dict[str, tp.Any] = {},
-         **kwargs,
-     ):
-         super().__init__()
-         self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
-         self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
-         self.norm_type = norm
-
-     def forward(self, x):
-         x = self.conv(x)
-         x = self.norm(x)
-         return x
-
-
- class NormConvTranspose1d(nn.Module):
-     """Wrapper around ConvTranspose1d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-
-     def __init__(
-         self,
-         *args,
-         causal: bool = False,
-         norm: str = "none",
-         norm_kwargs: tp.Dict[str, tp.Any] = {},
-         **kwargs,
-     ):
-         super().__init__()
-         self.convtr = apply_parametrization_norm(
-             nn.ConvTranspose1d(*args, **kwargs), norm
-         )
-         self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
-         self.norm_type = norm
-
-     def forward(self, x):
-         x = self.convtr(x)
-         x = self.norm(x)
-         return x
-
-
- class NormConvTranspose2d(nn.Module):
-     """Wrapper around ConvTranspose2d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-
-     def __init__(
-         self,
-         *args,
-         norm: str = "none",
-         norm_kwargs: tp.Dict[str, tp.Any] = {},
-         **kwargs,
-     ):
-         super().__init__()
-         self.convtr = apply_parametrization_norm(
-             nn.ConvTranspose2d(*args, **kwargs), norm
-         )
-         self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
-
-     def forward(self, x):
-         x = self.convtr(x)
-         x = self.norm(x)
-         return x
-
-
- class StreamableConv1d(nn.Module):
-     """Conv1d with some builtin handling of asymmetric or causal padding
-     and normalization.
-     """
-
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int,
-         stride: int = 1,
-         dilation: int = 1,
-         groups: int = 1,
-         bias: bool = True,
-         causal: bool = False,
-         norm: str = "none",
-         norm_kwargs: tp.Dict[str, tp.Any] = {},
-         pad_mode: str = "reflect",
-     ):
-         super().__init__()
-         # Warn the user on an unusual setup between dilation and stride
-         if stride > 1 and dilation > 1:
-             warnings.warn(
-                 "StreamableConv1d has been initialized with stride > 1 and dilation > 1"
-                 f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
-             )
-         self.conv = NormConv1d(
-             in_channels,
-             out_channels,
-             kernel_size,
-             stride,
-             dilation=dilation,
-             groups=groups,
-             bias=bias,
-             causal=causal,
-             norm=norm,
-             norm_kwargs=norm_kwargs,
-         )
-         self.causal = causal
-         self.pad_mode = pad_mode
-
-     def forward(self, x):
-         B, C, T = x.shape
-         kernel_size = self.conv.conv.kernel_size[0]
-         stride = self.conv.conv.stride[0]
-         dilation = self.conv.conv.dilation[0]
-         # Effective kernel size with dilations
-         kernel_size = (kernel_size - 1) * dilation + 1
-         padding_total = kernel_size - stride
-         extra_padding = get_extra_padding_for_conv1d(
-             x, kernel_size, stride, padding_total
-         )
-         if self.causal:
-             # Left padding for causal
-             x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
-         else:
-             # Asymmetric padding required for odd strides
-             padding_right = padding_total // 2
-             padding_left = padding_total - padding_right
-             x = pad1d(
-                 x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
-             )
-         return self.conv(x)
-
-
- class StreamableConvTranspose1d(nn.Module):
-     """ConvTranspose1d with some builtin handling of asymmetric or causal padding
-     and normalization.
-     """
-
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int,
-         stride: int = 1,
-         causal: bool = False,
-         norm: str = "none",
-         trim_right_ratio: float = 1.0,
-         norm_kwargs: tp.Dict[str, tp.Any] = {},
-     ):
-         super().__init__()
-         self.convtr = NormConvTranspose1d(
-             in_channels,
-             out_channels,
-             kernel_size,
-             stride,
-             causal=causal,
-             norm=norm,
-             norm_kwargs=norm_kwargs,
-         )
-         self.causal = causal
-         self.trim_right_ratio = trim_right_ratio
-         assert (
-             self.causal or self.trim_right_ratio == 1.0
-         ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
-         assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
-
-     def forward(self, x):
-         kernel_size = self.convtr.convtr.kernel_size[0]
-         stride = self.convtr.convtr.stride[0]
-         padding_total = kernel_size - stride
-
-         y = self.convtr(x)
-
-         # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
-         # removed at the very end, when keeping only the right length for the output,
-         # as removing it here would require also passing the length at the matching layer
-         # in the encoder.
-         if self.causal:
-             # Trim the padding on the right according to the specified ratio;
-             # if trim_right_ratio = 1.0, trim everything from the right.
-             padding_right = math.ceil(padding_total * self.trim_right_ratio)
-             padding_left = padding_total - padding_right
-             y = unpad1d(y, (padding_left, padding_right))
-         else:
-             # Asymmetric padding required for odd strides
-             padding_right = padding_total // 2
-             padding_left = padding_total - padding_right
-             y = unpad1d(y, (padding_left, padding_right))
-         return y
 
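The padding arithmetic in `get_extra_padding_for_conv1d` is easiest to verify with concrete numbers. A small check, with values chosen purely for illustration:

import torch
from audioseal.libs.audiocraft.modules.conv import (
    get_extra_padding_for_conv1d, pad_for_conv1d,
)

x = torch.zeros(1, 1, 10)  # length 10
kernel_size, stride = 4, 3
# n_frames = (10 - 4) / 3 + 1 = 3 exactly, so no extra padding is needed.
assert get_extra_padding_for_conv1d(x, kernel_size, stride) == 0

x = torch.zeros(1, 1, 11)  # length 11 -> n_frames = 7/3 + 1 ≈ 3.33
# ideal_length = (4 - 1) * 3 + 4 = 13, so 2 samples of right padding.
assert get_extra_padding_for_conv1d(x, kernel_size, stride) == 2
assert pad_for_conv1d(x, kernel_size, stride).shape[-1] == 13
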
src/audioseal/libs/audiocraft/modules/lstm.py DELETED
@@ -1,28 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- #
- # Vendored from https://github.com/facebookresearch/audiocraft
-
- from torch import nn
-
-
- class StreamableLSTM(nn.Module):
-     """LSTM without worrying about the hidden state, nor the layout of the data.
-     Expects input as convolutional layout.
-     """
-
-     def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
-         super().__init__()
-         self.skip = skip
-         self.lstm = nn.LSTM(dimension, dimension, num_layers)
-
-     def forward(self, x):
-         x = x.permute(2, 0, 1)
-         y, _ = self.lstm(x)
-         if self.skip:
-             y = y + x
-         y = y.permute(1, 2, 0)
-         return y
 
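`StreamableLSTM` accepts the convolutional layout `(batch, channels, time)` and permutes internally to the `(time, batch, channels)` layout that `nn.LSTM` expects, so it can be dropped into a `nn.Sequential` of convolutions. A quick shape check:

import torch
from audioseal.libs.audiocraft.modules.lstm import StreamableLSTM

lstm = StreamableLSTM(dimension=128, num_layers=2)
x = torch.randn(4, 128, 250)           # (batch, channels, time)
assert lstm(x).shape == (4, 128, 250)  # convolutional layout is preserved
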
src/audioseal/libs/audiocraft/modules/seanet.py DELETED
@@ -1,426 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- #
- # Vendored from https://github.com/facebookresearch/audiocraft
-
- import math
- import typing as tp
-
- import numpy as np
- import torch.nn as nn
-
- from audioseal.libs.audiocraft.modules.conv import (
-     StreamableConv1d,
-     StreamableConvTranspose1d,
- )
- from audioseal.libs.audiocraft.modules.lstm import StreamableLSTM
-
-
- class SEANetResnetBlock(nn.Module):
-     """Residual block from the SEANet model.
-
-     Args:
-         dim (int): Dimension of the input/output.
-         kernel_sizes (list): List of kernel sizes for the convolutions.
-         dilations (list): List of dilations for the convolutions.
-         activation (str): Activation function.
-         activation_params (dict): Parameters to provide to the activation function.
-         norm (str): Normalization method.
-         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-         causal (bool): Whether to use fully causal convolution.
-         pad_mode (str): Padding mode for the convolutions.
-         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-         true_skip (bool): Whether to use true skip connection or a simple
-             (streamable) convolution as the skip connection.
-     """
-
-     def __init__(
-         self,
-         dim: int,
-         kernel_sizes: tp.List[int] = [3, 1],
-         dilations: tp.List[int] = [1, 1],
-         activation: str = "ELU",
-         activation_params: dict = {"alpha": 1.0},
-         norm: str = "none",
-         norm_params: tp.Dict[str, tp.Any] = {},
-         causal: bool = False,
-         pad_mode: str = "reflect",
-         compress: int = 2,
-         true_skip: bool = True,
-     ):
-         super().__init__()
-         assert len(kernel_sizes) == len(
-             dilations
-         ), "Number of kernel sizes should match number of dilations"
-         act = getattr(nn, activation)
-         hidden = dim // compress
-         block = []
-         for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
-             in_chs = dim if i == 0 else hidden
-             out_chs = dim if i == len(kernel_sizes) - 1 else hidden
-             block += [
-                 act(**activation_params),
-                 StreamableConv1d(
-                     in_chs,
-                     out_chs,
-                     kernel_size=kernel_size,
-                     dilation=dilation,
-                     norm=norm,
-                     norm_kwargs=norm_params,
-                     causal=causal,
-                     pad_mode=pad_mode,
-                 ),
-             ]
-         self.block = nn.Sequential(*block)
-         self.shortcut: nn.Module
-         if true_skip:
-             self.shortcut = nn.Identity()
-         else:
-             self.shortcut = StreamableConv1d(
-                 dim,
-                 dim,
-                 kernel_size=1,
-                 norm=norm,
-                 norm_kwargs=norm_params,
-                 causal=causal,
-                 pad_mode=pad_mode,
-             )
-
-     def forward(self, x):
-         return self.shortcut(x) + self.block(x)
-
-
- class SEANetEncoder(nn.Module):
-     """SEANet encoder.
-
-     Args:
-         channels (int): Audio channels.
-         dimension (int): Intermediate representation dimension.
-         n_filters (int): Base width for the model.
-         n_residual_layers (int): Number of residual layers.
-         ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
-             upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
-             that must match the decoder order. We use the decoder order as some models may only employ the decoder.
-         activation (str): Activation function.
-         activation_params (dict): Parameters to provide to the activation function.
-         norm (str): Normalization method.
-         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-         kernel_size (int): Kernel size for the initial convolution.
-         last_kernel_size (int): Kernel size for the last convolution.
-         residual_kernel_size (int): Kernel size for the residual layers.
-         dilation_base (int): How much to increase the dilation with each layer.
-         causal (bool): Whether to use fully causal convolution.
-         pad_mode (str): Padding mode for the convolutions.
-         true_skip (bool): Whether to use true skip connection or a simple
-             (streamable) convolution as the skip connection in the residual network blocks.
-         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-         lstm (int): Number of LSTM layers at the end of the encoder.
-         disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-             For the encoder, it corresponds to the N first blocks.
-     """
-
-     def __init__(
-         self,
-         channels: int = 1,
-         dimension: int = 128,
-         n_filters: int = 32,
-         n_residual_layers: int = 3,
-         ratios: tp.List[int] = [8, 5, 4, 2],
-         activation: str = "ELU",
-         activation_params: dict = {"alpha": 1.0},
-         norm: str = "none",
-         norm_params: tp.Dict[str, tp.Any] = {},
-         kernel_size: int = 7,
-         last_kernel_size: int = 7,
-         residual_kernel_size: int = 3,
-         dilation_base: int = 2,
-         causal: bool = False,
-         pad_mode: str = "reflect",
-         true_skip: bool = True,
-         compress: int = 2,
-         lstm: int = 0,
-         disable_norm_outer_blocks: int = 0,
-     ):
-         super().__init__()
-         self.channels = channels
-         self.dimension = dimension
-         self.n_filters = n_filters
-         self.ratios = list(reversed(ratios))
-         del ratios
-         self.n_residual_layers = n_residual_layers
-         self.hop_length = np.prod(self.ratios)
-         self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
-         self.disable_norm_outer_blocks = disable_norm_outer_blocks
-         assert (
-             self.disable_norm_outer_blocks >= 0
-             and self.disable_norm_outer_blocks <= self.n_blocks
-         ), (
-             "Number of blocks for which to disable norm is invalid. "
-             "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
-         )
-
-         act = getattr(nn, activation)
-         mult = 1
-         model: tp.List[nn.Module] = [
-             StreamableConv1d(
-                 channels,
-                 mult * n_filters,
-                 kernel_size,
-                 norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
-                 norm_kwargs=norm_params,
-                 causal=causal,
-                 pad_mode=pad_mode,
-             )
-         ]
-         # Downsample to raw audio scale
-         for i, ratio in enumerate(self.ratios):
-             block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
-             # Add residual layers
-             for j in range(n_residual_layers):
-                 model += [
-                     SEANetResnetBlock(
-                         mult * n_filters,
-                         kernel_sizes=[residual_kernel_size, 1],
-                         dilations=[dilation_base**j, 1],
-                         norm=block_norm,
-                         norm_params=norm_params,
-                         activation=activation,
-                         activation_params=activation_params,
-                         causal=causal,
-                         pad_mode=pad_mode,
-                         compress=compress,
-                         true_skip=true_skip,
-                     )
-                 ]
-
-             # Add downsampling layers
-             model += [
-                 act(**activation_params),
-                 StreamableConv1d(
-                     mult * n_filters,
-                     mult * n_filters * 2,
-                     kernel_size=ratio * 2,
-                     stride=ratio,
-                     norm=block_norm,
-                     norm_kwargs=norm_params,
-                     causal=causal,
-                     pad_mode=pad_mode,
-                 ),
-             ]
-             mult *= 2
-
-         if lstm:
-             model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
-
-         model += [
-             act(**activation_params),
-             StreamableConv1d(
-                 mult * n_filters,
-                 dimension,
-                 last_kernel_size,
-                 norm=(
-                     "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
-                 ),
-                 norm_kwargs=norm_params,
-                 causal=causal,
-                 pad_mode=pad_mode,
-             ),
-         ]
-
-         self.model = nn.Sequential(*model)
-
-     def forward(self, x):
-         return self.model(x)
-
-
- class SEANetEncoderKeepDimension(SEANetEncoder):
-     """
-     Same architecture as the SEANet encoder, with an extra transposed
-     convolution that projects the output back to the temporal resolution
-     of the input.
-
-     Takes the same arguments as `SEANetEncoder`, plus:
-         output_dim (int): Channel dimension of the projected output.
-     """
-
-     def __init__(self, *args, **kwargs):
-         self.output_dim = kwargs.pop("output_dim")
-         super().__init__(*args, **kwargs)
-         # Add a reverse convolution layer
-         self.reverse_convolution = nn.ConvTranspose1d(
-             in_channels=self.dimension,
-             out_channels=self.output_dim,
-             kernel_size=math.prod(self.ratios),
-             stride=math.prod(self.ratios),
-             padding=0,
-         )
-
-     def forward(self, x):
-         orig_nframes = x.shape[-1]
-         x = self.model(x)
-         x = self.reverse_convolution(x)
-         # Make sure the temporal dimension didn't change
-         return x[:, :, :orig_nframes]
-
-
- class SEANetDecoder(nn.Module):
-     """SEANet decoder.
-
-     Args:
-         channels (int): Audio channels.
-         dimension (int): Intermediate representation dimension.
-         n_filters (int): Base width for the model.
-         n_residual_layers (int): Number of residual layers.
-         ratios (Sequence[int]): kernel size and stride ratios.
-         activation (str): Activation function.
-         activation_params (dict): Parameters to provide to the activation function.
-         final_activation (str): Final activation function after all convolutions.
-         final_activation_params (dict): Parameters to provide to the activation function.
-         norm (str): Normalization method.
-         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-         kernel_size (int): Kernel size for the initial convolution.
-         last_kernel_size (int): Kernel size for the last convolution.
-         residual_kernel_size (int): Kernel size for the residual layers.
-         dilation_base (int): How much to increase the dilation with each layer.
-         causal (bool): Whether to use fully causal convolution.
-         pad_mode (str): Padding mode for the convolutions.
-         true_skip (bool): Whether to use true skip connection or a simple
-             (streamable) convolution as the skip connection in the residual network blocks.
-         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-         lstm (int): Number of LSTM layers at the start of the decoder.
-         disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-             For the decoder, it corresponds to the N last blocks.
-         trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
-             If equal to 1.0, it means that all the trimming is done at the right.
-     """
-
-     def __init__(
-         self,
-         channels: int = 1,
-         dimension: int = 128,
-         n_filters: int = 32,
-         n_residual_layers: int = 3,
-         ratios: tp.List[int] = [8, 5, 4, 2],
-         activation: str = "ELU",
-         activation_params: dict = {"alpha": 1.0},
-         final_activation: tp.Optional[str] = None,
-         final_activation_params: tp.Optional[dict] = None,
-         norm: str = "none",
-         norm_params: tp.Dict[str, tp.Any] = {},
-         kernel_size: int = 7,
-         last_kernel_size: int = 7,
-         residual_kernel_size: int = 3,
-         dilation_base: int = 2,
-         causal: bool = False,
-         pad_mode: str = "reflect",
-         true_skip: bool = True,
-         compress: int = 2,
-         lstm: int = 0,
-         disable_norm_outer_blocks: int = 0,
-         trim_right_ratio: float = 1.0,
-     ):
-         super().__init__()
-         self.dimension = dimension
-         self.channels = channels
-         self.n_filters = n_filters
-         self.ratios = ratios
-         del ratios
-         self.n_residual_layers = n_residual_layers
-         self.hop_length = np.prod(self.ratios)
-         self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
-         self.disable_norm_outer_blocks = disable_norm_outer_blocks
-         assert (
-             self.disable_norm_outer_blocks >= 0
-             and self.disable_norm_outer_blocks <= self.n_blocks
-         ), (
-             "Number of blocks for which to disable norm is invalid. "
-             "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
-         )
-
-         act = getattr(nn, activation)
-         mult = int(2 ** len(self.ratios))
-         model: tp.List[nn.Module] = [
-             StreamableConv1d(
-                 dimension,
-                 mult * n_filters,
-                 kernel_size,
-                 norm=(
-                     "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
-                 ),
-                 norm_kwargs=norm_params,
-                 causal=causal,
-                 pad_mode=pad_mode,
-             )
-         ]
-
-         if lstm:
-             model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
-
-         # Upsample to raw audio scale
-         for i, ratio in enumerate(self.ratios):
-             block_norm = (
-                 "none"
-                 if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
-                 else norm
-             )
-             # Add upsampling layers
-             model += [
-                 act(**activation_params),
-                 StreamableConvTranspose1d(
-                     mult * n_filters,
-                     mult * n_filters // 2,
-                     kernel_size=ratio * 2,
-                     stride=ratio,
-                     norm=block_norm,
-                     norm_kwargs=norm_params,
-                     causal=causal,
-                     trim_right_ratio=trim_right_ratio,
-                 ),
-             ]
-             # Add residual layers
-             for j in range(n_residual_layers):
-                 model += [
-                     SEANetResnetBlock(
-                         mult * n_filters // 2,
-                         kernel_sizes=[residual_kernel_size, 1],
-                         dilations=[dilation_base**j, 1],
-                         activation=activation,
-                         activation_params=activation_params,
-                         norm=block_norm,
-                         norm_params=norm_params,
-                         causal=causal,
-                         pad_mode=pad_mode,
-                         compress=compress,
-                         true_skip=true_skip,
-                     )
-                 ]
-
-             mult //= 2
-
-         # Add final layers
-         model += [
-             act(**activation_params),
-             StreamableConv1d(
-                 n_filters,
-                 channels,
-                 last_kernel_size,
-                 norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
-                 norm_kwargs=norm_params,
-                 causal=causal,
-                 pad_mode=pad_mode,
-             ),
-         ]
-         # Add optional final activation to the decoder (e.g. tanh)
-         if final_activation is not None:
-             final_act = getattr(nn, final_activation)
-             final_activation_params = final_activation_params or {}
-             model += [final_act(**final_activation_params)]
-         self.model = nn.Sequential(*model)
-
-     def forward(self, z):
-         y = self.model(z)
-         return y
 
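With the default ratios `[8, 5, 4, 2]` the encoder downsamples time by `hop_length = 8 * 5 * 4 * 2 = 320` and the decoder upsamples by the same factor, while `SEANetEncoderKeepDimension` projects straight back to the input frame count. A shape sketch with the defaults:

import torch
from audioseal.libs.audiocraft.modules.seanet import (
    SEANetDecoder, SEANetEncoder, SEANetEncoderKeepDimension,
)

enc = SEANetEncoder(channels=1, dimension=128)
dec = SEANetDecoder(channels=1, dimension=128)
x = torch.randn(2, 1, 3200)              # 3200 samples = 10 * hop_length
z = enc(x)
assert z.shape == (2, 128, 10)           # downsampled by 320
assert dec(z).shape == (2, 1, 3200)      # upsampled back to input length

keep = SEANetEncoderKeepDimension(channels=1, dimension=128, output_dim=32)
assert keep(x).shape == (2, 32, 3200)    # temporal resolution preserved
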
src/audioseal/loader.py DELETED
@@ -1,227 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
-
- import os
- from dataclasses import fields
- from hashlib import sha1
- from pathlib import Path
- from typing import (  # type: ignore[attr-defined]
-     Any,
-     Dict,
-     List,
-     Optional,
-     Tuple,
-     Type,
-     TypeVar,
-     Union,
-     cast,
- )
- from urllib.parse import urlparse
-
- import torch
- from omegaconf import DictConfig, OmegaConf
-
- import audioseal
- from audioseal.builder import (
-     AudioSealDetectorConfig,
-     AudioSealWMConfig,
-     create_detector,
-     create_generator,
- )
- from audioseal.models import AudioSealDetector, AudioSealWM
-
- AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)
-
-
- class ModelLoadError(RuntimeError):
-     """Raised when the model loading fails"""
-
-
- def _get_path_from_env(var_name: str) -> Optional[Path]:
-     pathname = os.getenv(var_name)
-     if not pathname:
-         return None
-
-     try:
-         return Path(pathname)
-     except ValueError as ex:
-         raise RuntimeError(f"Expect valid pathname, get '{pathname}'.") from ex
-
-
- def _get_cache_dir(env_names: List[str]):
-     """Re-use the cache dir from a list of existing caches"""
-     for env in env_names:
-         cache_dir = _get_path_from_env(env)
-         if cache_dir:
-             break
-     else:
-         cache_dir = Path("~/.cache").expanduser().resolve()
-
-     # Create a sub-dir to not mess up with existing caches
-     cache_dir = cache_dir / "audioseal"
-     cache_dir.mkdir(exist_ok=True, parents=True)
-
-     return cache_dir
-
-
- def load_model_checkpoint(
-     model_path: Union[Path, str],
-     device: Union[str, torch.device] = "cpu",
- ):
-     if Path(model_path).is_file():
-         return torch.load(model_path, map_location=device)
-
-     cache_dir = _get_cache_dir(
-         ["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
-     )
-     parts = urlparse(str(model_path))
-     if parts.scheme == "https":
-         hash_ = sha1(parts.path.encode()).hexdigest()[:24]
-         return torch.hub.load_state_dict_from_url(
-             str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
-         )
-     elif str(model_path).startswith("facebook/audioseal/"):
-         hf_filename = str(model_path)[len("facebook/audioseal/") :]
-
-         try:
-             from huggingface_hub import hf_hub_download
-         except ModuleNotFoundError as ex:
-             raise ModelLoadError(
-                 f"The model path {model_path} seems to be a direct HF path, "
-                 "but huggingface_hub is not installed. Install it, for example "
-                 "with `pip install huggingface_hub`, to use this feature."
-             ) from ex
-         file = hf_hub_download(
-             repo_id="facebook/audioseal",
-             filename=hf_filename,
-             cache_dir=cache_dir,
-             library_name="audioseal",
-             library_version=audioseal.__version__,
-         )
-         return torch.load(file, map_location=device)
-     else:
-         raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")
-
-
- def load_local_model_config(model_card: str) -> Optional[DictConfig]:
-     config_file = Path(__file__).parent / "cards" / (model_card + ".yaml")
-     if Path(config_file).is_file():
-         return cast(DictConfig, OmegaConf.load(config_file.resolve()))
-     else:
-         return None
-
-
- class AudioSeal:
-
-     @staticmethod
-     def parse_model(
-         model_card_or_path: str,
-         model_type: Type[AudioSealT],
-         nbits: Optional[int] = None,
-     ) -> Tuple[Dict[str, Any], AudioSealT]:
-         """
-         Parse the information from the model card or checkpoint path using
-         the schema `model_type` that defines the model type
-         """
-         # Get the raw checkpoint and config from the local model cards
-         config = load_local_model_config(model_card_or_path)
-
-         if config:
-             assert "checkpoint" in config, f"Checkpoint missing in {model_card_or_path}"
-             config_dict = OmegaConf.to_container(config)
-             assert isinstance(
-                 config_dict, dict
-             ), f"Cannot parse config from {model_card_or_path}"
-             checkpoint = config_dict.pop("checkpoint")
-             checkpoint = load_model_checkpoint(checkpoint)
-
-         # Get the raw checkpoint and config from the checkpoint path
-         else:
-             config_dict = {}
-             checkpoint = load_model_checkpoint(model_card_or_path)
-
-         if "xp.cfg" in checkpoint:
-             config_dict = {**checkpoint["xp.cfg"], **config_dict}  # type: ignore
-
-         model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits)  # type: ignore
-
-         if "model" in checkpoint:
-             checkpoint = checkpoint["model"]
-
-         return checkpoint, model_config
-
-     @staticmethod
-     def parse_config(
-         config: Dict[str, Any],
-         config_type: Type[AudioSealT],
-         nbits: Optional[int] = None,
-     ) -> AudioSealT:
-
-         assert "seanet" in config, f"missing seanet backbone config in {config}"
-
-         # Patch 1: Resolve the variables in the checkpoint
-         config = OmegaConf.create(config)  # type: ignore
-         OmegaConf.resolve(config)  # type: ignore
-         config = OmegaConf.to_container(config)  # type: ignore
-
-         # Patch 2: Put decoder, encoder and detector outside seanet
-         seanet_config = config["seanet"]
-         for key_to_patch in ["encoder", "decoder", "detector"]:
-             if key_to_patch in seanet_config:
-                 config_to_patch = config.get(key_to_patch) or {}
-                 config[key_to_patch] = {
-                     **config_to_patch,
-                     **seanet_config.pop(key_to_patch),
-                 }
-
-         config["seanet"] = seanet_config
-
-         # Patch 3: Put nbits into config if specified
-         if nbits and "nbits" not in config:
-             config["nbits"] = nbits
-
-         # Remove attributes not related to the model_type
-         result_config = {}
-         assert config, "Empty config"
-         for field in fields(config_type):
-             if field.name in config:
-                 result_config[field.name] = config[field.name]
-
-         schema = OmegaConf.structured(config_type)
-         schema.merge_with(result_config)
-         return schema
-
-     @staticmethod
-     def load_generator(
-         model_card_or_path: str,
-         nbits: Optional[int] = None,
-     ) -> AudioSealWM:
-         """Load the AudioSeal generator from the model card"""
-         checkpoint, config = AudioSeal.parse_model(
-             model_card_or_path,
-             AudioSealWMConfig,
-             nbits=nbits,
-         )
-
-         model = create_generator(config)
-         model.load_state_dict(checkpoint)
-         return model
-
-     @staticmethod
-     def load_detector(
-         model_card_or_path: str,
-         nbits: Optional[int] = None,
-     ) -> AudioSealDetector:
-         checkpoint, config = AudioSeal.parse_model(
-             model_card_or_path,
-             AudioSealDetectorConfig,
-             nbits=nbits,
-         )
-         model = create_detector(config)
-         model.load_state_dict(checkpoint)
-         return model
 
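`load_model_checkpoint` accepts three kinds of sources and caches downloads under the first existing of `AUDIOSEAL_CACHE_DIR`, `AUDIOCRAFT_CACHE_DIR`, or `XDG_CACHE_HOME` (falling back to `~/.cache`), always inside an `audioseal/` subfolder. A sketch of the three forms (the local path is hypothetical):

from audioseal.loader import load_model_checkpoint

# 1. Local file (hypothetical path)
ckpt = load_model_checkpoint("/tmp/checkpoint_generator_base.pth")

# 2. Direct URL (cached under <cache>/audioseal with a sha1-derived name)
ckpt = load_model_checkpoint(
    "https://huggingface.co/facebook/audioseal/resolve/main/generator_base.pth"
)

# 3. Hugging Face shorthand (requires huggingface_hub)
ckpt = load_model_checkpoint("facebook/audioseal/generator_base.pth")
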
src/audioseal/models.py DELETED
@@ -1,175 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import logging
- from typing import Optional, Tuple
-
- import librosa
- import numpy as np
- import torch
-
- from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension
-
- logger = logging.getLogger("Audioseal")
-
- COMPATIBLE_WARNING = """
- AudioSeal is designed to work at a sample rate of 16 kHz.
- Implicit sample rate usage is deprecated and will be removed in a future version.
- To remove this warning, add this argument to the function call:
-     sample_rate = your_sample_rate
- """
-
-
- class MsgProcessor(torch.nn.Module):
-     def __init__(self, nbits: int, hidden_size: int):
-         super().__init__()
-         assert nbits > 0, "MsgProcessor should not be built for 0-bit watermarking"
-         self.nbits = nbits
-         self.hidden_size = hidden_size
-         self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)
-
-     def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
-         # Map bit b at position i to embedding index 2 * i + b.
-         indices = 2 * torch.arange(msg.shape[-1]).to(msg.device)
-         indices = indices.repeat(msg.shape[0], 1)
-         indices = (indices + msg).long()
-         msg_aux = self.msg_processor(indices)
-         msg_aux = msg_aux.sum(dim=-2)
-         msg_aux = msg_aux.unsqueeze(-1).repeat(1, 1, hidden.shape[2])
-         hidden = hidden + msg_aux
-         return hidden
-
-
- def compute_stft_energy(
-     audio: torch.Tensor, sr: int, n_fft: int = 2048, hop_length: int = 512
- ) -> torch.Tensor:
-     """Per-frame STFT energy of each (mono) item in the batch."""
-     batch_size = audio.size(0)
-     energy_values = []
-
-     for i in range(batch_size):
-         # Collapse a possible channel dimension; the model works on mono audio.
-         y = audio[i].reshape(-1).detach().cpu().numpy()
-         stft = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
-         frame_energy = torch.tensor(np.sum(stft**2, axis=0), device=audio.device)
-         energy_values.append(frame_energy)
-
-     return torch.stack(energy_values, dim=0)
-
-
- def compute_adaptive_alpha_librosa(
-     energy_values: torch.Tensor, min_alpha: float = 0.5, max_alpha: float = 1.5
- ) -> torch.Tensor:
-     """Map per-frame energies linearly onto alphas in [min_alpha, max_alpha]."""
-     normalized_energy = (energy_values - energy_values.min(dim=1, keepdim=True)[0]) / (
-         energy_values.max(dim=1, keepdim=True)[0]
-         - energy_values.min(dim=1, keepdim=True)[0]
-         + 1e-6
-     )
-     alpha_values = min_alpha + normalized_energy * (max_alpha - min_alpha)
-     return alpha_values
-
-
- class AudioSealWM(torch.nn.Module):
-     def __init__(
-         self,
-         encoder: torch.nn.Module,
-         decoder: torch.nn.Module,
-         msg_processor: Optional[torch.nn.Module] = None,
-     ):
-         super().__init__()
-         self.encoder = encoder
-         self.decoder = decoder
-         self.msg_processor = msg_processor
-         self._message: Optional[torch.Tensor] = None
-         self._original_payload: Optional[torch.Tensor] = None
-
-     @property
-     def message(self) -> Optional[torch.Tensor]:
-         return self._message
-
-     @message.setter
-     def message(self, message: torch.Tensor) -> None:
-         self._message = message
-
-     def get_original_payload(self) -> Optional[torch.Tensor]:
-         return self._original_payload
-
-     def get_watermark(
-         self,
-         x: torch.Tensor,
-         sample_rate: Optional[int] = None,
-         message: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         # Delegate to forward, which returns the watermarked audio directly.
-         return self.forward(x, sample_rate, message)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         sample_rate: Optional[int] = None,
-         message: Optional[torch.Tensor] = None,
-         n_fft: int = 2048,
-         hop_length: int = 512,
-         min_alpha: float = 0.5,
-         max_alpha: float = 1.5,
-     ) -> torch.Tensor:
-         if sample_rate is None:
-             logger.warning(COMPATIBLE_WARNING)
-             sample_rate = 16_000
-
-         if sample_rate != 16000:
-             # Detach before converting to a NumPy array for resampling
-             x_np = x.detach().cpu().numpy()
-             resampled_x = librosa.resample(x_np, orig_sr=sample_rate, target_sr=16000)
-             x = torch.tensor(resampled_x, device=x.device)
-
-         hidden = self.encoder(x)
-
-         if self.msg_processor is not None:
-             if message is None:
-                 if self.message is None:
-                     message = torch.randint(
-                         0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
-                     )
-                 else:
-                     message = self.message.to(device=x.device)
-             else:
-                 message = message.to(device=x.device)
-
-             hidden = self.msg_processor(hidden, message)
-             self._original_payload = message
-
-         watermark = self.decoder(hidden)
-
-         if sample_rate != 16000:
-             watermark_np = watermark.detach().cpu().numpy()
-             resampled_watermark = librosa.resample(
-                 watermark_np, orig_sr=16000, target_sr=sample_rate
-             )
-             watermark = torch.tensor(resampled_watermark, device=watermark.device)
-
-         energy_values = compute_stft_energy(
-             x, sr=sample_rate, n_fft=n_fft, hop_length=hop_length
-         )
-         adaptive_alpha = compute_adaptive_alpha_librosa(
-             energy_values, min_alpha=min_alpha, max_alpha=max_alpha
-         )
-
-         # Stretch the per-frame alphas back to sample resolution and trim to the
-         # input length (the time axis is the last dimension).
-         stretched_alpha = torch.repeat_interleave(adaptive_alpha, hop_length, dim=1)
-         stretched_alpha = stretched_alpha[:, : x.size(-1)]
-
-         # Align dimensions with the watermark tensor (add a channel dim if needed)
-         if stretched_alpha.dim() < watermark.dim():
-             stretched_alpha = stretched_alpha.unsqueeze(1)
-         stretched_alpha = stretched_alpha.expand_as(watermark)
-
-         watermarked_audio = x + stretched_alpha * watermark
-         return watermarked_audio
-
-
- class AudioSealDetector(torch.nn.Module):
-     def __init__(self, *args, nbits: int = 0, **kwargs):
-         super().__init__()
-         encoder = SEANetEncoderKeepDimension(*args, **kwargs)
-         last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
-         self.detector = torch.nn.Sequential(encoder, last_layer)
-         self.nbits = nbits
-
-     def detect_watermark(
-         self,
-         x: torch.Tensor,
-         sample_rate: Optional[int] = None,
-         message_threshold: float = 0.5,
-     ) -> Tuple[float, torch.Tensor]:
-         result, message = self.forward(x, sample_rate=sample_rate)
-         detected = (
-             torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
-         )
-         detect_prob = detected.cpu().item()
-         message = torch.gt(message, message_threshold).int()
-         return detect_prob, message
-
-     def decode_message(self, result: torch.Tensor) -> torch.Tensor:
-         assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
-             result.dim() == 2 and result.shape[0] == self.nbits
-         ), f"Expect message of size [bsz, {self.nbits}, frames] (got {result.size()})"
-         decoded_message = result.mean(dim=-1)
-         return torch.sigmoid(decoded_message)
-
-     def forward(
-         self, x: torch.Tensor, sample_rate: Optional[int] = None
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         if sample_rate is None:
-             logger.warning(COMPATIBLE_WARNING)
-             sample_rate = 16_000
-
-         if sample_rate != 16000:
-             x_np = x.detach().cpu().numpy()
-             resampled_x = librosa.resample(x_np, orig_sr=sample_rate, target_sr=16000)
-             x = torch.tensor(resampled_x, device=x.device)
-
-         result = self.detector(x)
-         result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
-         message = self.decode_message(result[:, 2:, :])
-         return result[:, :2, :], message
 
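The distinguishing change in this fork's models.py is the energy-adaptive watermark strength: instead of a fixed scale, the watermark is weighted per STFT frame between `min_alpha` and `max_alpha` according to local energy, so louder regions carry a stronger watermark. A sketch of the alpha computation in isolation, with random audio as a stand-in:

import torch
from audioseal.models import compute_adaptive_alpha_librosa, compute_stft_energy

wav = torch.randn(2, 16000)                  # (batch, time), mono
energy = compute_stft_energy(wav, sr=16000)  # (batch, n_frames)
alpha = compute_adaptive_alpha_librosa(energy, min_alpha=0.5, max_alpha=1.5)
assert alpha.min() >= 0.5 and alpha.max() <= 1.5
# One alpha per STFT frame; forward() stretches these back to sample
# resolution with repeat_interleave before scaling the watermark.
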
src/audioseal/py.typed DELETED
File without changes
src/scripts/checkpoints.py DELETED
@@ -1,51 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
-
- from pathlib import Path
-
- import torch
-
-
- def convert(checkpoint: str, outdir: str, suffix: str = "base"):
-     """Split a training checkpoint into generator and detector checkpoints"""
-     outdir_path = Path(outdir)
-     ckpt = torch.load(checkpoint)
-
-     # Keep inference-related params only
-     infer_cfg = {
-         "seanet": ckpt["xp.cfg"]["seanet"],
-         "channels": ckpt["xp.cfg"]["channels"],
-         "dtype": ckpt["xp.cfg"]["dtype"],
-         "sample_rate": ckpt["xp.cfg"]["sample_rate"],
-     }
-
-     generator_ckpt = {"xp.cfg": infer_cfg, "model": {}}
-     detector_ckpt = {"xp.cfg": infer_cfg, "model": {}}
-
-     for layer in ckpt["model"].keys():
-         if layer.startswith("detector"):
-             # Strip the "detector." prefix
-             new_layer = layer[len("detector.") :]
-             detector_ckpt["model"][new_layer] = ckpt["model"][layer]  # type: ignore
-         elif layer == "msg_processor.msg_processor.0.weight":
-             generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt["model"][layer]  # type: ignore
-         else:
-             assert layer.startswith("generator"), f"Invalid layer: {layer}"
-             # Strip the "generator." prefix
-             new_layer = layer[len("generator.") :]
-             generator_ckpt["model"][new_layer] = ckpt["model"][layer]  # type: ignore
-
-     torch.save(generator_ckpt, outdir_path / f"checkpoint_generator_{suffix}.pth")
-     torch.save(detector_ckpt, outdir_path / f"checkpoint_detector_{suffix}.pth")
-
-
- if __name__ == "__main__":
-     import fire
-
-     fire.Fire(convert)
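The script exposes `convert` through a `fire` CLI. A hedged invocation sketch; all paths are hypothetical, and the import path assumes the repo root is on `sys.path`:

# From the command line:
#   python src/scripts/checkpoints.py --checkpoint /tmp/full_ckpt.pth \
#       --outdir /tmp/out --suffix base
# or equivalently from Python:
from scripts.checkpoints import convert  # assumed import path

convert("/tmp/full_ckpt.pth", "/tmp/out", suffix="base")
# Writes /tmp/out/checkpoint_generator_base.pth and
# /tmp/out/checkpoint_detector_base.pth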