File size: 3,811 Bytes
f4f5a40 f03ec98 e7c7540 1111e0a 811d3ce 1111e0a dff69a4 671512a 880164c 671512a 4e2a28d fd5ff13 4e2a28d 4b84392 64693b7 880164c fd5ff13 1111e0a 8b7f20a fd5ff13 1111e0a 64693b7 fd5ff13 1111e0a 64693b7 1111e0a 79ee142 fd5ff13 79ee142 fd5ff13 5557760 79ee142 fd5ff13 880164c 9ffef8e 270455b fd5ff13 1111e0a 09f8c18 880164c 270455b 671512a f03ec98 fd5ff13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import traceback
import matplotlib.pyplot as plt
import io
from PIL import Image
# AudioSeal is an optional dependency. When the import fails, the name
# ``AudioSeal`` is left undefined and the reason is reported on stdout
# instead of aborting module import.
try:
    from audioseal import AudioSeal
except ImportError as e:
    print(f"AudioSeal could not be imported: {e}")
else:
    print("AudioSeal is available for watermark detection.")
def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and return it at ``target_sample_rate``.

    Args:
        audio_file_path: Location of the audio, passed straight to
            ``torchaudio.load``.
        target_sample_rate: Sampling rate (Hz) the returned waveform
            should have.

    Returns:
        A ``(waveform, target_sample_rate)`` tuple. The waveform is the
        tensor produced by ``torchaudio.load``, resampled only when the
        file's native rate differs from the target.
    """
    audio, native_rate = torchaudio.load(audio_file_path)

    # Already at the requested rate: hand the decoded tensor back untouched.
    if native_rate == target_sample_rate:
        return audio, target_sample_rate

    to_target_rate = torchaudio.transforms.Resample(
        orig_freq=native_rate,
        new_freq=target_sample_rate,
    )
    return to_target_rate(audio), target_sample_rate
def extract_mfcc_features(waveform, sample_rate, n_mfcc=40, n_mels=128, win_length=400, hop_length=160, n_fft=None):
    """Return the MFCCs of ``waveform`` averaged over time.

    Args:
        waveform: Audio tensor whose last dimension is time.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        n_mfcc: Number of cepstral coefficients to compute.
        n_mels: Number of mel filter banks.
        win_length: STFT window length in samples.
        hop_length: STFT hop length in samples.
        n_fft: FFT size. ``None`` (the default) selects 400, or
            ``win_length`` when that is larger, so a window longer than
            400 samples no longer produces an invalid STFT configuration.

    Returns:
        Tensor of MFCCs with the trailing frame axis averaged out, i.e.
        one ``n_mfcc``-sized vector per leading (channel/batch) index.
    """
    if n_fft is None:
        # The STFT requires win_length <= n_fft; 400 is the historical value.
        n_fft = max(400, win_length)

    mfcc_transform = T.MFCC(
        sample_rate=sample_rate,
        n_mfcc=n_mfcc,
        melkwargs={
            'n_fft': n_fft,
            'n_mels': n_mels,
            'hop_length': hop_length,
            'win_length': win_length,
        },
    )
    mfcc = mfcc_transform(waveform)
    # Time is the last axis of the MFCC output. Using -1 instead of a fixed
    # index gives the same result for (channel, time) input and also works
    # for 1-D or batched input.
    return mfcc.mean(dim=-1)
def plot_spectrogram(waveform, sample_rate):
    """Render the dB-scaled spectrogram of ``waveform`` as a PIL image.

    Args:
        waveform: 1-D ``(samples,)`` or 2-D audio tensor. Only the first
            row of a 2-D tensor is drawn.
        sample_rate: Currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        A ``PIL.Image`` holding an axis-free PNG rendering of the
        spectrogram.
    """
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)  # Ensure waveform is 2D

    spectrogram = T.Spectrogram()(waveform)
    spectrogram_db = T.AmplitudeToDB()(spectrogram)

    # Work on an explicit figure handle rather than pyplot's implicit
    # "current figure", and always close it, so a failure while drawing or
    # saving cannot leak the figure or close an unrelated one.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        # detach()/cpu() keep .numpy() valid for tensors that track
        # gradients or live on an accelerator.
        ax.imshow(
            spectrogram_db[0].detach().cpu().numpy(),
            cmap='hot',
            aspect='auto',
            origin='lower',
        )
        ax.axis('off')  # Hide axes for a clean image
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Loaded detectors keyed by model name, so the model is read only once per
# process instead of on every request.
_DETECTOR_CACHE = {}


def _get_detector(model_name="audioseal_detector_16bits"):
    """Return the AudioSeal detector ``model_name``, loading it on first use."""
    if model_name not in _DETECTOR_CACHE:
        _DETECTOR_CACHE[model_name] = AudioSeal.load_detector(model_name)
    return _DETECTOR_CACHE[model_name]


def detect_watermark(audio_data, sample_rate=None):
    """Check audio for an AudioSeal watermark and render its spectrogram.

    Args:
        audio_data: Either a path to an audio file (what the Gradio audio
            input configured with ``type="filepath"`` supplies) or an
            in-memory tensor shaped ``(samples,)``, ``(channels, samples)``
            or ``(batch, channels, samples)``. ``None`` means nothing was
            uploaded.
        sample_rate: Sampling rate of ``audio_data`` when it is a tensor.
            Ignored for file paths, whose rate is read from the file.

    Returns:
        A ``(message, image)`` tuple: a human-readable verdict and a PIL
        spectrogram image (``None`` when no audio was provided).

    Raises:
        ValueError: If a tensor is given without ``sample_rate`` or with an
            unsupported number of dimensions.
    """
    if audio_data is None:
        return "No audio provided", None

    target_sample_rate = 16000
    if isinstance(audio_data, torch.Tensor):
        if sample_rate is None:
            raise ValueError("sample_rate is required when audio_data is a tensor")
        waveform, sr = audio_data, sample_rate
        if sr != target_sample_rate:
            resampler = T.Resample(orig_freq=sr, new_freq=target_sample_rate)
            waveform = resampler(waveform)
            sr = target_sample_rate
    else:
        waveform, sr = load_and_resample_audio(audio_data, target_sample_rate=target_sample_rate)

    # Normalise to (batch, channels, samples).
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)
    if waveform.ndim != 3:
        raise ValueError(
            f"Expected audio with 1 to 3 dimensions, got {waveform.ndim}"
        )
    # Downmix to a single channel so stereo uploads do not reach the
    # detector with an unexpected channel count.
    batch = waveform.mean(dim=1, keepdim=True)

    # Plot the first item only; plot_spectrogram draws a single 2-D waveform.
    spectrogram_image = plot_spectrogram(batch[0], sr)

    # Ensure AudioSeal is available
    if 'AudioSeal' not in globals():
        return "AudioSeal not available", spectrogram_image

    detector = _get_detector()
    with torch.no_grad():  # inference only
        # Pass the rate the audio actually has after resampling.
        results, _ = detector(batch, sample_rate=sr)

    # Index 1 on the second axis holds the per-frame probability that the
    # audio is watermarked; average it into a single score.
    detect_probs = results[:, 1, :]
    result = detect_probs.mean().cpu().item()
    message = f"Detection result: {'Watermarked Audio' if result > 0.5 else 'Not watermarked'}"
    return message, spectrogram_image
def main(audio_file_path):
    """Run watermark detection on a file, split into 5-second segments.

    The audio is loaded at the default rate of ``load_and_resample_audio``,
    cut into consecutive 5-second segments stacked along the batch axis, and
    passed to ``detect_watermark``. The verdict text is printed.

    Args:
        audio_file_path: Path of the audio file to analyse.
    """
    waveform, resampled_sr = load_and_resample_audio(audio_file_path)
    samples_per_batch = 5 * resampled_sr  # 5s audios

    # torch.split leaves any shorter remainder as the final chunk. Keep only
    # full-length chunks, so a clip that is an exact multiple of 5 s keeps
    # its last segment and a trailing partial one is dropped.
    audio_batches = [
        chunk
        for chunk in torch.split(waveform, samples_per_batch, dim=1)
        if chunk.shape[1] == samples_per_batch
    ]
    if not audio_batches:
        # Clip shorter than one segment: analyse it whole rather than
        # failing on an empty concatenation.
        audio_batches = [waveform]

    audio_batched = torch.concat(audio_batches, dim=0)
    audio_batched = audio_batched.unsqueeze(1)  # add channel dimension

    # NOTE(review): this passes an in-memory (batch, 1, samples) tensor, not
    # a file path -- confirm detect_watermark accepts tensor input.
    message, _ = detect_watermark(audio_batched, resampled_sr)
    # detect_watermark returns a (verdict text, image) pair, not a
    # probability, so print the verdict text itself.
    print(message)
# Gradio interface: a single audio upload, handed to the callback as a file
# path, with one text output and one image output.
interface = gr.Interface(
    fn=detect_watermark,
    inputs=gr.Audio(label="Upload your audio", type="filepath"),
    outputs=["text", "image"],
    title="Deep Fake Defender: AudioSeal Watermark Detection",
    description="Analyzes audio to detect AI-generated content.",
)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()