File size: 3,811 Bytes
f4f5a40
f03ec98
e7c7540
1111e0a
811d3ce
1111e0a
 
 
dff69a4
671512a
 
 
880164c
671512a
 
 
4e2a28d
 
 
 
 
fd5ff13
4e2a28d
4b84392
 
 
 
 
 
 
 
 
 
 
64693b7
880164c
fd5ff13
1111e0a
8b7f20a
fd5ff13
1111e0a
 
 
 
64693b7
fd5ff13
1111e0a
64693b7
1111e0a
 
 
 
79ee142
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5ff13
79ee142
 
 
fd5ff13
 
5557760
79ee142
 
 
fd5ff13
880164c
 
 
 
 
 
 
 
 
 
9ffef8e
270455b
fd5ff13
 
1111e0a
09f8c18
880164c
270455b
671512a
f03ec98
fd5ff13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import traceback
import matplotlib.pyplot as plt
import io
from PIL import Image

# Ensure AudioSeal is imported correctly.
# AudioSeal is an optional dependency: if the import fails, the name
# `AudioSeal` is simply never bound, and detect_watermark() detects that via
# `'AudioSeal' not in globals()` and degrades to a spectrogram-only response.
try:
    from audioseal import AudioSeal
    print("AudioSeal is available for watermark detection.")
except ImportError as e:
    # Log and continue — the app still works without watermark detection.
    print(f"AudioSeal could not be imported: {e}")

def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and resample it to ``target_sample_rate`` if needed.

    Args:
        audio_file_path: Path to an audio file readable by ``torchaudio.load``.
        target_sample_rate: Desired output rate in Hz. Defaults to 16000,
            the rate used by the AudioSeal detector elsewhere in this file.

    Returns:
        Tuple ``(waveform, target_sample_rate)`` where ``waveform`` is the
        (channels, samples) float tensor produced by torchaudio.
    """
    waveform, sample_rate = torchaudio.load(audio_file_path)
    if sample_rate != target_sample_rate:
        # Use the `T` alias already imported at the top of the file, for
        # consistency with the other transforms in this module.
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    return waveform, target_sample_rate

def extract_mfcc_features(waveform, sample_rate, n_mfcc=40, n_mels=128, win_length=400, hop_length=160, n_fft=400):
    """Compute MFCC features averaged over time for a waveform.

    Args:
        waveform: Audio tensor in the (channels, samples) layout torchaudio
            transforms expect.
        sample_rate: Sample rate of ``waveform`` in Hz.
        n_mfcc: Number of MFCC coefficients to return.
        n_mels: Number of mel filterbanks.
        win_length: STFT window length in samples.
        hop_length: STFT hop length in samples.
        n_fft: FFT size. Previously hard-coded to 400; exposed as a
            parameter (default unchanged) so callers can match it to
            ``win_length`` when using other window sizes.

    Returns:
        Tensor of shape (channels, n_mfcc): the MFCCs averaged across the
        time axis (``dim=2``).
    """
    mfcc_transform = T.MFCC(
        sample_rate=sample_rate,
        n_mfcc=n_mfcc,
        melkwargs={
            'n_fft': n_fft,
            'n_mels': n_mels,
            'hop_length': hop_length,
            'win_length': win_length
        }
    )
    mfcc = mfcc_transform(waveform)
    # Collapse the time dimension to get one fixed-size vector per channel.
    return mfcc.mean(dim=2)
    
def plot_spectrogram(waveform, sample_rate):
    """Render a log-magnitude spectrogram of ``waveform`` as a PIL image.

    Note: ``sample_rate`` is accepted for interface symmetry with the other
    functions in this file but is not used in the rendering itself.

    Returns:
        A ``PIL.Image.Image`` holding the spectrogram as a borderless PNG.
    """
    # Promote mono 1-D input to the (channel, time) layout transforms expect.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    magnitude = T.Spectrogram()(waveform)
    magnitude_db = torchaudio.transforms.AmplitudeToDB()(magnitude)

    plt.figure(figsize=(10, 4))
    plt.imshow(magnitude_db[0].numpy(), cmap='hot', aspect='auto', origin='lower')
    plt.axis('off')  # Hide axes for a clean image

    # Serialize the figure to an in-memory PNG and hand it back as a PIL image.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()
    png_buffer.seek(0)
    return Image.open(png_buffer)

def detect_watermark(audio_data, sample_rate=16000):
    """Run AudioSeal watermark detection on an audio file (or waveform).

    Args:
        audio_data: Path to an audio file — this is what the Gradio
            ``type="filepath"`` input supplies — or an already-loaded
            waveform tensor (assumed to already be at ``sample_rate``;
            TODO confirm batched callers such as main()).
        sample_rate: Target sample rate in Hz. Defaults to 16000 so Gradio
            can call this function with a single filepath argument (the
            original two-required-arg signature made the Gradio callback
            fail with a missing-argument TypeError).

    Returns:
        Tuple ``(message, spectrogram_image)`` — a human-readable verdict
        string and a PIL image of the spectrogram.
    """
    # Accept a pre-loaded tensor directly; otherwise load from disk.
    # (The original fallback passed the raw filepath string to
    # plot_spectrogram, which expects a tensor and would crash.)
    if isinstance(audio_data, torch.Tensor):
        waveform, sr = audio_data, sample_rate
    else:
        waveform, sr = load_and_resample_audio(audio_data, target_sample_rate=sample_rate)

    # Graceful degradation when the optional `audioseal` package is absent:
    # still show the spectrogram, just without a detection verdict.
    if 'AudioSeal' not in globals():
        display_waveform = waveform[0] if waveform.ndim == 3 else waveform
        return "AudioSeal not available", plot_spectrogram(display_waveform, sr)

    # Load the detector
    detector = AudioSeal.load_detector("audioseal_detector_16bits")

    # The detector expects a batched (batch, channels, time) tensor; add the
    # batch dim for the (channels, time) output of torchaudio.load. Pass the
    # *actual* rate `sr` of the waveform — the original code forwarded the
    # caller-supplied `sample_rate`, which could disagree after resampling.
    batch = waveform if waveform.ndim == 3 else waveform.unsqueeze(0)
    results, messages = detector.forward(batch, sample_rate=sr)
    detect_probs = results[:, 1, :]
    result = detect_probs.mean().cpu().item()
    message = f"Detection result: {'Watermarked Audio' if result > 0.5 else 'Not watermarked'}"

    display_waveform = waveform[0] if waveform.ndim == 3 else waveform
    spectrogram_image = plot_spectrogram(display_waveform, sr)

    return message, spectrogram_image

def main(audio_file_path):
    """Batch-mode entry point: split an audio file into 5-second chunks and
    run watermark detection over the whole batch at once.

    NOTE(review): this function appears to be dead code — the Gradio
    interface below calls detect_watermark directly — and it looks out of
    sync with the rest of the file; verify before relying on it:
      - detect_watermark() passes its first argument to torchaudio.load,
        i.e. it expects a file path, but a tensor is passed here.
      - For audio shorter than two chunks, the [:-1] slice leaves zero
        batches and torch.concat raises on an empty sequence.
      - detect_watermark returns a (message, image) tuple, so the printed
        "probability" is actually that whole tuple.
    """
    waveform, resampled_sr = load_and_resample_audio(audio_file_path)
    plot_spectrogram(waveform, resampled_sr)  # rendered image is discarded
    samples_per_batch = 5 * resampled_sr  # 5s audios
    audio_batches = torch.split(waveform, samples_per_batch, dim=1)[:-1]  # Exclude the last batch if it's not 5 seconds long
    audio_batched = torch.concat(audio_batches, dim=0)
    audio_batched = audio_batched.unsqueeze(1) # add channel dimension
    result = detect_watermark(audio_batched, resampled_sr)
    print(f"Probability of watermark: {result}")

# Gradio interface: one audio-file input, a text verdict plus a spectrogram
# image as outputs.
# NOTE(review): a `type="filepath"` audio input supplies a single string
# argument, but detect_watermark is defined with two required positional
# parameters (audio_data, sample_rate) — as written, submitting audio would
# raise a missing-argument TypeError unless sample_rate is given a default.
# Verify against the detect_watermark signature.
interface = gr.Interface(
    fn=detect_watermark,
    inputs=gr.Audio(label="Upload your audio", type="filepath"),
    outputs=["text", "image"],
    title="Deep Fake Defender: AudioSeal Watermark Detection",
    description="Analyzes audio to detect AI-generated content."
)

# Launch the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()