# src/audioseal/models.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Optional, Tuple
import librosa
import numpy as np
import torch
from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension
logger = logging.getLogger("Audioseal")
# Deprecation notice logged whenever a model method is called without an
# explicit ``sample_rate`` argument; 16 kHz is then assumed.
COMPATIBLE_WARNING = """
AudioSeal is designed to work at a sample rate 16khz.
Implicit sampling rate usage is deprecated and will be removed in future version.
To remove this warning please add this argument to the function call:
sample_rate = your_sample_rate
"""
class MsgProcessor(torch.nn.Module):
    """Injects an n-bit binary message into the encoder's hidden representation.

    Each bit position ``i`` owns two learned embeddings (rows ``2*i`` and
    ``2*i + 1`` of the embedding table); the bit value selects which row is
    used. The selected embeddings are summed over bit positions and the
    result is broadcast across the time axis and added to the hidden state.
    """

    def __init__(self, nbits: int, hidden_size: int):
        super().__init__()
        assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
        self.nbits = nbits
        self.hidden_size = hidden_size
        # One pair of embeddings per bit: rows (2*i, 2*i + 1).
        self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """Return ``hidden`` plus the message embedding, tiled over time.

        Args:
            hidden: latent of shape (batch, hidden_size, time).
            msg: binary message of shape (batch, nbits) with 0/1 entries.
        """
        batch, nbits = msg.shape[0], msg.shape[-1]
        # Row index for bit i is 2*i + bit_value.
        base = (2 * torch.arange(nbits).to(msg.device)).repeat(batch, 1)
        embedded = self.msg_processor((base + msg).long())   # (batch, nbits, hidden)
        summed = embedded.sum(dim=-2)                        # (batch, hidden)
        tiled = summed.unsqueeze(-1).repeat(1, 1, hidden.shape[2])
        return hidden + tiled
def compute_stft_energy(audio: torch.Tensor, sr: int, n_fft: int = 2048, hop_length: int = 512) -> torch.Tensor:
    """Compute per-frame spectral energy for each item of a batch of audio.

    For every batch item the magnitude STFT is taken on CPU via librosa and
    the squared magnitudes are summed over the frequency axis, yielding one
    energy value per STFT frame.

    Args:
        audio: batched waveform tensor; the first dimension is the batch.
        sr: sample rate of ``audio``.
            NOTE(review): currently unused — kept for interface compatibility.
        n_fft: FFT window size passed to ``librosa.stft``.
        hop_length: hop between frames passed to ``librosa.stft``.

    Returns:
        Tensor of shape (batch, frames) on ``audio.device``.
    """
    per_item_energy = []
    for waveform in audio:
        spectrum = np.abs(librosa.stft(waveform.cpu().numpy(), n_fft=n_fft, hop_length=hop_length))
        per_item_energy.append(torch.tensor((spectrum ** 2).sum(axis=0), device=audio.device))
    return torch.stack(per_item_energy, dim=0)
def compute_adaptive_alpha_librosa(energy_values: torch.Tensor, min_alpha: float = 0.5, max_alpha: float = 1.5) -> torch.Tensor:
    """Map per-frame energies to per-frame watermark strengths in [min_alpha, max_alpha].

    Energies are min-max normalized independently per batch row; a small
    epsilon keeps the division finite when a row is constant (in that case
    every frame of the row gets roughly ``min_alpha``).

    Args:
        energy_values: tensor of shape (batch, frames) of frame energies.
        min_alpha: strength assigned to the quietest frame of each row.
        max_alpha: strength assigned to the loudest frame of each row.

    Returns:
        Tensor of shape (batch, frames) with values in [min_alpha, max_alpha].
    """
    # Hoist the row-wise extrema: the original recomputed min() twice.
    row_min = energy_values.min(dim=1, keepdim=True)[0]
    row_max = energy_values.max(dim=1, keepdim=True)[0]
    normalized_energy = (energy_values - row_min) / (row_max - row_min + 1e-6)
    return min_alpha + normalized_energy * (max_alpha - min_alpha)
class AudioSealWM(torch.nn.Module):
    """Generator model that embeds an (optional) n-bit message into audio.

    The watermark produced by ``decoder`` is scaled per-sample by an adaptive
    alpha derived from the STFT frame energy of the input, so louder regions
    carry a stronger watermark, and is then added to the input audio.
    """

    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module, msg_processor: Optional[torch.nn.Module] = None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Optional message embedder; when None the model is 0-bit (detection only).
        self.msg_processor = msg_processor
        # Default message used when forward() is called without one.
        self._message: Optional[torch.Tensor] = None
        # Last message actually embedded (random, explicit, or default).
        self._original_payload: Optional[torch.Tensor] = None

    @property
    def message(self) -> Optional[torch.Tensor]:
        """Default message embedded when no explicit message is passed."""
        return self._message

    @message.setter
    def message(self, message: torch.Tensor) -> None:
        self._message = message

    def get_original_payload(self) -> Optional[torch.Tensor]:
        """Return the message embedded by the most recent forward() call."""
        return self._original_payload

    def get_watermark(self, x: torch.Tensor, sample_rate: Optional[int] = None, message: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Convenience alias for forward(); returns the watermarked audio.

        NOTE(review): despite the name, this returns the *watermarked* audio
        (input + scaled watermark), not the watermark alone.
        """
        return self.forward(x, sample_rate, message)

    def forward(self, x: torch.Tensor, sample_rate: Optional[int] = None, message: Optional[torch.Tensor] = None,
                n_fft: int = 2048, hop_length: int = 512, min_alpha: float = 0.5, max_alpha: float = 1.5) -> torch.Tensor:
        """Embed a watermark (and optional message) into ``x``.

        Args:
            x: input audio batch; energy is computed per frame via STFT.
            sample_rate: sample rate of ``x``; 16 kHz is assumed (with a
                deprecation warning) when omitted.
            message: explicit (batch, nbits) binary message; falls back to
                ``self.message`` and then to a fresh random message.
            n_fft / hop_length: STFT parameters for the energy computation.
            min_alpha / max_alpha: range of the adaptive watermark strength.

        Returns:
            The watermarked audio: ``x + alpha * watermark``.
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        if sample_rate != 16000:
            # The model operates at 16 kHz; resample the input first.
            x_np = x.detach().cpu().numpy()
            resampled_x = librosa.resample(x_np, orig_sr=sample_rate, target_sr=16000)
            x = torch.tensor(resampled_x, device=x.device)
        hidden = self.encoder(x)
        if self.msg_processor is not None:
            if message is None:
                if self.message is None:
                    # No message anywhere: draw a random payload.
                    message = torch.randint(0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device)
                else:
                    message = self.message.to(device=x.device)
            else:
                message = message.to(device=x.device)
            hidden = self.msg_processor(hidden, message)
            self._original_payload = message
        watermark = self.decoder(hidden)
        if sample_rate != 16000:
            # NOTE(review): the watermark is resampled back to the caller's
            # rate while ``x`` stays at 16 kHz, so the addition below may
            # misalign in length when sample_rate != 16000 — confirm intended.
            watermark_np = watermark.detach().cpu().numpy()
            resampled_watermark = librosa.resample(watermark_np, orig_sr=16000, target_sr=sample_rate)
            watermark = torch.tensor(resampled_watermark, device=watermark.device)
        energy_values = compute_stft_energy(x, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
        adaptive_alpha = compute_adaptive_alpha_librosa(energy_values, min_alpha=min_alpha, max_alpha=max_alpha)
        # Expand one alpha per STFT frame to one alpha per audio sample,
        # then truncate to the input length.
        stretched_alpha = torch.repeat_interleave(adaptive_alpha, hop_length, dim=1)
        stretched_alpha = stretched_alpha[:, :x.size(1)]
        if stretched_alpha.dim() < watermark.dim():
            stretched_alpha = stretched_alpha.unsqueeze(-1)  # add trailing dim
        stretched_alpha = stretched_alpha.expand_as(watermark)
        watermarked_audio = x + stretched_alpha * watermark
        return watermarked_audio
class AudioSealDetector(torch.nn.Module):
    """Detects an AudioSeal watermark in audio and decodes the hidden message.

    The detector head emits ``2 + nbits`` channels per time step: the first
    two are softmax-normalized (not-watermarked / watermarked) probabilities
    and the remaining ``nbits`` channels are per-bit logits for the message.
    """

    def __init__(self, *args, nbits: int = 0, **kwargs):
        super().__init__()
        encoder = SEANetEncoderKeepDimension(*args, **kwargs)
        # 1x1 conv maps encoder features to (2 detection + nbits message) channels.
        last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
        self.detector = torch.nn.Sequential(encoder, last_layer)
        self.nbits = nbits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        message_threshold: float = 0.5,
        detection_threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor]:
        """Return (detection probability, decoded binary message) for ``x``.

        Args:
            x: input audio batch.
            sample_rate: sample rate of ``x``; 16 kHz is assumed (with a
                deprecation warning) when omitted.
            message_threshold: bit probabilities above this decode to 1.
            detection_threshold: frames whose "watermarked" probability
                exceeds this count toward the score (was hard-coded 0.5).
        """
        result, message = self.forward(x, sample_rate=sample_rate)
        # Fraction of frames classified as watermarked.
        detected = (torch.count_nonzero(torch.gt(result[:, 1, :], detection_threshold)) / result.shape[-1])
        detect_prob = detected.cpu().item()
        message = torch.gt(message, message_threshold).int()
        return detect_prob, message

    def decode_message(self, result: torch.Tensor) -> torch.Tensor:
        """Average per-frame bit logits over time and squash to probabilities."""
        assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
            result.dim() == 2 and result.shape[0] == self.nbits
        ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
        decoded_message = result.mean(dim=-1)
        return torch.sigmoid(decoded_message)

    def forward(self, x: torch.Tensor, sample_rate: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the detector.

        Returns:
            A tuple of (detection probabilities of shape (batch, 2, frames),
            per-bit message probabilities of shape (batch, nbits)).
        """
        if sample_rate is None:
            logger.warning(COMPATIBLE_WARNING)
            sample_rate = 16_000
        if sample_rate != 16000:
            # The detector operates at 16 kHz; resample the input first.
            x_np = x.detach().cpu().numpy()
            resampled_x = librosa.resample(x_np, orig_sr=sample_rate, target_sr=16000)
            x = torch.tensor(resampled_x, device=x.device)
        result = self.detector(x)
        # Normalize the two detection channels into probabilities (in place).
        result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
        message = self.decode_message(result[:, 2:, :])
        return result[:, :2, :], message