test1 / src /audioseal /models.py
Zw07's picture
Upload 14 files
0209786 verified
raw
history blame
8.62 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Optional, Tuple
import julius
import torch
from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension
# Module-level logger used by the watermark generator and detector in this file.
logger = logging.getLogger("Audioseal")

# Deprecation notice logged whenever a public entry point in this file is
# called without an explicit ``sample_rate`` (the code then falls back to 16 kHz).
COMPATIBLE_WARNING = """
AudioSeal is designed to work at a sample rate 16khz.
Implicit sampling rate usage is deprecated and will be removed in future version.
To remove this warning please add this argument to the function call:
sample_rate = your_sample_rate
"""
class MsgProcessor(torch.nn.Module):
    """
    Mix a secret binary message into the encoder output.

    Each bit position owns two consecutive rows of an embedding table, one
    per bit value. The rows selected by the message are summed into a single
    vector which is then added to every frame of the hidden representation.

    Args:
        nbits: Number of bits used to generate the message. Must be non-zero
        hidden_size: Dimension of the encoder output
    """

    def __init__(self, nbits: int, hidden_size: int):
        super().__init__()
        assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
        self.nbits = nbits
        self.hidden_size = hidden_size
        # Row 2*i is picked when bit i is 0, row 2*i + 1 when bit i is 1.
        self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """
        Add the embedded message to every frame of ``hidden``.

        Args:
            hidden: The encoder output, size: batch x hidden x frames
            msg: The secret message, size: batch x k

        Returns:
            ``hidden`` plus the summed message embedding, size:
            batch x hidden x frames
        """
        batch_size = msg.shape[0]
        bit_count = msg.shape[-1]
        frame_count = hidden.shape[2]

        # First embedding row of each bit position: 0, 2, 4, ..., 2 * (k - 1)
        row_offsets = torch.arange(bit_count, device=msg.device) * 2
        # The bit value (0 or 1) selects which of the two rows is looked up.
        rows = (row_offsets.repeat(batch_size, 1) + msg).long()  # b x k

        per_bit = self.msg_processor(rows)  # b x k x h
        message_vector = per_bit.sum(dim=-2)  # b x h
        # Same vector for every frame.
        per_frame = message_vector.unsqueeze(-1).repeat(1, 1, frame_count)  # b x h x t
        return hidden + per_frame
class AudioSealWM(torch.nn.Module):
    """
    Generate watermarking for a given audio signal

    The audio is passed through ``encoder``, an optional ``msg_processor``
    mixes a secret message into the hidden representation, and ``decoder``
    turns the result into a watermark signal that ``forward`` adds to the
    input audio.

    Args:
        encoder: Module mapping the audio to a hidden representation
        decoder: Module mapping the hidden representation to the watermark
        msg_processor: Optional module called as ``msg_processor(hidden, message)``;
            it must expose an ``nbits`` attribute. When None, no message is
            embedded.
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        msg_processor: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # The build should take care of validating the dimensions between component
        self.msg_processor = msg_processor
        # Default message used when a call does not pass one explicitly.
        self._message: Optional[torch.Tensor] = None

    @property
    def message(self) -> Optional[torch.Tensor]:
        """Default message used when ``get_watermark`` receives none."""
        return self._message

    @message.setter
    def message(self, message: torch.Tensor) -> None:
        self._message = message

    def get_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get the watermark from an audio tensor and a message.
        If the input message is None, the message stored on the model is
        used; if that is None too, a random message of n bits {0,1} will be
        generated.

        Args:
            x: Audio signal, size: batch x frames
            sample_rate: The sample rate of the input audio (default 16khz as
                currently supported by the main AudioSeal model)
            message: An optional binary message, size: batch x k

        Returns:
            The watermark signal, whose last dimension always has the same
            length as the last dimension of ``x``.
        """
        length = x.size(-1)
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        assert sample_rate
        if sample_rate != 16000:
            # The encoder/decoder pair operates on 16 kHz audio.
            x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
        hidden = self.encoder(x)

        if self.msg_processor is not None:
            # Priority: explicit argument > message stored on the model > random bits.
            if message is None:
                message = self.message
            if message is None:
                message = torch.randint(
                    0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
                )
            else:
                message = message.to(device=x.device)
            hidden = self.msg_processor(hidden, message)

        watermark = self.decoder(hidden)

        if sample_rate != 16000:
            # Bring the watermark back to the caller's sample rate.
            watermark = julius.resample_frac(
                watermark, old_sr=16000, new_sr=sample_rate
            )

        watermark = watermark[..., :length]  # trim output cf encodec codebase
        # The resampling round trip (output length is floored by default) or the
        # decoder may yield fewer samples than the input. Zero-pad the tail so
        # that ``x + alpha * watermark`` in ``forward`` never hits a shape mismatch.
        missing = length - watermark.size(-1)
        if missing > 0:
            watermark = torch.nn.functional.pad(watermark, (0, missing))
        return watermark

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """
        Apply the watermarking to the audio signal x with a tune-down ratio (default 1.0)

        Args:
            x: Audio signal to watermark
            sample_rate: The sample rate of ``x``; a warning is logged and
                16 kHz is assumed when omitted
            message: An optional binary message, forwarded to ``get_watermark``
            alpha: Scale applied to the watermark before it is added to ``x``

        Returns:
            ``x + alpha * watermark``
        """
        if sample_rate is None:
            # Resolve the default here so get_watermark does not warn a second time.
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        wm = self.get_watermark(x, sample_rate=sample_rate, message=message)
        return x + alpha * wm
class AudioSealDetector(torch.nn.Module):
    """
    Detect the watermarking from an audio signal

    The detector is a ``SEANetEncoderKeepDimension`` followed by a 1x1
    convolution that produces ``2 + nbits`` output channels.

    Args:
        *args, **kwargs: Forwarded unchanged to ``SEANetEncoderKeepDimension``
        nbits (int): The number of bits in the secret message. The result will have size
            of 2 + nbits, where the first two items indicate the possibilities of the
            audio being watermarked (positive / negative scores), the rest is used to decode
            the secret message. In 0bit watermarking (no secret message), the detector just
            returns 2 values.
    """

    def __init__(self, *args, nbits: int = 0, **kwargs):
        super().__init__()
        encoder = SEANetEncoderKeepDimension(*args, **kwargs)
        # Kernel-size-1 convolution: maps each frame's features to 2 + nbits scores.
        last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
        self.detector = torch.nn.Sequential(encoder, last_layer)
        self.nbits = nbits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message_threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor]:
        """
        A convenience function that returns a probability of an audio being watermarked,
        together with its message in n-bits (binary) format. If the audio is not watermarked,
        the message will be random.

        Args:
            x: Audio signal, size: batch x frames
            sample_rate: The sample rate of the input audio
            message_threshold: threshold used to convert the watermark output (probability
                of each bits being 0 or 1) into the binary n-bit message.

        Returns:
            A tuple ``(detect_prob, message)``. ``detect_prob`` is the fraction
            of frames, over the whole batch, whose second detection score
            exceeds 0.5. ``message`` is an int tensor of size batch x nbits.
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        result, message = self.forward(x, sample_rate=sample_rate)  # b x 2+nbits
        positive = torch.gt(result[:, 1, :], 0.5)
        # Normalise by every frame in the batch so the value stays within
        # [0, 1] for batch sizes above one (identical to before for batch == 1).
        detected = torch.count_nonzero(positive) / positive.numel()
        detect_prob = detected.cpu().item()  # type: ignore
        message = torch.gt(message, message_threshold).int()
        return detect_prob, message

    def decode_message(self, result: torch.Tensor) -> torch.Tensor:
        """
        Decode the message from the watermark result (batch x nbits x frames)

        Args:
            result: watermark result (batch x nbits x frames), or an
                unbatched nbits x frames tensor

        Returns:
            The message of size batch x nbits, indicating probability of 1 for each bit
        """
        # The unbatched check must inspect ``result``; modules have no ``dim()``.
        assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
            result.dim() == 2 and result.shape[0] == self.nbits
        ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
        # Average the per-frame scores, then squash them into probabilities.
        decoded_message = result.mean(dim=-1)
        return torch.sigmoid(decoded_message)

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Detect the watermarks from the audio signal

        Args:
            x: Audio signal, size batch x frames
            sample_rate: The sample rate of the input audio

        Returns:
            A tuple ``(detection, message)``: ``detection`` holds the two
            softmax-normalised detection scores per frame (batch x 2 x frames)
            and ``message`` the per-bit probabilities (batch x nbits).
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        assert sample_rate
        if sample_rate != 16000:
            # The detector operates on 16 kHz audio.
            x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
        result = self.detector(x)  # b x 2+nbits
        # hardcode softmax on 2 first units used for detection; computed out of
        # place so the raw detector output is not overwritten.
        detection = torch.softmax(result[:, :2, :], dim=1)
        message = self.decode_message(result[:, 2:, :])
        return detection, message