Upload 14 files
Browse files- src/audioseal/__init__.py +21 -0
- src/audioseal/builder.py +118 -0
- src/audioseal/cards/audioseal_detector_16bits.yaml +33 -0
- src/audioseal/cards/audioseal_wm_16bits.yaml +39 -0
- src/audioseal/libs/__init__.py +5 -0
- src/audioseal/libs/audiocraft/__init__.py +5 -0
- src/audioseal/libs/audiocraft/modules/__init__.py +8 -0
- src/audioseal/libs/audiocraft/modules/conv.py +337 -0
- src/audioseal/libs/audiocraft/modules/lstm.py +28 -0
- src/audioseal/libs/audiocraft/modules/seanet.py +426 -0
- src/audioseal/loader.py +227 -0
- src/audioseal/models.py +229 -0
- src/audioseal/py.typed +0 -0
- src/scripts/checkpoints.py +51 -0
src/audioseal/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Watermarking and detection for speech audio
|
9 |
+
|
10 |
+
A Pytorch-based localized algorithm for proactive detection
|
11 |
+
of watermarks in AI-generated audio, with a very fast
|
12 |
+
detector.
|
13 |
+
|
14 |
+
"""
|
15 |
+
|
16 |
+
__version__ = "0.1.4"
|
17 |
+
|
18 |
+
|
19 |
+
from audioseal import builder
|
20 |
+
from audioseal.loader import AudioSeal
|
21 |
+
from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
|
src/audioseal/builder.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from dataclasses import asdict, dataclass, field, is_dataclass
|
8 |
+
from typing import Any, Dict, List, Optional
|
9 |
+
|
10 |
+
from omegaconf import DictConfig, OmegaConf
|
11 |
+
from torch import device, dtype
|
12 |
+
from typing_extensions import TypeAlias
|
13 |
+
|
14 |
+
from audioseal.libs import audiocraft
|
15 |
+
from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
|
16 |
+
|
17 |
+
Device: TypeAlias = device
|
18 |
+
|
19 |
+
DataType: TypeAlias = dtype
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
class SEANetConfig:
    """
    Map common hparams of SEANet encoder and decoder.

    Field meanings mirror the `SEANetEncoder` constructor arguments
    (see `audioseal.libs.audiocraft.modules.seanet`).
    """

    channels: int  # audio channels (e.g. 1 for mono)
    dimension: int  # intermediate representation dimension
    n_filters: int  # base width for the model
    n_residual_layers: int  # number of residual layers
    ratios: List[int]  # up/downsampling ratios (decoder order)
    activation: str  # activation function name (attribute of torch.nn)
    activation_params: Dict[str, float]  # kwargs for the activation function
    norm: str  # normalization method (e.g. "weight_norm")
    norm_params: Dict[str, Any]  # kwargs for the normalization layer
    kernel_size: int  # kernel size of the initial convolution
    last_kernel_size: int  # kernel size of the final convolution
    residual_kernel_size: int  # kernel size inside residual layers
    dilation_base: int  # dilation growth factor per layer
    causal: bool  # whether to use fully causal convolutions
    pad_mode: str  # padding mode for the convolutions
    true_skip: bool  # true skip connection vs. streamable conv skip
    compress: int  # dimensionality reduction in residual branches
    lstm: int  # number of LSTM layers at the end of the encoder
    disable_norm_outer_blocks: int  # number of outer blocks without norm
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
class DecoderConfig:
    """Decoder-specific hparams layered on top of `SEANetConfig`."""

    final_activation: Optional[str]  # optional final activation name, or None
    final_activation_params: Optional[dict]  # kwargs for the final activation
    trim_right_ratio: float  # ratio of right-side trimming for causal transposed convs
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
class DetectorConfig:
    """Detector-specific hparams layered on top of `SEANetConfig`."""

    output_dim: int = 32  # output dimension of the detector head
|
59 |
+
|
60 |
+
|
61 |
+
@dataclass
class AudioSealWMConfig:
    """Full configuration for the AudioSeal watermark generator."""

    nbits: int  # number of message bits embedded in the watermark
    seanet: SEANetConfig  # shared SEANet encoder/decoder hparams
    decoder: DecoderConfig  # decoder-specific overrides
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
class AudioSealDetectorConfig:
    """Full configuration for the AudioSeal watermark detector."""

    nbits: int  # number of message bits to decode
    seanet: SEANetConfig  # shared SEANet hparams
    # Pass the class itself as the factory; wrapping it in a lambda is redundant.
    detector: DetectorConfig = field(default_factory=DetectorConfig)
|
73 |
+
|
74 |
+
|
75 |
+
def as_dict(obj: Any) -> Dict[str, Any]:
    """Normalize a config object into a plain dictionary.

    Accepts a plain ``dict`` (returned unchanged), a dataclass instance
    (converted via ``asdict``), or an omegaconf ``DictConfig`` (converted
    via ``OmegaConf.to_container``).

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(obj, dict):
        return obj
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    if isinstance(obj, DictConfig):
        return OmegaConf.to_container(obj)  # type: ignore
    raise NotImplementedError(f"Unsupported type for config: {type(obj)}")
|
84 |
+
|
85 |
+
|
86 |
+
def create_generator(
    config: AudioSealWMConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> AudioSealWM:
    """Create a generator (watermarker) from hparams.

    Args:
        config: Generator configuration (SEANet hparams plus decoder overrides).
        device: Optional device for the created modules.
        dtype: Optional dtype for the created modules.

    Returns:
        An `AudioSealWM` wrapping the SEANet encoder/decoder and message processor.
    """

    # Currently the encoder hparams are the same as
    # SEANet, but this can be changed in the future.
    encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
    encoder = encoder.to(device=device, dtype=dtype)

    # The decoder takes the shared SEANet hparams with decoder-specific
    # overrides layered on top. `decoder_config` is already a plain dict,
    # so pass it directly (the previous double `as_dict(...)` wrap was
    # redundant and inconsistent with `create_detector`).
    decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
    decoder = audiocraft.modules.SEANetDecoder(**decoder_config)
    decoder = decoder.to(device=device, dtype=dtype)

    msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
    msgprocessor = msgprocessor.to(device=device, dtype=dtype)

    return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msgprocessor)
|
107 |
+
|
108 |
+
|
109 |
+
def create_detector(
    config: AudioSealDetectorConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> AudioSealDetector:
    """Create a watermark detector from hparams.

    Merges the shared SEANet hparams with the detector-specific options and
    instantiates an `AudioSealDetector`, optionally moved to `device`/`dtype`.
    """
    kwargs = {**as_dict(config.seanet), **as_dict(config.detector)}
    model = AudioSealDetector(nbits=config.nbits, **kwargs)
    return model.to(device=device, dtype=dtype)
|
src/audioseal/cards/audioseal_detector_16bits.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package __global__
|
2 |
+
|
3 |
+
name: audioseal_detector_16bits
|
4 |
+
model_type: seanet
|
5 |
+
checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/detector_base.pth"
|
6 |
+
nbits: 16
|
7 |
+
seanet:
|
8 |
+
activation: ELU
|
9 |
+
activation_params:
|
10 |
+
alpha: 1.0
|
11 |
+
causal: false
|
12 |
+
channels: 1
|
13 |
+
compress: 2
|
14 |
+
dilation_base: 2
|
15 |
+
dimension: 128
|
16 |
+
disable_norm_outer_blocks: 0
|
17 |
+
kernel_size: 7
|
18 |
+
last_kernel_size: 7
|
19 |
+
lstm: 2
|
20 |
+
n_filters: 32
|
21 |
+
n_residual_layers: 1
|
22 |
+
norm: weight_norm
|
23 |
+
norm_params: {}
|
24 |
+
pad_mode: constant
|
25 |
+
ratios:
|
26 |
+
- 8
|
27 |
+
- 5
|
28 |
+
- 4
|
29 |
+
- 2
|
30 |
+
residual_kernel_size: 3
|
31 |
+
true_skip: true
|
32 |
+
detector:
|
33 |
+
output_dim: 32
|
src/audioseal/cards/audioseal_wm_16bits.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
name: audioseal_wm_16bits
|
8 |
+
model_type: seanet
|
9 |
+
checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/generator_base.pth"
|
10 |
+
nbits: 16
|
11 |
+
seanet:
|
12 |
+
activation: ELU
|
13 |
+
activation_params:
|
14 |
+
alpha: 1.0
|
15 |
+
causal: false
|
16 |
+
channels: 1
|
17 |
+
compress: 2
|
18 |
+
dilation_base: 2
|
19 |
+
dimension: 128
|
20 |
+
disable_norm_outer_blocks: 0
|
21 |
+
kernel_size: 7
|
22 |
+
last_kernel_size: 7
|
23 |
+
lstm: 2
|
24 |
+
n_filters: 32
|
25 |
+
n_residual_layers: 1
|
26 |
+
norm: weight_norm
|
27 |
+
norm_params: {}
|
28 |
+
pad_mode: constant
|
29 |
+
ratios:
|
30 |
+
- 8
|
31 |
+
- 5
|
32 |
+
- 4
|
33 |
+
- 2
|
34 |
+
residual_kernel_size: 3
|
35 |
+
true_skip: true
|
36 |
+
decoder:
|
37 |
+
final_activation: null
|
38 |
+
final_activation_params: null
|
39 |
+
trim_right_ratio: 1.0
|
src/audioseal/libs/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
src/audioseal/libs/audiocraft/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
src/audioseal/libs/audiocraft/modules/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
from .seanet import SEANetDecoder, SEANetEncoder, SEANetEncoderKeepDimension
|
src/audioseal/libs/audiocraft/modules/conv.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Vendor from https://github.com/facebookresearch/audiocraft
|
8 |
+
|
9 |
+
import math
|
10 |
+
import typing as tp
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from torch.nn.utils import spectral_norm
|
17 |
+
|
18 |
+
try:
|
19 |
+
from torch.nn.utils.parametrizations import weight_norm
|
20 |
+
except ImportError:
|
21 |
+
# Old Pytorch
|
22 |
+
from torch.nn.utils import weight_norm
|
23 |
+
|
24 |
+
|
25 |
+
CONV_NORMALIZATIONS = frozenset(
|
26 |
+
["none", "weight_norm", "spectral_norm", "time_group_norm"]
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
    """Wrap `module` with the requested weight re-parametrization.

    Only "weight_norm" and "spectral_norm" re-parametrize the module; any
    other valid choice (e.g. "none", "time_group_norm") returns it unchanged.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    if norm == "spectral_norm":
        return spectral_norm(module)
    # Membership in CONV_NORMALIZATIONS was already checked above, so the
    # remaining choices need no reparametrization.
    return module
|
40 |
+
|
41 |
+
|
42 |
+
def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
):
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm != "time_group_norm":
        # Every other norm is applied as a parametrization on the conv itself
        # (see `apply_parametrization_norm`), so no extra module is needed.
        return nn.Identity()
    if causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
56 |
+
|
57 |
+
|
58 |
+
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Compute the extra right-padding needed so the last conv window is full.

    See `pad_for_conv1d` for the rationale.
    """
    length = x.shape[-1]
    frames = (length - kernel_size + padding_total) / stride + 1
    # Round the frame count up, then derive the input length that would
    # produce exactly that many full frames.
    target_length = (math.ceil(frames) - 1) * stride + (kernel_size - padding_total)
    return target_length - length
|
66 |
+
|
67 |
+
|
68 |
+
def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Zero-pad `x` on the right so that the last convolution window is full.

    This is required to ensure we can rebuild an output of the same length,
    as otherwise, even with padding, some time steps might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    extra = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra))
|
83 |
+
|
84 |
+
|
85 |
+
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != "reflect":
        return F.pad(x, paddings, mode, value)
    # Reflect padding requires input strictly longer than the pad size;
    # zero-extend on the right first if needed, then trim the extension.
    length = x.shape[-1]
    max_pad = max(padding_left, padding_right)
    extra_pad = max_pad - length + 1 if length <= max_pad else 0
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    return padded[..., : padded.shape[-1] - extra_pad]
|
108 |
+
|
109 |
+
|
110 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert (left + right) <= x.shape[-1]
    return x[..., left : x.shape[-1] - right]
|
117 |
+
|
118 |
+
|
119 |
+
class NormConv1d(nn.Module):
    """Conv1d bundled with its normalization, exposing one uniform
    interface regardless of the normalization approach chosen.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        # Parametrization norms wrap the conv itself; module norms (e.g.
        # time_group_norm) become a separate layer applied afterwards.
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
141 |
+
|
142 |
+
|
143 |
+
class NormConv2d(nn.Module):
    """Conv2d bundled with its normalization, exposing one uniform
    interface regardless of the normalization approach chosen.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        # 2d convs are never causal here, hence causal=False.
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
164 |
+
|
165 |
+
|
166 |
+
class NormConvTranspose1d(nn.Module):
    """ConvTranspose1d bundled with its normalization, exposing one uniform
    interface regardless of the normalization approach chosen.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        wrapped = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.convtr = wrapped
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
|
190 |
+
|
191 |
+
|
192 |
+
class NormConvTranspose2d(nn.Module):
    """ConvTranspose2d bundled with its normalization, exposing one uniform
    interface regardless of the normalization approach chosen.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        wrapped = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.convtr = wrapped
        # 2d transposed convs are never causal here, hence causal=False.
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        return self.norm(self.convtr(x))
|
214 |
+
|
215 |
+
|
216 |
+
class StreamableConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.

    Padding is computed in `forward` from the conv's effective kernel size
    so that the output length matches the strided-downsampled input length.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "StreamableConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        # x is expected to be (B, C, T); the unpacking doubles as a rank check.
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        # Extra right-padding so the last conv window is full (see
        # `get_extra_padding_for_conv1d`).
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)
|
280 |
+
|
281 |
+
|
282 |
+
class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.

    The transposed conv's implicit output padding is trimmed in `forward`;
    `trim_right_ratio` controls how much of it is removed from the right
    (only meaningful in causal mode).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
|
src/audioseal/libs/audiocraft/modules/lstm.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Vendor from https://github.com/facebookresearch/audiocraft
|
8 |
+
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
|
12 |
+
class StreamableLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        # Convolutional layout (B, C, T) -> sequence layout (T, B, C)
        # expected by the (non batch-first) LSTM.
        seq = x.permute(2, 0, 1)
        out, _ = self.lstm(seq)
        if self.skip:
            # Residual connection with the LSTM input.
            out = out + seq
        # Back to convolutional layout (B, C, T).
        return out.permute(1, 2, 0)
|
src/audioseal/libs/audiocraft/modules/seanet.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Vendor from https://github.com/facebookresearch/audiocraft
|
8 |
+
|
9 |
+
import math
|
10 |
+
import typing as tp
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
from audioseal.libs.audiocraft.modules.conv import (
|
16 |
+
StreamableConv1d,
|
17 |
+
StreamableConvTranspose1d,
|
18 |
+
)
|
19 |
+
from audioseal.libs.audiocraft.modules.lstm import StreamableLSTM
|
20 |
+
|
21 |
+
|
22 |
+
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """

    def __init__(
        self,
        dim: int,
        kernel_sizes: tp.List[int] = [3, 1],
        dilations: tp.List[int] = [1, 1],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        causal: bool = False,
        pad_mode: str = "reflect",
        compress: int = 2,
        true_skip: bool = True,
    ):
        super().__init__()
        assert len(kernel_sizes) == len(
            dilations
        ), "Number of kernel sizes should match number of dilations"
        act = getattr(nn, activation)
        hidden = dim // compress
        last = len(kernel_sizes) - 1
        layers: tp.List[nn.Module] = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            # First conv maps dim -> hidden, the last maps hidden -> dim;
            # anything in between stays at the compressed width.
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == last else hidden
            layers.append(act(**activation_params))
            layers.append(
                StreamableConv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                )
            )
        self.block = nn.Sequential(*layers)
        self.shortcut: nn.Module
        if true_skip:
            # Identity skip connection.
            self.shortcut = nn.Identity()
        else:
            # Learned 1x1 (streamable) conv skip connection.
            self.shortcut = StreamableConv1d(
                dim,
                dim,
                kernel_size=1,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )

    def forward(self, x):
        return self.shortcut(x) + self.block(x)
|
94 |
+
|
95 |
+
|
96 |
+
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    A downsampling convolutional stack: an initial conv, then one stage of
    (residual blocks + strided conv) per ratio, an optional LSTM, and a final
    conv projecting to `dimension` channels.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = True,
        compress: int = 2,
        lstm: int = 0,
        disable_norm_outer_blocks: int = 0,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # Ratios are specified in decoder (upsampling) order; reverse them so
        # the encoder downsamples in the mirrored order.
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert (
            self.disable_norm_outer_blocks >= 0
            and self.disable_norm_outer_blocks <= self.n_blocks
        ), (
            "Number of blocks for which to disable norm is invalid."
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        )

        act = getattr(nn, activation)
        mult = 1  # channel multiplier, doubled after every downsampling stage
        model: tp.List[nn.Module] = [
            StreamableConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                # Block 1 is the initial conv; norm may be disabled on it.
                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # i + 2 because the initial conv above counts as block 1.
            block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        # Dilation grows geometrically with the layer index.
                        dilations=[dilation_base**j, 1],
                        norm=block_norm,
                        norm_params=norm_params,
                        activation=activation,
                        activation_params=activation_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            StreamableConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                # The final conv is the last block; norm is disabled on it only
                # when *all* blocks have norm disabled.
                norm=(
                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
                ),
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        # Run the full sequential stack; returns the latent representation.
        return self.model(x)
|
237 |
+
|
238 |
+
|
239 |
+
class SEANetEncoderKeepDimension(SEANetEncoder):
    """
    Similar architecture to the SEANet encoder, with an extra transposed
    convolution that upsamples the latent back to the input's temporal
    resolution, projecting to `output_dim` channels.

    Args:
        output_dim (int): Channel dimension of the projected output
            (popped from **kwargs before building the base encoder).
        *args / **kwargs: Forwarded unchanged to :class:`SEANetEncoder`.
    """

    def __init__(self, *args, **kwargs):

        # `output_dim` belongs to this subclass, not the base encoder:
        # remove it before delegating. Raises KeyError if not provided.
        self.output_dim = kwargs.pop("output_dim")
        super().__init__(*args, **kwargs)
        # Adding a reverse convolution layer: kernel and stride both equal the
        # total downsampling factor, so one latent frame expands back to
        # prod(ratios) audio-rate frames.
        self.reverse_convolution = nn.ConvTranspose1d(
            in_channels=self.dimension,
            out_channels=self.output_dim,
            kernel_size=math.prod(self.ratios),
            stride=math.prod(self.ratios),
            padding=0,
        )

    def forward(self, x):
        # x: batch x channels x frames; output keeps the input frame count.
        orig_nframes = x.shape[-1]
        x = self.model(x)
        x = self.reverse_convolution(x)
        # make sure dim didn't change: trim any extra frames introduced by the
        # transposed convolution.
        return x[:, :, :orig_nframes]
|
268 |
+
|
269 |
+
|
270 |
+
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Mirror of :class:`SEANetEncoder`: an initial conv from `dimension`
    channels, an optional LSTM, one stage of (transposed conv + residual
    blocks) per ratio, and a final conv down to `channels`.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple.
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = True,
        compress: int = 2,
        lstm: int = 0,
        disable_norm_outer_blocks: int = 0,
        trim_right_ratio: float = 1.0,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        # Ratios are already in decoder (upsampling) order — no reversal here,
        # unlike the encoder.
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert (
            self.disable_norm_outer_blocks >= 0
            and self.disable_norm_outer_blocks <= self.n_blocks
        ), (
            "Number of blocks for which to disable norm is invalid."
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        )

        act = getattr(nn, activation)
        # Channel multiplier starts at its maximum and is halved per stage.
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                # For the decoder, norm is disabled on the *last* N blocks, so
                # the first conv only loses norm when all blocks are disabled.
                norm=(
                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
                ),
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Disable norm when this stage falls within the last N blocks.
            block_norm = (
                "none"
                if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
                else norm
            )
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        # Dilation grows geometrically with the layer index.
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=block_norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        # z: latent representation; returns the reconstructed waveform.
        y = self.model(z)
        return y
|
src/audioseal/loader.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import os
|
9 |
+
from dataclasses import fields
|
10 |
+
from hashlib import sha1
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import ( # type: ignore[attr-defined]
|
13 |
+
Any,
|
14 |
+
Dict,
|
15 |
+
List,
|
16 |
+
Optional,
|
17 |
+
Tuple,
|
18 |
+
Type,
|
19 |
+
TypeVar,
|
20 |
+
Union,
|
21 |
+
cast,
|
22 |
+
)
|
23 |
+
from urllib.parse import urlparse # noqa: F401
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from omegaconf import DictConfig, OmegaConf
|
27 |
+
|
28 |
+
import audioseal
|
29 |
+
from audioseal.builder import (
|
30 |
+
AudioSealDetectorConfig,
|
31 |
+
AudioSealWMConfig,
|
32 |
+
create_detector,
|
33 |
+
create_generator,
|
34 |
+
)
|
35 |
+
from audioseal.models import AudioSealDetector, AudioSealWM
|
36 |
+
|
37 |
+
AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)
|
38 |
+
|
39 |
+
|
40 |
+
class ModelLoadError(RuntimeError):
    """Error raised whenever loading an AudioSeal model checkpoint fails."""
|
42 |
+
|
43 |
+
|
44 |
+
def _get_path_from_env(var_name: str) -> Optional[Path]:
|
45 |
+
pathname = os.getenv(var_name)
|
46 |
+
if not pathname:
|
47 |
+
return None
|
48 |
+
|
49 |
+
try:
|
50 |
+
return Path(pathname)
|
51 |
+
except ValueError as ex:
|
52 |
+
raise RuntimeError(f"Expect valid pathname, get '{pathname}'.") from ex
|
53 |
+
|
54 |
+
|
55 |
+
def _get_cache_dir(env_names: List[str]):
    """Return (and create) the audioseal cache directory.

    The first environment variable in *env_names* that holds a path wins;
    otherwise fall back to ``~/.cache``. A dedicated ``audioseal`` subfolder
    is used so we never clutter another tool's cache.
    """
    candidates = (_get_path_from_env(env) for env in env_names)
    base_dir = next((path for path in candidates if path), None)
    if base_dir is None:
        base_dir = Path("~/.cache").expanduser().resolve()

    # Create a sub-dir to not mess up with existing caches
    cache_dir = base_dir / "audioseal"
    cache_dir.mkdir(exist_ok=True, parents=True)

    return cache_dir
|
69 |
+
|
70 |
+
|
71 |
+
def load_model_checkpoint(
    model_path: Union[Path, str],
    device: Union[str, torch.device] = "cpu",
):
    """
    Load a raw checkpoint from a local file, an https URL, or a
    "facebook/audioseal/<file>" Hugging Face path.

    Args:
        model_path: Local file path, https URL, or HF repo path of the checkpoint.
        device: Device the tensors are mapped onto (default "cpu").
    Returns:
        The deserialized checkpoint object (typically a dict).
    Raises:
        ModelLoadError: If the path/uri is unknown, or if a HF path is given
            but huggingface_hub is not installed.
    """
    if Path(model_path).is_file():
        # NOTE(review): torch.load deserializes via pickle — only load
        # checkpoints from trusted sources.
        return torch.load(model_path, map_location=device)

    cache_dir = _get_cache_dir(
        ["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
    )
    parts = urlparse(str(model_path))
    if parts.scheme == "https":

        # Short content-addressed file name so repeated downloads hit the cache.
        hash_ = sha1(parts.path.encode()).hexdigest()[:24]
        return torch.hub.load_state_dict_from_url(
            str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
        )
    elif str(model_path).startswith("facebook/audioseal/"):
        hf_filename = str(model_path)[len("facebook/audioseal/") :]

        try:
            from huggingface_hub import hf_hub_download
        except ModuleNotFoundError as ex:
            # Bug fix: previously this branch only printed a message and then
            # fell through to a confusing NameError on hf_hub_download.
            # Fail loudly with an actionable error instead.
            raise ModelLoadError(
                f"The model path {model_path} seems to be a direct HF path, "
                "but you do not install Huggingface_hub. Install with for example "
                "`pip install huggingface_hub` to use this feature."
            ) from ex
        file = hf_hub_download(
            repo_id="facebook/audioseal",
            filename=hf_filename,
            cache_dir=cache_dir,
            library_name="audioseal",
            library_version=audioseal.__version__,
        )
        return torch.load(file, map_location=device)
    else:
        raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")
|
109 |
+
|
110 |
+
|
111 |
+
def load_local_model_config(model_card: str) -> Optional[DictConfig]:
    """Return the parsed YAML model card shipped with the package, or None if absent."""
    config_file = Path(__file__).parent / "cards" / (model_card + ".yaml")
    if not Path(config_file).is_file():
        return None
    return cast(DictConfig, OmegaConf.load(config_file.resolve()))
|
117 |
+
|
118 |
+
|
119 |
+
class AudioSeal:
    """Facade for loading AudioSeal generators and detectors from a packaged
    model card name, a local checkpoint file, an https URL, or a
    facebook/audioseal HF path (see `load_model_checkpoint`)."""

    @staticmethod
    def parse_model(
        model_card_or_path: str,
        model_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> Tuple[Dict[str, Any], AudioSealT]:
        """
        Parse the information from the model card or checkpoint path using
        the schema `model_type` that defines the model type
        """
        # Get the raw checkpoint and config from the local model cards
        config = load_local_model_config(model_card_or_path)

        if config:
            assert "checkpoint" in config, f"Checkpoint missing in {model_card_or_path}"
            config_dict = OmegaConf.to_container(config)
            assert isinstance(
                config_dict, dict
            ), f"Cannot parse config from {model_card_or_path}"
            # The card's "checkpoint" entry points to where the weights live;
            # it is config metadata, not part of the model config itself.
            checkpoint = config_dict.pop("checkpoint")
            checkpoint = load_model_checkpoint(checkpoint)

        # Get the raw checkpoint and config from the checkpoint path
        else:
            config_dict = {}
            checkpoint = load_model_checkpoint(model_card_or_path)

        # Config entries embedded in the checkpoint are defaults; the card's
        # entries (config_dict) take precedence.
        if "xp.cfg" in checkpoint:
            config_dict = {**checkpoint["xp.cfg"], **config_dict}  # type: ignore

        model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits)  # type: ignore

        # Training checkpoints nest the weights under "model".
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]

        return checkpoint, model_config

    @staticmethod
    def parse_config(
        config: Dict[str, Any],
        config_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> AudioSealT:
        """Normalize a raw config dict and validate it against the structured
        schema `config_type` (AudioSealWMConfig or AudioSealDetectorConfig).

        Returns an OmegaConf structured config merged with the relevant
        fields of `config`; unrelated fields are dropped.
        """

        assert "seanet" in config, f"missing seanet backbone config in {config}"

        # Patch 1: Resolve the variables in the checkpoint
        config = OmegaConf.create(config)  # type: ignore
        OmegaConf.resolve(config)  # type: ignore
        config = OmegaConf.to_container(config)  # type: ignore

        # Patch 2: Put decoder, encoder and detector outside seanet
        seanet_config = config["seanet"]
        for key_to_patch in ["encoder", "decoder", "detector"]:
            if key_to_patch in seanet_config:
                # Values nested under seanet override any top-level ones.
                config_to_patch = config.get(key_to_patch) or {}
                config[key_to_patch] = {
                    **config_to_patch,
                    **seanet_config.pop(key_to_patch),
                }

        config["seanet"] = seanet_config

        # Patch 3: Put nbits into config if specified
        # (note: nbits=0 is treated the same as "not specified" here)
        if nbits and "nbits" not in config:
            config["nbits"] = nbits

        # remove attributes not related to the model_type
        result_config = {}
        assert config, f"Empty config"
        for field in fields(config_type):
            if field.name in config:
                result_config[field.name] = config[field.name]

        # Merging into the structured schema type-checks the values.
        schema = OmegaConf.structured(config_type)
        schema.merge_with(result_config)
        return schema

    @staticmethod
    def load_generator(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealWM:
        """Load the AudioSeal generator from the model card"""
        checkpoint, config = AudioSeal.parse_model(
            model_card_or_path,
            AudioSealWMConfig,
            nbits=nbits,
        )

        model = create_generator(config)
        model.load_state_dict(checkpoint)
        return model

    @staticmethod
    def load_detector(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealDetector:
        """Load the AudioSeal detector from the model card"""
        checkpoint, config = AudioSeal.parse_model(
            model_card_or_path,
            AudioSealDetectorConfig,
            nbits=nbits,
        )
        model = create_detector(config)
        model.load_state_dict(checkpoint)
        return model
|
src/audioseal/models.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from typing import Optional, Tuple
|
9 |
+
|
10 |
+
import julius
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension
|
14 |
+
|
15 |
+
logger = logging.getLogger("Audioseal")
|
16 |
+
|
17 |
+
COMPATIBLE_WARNING = """
|
18 |
+
AudioSeal is designed to work at a sample rate 16khz.
|
19 |
+
Implicit sampling rate usage is deprecated and will be removed in future version.
|
20 |
+
To remove this warning please add this argument to the function call:
|
21 |
+
sample_rate = your_sample_rate
|
22 |
+
"""
|
23 |
+
|
24 |
+
|
25 |
+
class MsgProcessor(torch.nn.Module):
    """
    Mix a binary secret message into the encoder output.

    Each of the ``nbits`` message positions owns two consecutive rows of an
    embedding table (row ``2*i`` for bit value 0, row ``2*i + 1`` for bit
    value 1). The rows selected by the message are summed into a single
    hidden vector, broadcast along the time axis and added to the latent.

    Args:
        nbits: Number of bits used to generate the message. Must be non-zero
        hidden_size: Dimension of the encoder output
    """

    def __init__(self, nbits: int, hidden_size: int):
        super().__init__()
        assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
        self.nbits = nbits
        self.hidden_size = hidden_size
        # One embedding row per (bit position, bit value) pair.
        self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """
        Add the embedded message to the latent representation.

        Args:
            hidden: The encoder output, size: batch x hidden x frames
            msg: The secret message (0/1 entries), size: batch x k
        Returns:
            The latent with the message embedding added, size: batch x hidden x frames
        """
        # Row index for bit i with value v is 2*i + v.
        base_rows = 2 * torch.arange(msg.shape[-1]).to(msg.device)  # k: 0 2 4 ... 2k
        row_ids = (base_rows.repeat(msg.shape[0], 1) + msg).long()  # b x k
        per_bit = self.msg_processor(row_ids)  # b x k x h
        summed = per_bit.sum(dim=-2)  # b x h
        # Broadcast the message vector over every frame.
        tiled = summed.unsqueeze(-1).repeat(1, 1, hidden.shape[2])  # b x h x t/f
        return hidden + tiled
|
58 |
+
|
59 |
+
|
60 |
+
class AudioSealWM(torch.nn.Module):
    """
    Watermark generator: embeds an (optional) binary message into an audio
    signal as an additive watermark.
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        msg_processor: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # The build should take care of validating the dimensions between component
        self.msg_processor = msg_processor
        self._message: Optional[torch.Tensor] = None

    @property
    def message(self) -> Optional[torch.Tensor]:
        """Default message used when `get_watermark` is called without one."""
        return self._message

    @message.setter
    def message(self, message: torch.Tensor) -> None:
        self._message = message

    def get_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the additive watermark for an audio tensor and a message.
        If no message is given (neither here nor via the `message` property)
        a random n-bit {0,1} message is drawn.

        Args:
            x: Audio signal, size: batch x frames
            sample_rate: The sample rate of the input audio (default 16khz as
                currently supported by the main AudioSeal model)
            message: An optional binary message, size: batch x k
        """
        orig_length = x.size(-1)
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        assert sample_rate
        # The model operates at 16 kHz; resample on the way in and out.
        if sample_rate != 16000:
            x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
        latent = self.encoder(x)

        if self.msg_processor is not None:
            if message is not None:
                message = message.to(device=x.device)
            elif self.message is not None:
                message = self.message.to(device=x.device)
            else:
                # No message supplied anywhere: draw a random n-bit one.
                message = torch.randint(
                    0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
                )

            latent = self.msg_processor(latent, message)

        watermark = self.decoder(latent)

        if sample_rate != 16000:
            watermark = julius.resample_frac(
                watermark, old_sr=16000, new_sr=sample_rate
            )

        return watermark[..., :orig_length]  # trim output cf encodec codebase

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Apply the watermarking to the audio signal x with a tune-down ratio (default 1.0)"""
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        wm = self.get_watermark(x, sample_rate=sample_rate, message=message)
        return x + alpha * wm
|
146 |
+
|
147 |
+
|
148 |
+
class AudioSealDetector(torch.nn.Module):
    """
    Detect the watermark in an audio signal.

    The detector is a SEANet encoder that keeps the frame resolution,
    followed by a 1x1 convolution producing ``2 + nbits`` channels per frame.

    Args:
        *args / **kwargs: Forwarded to :class:`SEANetEncoderKeepDimension`.
        nbits (int): The number of bits in the secret message. The result will have size
            of 2 + nbits, where the first two items indicate the possibilities of the
            audio being watermarked (positive / negative scores), the rest is used to decode
            the secret message. In 0bit watermarking (no secret message), the detector just
            returns 2 values.
    """

    def __init__(self, *args, nbits: int = 0, **kwargs):
        super().__init__()
        encoder = SEANetEncoderKeepDimension(*args, **kwargs)
        # 1x1 conv: per-frame detection logits (2 channels) + one per message bit.
        last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
        self.detector = torch.nn.Sequential(encoder, last_layer)
        self.nbits = nbits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message_threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor]:
        """
        A convenience function that returns a probability of an audio being watermarked,
        together with its message in n-bits (binary) format. If the audio is not watermarked,
        the message will be random.

        Args:
            x: Audio signal, size: batch x frames
            sample_rate: The sample rate of the input audio
            message_threshold: threshold used to convert the watermark output (probability
                of each bits being 0 or 1) into the binary n-bit message.
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        result, message = self.forward(x, sample_rate=sample_rate)  # b x 2+nbits
        # Fraction of frames whose "watermarked" probability exceeds 0.5.
        detected = (
            torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
        )
        detect_prob = detected.cpu().item()  # type: ignore
        message = torch.gt(message, message_threshold).int()
        return detect_prob, message

    def decode_message(self, result: torch.Tensor) -> torch.Tensor:
        """
        Decode the message from the watermark result (batch x nbits x frames)
        Args:
            result: watermark result (batch x nbits x frames)
        Returns:
            The message of size batch x nbits, indicating probability of 1 for each bit
        """
        # Bug fix: the 2-D branch previously called ``self.dim()`` — an
        # AttributeError on nn.Module — instead of ``result.dim()``.
        assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
            result.dim() == 2 and result.shape[0] == self.nbits
        ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
        # Average the per-frame logits over time, then map to probabilities.
        decoded_message = result.mean(dim=-1)
        return torch.sigmoid(decoded_message)

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Detect the watermarks from the audio signal
        Args:
            x: Audio signal, size batch x frames
            sample_rate: The sample rate of the input audio
        Returns:
            A tuple (per-frame detection probabilities of size batch x 2 x frames,
            decoded message probabilities of size batch x nbits)
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        assert sample_rate
        # The model operates at 16 kHz; resample when needed.
        if sample_rate != 16000:
            x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
        result = self.detector(x)  # b x 2+nbits
        # hardcode softmax on 2 first units used for detection
        result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
        message = self.decode_message(result[:, 2:, :])
        return result[:, :2, :], message
|
src/audioseal/py.typed
ADDED
File without changes
|
src/scripts/checkpoints.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
def convert(checkpoint: str, outdir: str, suffix: str = "base"):
    """Convert a joint training checkpoint into separate generator and
    detector checkpoints.

    Args:
        checkpoint: Path to the joint checkpoint (with "xp.cfg" and "model" keys).
        outdir: Directory receiving ``checkpoint_generator_<suffix>.pth`` and
            ``checkpoint_detector_<suffix>.pth``.
        suffix: Tag appended to the output file names.
    Raises:
        AssertionError: If a layer name matches none of the known prefixes.
    """
    detector_prefix = "detector."
    generator_prefix = "generator."
    # The training checkpoint wraps the message embedding in a sequential
    # ("...0.weight"); the inference model addresses it directly.
    msg_processor_key = "msg_processor.msg_processor.0.weight"

    outdir_path = Path(outdir)
    # NOTE(review): loads on the checkpoint's saved device; pass map_location
    # if conversion must run on CPU-only machines — confirm intended usage.
    ckpt = torch.load(checkpoint)

    # keep inference-related params only
    cfg = ckpt["xp.cfg"]
    infer_cfg = {
        "seanet": cfg["seanet"],
        "channels": cfg["channels"],
        "dtype": cfg["dtype"],
        "sample_rate": cfg["sample_rate"],
    }

    generator_ckpt = {"xp.cfg": infer_cfg, "model": {}}
    detector_ckpt = {"xp.cfg": infer_cfg, "model": {}}

    for layer, weights in ckpt["model"].items():
        if layer.startswith(detector_prefix):
            # "detector.xxx" -> "xxx"
            detector_ckpt["model"][layer[len(detector_prefix) :]] = weights  # type: ignore
        elif layer == msg_processor_key:
            generator_ckpt["model"]["msg_processor.msg_processor.weight"] = weights  # type: ignore
        else:
            # "generator.xxx" -> "xxx"
            assert layer.startswith(generator_prefix), f"Invalid layer: {layer}"
            generator_ckpt["model"][layer[len(generator_prefix) :]] = weights  # type: ignore

    torch.save(generator_ckpt, outdir_path / f"checkpoint_generator_{suffix}.pth")
    torch.save(detector_ckpt, outdir_path / f"checkpoint_detector_{suffix}.pth")
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
    # Command-line entry point: python-fire turns `convert` into a CLI, e.g.
    #   python checkpoints.py --checkpoint=ckpt.pth --outdir=out --suffix=base
    import fire

    fire.Fire(convert)
|