|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from typing import Optional, Tuple |
|
|
|
import julius |
|
import torch |
|
|
|
from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension |
|
|
|
# Module logger; the sample-rate fallback warnings of the model classes are
# emitted through it.
logger = logging.getLogger("Audioseal")



# Message logged whenever a model entry point is called without an explicit
# ``sample_rate`` and silently falls back to 16 kHz.
COMPATIBLE_WARNING = """

AudioSeal is designed to work at a sample rate 16khz.

Implicit sampling rate usage is deprecated and will be removed in future version.

To remove this warning please add this argument to the function call:

sample_rate = your_sample_rate

"""
|
|
|
|
|
class MsgProcessor(torch.nn.Module):
    """
    Apply the secret message to the encoder output.

    Every bit position owns two rows of an embedding table: one for the bit
    being 0, one for the bit being 1. The rows selected by a message are
    summed into a single ``hidden_size`` vector, which is added to every
    frame of the encoder output.

    Args:
        nbits: Number of bits used to generate the message. Must be non-zero.
        hidden_size: Dimension of the encoder output.
    """

    def __init__(self, nbits: int, hidden_size: int):
        super().__init__()
        assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
        self.nbits = nbits
        self.hidden_size = hidden_size
        # Row 2*i holds the embedding of "bit i is 0", row 2*i + 1 of "bit i is 1".
        self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """
        Add the message embedding to the encoder output.

        Builds the embedding map 2 x k -> k x h, sums it over the bit
        dimension and adds the result to every frame of ``hidden``.

        Args:
            hidden: The encoder output, size: batch x hidden x frames.
            msg: The secret message with values in {0, 1}, size: batch x k.

        Returns:
            ``hidden`` with the message embedding added, same size as ``hidden``.
        """
        # Bit i starts at row 2 * i; adding the bit value selects its 0/1 row.
        offsets = 2 * torch.arange(msg.shape[-1], device=msg.device)
        indices = (offsets.repeat(msg.shape[0], 1) + msg).long()
        msg_aux = self.msg_processor(indices).sum(dim=-2)
        # Broadcast over the frame axis instead of materialising a
        # batch x hidden x frames copy of the embedding with ``repeat``.
        return hidden + msg_aux.unsqueeze(-1)
|
|
|
|
|
class AudioSealWM(torch.nn.Module):
    """
    Generate watermarking for a given audio signal.

    The audio is resampled to 16 kHz when needed, encoded, optionally
    conditioned on a secret message, decoded into a watermark waveform and
    resampled back to the caller's sample rate.

    Args:
        encoder: Module applied to the (16 kHz) audio to get the hidden
            representation.
        decoder: Module turning the hidden representation into the watermark.
        msg_processor: Optional module mixing a secret message into the hidden
            representation. When ``None`` no message is embedded and any
            ``message`` argument is ignored.
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        msg_processor: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.msg_processor = msg_processor
        # Default message used by ``get_watermark`` when no message is passed.
        self._message: Optional[torch.Tensor] = None

    @property
    def message(self) -> Optional[torch.Tensor]:
        """Default secret message; ``None`` means a random one is drawn."""
        return self._message

    @message.setter
    def message(self, message: torch.Tensor) -> None:
        self._message = message

    def get_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get the watermark from an audio tensor and a message.

        The embedded message is, in order of preference: the ``message``
        argument, the ``message`` property, or a random message of n bits
        in {0, 1}.

        Args:
            x: Audio signal, size: batch x frames
            sample_rate: The sample rate of the input audio (default 16khz as
                currently supported by the main AudioSeal model)
            message: An optional binary message, size: batch x k

        Returns:
            The watermark, trimmed or zero-padded so that its last dimension
            has the same length as ``x``.
        """
        length = x.size(-1)
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        assert sample_rate
        if sample_rate != 16000:
            x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
        hidden = self.encoder(x)

        if self.msg_processor is not None:
            hidden = self.msg_processor(hidden, self._resolve_message(x, message))

        watermark = self.decoder(hidden)

        if sample_rate != 16000:
            watermark = julius.resample_frac(
                watermark, old_sr=16000, new_sr=sample_rate
            )

        return self._match_length(watermark, length)

    def _resolve_message(
        self, x: torch.Tensor, message: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Pick the message to embed and place it on ``x``'s device."""
        if message is not None:
            return message.to(device=x.device)
        if self.message is not None:
            return self.message.to(device=x.device)
        return torch.randint(
            0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
        )

    @staticmethod
    def _match_length(watermark: torch.Tensor, length: int) -> torch.Tensor:
        """Trim or zero-pad the last dimension of ``watermark`` to ``length``."""
        missing = length - watermark.size(-1)
        if missing > 0:
            # The decoder / resampling round trip can come back a few samples
            # short of the input; pad so ``x + watermark`` always lines up.
            return torch.nn.functional.pad(watermark, (0, missing))
        return watermark[..., :length]

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """
        Apply the watermark to the audio signal ``x``.

        Args:
            x: Audio signal, same layout as for ``get_watermark``.
            sample_rate: The sample rate of the input audio.
            message: An optional binary message, size: batch x k.
            alpha: Tune-down ratio applied to the watermark (default 1.0).

        Returns:
            ``x + alpha * watermark``.
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        wm = self.get_watermark(x, sample_rate=sample_rate, message=message)
        return x + alpha * wm
|
|
|
|
|
class AudioSealDetector(torch.nn.Module):
    """
    Detect the watermarking from an audio signal.

    Args:
        *args, **kwargs: Forwarded to ``SEANetEncoderKeepDimension``, which
            builds the feature extractor of the detector.
        nbits (int): The number of bits in the secret message. The result will
            have size 2 + nbits, where the first two items indicate the
            possibilities of the audio being watermarked (positive / negative
            scores), and the rest is used to decode the secret message. In 0bit
            watermarking (no secret message), the detector just returns 2
            values.
    """

    def __init__(self, *args, nbits: int = 0, **kwargs):
        super().__init__()
        encoder = SEANetEncoderKeepDimension(*args, **kwargs)
        # 1x1 conv: two detection channels followed by one channel per bit.
        last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
        self.detector = torch.nn.Sequential(encoder, last_layer)
        self.nbits = nbits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message_threshold: float = 0.5,
        detection_threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor]:
        """
        A convenience function that returns a probability of an audio being
        watermarked, together with its message in n-bits (binary) format. If
        the audio is not watermarked, the message will be random.

        Args:
            x: Audio signal, size: batch x frames
            sample_rate: The sample rate of the input audio
            message_threshold: threshold used to convert the watermark output
                (probability of each bit being 1) into the binary n-bit message.
            detection_threshold: a frame counts as watermarked when its
                channel-1 detection score exceeds this value.

        Returns:
            ``(detect_prob, message)``: the fraction of frames, over the whole
            batch, scored as watermarked, and the int tensor of decoded bits.
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        result, message = self.forward(x, sample_rate=sample_rate)
        positive = torch.gt(result[:, 1, :], detection_threshold)
        # Normalise by every counted element (batch x frames), not only by
        # the frame count, so the score stays in [0, 1] for batch > 1.
        detected = torch.count_nonzero(positive) / positive.numel()
        detect_prob = detected.cpu().item()
        message = torch.gt(message, message_threshold).int()
        return detect_prob, message

    def decode_message(self, result: torch.Tensor) -> torch.Tensor:
        """
        Decode the message from the watermark result (batch x nbits x frames).

        Args:
            result: watermark result (batch x nbits x frames), or
                (nbits x frames) without a batch dimension.

        Returns:
            The message of size batch x nbits (or nbits), indicating the
            probability of 1 for each bit.
        """
        assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
            result.dim() == 2 and result.shape[0] == self.nbits
        ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
        decoded_message = result.mean(dim=-1)
        return torch.sigmoid(decoded_message)

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Detect the watermarks from the audio signal.

        Args:
            x: Audio signal, size batch x frames
            sample_rate: The sample rate of the input audio

        Returns:
            The softmax-normalised detection scores (batch x 2 x frames) and
            the decoded message probabilities (batch x nbits).
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        assert sample_rate
        if sample_rate != 16000:
            x = julius.resample_frac(x, old_sr=sample_rate, new_sr=16000)
        result = self.detector(x)
        # Normalise the two detection channels without mutating ``result``.
        detection = torch.softmax(result[:, :2, :], dim=1)
        message = self.decode_message(result[:, 2:, :])
        return detection, message
|
|