Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc +0 -0
- ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc +0 -0
- ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc +0 -0
- ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc +0 -0
- ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc +0 -0
- ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc +0 -0
- ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc +0 -0
- ISMIR_2025/MERT/datalib.py +203 -0
- ISMIR_2025/MERT/main.py +197 -0
- ISMIR_2025/MERT/networks.py +107 -0
- ISMIR_2025/MERT/test.py +114 -0
- ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc +0 -0
- ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc +0 -0
- ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc +0 -0
- ISMIR_2025/MERT/utils/config.py +565 -0
- ISMIR_2025/MERT/utils/confusion_matrix_plot.py +29 -0
- ISMIR_2025/MERT/utils/freqeuncy.py +24 -0
- ISMIR_2025/MERT/utils/hf_vis.py +89 -0
- ISMIR_2025/MERT/utils/idr_torch.py +23 -0
- ISMIR_2025/MERT/utils/mfcc.py +266 -0
- ISMIR_2025/MERT/utils/utilities.py +305 -0
- ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc +0 -0
- ISMIR_2025/Model/datalib.py +206 -0
- ISMIR_2025/Model/main.py +336 -0
- ISMIR_2025/Model/networks.py +237 -0
- ISMIR_2025/Model/test.py +129 -0
- ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc +0 -0
- ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc +0 -0
- ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc +0 -0
- ISMIR_2025/music2vec/datalib.py +144 -0
- ISMIR_2025/music2vec/inference.py +64 -0
- ISMIR_2025/music2vec/main.py +155 -0
- ISMIR_2025/music2vec/networks.py +247 -0
- ISMIR_2025/music2vec/test.py +119 -0
- ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc +0 -0
- ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/datalib.py +139 -0
- ISMIR_2025/wav2vec/inference.py +71 -0
- ISMIR_2025/wav2vec/main.py +162 -0
- ISMIR_2025/wav2vec/networks.py +161 -0
- ISMIR_2025/wav2vec/test.py +148 -0
- ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc +0 -0
- ISMIR_2025/wav2vec/utils/config.py +565 -0
- ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py +29 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+v2_request/request_music/AI_Detection/9486_music_id_Lost[[:space:]]Sky[[:space:]]x[[:space:]]Anna[[:space:]]Yvette[[:space:]]-[[:space:]]Carry[[:space:]]On[[:space:]]|[[:space:]]Trap[[:space:]]|[[:space:]]NCS[[:space:]]-[[:space:]]Copyright[[:space:]]Free[[:space:]]Music.wav filter=lfs diff=lfs merge=lfs -text
ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc
ADDED
Binary file (10.1 kB).

ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc
ADDED
Binary file (9.26 kB).

ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc
ADDED
Binary file (8.69 kB).

ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc
ADDED
Binary file (11.6 kB).

ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc
ADDED
Binary file (6.86 kB).

ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc
ADDED
Binary file (6.62 kB).

ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc
ADDED
Binary file (4.33 kB).
ISMIR_2025/MERT/datalib.py
ADDED
@@ -0,0 +1,203 @@
import os
import glob
import numpy as np
import torch
import torchaudio
import librosa
import scipy.signal as signal
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from imblearn.over_sampling import RandomOverSampler
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2FeatureExtractor

# class FakeMusicCapsDataset(Dataset):
#     def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
#         self.file_paths = file_paths
#         self.labels = labels
#         self.sr = sr
#         self.target_samples = int(target_duration * sr)  # Fixed length: 5 seconds
#
#     def __len__(self):
#         return len(self.file_paths)
#
#     def __getitem__(self, idx):
#         audio_path = self.file_paths[idx]
#         label = self.labels[idx]
#
#         waveform, sr = torchaudio.load(audio_path)
#         waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)
#         waveform = waveform.mean(dim=0)  # Convert to mono
#         waveform = waveform.squeeze(0)
#
#         current_samples = waveform.shape[0]
#
#         # Ensure waveform is exactly `target_samples` long
#         if current_samples > self.target_samples:
#             waveform = waveform[:self.target_samples]  # Truncate if too long
#         elif current_samples < self.target_samples:
#             pad_length = self.target_samples - current_samples
#             waveform = torch.nn.functional.pad(waveform, (0, pad_length))  # Pad if too short
#
#         return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long)  # Ensure 2D shape (1, target_samples)


class FakeMusicCapsDataset(Dataset):
    def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
        self.file_paths = file_paths
        self.labels = labels
        self.sr = sr
        self.target_samples = int(target_duration * sr)  # Fixed length: 10 seconds
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

    def __len__(self):
        return len(self.file_paths)

    def highpass_filter(self, y, sr, cutoff=500, order=5):
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(audio_path)

        target_sr = self.processor.sampling_rate

        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
            waveform = resampler(waveform)

        waveform = waveform.mean(dim=0).squeeze(0)  # [Time]

        if label == 1:
            # highpass_filter returns a NumPy array; convert back to a float tensor so the
            # truncation / padding below keeps working
            waveform = torch.from_numpy(self.highpass_filter(waveform.numpy(), self.sr)).float()

        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            waveform = waveform[:self.target_samples]  # Truncate
        elif current_samples < self.target_samples:
            pad_length = self.target_samples - current_samples
            waveform = torch.nn.functional.pad(waveform, (0, pad_length))  # Pad

        if isinstance(waveform, torch.Tensor):
            waveform = waveform.numpy()  # convert only when it is still a Tensor

        inputs = self.processor(waveform, sampling_rate=target_sr, return_tensors="pt", padding=True)

        return inputs["input_values"].squeeze(0), torch.tensor(label, dtype=torch.long)  # [1, time] -> [time]

    @staticmethod
    def collate_fn(batch, target_samples=16000 * 10):
        inputs, labels = zip(*batch)  # Unzip batch

        processed_inputs = []
        for waveform in inputs:
            current_samples = waveform.shape[0]

            if current_samples > target_samples:
                start_idx = (current_samples - target_samples) // 2
                cropped_waveform = waveform[start_idx:start_idx + target_samples]
            else:
                pad_length = target_samples - current_samples
                cropped_waveform = torch.nn.functional.pad(waveform, (0, pad_length))

            processed_inputs.append(cropped_waveform)

        processed_inputs = torch.stack(processed_inputs)  # [batch, target_samples]
        labels = torch.tensor(labels, dtype=torch.long)  # [batch]

        return processed_inputs, labels


def preprocess_audio(audio_path, target_sr=16000, max_length=160000):
    """
    Convert an audio file into the model input format.
    - target_sr: resample to 16 kHz
    - max_length: maximum length of 160,000 samples (10 seconds)
    """
    waveform, sr = torchaudio.load(audio_path)

    # Resample if needed
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)

    # Convert to mono
    waveform = waveform.mean(dim=0).unsqueeze(0)  # (1, sequence_length)

    current_samples = waveform.shape[1]
    if current_samples > max_length:
        start_idx = (current_samples - max_length) // 2
        waveform = waveform[:, start_idx:start_idx + max_length]
    elif current_samples < max_length:
        pad_length = max_length - current_samples
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))

    return waveform


DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data used for the open-set protocol

# Closed test: the FakeMusicCaps dataset
real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)

# Open-set test: additionally include data from SUNOCAPS_PATH
open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)

real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)

open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)

# Closed-set train / validation split
real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)

train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels

# Closed-set test split
closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels

# Open-set test split
open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels

# Apply oversampling
ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)

train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled

print(f"📌 Train Original FAKE: {len(gen_train)}")
print(f"📌 Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, "
      f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
print(f"📌 Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}")
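A minimal usage sketch for the dataset above, assuming the FakeMusicCaps directory layout under DATASET_PATH is actually present (datalib.py scans it at import time); the batch size is illustrative. It mirrors how main.py wires the dataset into a DataLoader with the custom collate function.

from torch.utils.data import DataLoader
from ISMIR_2025.MERT.datalib import FakeMusicCapsDataset, train_files, train_labels

train_dataset = FakeMusicCapsDataset(train_files, train_labels, sr=16000, target_duration=10.0)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    collate_fn=FakeMusicCapsDataset.collate_fn,  # center-crop / zero-pad every clip to 10 s
)

inputs, labels = next(iter(train_loader))
print(inputs.shape, labels.shape)  # expected: torch.Size([8, 160000]) torch.Size([8])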
ISMIR_2025/MERT/main.py
ADDED
@@ -0,0 +1,197 @@
import os
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score
import wandb
from transformers import AutoModel, AutoConfig, Wav2Vec2FeatureExtractor
from ISMIR_2025.MERT.datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels
from ISMIR_2025.MERT.networks import MERTFeatureExtractor

# Hyperparameters referenced throughout this script (default values are illustrative)
parser = argparse.ArgumentParser(description="AI Music Detection Training with MERT")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--pretrain_epochs', type=int, default=10, help='Number of pretraining epochs')
parser.add_argument('--finetune_epochs', type=int, default=10, help='Number of fine-tuning epochs')
parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
parser.add_argument('--finetune_lr', type=float, default=1e-4, help='Fine-tuning learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay')
parser.add_argument('--checkpoint_dir', type=str, default='./ckpt', help='Directory for model checkpoints')
args = parser.parse_args()

# Set device
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Initialize wandb
wandb.init(project="mert", name=f"hpfilter_pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args)

# Load datasets
print("🔍 Preparing datasets...")
train_dataset = FakeMusicCapsDataset(train_files, train_labels)
val_dataset = FakeMusicCapsDataset(val_files, val_labels)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn)

# Model checkpoint paths
pretrain_ckpt = os.path.join(args.checkpoint_dir, f"mert_pretrain_{args.pretrain_epochs}.pth")
finetune_ckpt = os.path.join(args.checkpoint_dir, f"mert_finetune_{args.finetune_epochs}.pth")

# Load the MERT model for pretraining
print("🔍 Initializing MERT model for Pretraining...")

config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
if not hasattr(config, "conv_pos_batch_norm"):
    setattr(config, "conv_pos_batch_norm", False)

mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device)
mert_model = MERTFeatureExtractor().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mert_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Training function
def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"):
        labels = labels.to(device)
        inputs = inputs.to(device)

        output = model(inputs)

        # Check if the output is a tensor or an object with logits
        if isinstance(output, torch.Tensor):
            logits = output
        elif hasattr(output, "logits"):
            logits = output.logits
        elif isinstance(output, (tuple, list)):
            logits = output[0]
        else:
            raise ValueError("Unexpected model output type")

        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    scheduler.step()

    accuracy = total_correct / total_samples
    f1 = f1_score(all_labels, all_preds, average="binary")
    precision = precision_score(all_labels, all_preds, average="binary")
    recall = recall_score(all_labels, all_preds, average="binary", pos_label=1)
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)

    wandb.log({
        f"{phase} Train Loss": total_loss / len(dataloader),
        f"{phase} Train Accuracy": accuracy,
        f"{phase} Train F1 Score": f1,
        f"{phase} Train Precision": precision,
        f"{phase} Train Recall": recall,
        f"{phase} Train Balanced Accuracy": balanced_acc,
    })

    print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, "
          f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}")

def validate(model, dataloader, optimizer, criterion, device, epoch, phase="Validation"):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    # Validation only measures loss and metrics; no backpropagation or optimizer updates
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"{phase} Validation Epoch {epoch+1}"):
            labels = labels.to(device)
            inputs = inputs.to(device)

            output = model(inputs)

            # Check if the output is a tensor or an object with logits
            if isinstance(output, torch.Tensor):
                logits = output
            elif hasattr(output, "logits"):
                logits = output.logits
            elif isinstance(output, (tuple, list)):
                logits = output[0]
            else:
                raise ValueError("Unexpected model output type")

            loss = criterion(logits, labels)

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples
    val_f1 = f1_score(all_labels, all_preds, average="weighted")
    val_precision = precision_score(all_labels, all_preds, average="binary")
    val_recall = recall_score(all_labels, all_preds, average="binary")
    val_bal_acc = balanced_accuracy_score(all_labels, all_preds)

    wandb.log({
        f"{phase} Val Loss": total_loss / len(dataloader),
        f"{phase} Val Accuracy": accuracy,
        f"{phase} Val F1 Score": val_f1,
        f"{phase} Val Precision": val_precision,
        f"{phase} Val Recall": val_recall,
        f"{phase} Val Balanced Accuracy": val_bal_acc,
    })
    print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, "
          f"Val Acc: {accuracy:.4f}, Val F1: {val_f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}")
    return total_loss / len(dataloader), accuracy, val_f1


print("\n🔍 Step 1: Self-Supervised Pretraining on REAL Data")
# for epoch in range(args.pretrain_epochs):
#     train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain")
# torch.save(mert_model.state_dict(), pretrain_ckpt)
# print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}")

# print("\n🔍 Initializing CCV Model for Fine-Tuning...")
# mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device)
# mert_model.feature_extractor.load_state_dict(torch.load(pretrain_ckpt), strict=False)

# optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

print("\n🔍 Step 2: Fine-Tuning CCV Model")
for epoch in range(args.finetune_epochs):
    train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune")

torch.save(mert_model.state_dict(), finetune_ckpt)
print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")

print("\n🔍 Step 2: Fine-Tuning MERT Model")
mert_model.load_state_dict(torch.load(pretrain_ckpt), strict=False)

optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(args.finetune_epochs):
    train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune")

torch.save(mert_model.state_dict(), finetune_ckpt)
print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")
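A single-file inference sketch for a trained checkpoint. This example is assumption-heavy: it presumes the fine-tuned weights correspond to the CCV classifier from networks.py (the commented-out "CCV Model for Fine-Tuning" path above), the checkpoint path and audio file name are hypothetical, and importing datalib only works when its dataset paths resolve.

import torch
from ISMIR_2025.MERT.datalib import preprocess_audio
from ISMIR_2025.MERT.networks import CCV

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2).to(device)
state = torch.load("ckpt/mert_finetune_10.pth", map_location=device)  # hypothetical path
model.load_state_dict(state, strict=False)
model.eval()

waveform = preprocess_audio("example.wav")   # hypothetical file; shape (1, 160000) at 16 kHz
with torch.no_grad():
    logits, _ = model(waveform.to(device))   # CCV returns (logits, embedding)
    pred = logits.argmax(dim=1).item()
print("prediction:", "generative" if pred == 1 else "real")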
ISMIR_2025/MERT/networks.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig

class MERTFeatureExtractor(nn.Module):
    def __init__(self, freeze_feature_extractor=True):
        super(MERTFeatureExtractor, self).__init__()
        config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
        if not hasattr(config, "conv_pos_batch_norm"):
            setattr(config, "conv_pos_batch_norm", False)
        self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True)

        if freeze_feature_extractor:
            self.freeze()

    def forward(self, input_values):
        # Input: [batch, time]
        # Extract the hidden states of the pretrained MERT (all layers are used here)
        with torch.no_grad():
            outputs = self.mert(input_values, output_hidden_states=True)
        # hidden_states: tuple of [batch, time, feature_dim]
        # Stack the per-layer hidden states, then average over the time axis to obtain features
        hidden_states = torch.stack(outputs.hidden_states)  # [num_layers, batch, time, feature_dim]
        hidden_states = hidden_states.detach().clone().requires_grad_(True)
        time_reduced = hidden_states.mean(dim=2)  # [num_layers, batch, feature_dim]
        time_reduced = time_reduced.permute(1, 0, 2)  # [batch, num_layers, feature_dim]
        return time_reduced

    def freeze(self):
        for param in self.mert.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.mert.parameters():
            param.requires_grad = True


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x, cross_input):
        # Attention between x and cross_input
        attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        x = self.layer_norm1(x + attn_output)
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x


class CCV(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
        super(CCV, self).__init__()
        # MERT-based feature extractor (extracts meaningful features from the pretrained weights)
        self.feature_extractor = MERTFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor)
        # Stack of cross-attention layers
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        # Transformer encoder (batch dimension first)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_values):
        """
        input_values: Tensor [batch, time]
        1. Extract features with MERT -> [batch, num_layers, feature_dim]
        2. (Optionally) transpose to match the embedding dimension -> [batch, feature_dim, num_layers]
        3. Apply cross-attention
        4. Transformer encoding followed by mean pooling
        5. Pass through the classifier and return the final logits
        """
        features = self.feature_extractor(input_values)  # [batch, num_layers, feature_dim]
        # embed_dim is usually kept equal to feature_dim (e.g. 768)
        # features = features.permute(0, 2, 1)  # [batch, embed_dim, num_layers]

        # Apply cross-attention (here the features attend to themselves)
        for layer in self.cross_attention_layers:
            features = layer(features, features)

        # Average over the time axis (here the num_layers axis) before the Transformer encoder
        features = features.mean(dim=1).unsqueeze(1)  # [batch, 1, embed_dim]
        encoded = self.transformer(features)  # [batch, 1, embed_dim]
        encoded = encoded.mean(dim=1)  # [batch, embed_dim]
        output = self.classifier(encoded)  # [batch, num_classes]
        return output, encoded

    def unfreeze_feature_extractor(self):
        self.feature_extractor.unfreeze()
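A quick shape sanity check for MERTFeatureExtractor. This is only a sketch: it downloads the m-a-p/MERT-v1-95M weights on first run, and the exact number of hidden states depends on that checkpoint.

import torch
from ISMIR_2025.MERT.networks import MERTFeatureExtractor

extractor = MERTFeatureExtractor(freeze_feature_extractor=True)
dummy = torch.randn(2, 16000 * 5)   # two 5-second clips at 16 kHz
feats = extractor(dummy)            # [2, num_hidden_states, 768]
print(feats.shape)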
ISMIR_2025/MERT/test.py
ADDED
@@ -0,0 +1,114 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
from datalib import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels
)
from networks import MERTFeatureExtractor
import argparse

parser = argparse.ArgumentParser(description="AI Music Detection Testing with MERT")
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default="/data/kym/AI_Music_Detection/Code/model/MERT/ckpt/1e-3/mert_finetune_10.pth", help='Path to the pretrained checkpoint')
parser.add_argument('--model_name', type=str, default="mert", help="Model name")
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='', help='Path to save test results')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)

model = MERTFeatureExtractor().to(device)

ckpt_file = args.ckpt_path
if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
print(f"\nLoading MERT model from {ckpt_file}")
model.load_state_dict(torch.load(ckpt_file, map_location=device))
model.eval()

torch.cuda.empty_cache()

if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

def test_mert(model, test_loader, device):
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)

            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)

            all_labels.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")

    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)

print("\nEvaluating MERT Model on Test Set...")
test_mert(model, test_loader, device)
ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc
ADDED
Binary file (4.79 kB).

ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc
ADDED
Binary file (1.01 kB).

ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc
ADDED
Binary file (16.1 kB).
ISMIR_2025/MERT/utils/config.py
ADDED
@@ -0,0 +1,565 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv

import numpy as np

sample_rate = 32000
clip_samples = sample_rate * 10  # Audio clips are 10-second

# Load label
with open(
    "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv",
    "r",
) as f:
    reader = csv.reader(f, delimiter=",")
    lines = list(reader)

labels = []
ids = []  # Each label has a unique id such as "/m/068hy"
for i1 in range(1, len(lines)):
    id = lines[i1][1]
    label = lines[i1][2]
    ids.append(id)
    labels.append(label)

classes_num = len(labels)

lb_to_ix = {label: i for i, label in enumerate(labels)}
ix_to_lb = {i: label for i, label in enumerate(labels)}

id_to_ix = {id: i for i, id in enumerate(ids)}
ix_to_id = {i: id for i, id in enumerate(ids)}

full_samples_per_class = np.array(
    [
        937432, 16344, 7822, 10271, 2043, 14420, 733, 1511, 1258, 424,
        1751, 704, 369, 590, 1063, 1375, 5026, 743, 853, 1648,
        714, 1497, 1251, 2139, 1093, 133, 224, 39469, 6423, 407,
        1559, 4546, 6826, 7464, 2468, 549, 4063, 334, 587, 238,
        1766, 691, 114, 2153, 236, 209, 421, 740, 269, 959,
        137, 4192, 485, 1515, 655, 274, 69, 157, 1128, 807,
        1022, 346, 98, 680, 890, 352, 4169, 2061, 1753, 9883,
        1339, 708, 37857, 18504, 12864, 2475, 2182, 757, 3624, 677,
        1683, 3583, 444, 1780, 2364, 409, 4060, 3097, 3143, 502,
        723, 600, 230, 852, 1498, 1865, 1879, 2429, 5498, 5430,
        2139, 1761, 1051, 831, 2401, 2258, 1672, 1711, 987, 646,
        794, 25061, 5792, 4256, 96, 8126, 2740, 752, 513, 554,
        106, 254, 1592, 556, 331, 615, 2841, 737, 265, 1349,
        358, 1731, 1115, 295, 1070, 972, 174, 937780, 112337, 42509,
        49200, 11415, 6092, 13851, 2665, 1678, 13344, 2329, 1415, 2244,
        1099, 5024, 9872, 10948, 4409, 2732, 1211, 1289, 4807, 5136,
        1867, 16134, 14519, 3086, 19261, 6499, 4273, 2790, 8820, 1228,
        1575, 4420, 3685, 2019, 664, 324, 513, 411, 436, 2997,
        5162, 3806, 1389, 899, 8088, 7004, 1105, 3633, 2621, 9753,
        1082, 26854, 3415, 4991, 2129, 5546, 4489, 2850, 1977, 1908,
        1719, 1106, 1049, 152, 136, 802, 488, 592, 2081, 2712,
        1665, 1128, 250, 544, 789, 2715, 8063, 7056, 2267, 8034,
        6092, 3815, 1833, 3277, 8813, 2111, 4662, 2678, 2954, 5227,
        1472, 2591, 3714, 1974, 1795, 4680, 3751, 6585, 2109, 36617,
        6083, 16264, 17351, 3449, 5034, 3931, 2599, 4134, 3892, 2334,
        2211, 4516, 2766, 2862, 3422, 1788, 2544, 2403, 2892, 4042,
        3460, 1516, 1972, 1563, 1579, 2776, 1647, 4535, 3921, 1261,
        6074, 2922, 3068, 1948, 4407, 712, 1294, 1019, 1572, 3764,
        5218, 975, 1539, 6376, 1606, 6091, 1138, 1169, 7925, 3136,
        1108, 2677, 2680, 1383, 3144, 2653, 1986, 1800, 1308, 1344,
        122231, 12977, 2552, 2678, 7824, 768, 8587, 39503, 3474, 661,
        430, 193, 1405, 1442, 3588, 6280, 10515, 785, 710, 305,
        206, 4990, 5329, 3398, 1771, 3022, 6907, 1523, 8588, 12203,
        666, 2113, 7916, 434, 1636, 5185, 1062, 664, 952, 3490,
        2811, 2749, 2848, 15555, 363, 117, 1494, 1647, 5886, 4021,
        633, 1013, 5951, 11343, 2324, 243, 372, 943, 734, 242,
        3161, 122, 127, 201, 1654, 768, 134, 1467, 642, 1148,
        2156, 1368, 1176, 302, 1909, 61, 223, 1812, 287, 422,
        311, 228, 748, 230, 1876, 539, 1814, 737, 689, 1140,
        591, 943, 353, 289, 198, 490, 7938, 1841, 850, 457,
        814, 146, 551, 728, 1627, 620, 648, 1621, 2731, 535,
        88, 1736, 736, 328, 293, 3170, 344, 384, 7640, 433,
        215, 715, 626, 128, 3059, 1833, 2069, 3732, 1640, 1508,
        836, 567, 2837, 1151, 2068, 695, 1494, 3173, 364, 88,
        188, 740, 677, 273, 1533, 821, 1091, 293, 647, 318,
        1202, 328, 532, 2847, 526, 721, 370, 258, 956, 1269,
        1641, 339, 1322, 4485, 286, 1874, 277, 757, 1393, 1330,
        380, 146, 377, 394, 318, 339, 1477, 1886, 101, 1435,
        284, 1425, 686, 621, 221, 117, 87, 1340, 201, 1243,
        1222, 651, 1899, 421, 712, 1016, 1279, 124, 351, 258,
        7043, 368, 666, 162, 7664, 137, 70159, 26179, 6321, 32236,
        33320, 771, 1169, 269, 1103, 444, 364, 2710, 121, 751,
        1609, 855, 1141, 2287, 1940, 3943, 289,
    ]
)
ISMIR_2025/MERT/utils/confusion_matrix_plot.py
ADDED
@@ -0,0 +1,29 @@
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    writer.add_figure("Confusion Matrix", fig, epoch)
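A usage sketch for the helper above, assuming the `writer` argument is a TensorBoard SummaryWriter (which the `writer.add_figure` call implies); the log directory and labels are illustrative.

from torch.utils.tensorboard import SummaryWriter
from ISMIR_2025.MERT.utils.confusion_matrix_plot import plot_confusion_matrix

writer = SummaryWriter(log_dir="runs/mert_eval")  # hypothetical log directory
y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
plot_confusion_matrix(y_true, y_pred, classes=["real", "generative"], writer=writer, epoch=0)
writer.close()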
ISMIR_2025/MERT/utils/freqeuncy.py
ADDED
@@ -0,0 +1,24 @@
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt

# 🔹 Audio files to load
file_real = "/path/to/real_audio.wav"        # path to the real audio
file_fake = "/path/to/generative_audio.wav"  # path to the AI-generated audio

def plot_spectrogram(audio_file, title):
    y, sr = librosa.load(audio_file, sr=16000)  # 16 kHz sampling rate
    D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)  # STFT

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='magma')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.ylim(4000, 16000)  # show only the high-frequency region above 4 kHz
    plt.show()

# 🔹 Compare real vs. generative spectrograms
plot_spectrogram(file_real, "Real Audio Spectrogram (4kHz+)")
plot_spectrogram(file_fake, "Generative Audio Spectrogram (4kHz+)")
ISMIR_2025/MERT/utils/hf_vis.py
ADDED
@@ -0,0 +1,89 @@
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as signal
import torch
import torch.nn as nn
import soundfile as sf

from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention


def highpass_filter(y, sr, cutoff=500, order=5):
    """High-pass filter to remove low frequencies below `cutoff` Hz."""
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered

def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"):
    """Plot waveform comparison and spectrograms in a single figure."""
    fig, axes = plt.subplots(3, 1, figsize=(12, 12))

    # 1️⃣ Waveform comparison
    time = np.linspace(0, len(y_original) / sr, len(y_original))
    axes[0].plot(time, y_original, label='Original', alpha=0.7)
    axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed')
    axes[0].set_xlabel("Time (s)")
    axes[0].set_ylabel("Amplitude")
    axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)")
    axes[0].legend()

    # 2️⃣ Spectrogram - original
    S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max)
    img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1])
    axes[1].set_title("Original Spectrogram")
    fig.colorbar(img, ax=axes[1], format="%+2.0f dB")

    # 3️⃣ Spectrogram - high-pass filtered
    S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max)
    img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2])
    axes[2].set_title("High-pass Filtered Spectrogram")
    fig.colorbar(img, ax=axes[2], format="%+2.0f dB")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()


def load_model(checkpoint_path, model_class, device):
    """Load a trained model from checkpoint."""
    model = model_class()
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict_audio(model, audio_tensor, device):
    """Make predictions using a trained model."""
    with torch.no_grad():
        audio_tensor = audio_tensor.unsqueeze(0).to(device)  # Add batch dimension
        output = model(audio_tensor)
        prediction = torch.argmax(output, dim=1).cpu().numpy()[0]
    return prediction

# Load audio
audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav"  # Replace with actual file path
y, sr = librosa.load(audio_path, sr=None)
y_filtered = highpass_filter(y, sr, cutoff=500)

# Convert audio to tensor
audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0)
audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0)

# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device)
highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device)

# Predict
original_pred = predict_audio(original_model, audio_tensor, device)
highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device)

print(f"Original Model Prediction: {original_pred}")
print(f"High-pass Filter Model Prediction: {highpass_pred}")

# Generate combined visualization (all plots in one image)
plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png")
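A small self-contained check of the 500 Hz Butterworth high-pass recipe used above (same butter/lfilter calls, applied to synthetic tones): a 100 Hz component should be almost removed while a 2 kHz component passes nearly unchanged.

import numpy as np
import scipy.signal as signal

sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False)
low_tone = np.sin(2 * np.pi * 100 * t)    # below the cutoff
high_tone = np.sin(2 * np.pi * 2000 * t)  # above the cutoff

b, a = signal.butter(5, 500 / (0.5 * sr), btype='high', analog=False)

def rms(x):
    return float(np.sqrt(np.mean(x ** 2)))

print("100 Hz RMS after filtering:", rms(signal.lfilter(b, a, low_tone)))   # close to 0
print("2 kHz RMS after filtering:", rms(signal.lfilter(b, a, high_tone)))   # close to 0.707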
ISMIR_2025/MERT/utils/idr_torch.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python
# coding: utf-8

import os
import hostlist

# get SLURM variables
# rank = int(os.environ["SLURM_PROCID"])
local_rank = int(os.environ["SLURM_LOCALID"])
size = int(os.environ["SLURM_NTASKS"])
cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"])

# get node list from slurm
hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])

# get IDs of reserved GPUs
gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")

# define MASTER_ADDR & MASTER_PORT
os.environ["MASTER_ADDR"] = hostnames[0]
os.environ["MASTER_PORT"] = str(
    12345 + int(min(gpu_ids))
)  # to avoid port conflict on the same node
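A sketch of how these SLURM-derived values are typically consumed (an assumption about intended use: the module is imported for its side effect of setting MASTER_ADDR/MASTER_PORT, then its fields feed torch.distributed). It only runs inside a SLURM job step where the environment variables above are defined and the hostlist package is installed.

import os
import torch
import torch.distributed as dist

import idr_torch  # importing it sets MASTER_ADDR / MASTER_PORT from the SLURM environment

rank = int(os.environ["SLURM_PROCID"])
dist.init_process_group(
    backend="nccl",
    init_method="env://",
    world_size=idr_torch.size,
    rank=rank,
)
torch.cuda.set_device(idr_torch.local_rank)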
ISMIR_2025/MERT/utils/mfcc.py
ADDED
@@ -0,0 +1,266 @@
import os
import glob
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm
import argparse
import wandb

class RealFakeDataset(Dataset):
    """
    audio/FakeMusicCaps/
      ├─ real/
      │    └─ MusicCaps/*.wav   (label=0)
      └─ generative/
           └─ .../*.wav         (label=1)
    """
    def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0):
        self.sr = sr
        self.n_mels = n_mels
        self.target_duration = target_duration
        self.target_samples = int(target_duration * sr)  # 10 seconds = 160,000 samples

        self.file_paths = []
        self.labels = []

        # Real data (label=0)
        real_dir = os.path.join(root_dir, "real")
        real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True)
        for f in real_wav_files:
            self.file_paths.append(f)
            self.labels.append(0)

        # Generative data (label=1)
        gen_dir = os.path.join(root_dir, "generative")
        gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True)
        for f in gen_wav_files:
            self.file_paths.append(f)
            self.labels.append(1)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        # print(f"[DEBUG] Path: {audio_path}, Label: {label}")  # debug

        waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True)

        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            waveform = waveform[:self.target_samples]
        elif current_samples < self.target_samples:
            stretch_factor = self.target_samples / current_samples
            waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor)
            waveform = waveform[:self.target_samples]

        mfcc = librosa.feature.mfcc(
            y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256
        )
        mfcc = librosa.util.normalize(mfcc)

        mfcc = np.expand_dims(mfcc, axis=0)
        mfcc_tensor = torch.tensor(mfcc, dtype=torch.float)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return mfcc_tensor, label_tensor


class AudioCNN(nn.Module):
    def __init__(self, num_classes=2):
        super(AudioCNN, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.fc_block = nn.Sequential(
            nn.Linear(32 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.conv_block(x)
        # x.shape: (B, 32, new_freq, new_time)

        # 1) Flatten
        B, C, H, W = x.shape  # dynamic shape
        x = x.view(B, -1)  # (B, 32*H*W)

        # 2) FC
        x = self.fc_block(x)
        return x


def my_collate_fn(batch):
    mel_list, label_list = zip(*batch)

    max_frames = max(m.shape[2] for m in mel_list)

    padded = []
    for m in mel_list:
        diff = max_frames - m.shape[2]
        if diff > 0:
            print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}")
            m = F.pad(m, (0, diff), mode='constant', value=0)
        padded.append(m)

    mel_batch = torch.stack(padded, dim=0)
    label_batch = torch.tensor(label_list, dtype=torch.long)
    return mel_batch, label_batch


class EarlyStopping:
    def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self._save_checkpoint(val_loss, model)
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self._save_checkpoint(val_loss, model)
            self.counter = 0

    def _save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...")
        torch.save(model.state_dict(), self.path)

def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"):
    if not os.path.exists("./ckpt/mfcc/"):
        os.makedirs("./ckpt/mfcc/")

    wandb.init(
        project="AI Music Detection",
        name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}",
        config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate},
    )

    dataset = RealFakeDataset(root_dir=root_dir)
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    n_val = n_total - n_train
    train_ds, val_ds = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AudioCNN(num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')
    patience = 3
    patience_counter = 0

    for epoch in range(1, epochs + 1):
        print(f"\n[Epoch {epoch}/{epochs}]")

        # Training
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        train_pbar = tqdm(train_loader, desc="Train", leave=False)
        for mel_batch, labels in train_pbar:
            mel_batch, labels = mel_batch.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(mel_batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * mel_batch.size(0)
            preds = outputs.argmax(dim=1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        train_loss /= train_total
        train_acc = train_correct / train_total

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        all_preds, all_labels = [], []
        val_pbar = tqdm(val_loader, desc=" Val ", leave=False)
        with torch.no_grad():
            for mel_batch, labels in val_pbar:
                mel_batch, labels = mel_batch.to(device), labels.to(device)
                outputs = model(mel_batch)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * mel_batch.size(0)
                preds = outputs.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss /= val_total
        val_acc = val_correct / val_total
        val_precision = precision_score(all_labels, all_preds, average="macro")
        val_recall = recall_score(all_labels, all_preds, average="macro")
        val_f1 = f1_score(all_labels, all_preds, average="macro")

        print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} "
              f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}")

        wandb.log({"train_loss": train_loss, "train_acc": train_acc,
                   "val_loss": val_loss, "val_acc": val_acc,
                   "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1})

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth"
            torch.save(model.state_dict(), best_model_path)
            print(f"[INFO] New best model saved: {best_model_path}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered!")
                break

    wandb.finish()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train AI Music Detection model.")
    parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training")
    parser.add_argument('--epochs', type=int, required=True, help="Number of epochs")
    parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate")
    parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset")

    args = parser.parse_args()

    train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir)
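As a quick orientation for this baseline script, the `train` function can also be driven directly from Python instead of the argparse entry point. The hyperparameter values below are illustrative assumptions; only the argument names come from the script above.

# Equivalent to: python mfcc.py --batch_size 32 --epochs 20 --learning_rate 1e-4
# (values are illustrative; wandb must be configured for logging to work)
from mfcc import train

train(batch_size=32, epochs=20, learning_rate=1e-4, root_dir="audio/FakeMusicCaps")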
ISMIR_2025/MERT/utils/utilities.py
ADDED
@@ -0,0 +1,305 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import logging
import pickle
import datetime  # needed by StatisticsContainer below

import numpy as np

from scipy import stats

import csv
import json

def create_folder(fd):
    if not os.path.exists(fd):
        os.makedirs(fd, exist_ok=True)


def get_filename(path):
    path = os.path.realpath(path)
    na_ext = path.split("/")[-1]
    na = os.path.splitext(na_ext)[0]
    return na


def get_sub_filepaths(folder):
    paths = []
    for root, dirs, files in os.walk(folder):
        for name in files:
            path = os.path.join(root, name)
            paths.append(path)
    return paths


def create_logging(log_dir, filemode):
    create_folder(log_dir)
    i1 = 0

    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Print to console
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging


def read_metadata(csv_path, audio_dir, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    Args:
        csv_path: str

    Returns:
        meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
    """

    with open(csv_path, "r") as fr:
        lines = fr.readlines()
        lines = lines[3:]  # Remove heads

    # First, count only the audio files that actually exist on disk
    audios_num = 0
    for n, line in enumerate(lines):
        items = line.split(", ")
        """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""

        # audio_name = 'Y{}.wav'.format(items[0])  # Audios are started with an extra 'Y' when downloading
        audio_name = "{}_{}_{}.flac".format(
            items[0], items[1].replace(".", ""), items[2].replace(".", "")
        )
        audio_name = audio_name.replace("_0000_", "_0_")

        if os.path.exists(os.path.join(audio_dir, audio_name)):
            audios_num += 1

    print("CSV audio files: %d" % (len(lines)))
    print("Existing audio files: %d" % audios_num)

    # audios_num = len(lines)
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    n = 0
    for line in lines:
        items = line.split(", ")
        """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""

        # audio_name = 'Y{}.wav'.format(items[0])  # Audios are started with an extra 'Y' when downloading
        audio_name = "{}_{}_{}.flac".format(
            items[0], items[1].replace(".", ""), items[2].replace(".", "")
        )
        audio_name = audio_name.replace("_0000_", "_0_")

        if not os.path.exists(os.path.join(audio_dir, audio_name)):
            continue

        label_ids = items[3].split('"')[1].split(",")

        audio_names.append(audio_name)

        # Target
        for id in label_ids:
            ix = id_to_ix[id]
            targets[n, ix] = 1
        n += 1

    meta_dict = {"audio_name": np.array(audio_names), "target": targets}
    return meta_dict


def read_audioset_ontology(id_to_ix):
    with open('../metadata/audioset_ontology.json', 'r') as f:
        data = json.load(f)

    # Output: {'name': 'Bob', 'languages': ['English', 'French']}
    sentences = []
    for el in data:
        print(el.keys())
        id = el['id']
        if id in id_to_ix:
            name = el['name']
            desc = el['description']
            # if '(' in desc:
            #     print(name, '---', desc)
            # print(id_to_ix[id], name, '---', )

            # sent = name
            # sent = name + ', ' + desc.replace('(', '').replace(')', '').lower()
            # sent = desc.replace('(', '').replace(')', '').lower()
            # sentences.append(sent)
            sentences.append(desc)
            # print(sent)
            # break
    return sentences


def original_read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    Args:
        csv_path: str

    Returns:
        meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
    """

    with open(csv_path, "r") as fr:
        lines = fr.readlines()
        lines = lines[3:]  # Remove heads

    # Thomas Pellegrini: added 02/12/2022
    # check if the audio files indeed exist, otherwise remove from list

    audios_num = len(lines)
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    for n, line in enumerate(lines):
        items = line.split(", ")
        """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""

        audio_name = "{}_{}_{}.flac".format(
            items[0], items[1].replace(".", ""), items[2].replace(".", "")
        )  # Audios are started with an extra 'Y' when downloading
        audio_name = audio_name.replace("_0000_", "_0_")

        label_ids = items[3].split('"')[1].split(",")

        audio_names.append(audio_name)

        # Target
        for id in label_ids:
            ix = id_to_ix[id]
            targets[n, ix] = 1

    meta_dict = {"audio_name": np.array(audio_names), "target": targets}
    return meta_dict

def read_audioset_label_tags(class_labels_indices_csv):
    with open(class_labels_indices_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)

    labels = []
    ids = []  # Each label has a unique id such as "/m/068hy"
    for i1 in range(1, len(lines)):
        id = lines[i1][1]
        label = lines[i1][2]
        ids.append(id)
        labels.append(label)

    classes_num = len(labels)

    lb_to_ix = {label: i for i, label in enumerate(labels)}
    ix_to_lb = {i: label for i, label in enumerate(labels)}

    id_to_ix = {id: i for i, id in enumerate(ids)}
    ix_to_id = {i: id for i, id in enumerate(ids)}

    return lb_to_ix, ix_to_lb, id_to_ix, ix_to_id


def float32_to_int16(x):
    # assert np.max(np.abs(x)) <= 1.5
    x = np.clip(x, -1, 1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def pad_or_truncate(x, audio_length):
    """Pad all audio to specific length."""
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0)
    else:
        return x[0:audio_length]


def pad_audio(x, audio_length):
    """Pad all audio to specific length."""
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0)
    else:
        return x


def d_prime(auc):
    d_prime = stats.norm().ppf(auc) * np.sqrt(2.0)
    return d_prime


class Mixup(object):
    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator."""
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.
        Args:
            batch_size: int
        Returns:
            mixup_lambdas: (batch_size,)
        """
        mixup_lambdas = []
        for n in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            mixup_lambdas.append(lam)
            mixup_lambdas.append(1.0 - lam)

        return np.array(mixup_lambdas)


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        """Contain statistics of different training iterations."""
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"bal": [], "test": []}

    def append(self, iteration, statistics, data_type):
        statistics["iteration"] = iteration
        self.statistics_dict[data_type].append(statistics)

    def dump(self):
        pickle.dump(self.statistics_dict, open(self.statistics_path, "wb"))
        pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb"))
        logging.info("    Dump statistics to {}".format(self.statistics_path))
        logging.info("    Dump statistics to {}".format(self.backup_statistics_path))

    def load_state_dict(self, resume_iteration):
        self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))

        resume_statistics_dict = {"bal": [], "test": []}

        for key in self.statistics_dict.keys():
            for statistics in self.statistics_dict[key]:
                if statistics["iteration"] <= resume_iteration:
                    resume_statistics_dict[key].append(statistics)

        self.statistics_dict = resume_statistics_dict
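As a small illustration of the `Mixup` helper above: `get_lambda` returns pairwise coefficients (lam, 1-lam), which a training loop can use to mix consecutive samples in a batch. The dummy batch shape and the pairwise mixing scheme below are assumptions; only `get_lambda` comes from this file.

import torch
from utilities import Mixup  # assuming the module above is importable as `utilities`

mixup = Mixup(mixup_alpha=1.0)
lam = torch.from_numpy(mixup.get_lambda(batch_size=8)).float()   # shape (8,), pairs sum to 1

x = torch.randn(8, 1, 64, 626)   # a dummy batch of log-mel inputs
# mix each even-indexed sample with the following odd-indexed one
x_mix = lam[0::2, None, None, None] * x[0::2] + lam[1::2, None, None, None] * x[1::2]
print(x_mix.shape)               # -> torch.Size([4, 1, 64, 626])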
ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc
ADDED
Binary file (10.5 kB). View file
ISMIR_2025/Model/datalib.py
ADDED
@@ -0,0 +1,206 @@
import os
import glob
import random
import logging
import csv
import time

import numpy as np
import h5py
import librosa
import scipy.signal as signal
from scipy.signal import butter, lfilter
import torch
import torchaudio
import utils
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
# Oversampling lib
from imblearn.over_sampling import RandomOverSampler

class FakeMusicCapsDataset(Dataset):
    def __init__(self, file_paths, labels, feat_type=['mel'], sr=16000, n_mels=64, target_duration=10.0, augment=True, augment_real=True):
        self.file_paths = file_paths
        self.labels = labels
        self.feat_type = feat_type
        self.sr = sr
        self.n_mels = n_mels
        self.target_duration = target_duration
        self.target_samples = int(target_duration * sr)
        self.augment = augment
        self.augment_real = augment_real

    def pre_emphasis(self, x, alpha=0.97):
        return np.append(x[0], x[1:] - alpha * x[:-1])

    def highpass_filter(self, y, sr, cutoff=1000, order=5):
        nyquist = 0.5 * sr
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        return signal.lfilter(b, a, y)

    def augment_audio(self, y, sr):
        if random.random() < 0.5:
            rate = random.uniform(0.8, 1.2)
            y = librosa.effects.time_stretch(y=y, rate=rate)

        if random.random() < 0.5:
            n_steps = random.randint(-2, 2)
            y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps)

        if random.random() < 0.5:
            noise_level = np.random.uniform(0.001, 0.005)
            y = y + np.random.normal(0, noise_level, y.shape)

        if random.random() < 0.5:
            gain = np.random.uniform(0.9, 1.1)
            y = y * gain

        return y

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load and preprocess one audio file."""
        audio_path = self.file_paths[idx]
        label = self.labels[idx]

        waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True)
        if label == 0:
            if self.augment_real:
                waveform = self.augment_audio(waveform, self.sr)
        if label == 1:
            waveform = self.highpass_filter(waveform, self.sr)
            waveform = self.augment_audio(waveform, self.sr)

        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            start_idx = (current_samples - self.target_samples) // 2
            waveform = waveform[start_idx:start_idx + self.target_samples]
        elif current_samples < self.target_samples:
            waveform = np.pad(waveform, (0, self.target_samples - current_samples), mode='constant')

        mel_spec = librosa.feature.melspectrogram(
            y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256
        )
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        log_mel_spec = np.expand_dims(log_mel_spec, axis=0)
        mel_tensor = torch.tensor(log_mel_spec, dtype=torch.float)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return mel_tensor, label_tensor

    def extract_feature(self, waveform, feat):
        """Extracts the specified feature (mel, stft, cqt) from a waveform."""
        try:
            if feat == 'mel':
                mel_spec = librosa.feature.melspectrogram(y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256)
                log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
                return torch.tensor(log_mel_spec, dtype=torch.float).unsqueeze(0)
            elif feat == 'stft':
                stft = librosa.stft(waveform, n_fft=512, hop_length=128, window="hann")
                logSTFT = np.log(np.abs(stft) + 1e-3)
                return torch.tensor(logSTFT, dtype=torch.float).unsqueeze(0)
            elif feat == 'cqt':
                cqt = librosa.cqt(waveform, sr=self.sr, hop_length=128, bins_per_octave=24)
                logCQT = np.log(np.abs(cqt) + 1e-3)
                return torch.tensor(logCQT, dtype=torch.float).unsqueeze(0)
            else:
                raise ValueError(f"[ERROR] Unsupported feature type: {feat}")
        except Exception as e:
            print(f"[ERROR] Feature extraction failed for {feat}: {e}")
            return None

    # NOTE: this stricter version overrides the simpler highpass_filter defined above.
    def highpass_filter(self, y, sr, cutoff=1000, order=5):
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered

def preprocess_audio(audio_path, sr=16000, n_mels=64, target_duration=10.0):
    try:
        waveform, _ = librosa.load(audio_path, sr=sr, mono=True)

        target_samples = int(target_duration * sr)
        if len(waveform) > target_samples:
            start_idx = (len(waveform) - target_samples) // 2
            waveform = waveform[start_idx:start_idx + target_samples]
        elif len(waveform) < target_samples:
            waveform = np.pad(waveform, (0, target_samples - len(waveform)), mode='constant')
        mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels, n_fft=1024, hop_length=256)
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        return torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    except Exception as e:
        print(f"[ERROR] Preprocessing failed: {audio_path} | error: {e}")
        return None


DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data that includes the open set

real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)

open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)

real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)

open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)

real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)

train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels

closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels

open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels

ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)

train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled
print(f"type(train_labels_resampled): {type(train_labels_resampled)}")

print(f"Train Org Fake: {len(gen_val)}")
print(f"Train set (Oversampled) - Real: {sum(1 for label in train_labels if label == 0)}, "
      f"Fake: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
print(f"Validation set - Real: {len(real_val)}, Fake: {len(gen_val)}, Total: {len(val_files)}")
print(f"Closed Test set - Real: {len(real_files)}, Fake: {len(gen_files)}, Total: {len(closed_test_files)}")
print(f"Open Test set - Real: {len(open_real_files)}, Fake: {len(open_gen_files)}, Total: {len(open_test_files)}")
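A minimal sketch of how the splits defined above are consumed, mirroring the DataLoader setup in `ISMIR_2025/Model/main.py`. The batch size and worker count are illustrative, and importing `datalib` assumes the dataset paths above exist on disk.

from torch.utils.data import DataLoader
from datalib import FakeMusicCapsDataset, train_files, train_labels

train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type='mel', target_duration=10.0)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

mel, label = train_dataset[0]
print(mel.shape, label)   # -> torch.Size([1, 64, 626]) and a tensor holding 0 (real) or 1 (fake)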
ISMIR_2025/Model/main.py
ADDED
@@ -0,0 +1,336 @@
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import wandb
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score
from datalib import (
    FakeMusicCapsDataset,
    train_files, val_files, train_labels, val_labels,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    preprocess_audio
)
from networks import CCV
from attentionmap import visualize_attention_map
from confusion_matrix import plot_confusion_matrix

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
'''
python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --oversample True

audiocnn encoder - crossattn based decoder (ViT) model
'''
# Argument parsing
import argparse
parser = argparse.ArgumentParser(description='AI Music Detection Training')
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV', help='Model name')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
parser.add_argument('--audio_duration', type=float, default=10, help='Length of the audio slice in seconds')
parser.add_argument('--patience_counter', type=int, default=5, help='Early stopping patience')
parser.add_argument('--log_dir', type=str, default='', help='TensorBoard log directory')
parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
parser.add_argument("--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)")
parser.add_argument("--loss_type", type=str, choices=["ce", "weighted_ce", "focal"], default="ce", help="Loss function type")

parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
parser.add_argument("--closed_test", action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument("--open_test", action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument("--oversample", type=bool, default=True, help="Apply Oversampling to balance classes")  # real data oversampling


args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
wandb.init(project="",
           name=f"{args.model_name}_lr{args.learning_rate}_ep{args.epochs}_bs{args.batch_size}", config=args)

if args.model_name == 'CCV':
    model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).cuda()
    feat_type = 'mel'
else:
    raise ValueError(f"Invalid model name: {args.model_name}")

model = model.to(device)
print(f"Using model: {args.model_name}, Parameters: {count_parameters(model)}")
print(f"weight_decay WD: {args.weight_decay}")

optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

if args.loss_type == "ce":
    print("Using CrossEntropyLoss")
    criterion = nn.CrossEntropyLoss()

elif args.loss_type == "weighted_ce":
    print("Using Weighted CrossEntropyLoss")

    num_real = sum(1 for label in train_labels if label == 0)
    num_fake = sum(1 for label in train_labels if label == 1)

    total_samples = num_real + num_fake
    weight_real = total_samples / (2 * num_real)
    weight_fake = total_samples / (2 * num_fake)
    class_weights = torch.tensor([weight_real, weight_fake]).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)

elif args.loss_type == "focal":
    print("Using Focal Loss")

    class FocalLoss(torch.nn.Module):
        def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
            super(FocalLoss, self).__init__()
            self.alpha = alpha
            self.gamma = gamma
            self.reduction = reduction

        def forward(self, inputs, targets):
            ce_loss = F.cross_entropy(inputs, targets, reduction='none')
            pt = torch.exp(-ce_loss)
            focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

            if self.reduction == 'mean':
                return focal_loss.mean()
            elif self.reduction == 'sum':
                return focal_loss.sum()
            else:
                return focal_loss

    criterion = FocalLoss().to(device)

if not os.path.exists(args.ckpt_path):
    os.makedirs(args.ckpt_path)

train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type=feat_type, target_duration=args.audio_duration)
val_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

def train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args):
    writer = SummaryWriter(log_dir=args.log_dir)
    best_val_bal_acc = 0.0  # track the highest validation balanced accuracy
    early_stop_cnt = 0
    log_interval = 1

    for epoch in range(args.epochs):
        print(f"\n[Epoch {epoch + 1}/{args.epochs}]")

        # Training
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0

        all_train_preds = []
        all_train_labels = []
        attention_maps = []

        train_pbar = tqdm(train_loader, desc="Train", leave=False)
        for batch_idx, (data, target) in enumerate(train_pbar):
            data = data.to(device)
            target = target.to(device)
            output, _ = model(data)  # CCV returns (logits, embedding)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            train_correct += (preds == target).sum().item()
            train_total += target.size(0)

            all_train_labels.extend(target.cpu().numpy())
            all_train_preds.extend(preds.cpu().numpy())

            if hasattr(model, "get_attention_maps"):
                attention_maps.append(model.get_attention_maps())

        train_loss /= train_total
        train_acc = train_correct / train_total
        train_bal_acc = balanced_accuracy_score(all_train_labels, all_train_preds)
        train_precision = precision_score(all_train_labels, all_train_preds, average="binary")
        train_recall = recall_score(all_train_labels, all_train_preds, average="binary")
        train_f1 = f1_score(all_train_labels, all_train_preds, average="binary")

        wandb.log({
            "Train Loss": train_loss, "Train Accuracy": train_acc,
            "Train Precision": train_precision, "Train Recall": train_recall,
            "Train F1 Score": train_f1, "Train B_ACC": train_bal_acc,
        })

        print(f"Train Epoch: {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | "
              f"Train B_ACC: {train_bal_acc:.4f} | Train Prec: {train_precision:.3f} | "
              f"Train Rec: {train_recall:.3f} | Train F1: {train_f1:.3f}")

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        all_val_preds, all_val_labels = [], []
        attention_maps = []
        val_pbar = tqdm(val_loader, desc=" Val ", leave=False)
        with torch.no_grad():
            for data, target in val_pbar:
                data, target = data.to(device), target.to(device)
                output, _ = model(data)  # CCV returns (logits, embedding)
                loss = criterion(output, target)
                val_loss += loss.item() * data.size(0)
                preds = output.argmax(dim=1)
                val_correct += (preds == target).sum().item()
                val_total += target.size(0)

                all_val_labels.extend(target.cpu().numpy())
                all_val_preds.extend(preds.cpu().numpy())

                if hasattr(model, "get_attention_maps"):
                    attention_maps.append(model.get_attention_maps())

        val_loss /= val_total
        val_acc = val_correct / val_total
        val_bal_acc = balanced_accuracy_score(all_val_labels, all_val_preds)
        val_precision = precision_score(all_val_labels, all_val_preds, average="binary")
        val_recall = recall_score(all_val_labels, all_val_preds, average="binary")
        val_f1 = f1_score(all_val_labels, all_val_preds, average="binary")

        wandb.log({
            "Validation Loss": val_loss, "Validation Accuracy": val_acc,
            "Validation Precision": val_precision, "Validation Recall": val_recall,
            "Validation F1 Score": val_f1, "Validation B_ACC": val_bal_acc,
        })

        print(f"Val Epoch: {epoch+1} [{batch_idx * len(data)}/{len(val_loader.dataset)} "
              f"({100. * batch_idx / len(val_loader):.0f}%)]\t"
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | "
              f"Val B_ACC: {val_bal_acc:.4f} | Val Prec: {val_precision:.3f} | "
              f"Val Rec: {val_recall:.3f} | Val F1: {val_f1:.3f}")

        if epoch % 1 == 0 and len(attention_maps) > 0:
            print(f"Visualizing Attention Map at Epoch {epoch+1}")

            if isinstance(attention_maps[0], list):
                attn_map_numpy = np.array([t.detach().cpu().numpy() for t in attention_maps[0]])
            elif isinstance(attention_maps[0], torch.Tensor):
                attn_map_numpy = attention_maps[0].detach().cpu().numpy()
            else:
                attn_map_numpy = np.array(attention_maps[0])

            print(f"Attention Map Shape: {attn_map_numpy.shape}")

            if len(attn_map_numpy) > 0:
                fig, ax = plt.subplots(figsize=(10, 8))
                ax.imshow(attn_map_numpy[0], cmap='viridis', interpolation='nearest')
                ax.set_title(f"Attention Map - Epoch {epoch+1}")
                plt.colorbar(ax.imshow(attn_map_numpy[0], cmap='viridis'))
                plt.savefig("")
                plt.show()
            else:
                print(f"Warning: attention_maps[0] is empty! Shape={attn_map_numpy.shape}")

        if val_bal_acc > best_val_bal_acc:  # save when balanced accuracy improves
            best_val_bal_acc = val_bal_acc
            early_stop_cnt = 0
            torch.save(model.state_dict(), os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"))
            print("Best model saved.")
        else:
            early_stop_cnt += 1
            print(f'PATIENCE {early_stop_cnt}/{args.patience_counter}')

        if early_stop_cnt >= args.patience_counter:
            print("Early stopping triggered.")
            break

        scheduler.step()
        plot_confusion_matrix(all_val_labels, all_val_preds, classes=["REAL", "FAKE"], writer=writer, epoch=epoch)

    wandb.finish()
    writer.close()

def predict(audio_path):
    print(f"Loading model from {args.ckpt_path}/best_model_{args.model_name}.pth")
    model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device))
    model.eval()

    input_tensor = preprocess_audio(audio_path).to(device)

    with torch.no_grad():
        output, _ = model(input_tensor)  # CCV returns (logits, embedding)
        probabilities = F.softmax(output, dim=1)
        ai_music_prob = probabilities[0, 1].item()

    if ai_music_prob > 0.5:
        print(f"FAKE MUSIC ({ai_music_prob:.2%})")
    else:
        print(f"REAL MUSIC ({100 - ai_music_prob * 100:.2f}%)")

def Test(model, test_loader, criterion, device):
    model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device))
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=" Test ", leave=False):
            data, target = data.to(device), target.to(device)
            output, _ = model(data)  # CCV returns (logits, embedding)
            loss = criterion(output, target)

            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)

            all_labels.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")


if __name__ == "__main__":
    train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args)
    if args.closed_test:
        print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
        test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type=feat_type, target_duration=args.audio_duration)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    elif args.open_test:
        print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
        test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type=feat_type, target_duration=args.audio_duration)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    else:
        print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
        test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    print("\nEvaluating Model on Test Set...")
    Test(model, test_loader, criterion, device)

    if args.inference:
        if not os.path.exists(args.inference):
            print(f"[ERROR] No File Found: {args.inference}")
        else:
            predict(args.inference)
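For reference, the evaluation modes above map to command-line flags as follows; the checkpoint and log directories are assumptions and must be supplied or created, only the flag names come from the argparse block in main.py.

# Typical invocations (mirroring the usage note at the top of main.py):
#   python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --closed_test
#   python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --open_test
#   python3 main.py --model_name CCV --inference path/to/file.wav
# Without --closed_test or --open_test, the script evaluates on the 20% validation split after training.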
ISMIR_2025/Model/networks.py
ADDED
@@ -0,0 +1,237 @@
import torch
import torch.nn as nn

class audiocnn(nn.Module):
    def __init__(self, num_classes=2):
        super(audiocnn, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.fc_block = nn.Sequential(
            nn.Linear(32 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.conv_block(x)
        # x.shape: (B, 32, new_freq, new_time)

        # 1) Flatten
        B, C, H, W = x.shape  # dynamic shape
        x = x.view(B, -1)  # (B, 32*H*W)

        # 2) FC
        x = self.fc_block(x)
        return x

class AudioCNN(nn.Module):
    def __init__(self, embed_dim=512):
        super(AudioCNN, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        x = self.conv_block(x)
        B, C, H, W = x.shape
        x = x.view(B, -1)  # Flatten (B, C * H * W)
        x = self.projection(x)  # Project to embed_dim
        return x

class ViTDecoder(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(ViTDecoder, self).__init__()

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        # Transformer expects input of shape (seq_len, batch, embed_dim)
        x = x.unsqueeze(1).permute(1, 0, 2)  # Add sequence dim (1, B, embed_dim)
        x = self.transformer(x)  # Pass through Transformer
        x = x.mean(dim=0)  # Take the mean over the sequence dimension (B, embed_dim)

        x = self.classifier(x)  # Classification head
        return x

class AudioCNNWithViTDecoder(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(AudioCNNWithViTDecoder, self).__init__()
        self.encoder = AudioCNN(embed_dim=embed_dim)
        self.decoder = ViTDecoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)

    def forward(self, x):
        x = self.encoder(x)  # Pass through AudioCNN encoder
        x = self.decoder(x)  # Pass through ViT decoder
        return x


# class AudioCNN(nn.Module):
#     def __init__(self, num_classes=2):
#         super(AudioCNN, self).__init__()
#         self.conv_block = nn.Sequential(
#             nn.Conv2d(1, 16, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(16, 32, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.AdaptiveAvgPool2d((4,4))  # final output -> (B,32,4,4)
#         )
#         self.fc_block = nn.Sequential(
#             nn.Linear(32*4*4, 128),
#             nn.ReLU(),
#             nn.Linear(128, num_classes)
#         )
#
#     def forward(self, x):
#         x = self.conv_block(x)
#         # x.shape: (B,32,new_freq,new_time)
#
#         # 1) Flatten
#         B, C, H, W = x.shape  # dynamic shape
#         x = x.view(B, -1)  # (B, 32*H*W)
#
#         # 2) FC
#         x = self.fc_block(x)
#         return x


class audio_crossattn(nn.Module):
    def __init__(self, embed_dim=512):
        super(audio_crossattn, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        x = self.conv_block(x)  # Convolutional feature extraction
        B, C, H, W = x.shape
        x = x.view(B, -1)  # Flatten (B, C * H * W)
        x = self.projection(x)  # Linear projection to embed_dim
        return x


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x, cross_input):
        # Cross-attention between x and cross_input
        attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        x = self.layer_norm(x + attn_output)  # Add & Norm
        feed_forward_output = self.feed_forward(x)
        x = self.layer_norm(x + feed_forward_output)  # Add & Norm
        return x

class ViTDecoderWithCrossAttention(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(ViTDecoderWithCrossAttention, self).__init__()

        # Cross-Attention layers
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        # Transformer Encoder layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x, cross_attention_input):
        # Pass through Cross-Attention layers
        for layer in self.cross_attention_layers:
            x = layer(x, cross_attention_input)

        # Transformer expects input of shape (seq_len, batch, embed_dim)
        x = x.unsqueeze(1).permute(1, 0, 2)  # Add sequence dim (1, B, embed_dim)
        x = self.transformer(x)  # Pass through Transformer
        embedding = x.mean(dim=0)  # Take the mean over the sequence dimension (B, embed_dim)

        # Classification head
        x = self.classifier(embedding)
        return x, embedding

# class AudioCNNWithViTDecoderAndCrossAttention(nn.Module):
#     def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
#         super(AudioCNNWithViTDecoderAndCrossAttention, self).__init__()
#         self.encoder = audio_crossattn(embed_dim=embed_dim)
#         self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
#
#     def forward(self, x, cross_attention_input):
#         # Pass through AudioCNN encoder
#         x = self.encoder(x)
#
#         # Pass through ViTDecoder with Cross-Attention
#         x = self.decoder(x, cross_attention_input)
#         return x

class CCV(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
        super(CCV, self).__init__()
        self.encoder = AudioCNN(embed_dim=embed_dim)
        self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
        # NOTE: with the default freeze_feature_extractor=True, both encoder and decoder
        # parameters get requires_grad=False, so no weights are updated during training
        # unless the flag is set to False or a pretrained checkpoint is loaded.
        if freeze_feature_extractor:
            for param in self.encoder.parameters():
                param.requires_grad = False
            for param in self.decoder.parameters():
                param.requires_grad = False

    def forward(self, x, cross_attention_input=None):
        # Pass through AudioCNN encoder
        x = self.encoder(x)

        # If cross_attention_input is not provided, use the encoder output
        if cross_attention_input is None:
            cross_attention_input = x

        # Pass through ViTDecoder with Cross-Attention
        x, embedding = self.decoder(x, cross_attention_input)
        return x, embedding
|
231 |
+
|
232 |
+
#---------------------------------------------------------
|
233 |
+
'''
|
234 |
+
audiocnn weight frozen
|
235 |
+
crossatten decoder -lora tuning
|
236 |
+
'''
|
237 |
+
|
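For orientation, here is a minimal sketch of how the decoder above can be exercised in isolation. It drives ViTDecoderWithCrossAttention with random (batch, embed_dim) features that stand in for the AudioCNN encoder output; the import path and feature shapes are illustrative assumptions, not part of the project's actual pipeline.

import torch
from networks import ViTDecoderWithCrossAttention  # assumed import path for the module above

# Sketch only: random features replace the AudioCNN output that CCV would normally provide.
decoder = ViTDecoderWithCrossAttention(embed_dim=512, num_heads=8, num_layers=2, num_classes=2)
feats = torch.randn(2, 512)                # stand-in for (batch, embed_dim) encoder features
logits, embedding = decoder(feats, feats)  # CCV uses the encoder output as the cross-attention input by default
print(logits.shape, embedding.shape)       # torch.Size([2, 2]) torch.Size([2, 512])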
ISMIR_2025/Model/test.py
ADDED
@@ -0,0 +1,129 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
from datalib_f import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels
)
from networks_f import CCV_Wav2Vec2
import argparse

parser = argparse.ArgumentParser(description="AI Music Detection Testing")
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
# Include the default model name in the choices so the CLI accepts it.
parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV', 'CCV_Wav2Vec2'], default='CCV_Wav2Vec2', help='Model name')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/tensorboard/wav2vec', help='Checkpoint directory')
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/test_results/w_celoss_repreprocess/wav2vec', help='Path to save test results')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)

if args.model_name == 'CCV_Wav2Vec2':
    model = CCV_Wav2Vec2(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).to(device)
else:
    raise ValueError(f"Invalid model name: {args.model_name}")

ckpt_file = os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth")
if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

print(f"\nLoading model from {ckpt_file}")

# model.load_state_dict(torch.load(ckpt_file, map_location=device))
# Checkpoints saved from an nn.DataParallel model carry a "module." prefix; strip it before loading.
state_dict = torch.load(ckpt_file, map_location=device)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.eval()

torch.cuda.empty_cache()

if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type="mel", target_duration=10.0)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type="mel", target_duration=10.0)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type="mel", target_duration=10.0)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

def Test(model, test_loader, device):
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)

            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)

            all_labels.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")

    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)

print("\nEvaluating Model on Test Set...")
Test(model, test_loader, device)
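The checkpoint-loading block above strips the "module." prefix that nn.DataParallel adds to every parameter name. A self-contained sketch of that behaviour with a toy module (no GPU required); the layer sizes are arbitrary illustration:

import torch.nn as nn
from collections import OrderedDict

base = nn.Linear(4, 2)
wrapped = nn.DataParallel(base)      # wrapping stores the module under "module."
saved = wrapped.state_dict()         # keys: "module.weight", "module.bias"

# Strip the prefix so the weights load into an unwrapped module.
clean = OrderedDict((k[7:] if k.startswith("module.") else k, v) for k, v in saved.items())
fresh = nn.Linear(4, 2)
fresh.load_state_dict(clean)
print(list(saved.keys()), list(clean.keys()))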
ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc
ADDED
Binary file (10.2 kB). View file
|
|
ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc
ADDED
Binary file (11.4 kB). View file
|
|
ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc
ADDED
Binary file (11 kB). View file
|
|
ISMIR_2025/music2vec/datalib.py
ADDED
@@ -0,0 +1,144 @@
import os
import glob
import random

import numpy as np
import librosa
import scipy.signal as signal
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from transformers import Wav2Vec2Processor

class FakeMusicCapsDataset(Dataset):
    def __init__(self, file_paths, labels, sr=16000, target_duration=10.0, augment=True):
        self.file_paths = file_paths
        self.labels = labels
        self.sr = sr
        self.target_samples = int(target_duration * sr)
        self.augment = augment

    def __len__(self):
        return len(self.file_paths)

    def augment_audio(self, y, sr):
        if isinstance(y, torch.Tensor):
            y = y.numpy()
        if random.random() < 0.5:
            rate = random.uniform(0.8, 1.2)
            y = librosa.effects.time_stretch(y=y, rate=rate)
        if random.random() < 0.5:
            n_steps = random.randint(-2, 2)
            y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps)
        if random.random() < 0.5:
            noise_level = np.random.uniform(0.001, 0.005)
            y = y + np.random.normal(0, noise_level, y.shape)
        if random.random() < 0.5:
            gain = np.random.uniform(0.9, 1.1)
            y = y * gain
        return torch.tensor(y, dtype=torch.float32)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(audio_path)
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)
        waveform = waveform.mean(dim=0)

        if label == 0:
            waveform = self.augment_audio(waveform, self.sr)
        if label == 1:
            waveform = self.highpass_filter(waveform, self.sr)
            waveform = self.augment_audio(waveform, self.sr)

        # Time-stretching can change the length, so measure it after augmentation.
        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            waveform = waveform[:self.target_samples]
        elif current_samples < self.target_samples:
            pad_length = self.target_samples - current_samples
            waveform = torch.nn.functional.pad(waveform, (0, pad_length))

        # waveform = waveform.squeeze(0)
        if isinstance(waveform, np.ndarray):
            waveform = torch.tensor(waveform, dtype=torch.float32)

        return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long)

    def highpass_filter(self, y, sr, cutoff=500, order=5):
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered

def preprocess_audio(audio_path, target_sr=16000, max_length=160000):
    waveform, sr = torchaudio.load(audio_path)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)

    waveform = waveform.mean(dim=0).unsqueeze(0)

    current_samples = waveform.shape[1]
    if current_samples > max_length:
        start_idx = (current_samples - max_length) // 2
        waveform = waveform[:, start_idx:start_idx + max_length]
    elif current_samples < max_length:
        pad_length = max_length - current_samples
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))

    return waveform


DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data containing the open-set portion

real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)

open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)

real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)

open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)

real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)

train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels

closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels

open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels

ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)

train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled

print(f"Train Original FAKE: {len(gen_train)}")
print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, "
      f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}")
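A small, self-contained sketch of how the dataset above is consumed. It writes a synthetic 16 kHz tone to a temporary .wav file as a placeholder for a real FakeMusicCaps clip (the file path and import path are assumptions for illustration) and checks the shapes the training code expects.

import os
import tempfile
import torch
import torchaudio
from torch.utils.data import DataLoader
from datalib import FakeMusicCapsDataset  # assumed import path for the file above

sr = 16000
# 1-second 440 Hz tone, mono, as a stand-in for a real clip.
tone = 0.1 * torch.sin(2 * torch.pi * 440.0 * torch.arange(sr) / sr).unsqueeze(0)
tmp_wav = os.path.join(tempfile.mkdtemp(), "dummy.wav")
torchaudio.save(tmp_wav, tone, sr)

dataset = FakeMusicCapsDataset([tmp_wav], [0], sr=sr, target_duration=10.0)
waveform, label = dataset[0]
print(waveform.shape, label)        # torch.Size([1, 160000]) tensor(0)

loader = DataLoader(dataset, batch_size=1)
batch_wave, batch_label = next(iter(loader))
print(batch_wave.shape)             # torch.Size([1, 1, 160000])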
ISMIR_2025/music2vec/inference.py
ADDED
@@ -0,0 +1,64 @@
import os
import torch
import torch.nn.functional as F
import torchaudio
import argparse
from datalib import preprocess_audio
from networks import Wav2Vec2ForFakeMusic

# Argument Parsing
parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name')
parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory')
parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model')
parser.add_argument('--inference', type=str, required=True, help='Path to a .wav file for inference')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Model Checkpoint
if args.model_type == 'pretrain':
    model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth")
elif args.model_type == 'finetune':
    model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth")
else:
    raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.")

if not os.path.exists(model_file):
    raise FileNotFoundError(f"Model checkpoint not found: {model_file}")

if args.model_name == 'Wav2Vec2ForFakeMusic':
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune'))
else:
    raise ValueError(f"Invalid model name: {args.model_name}")

def predict(audio_path):
    print(f"\n🔍 Loading model from {model_file}")

    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}")

    model.to(device)
    model.eval()

    input_tensor = preprocess_audio(audio_path).to(device)
    print(f"Input shape after preprocessing: {input_tensor.shape}")

    with torch.no_grad():
        output = model(input_tensor)
        print(f"Raw model output (logits): {output}")

    probabilities = F.softmax(output, dim=1)
    ai_music_prob = probabilities[0, 1].item()

    print(f"Softmax Probabilities: {probabilities}")
    print(f"AI Music Probability: {ai_music_prob:.4f}")

    if ai_music_prob > 0.5:
        print(f"FAKE MUSIC DETECTED ({ai_music_prob:.2%})")
    else:
        print(f"REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)")

if __name__ == "__main__":
    predict(args.inference)
ISMIR_2025/music2vec/main.py
ADDED
@@ -0,0 +1,155 @@
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score
import wandb
import argparse
from transformers import Wav2Vec2Processor
from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels
from networks import Music2VecClassifier, CCV

parser = argparse.ArgumentParser(description='AI Music Detection Training with Music2Vec + CCV')
parser.add_argument('--gpu', type=str, default='2', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--finetune_lr', type=float, default=1e-3, help='Fine-Tune Learning rate')
parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)')
parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)')
parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory')
parser.add_argument('--weight_decay', type=float, default=0.001, help="Weight decay for optimizer")

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

wandb.init(project="music2vec_ccv", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args)

print("Preparing datasets...")
train_dataset = FakeMusicCapsDataset(train_files, train_labels)
val_dataset = FakeMusicCapsDataset(val_files, val_labels)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

pretrain_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_pretrain_{args.pretrain_epochs}.pth")
finetune_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_ccv_finetune_{args.finetune_epochs}.pth")

print("Initializing Music2Vec model for Pretraining...")
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Music2VecClassifier(freeze_feature_extractor=False).to(device)  # backbone left trainable for this stage

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"):
        labels = labels.to(device)
        inputs = inputs.to(device)

        logits, _ = model(inputs)  # Music2VecClassifier returns (logits, features)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    scheduler.step()
    accuracy = total_correct / total_samples
    f1 = f1_score(all_labels, all_preds, average="binary")
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="binary")
    recall = recall_score(all_labels, all_preds, average="binary")

    wandb.log({
        f"{phase} Train Loss": total_loss / len(dataloader),
        f"{phase} Train Accuracy": accuracy,
        f"{phase} Train F1 Score": f1,
        f"{phase} Train Precision": precision,
        f"{phase} Train Recall": recall,
        f"{phase} Train Balanced Accuracy": balanced_acc,
    })

    print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, "
          f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}")

def validate(model, dataloader, criterion, device, phase="Validation"):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.squeeze(1)
            outputs, _ = model(inputs)  # keep only the logits
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples
    f1 = f1_score(all_labels, all_preds, average="weighted")
    val_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    val_precision = precision_score(all_labels, all_preds, average="binary")
    val_recall = recall_score(all_labels, all_preds, average="binary")

    wandb.log({
        f"{phase} Val Loss": total_loss / len(dataloader),
        f"{phase} Val Accuracy": accuracy,
        f"{phase} Val F1 Score": f1,
        f"{phase} Val Precision": val_precision,
        f"{phase} Val Recall": val_recall,
        f"{phase} Val Balanced Accuracy": val_bal_acc,
    })
    print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, "
          f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}")
    return total_loss / len(dataloader), accuracy, f1

print("\nStep 1: Self-Supervised Pretraining on REAL Data")
for epoch in range(args.pretrain_epochs):
    train(model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain")

torch.save(model.state_dict(), pretrain_ckpt)
print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}")

print("\nInitializing Music2Vec + CCV Model for Fine-Tuning...")
# model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
model = Music2VecClassifier(freeze_feature_extractor=False).to(device)
model.load_state_dict(torch.load(pretrain_ckpt, map_location=device))  # carry the pretrained weights into the fine-tuning model
optimizer = optim.Adam(model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

print("\nStep 2: Fine-Tuning CCV Model using Music2Vec Features")
for epoch in range(args.finetune_epochs):
    train(model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune")

torch.save(model.state_dict(), finetune_ckpt)
print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")
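As a quick illustration of the scheduler configured above, StepLR(step_size=10, gamma=0.5) halves the learning rate every 10 epochs. A tiny standalone check with a dummy parameter (values shown are what PyTorch prints for a 1e-4 starting rate):

import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.Adam(params, lr=1e-4)
sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

for epoch in range(30):
    opt.step()      # one (dummy) optimisation step per epoch
    sched.step()
    if (epoch + 1) % 10 == 0:
        print(epoch + 1, opt.param_groups[0]["lr"])
# -> 10 5e-05, 20 2.5e-05, 30 1.25e-05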
ISMIR_2025/music2vec/networks.py
ADDED
@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from transformers import Data2VecAudioModel, Wav2Vec2Processor

class Music2VecClassifier(nn.Module):
    def __init__(self, num_classes=2, freeze_feature_extractor=True):
        super(Music2VecClassifier, self).__init__()

        self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
        self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")

        if freeze_feature_extractor:
            for param in self.music2vec.parameters():
                param.requires_grad = False

        # Conv1d for learnable weighted average across layers
        self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(self.music2vec.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_values):
        input_values = input_values.squeeze(1)  # Ensure shape [batch, time]

        with torch.no_grad():
            outputs = self.music2vec(input_values, output_hidden_states=True)
        hidden_states = torch.stack(outputs.hidden_states)
        time_reduced = hidden_states.mean(dim=2)
        time_reduced = time_reduced.permute(1, 0, 2)
        weighted_avg = self.conv1d(time_reduced).squeeze(1)

        return self.classifier(weighted_avg), weighted_avg

    def unfreeze_feature_extractor(self):
        for param in self.music2vec.parameters():
            param.requires_grad = True

class Music2VecFeatureExtractor(nn.Module):
    def __init__(self, freeze_feature_extractor=True):
        super(Music2VecFeatureExtractor, self).__init__()
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
        self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")

        if freeze_feature_extractor:
            for param in self.music2vec.parameters():
                param.requires_grad = False

        # Conv1d for learnable weighted average across layers
        self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)

    def forward(self, input_values):
        # input_values: [batch, time]
        input_values = input_values.squeeze(1)
        with torch.no_grad():
            outputs = self.music2vec(input_values, output_hidden_states=True)
        hidden_states = torch.stack(outputs.hidden_states)  # [num_layers, batch, time, hidden_dim]
        time_reduced = hidden_states.mean(dim=2)  # [num_layers, batch, hidden_dim]
        time_reduced = time_reduced.permute(1, 0, 2)  # [batch, num_layers, hidden_dim]
        weighted_avg = self.conv1d(time_reduced).squeeze(1)  # [batch, hidden_dim]
        return weighted_avg

'''
music2vec + CCV
'''
# import torch
# import torch.nn as nn
# from transformers import Data2VecAudioModel, Wav2Vec2Processor
# import torch.nn.functional as F


# ### Music2Vec Feature Extractor (Pretrained Model)
# class Music2VecFeatureExtractor(nn.Module):
#     def __init__(self, freeze_feature_extractor=True):
#         super(Music2VecFeatureExtractor, self).__init__()
#
#         self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
#         self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
#
#         if freeze_feature_extractor:
#             for param in self.music2vec.parameters():
#                 param.requires_grad = False
#
#         # Conv1d for learnable weighted average across layers
#         self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
#
#     def forward(self, input_values):
#         with torch.no_grad():
#             outputs = self.music2vec(input_values, output_hidden_states=True)
#
#         hidden_states = torch.stack(outputs.hidden_states)  # [13, batch, time, hidden_size]
#         time_reduced = hidden_states.mean(dim=2)  # mean pooling over time: [13, batch, hidden_size]
#         time_reduced = time_reduced.permute(1, 0, 2)  # [batch, 13, hidden_size]
#         weighted_avg = self.conv1d(time_reduced).squeeze(1)  # [batch, hidden_size]
#
#         return weighted_avg  # Extracted feature representation
#
#     def unfreeze_feature_extractor(self):
#         for param in self.music2vec.parameters():
#             param.requires_grad = True  # Unfreeze for Fine-tuning

# ### CNN Feature Extractor for CCV
class CNNEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super(CNNEncoder, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # changed from MaxPool2d(2) to MaxPool2d((2, 1))
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((1, 1)),  # MaxPool2d((1, 1)) keeps the spatial size
            nn.AdaptiveAvgPool2d((4, 4))  # final resizing
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        # print(f"Input shape before CNNEncoder: {x.shape}")  # debug output
        x = self.conv_block(x)
        B, C, H, W = x.shape
        x = x.view(B, -1)
        x = self.projection(x)
        return x


### Cross-Attention Module
class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.attention_weights = None

    def forward(self, x, cross_input):
        attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        self.attention_weights = attn_weights
        x = self.layer_norm(x + attn_output)
        feed_forward_output = self.feed_forward(x)
        x = self.layer_norm(x + feed_forward_output)
        return x

### Cross-Attention Transformer
class CrossAttentionViT(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(CrossAttentionViT, self).__init__()

        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x, cross_attention_input):
        self.attention_maps = []
        for layer in self.cross_attention_layers:
            x = layer(x, cross_attention_input)
            self.attention_maps.append(layer.attention_weights)

        x = x.unsqueeze(1).permute(1, 0, 2)
        x = self.transformer(x)
        x = x.mean(dim=0)
        x = self.classifier(x)
        return x

### CCV Model (Final Classifier)
# class CCV(nn.Module):
#     def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
#         super(CCV, self).__init__()
#
#         self.music2vec_extractor = Music2VecClassifier(freeze_feature_extractor=freeze_feature_extractor)
#
#         # CNN Encoder for Image Representation
#         self.encoder = CNNEncoder(embed_dim=embed_dim)
#
#         # Transformer with Cross-Attention
#         self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
#
#     def forward(self, x, cross_attention_input=None):
#         x = self.music2vec_extractor(x)
#         # print(f"After Music2VecExtractor: {x.shape}")  # prints (batch, 2)
#
#         # Match the input size that CNNEncoder expects
#         x = x.unsqueeze(1).unsqueeze(-1)  # reshape to (batch, 1, 2, 1)
#         # print(f"Before CNNEncoder: {x.shape}")  # check the CNN input
#
#         x = self.encoder(x)
#
#         if cross_attention_input is None:
#             cross_attention_input = x
#
#         x = self.decoder(x, cross_attention_input)
#
#         return x

class CCV(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
        super(CCV, self).__init__()
        self.feature_extractor = Music2VecFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor)

        # Cross-Attention Transformer
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification Head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, input_values):
        # Extract feature embeddings
        features = self.feature_extractor(input_values)  # [batch, feature_dim]
        # Features are already averaged over the layer dimension: [batch, hidden_dim]
        # Apply Cross-Attention Layers
        for layer in self.cross_attention_layers:
            features = layer(features.unsqueeze(1), features.unsqueeze(1)).squeeze(1)
        # Transformer Encoding
        encoded = self.transformer(features.unsqueeze(1))
        encoded = encoded.mean(dim=1)
        # Classification Head
        logits = self.classifier(encoded)
        return logits

    def get_attention_maps(self):
        # Implement this if you want to expose the attention_maps stored by CrossAttentionLayer
        return None
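The Conv1d(13 -> 1, kernel_size=1) in the modules above implements a learned weighted average over the 13 hidden-state tensors (feature projection plus 12 transformer blocks) returned by the base model. A shape-only sketch with random tensors, independent of the Hugging Face weights; the time length 499 is just an assumed frame count for a 10-second clip:

import torch
import torch.nn as nn

num_layers, batch, time, hidden = 13, 2, 499, 768   # 768 = data2vec-audio-base hidden size
hidden_states = torch.randn(num_layers, batch, time, hidden)  # stand-in for outputs.hidden_states

time_reduced = hidden_states.mean(dim=2)          # [13, batch, hidden]  (mean pooling over time)
time_reduced = time_reduced.permute(1, 0, 2)      # [batch, 13, hidden]

layer_mix = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
weighted_avg = layer_mix(time_reduced).squeeze(1)  # [batch, hidden]

classifier = nn.Sequential(nn.Linear(hidden, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 2))
print(weighted_avg.shape, classifier(weighted_avg).shape)  # torch.Size([2, 768]) torch.Size([2, 2])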
ISMIR_2025/music2vec/test.py
ADDED
@@ -0,0 +1,119 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
from datalib import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels
)
from networks import Music2VecClassifier
import argparse

'''
python3 test.py --gpu 1 --closed_test --ckpt_path ""
'''
parser = argparse.ArgumentParser(description="AI Music Detection Testing with Music2Vec")
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/music2vec_pretrain_10.pth', help='Checkpoint directory')
parser.add_argument('--model_name', type=str, default="music2vec", help="Model name")
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='', help='Path to save test results')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)

model = Music2VecClassifier().to(device)

ckpt_file = os.path.join(args.ckpt_path)
if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

print(f"\nLoading model from {ckpt_file}")
model.load_state_dict(torch.load(ckpt_file, map_location=device))
model.eval()

torch.cuda.empty_cache()

if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

def Test(model, test_loader, device):
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, _ = model(data)  # Music2VecClassifier returns (logits, features)
            loss = F.cross_entropy(output, target)

            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)

            all_labels.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")

    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)

print("\nEvaluating Model on Test Set...")
Test(model, test_loader, device)
ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc
ADDED
Binary file (9.2 kB). View file
|
|
ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc
ADDED
Binary file (1.77 kB). View file
|
|
ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc
ADDED
Binary file (11.1 kB). View file
|
|
ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc
ADDED
Binary file (9.47 kB). View file
|
|
ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc
ADDED
Binary file (9.15 kB). View file
|
|
ISMIR_2025/wav2vec/datalib.py
ADDED
@@ -0,0 +1,139 @@
import os
import glob
import random
import csv
import time
import logging

import h5py
import numpy as np
import librosa
import scipy.signal as signal
from scipy.signal import butter, lfilter
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from transformers import Wav2Vec2Processor

import utils
from networks import Wav2Vec2ForFakeMusic

class FakeMusicCapsDataset(Dataset):
    def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
        self.file_paths = file_paths
        self.labels = labels
        self.sr = sr
        self.target_duration = target_duration
        self.target_samples = int(target_duration * sr)

        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

    def highpass_filter(self, y, sr, cutoff=500, order=5):
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered

    def augment_audio(self, y, sr):
        # Light augmentation chain (same as in music2vec/datalib.py); the original file
        # called this method without defining it.
        if isinstance(y, torch.Tensor):
            y = y.numpy()
        if random.random() < 0.5:
            y = librosa.effects.time_stretch(y=y, rate=random.uniform(0.8, 1.2))
        if random.random() < 0.5:
            y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=random.randint(-2, 2))
        if random.random() < 0.5:
            y = y + np.random.normal(0, np.random.uniform(0.001, 0.005), y.shape)
        if random.random() < 0.5:
            y = y * np.random.uniform(0.9, 1.1)
        return torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(audio_path)
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)

        waveform = waveform.squeeze(0)
        if label == 0:
            waveform = self.augment_audio(waveform, self.sr)
        if label == 1:
            waveform = self.highpass_filter(waveform, self.sr)
            waveform = torch.as_tensor(waveform, dtype=torch.float32)  # lfilter returns a NumPy array

        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            start_idx = (current_samples - self.target_samples) // 2
            waveform = waveform[start_idx:start_idx + self.target_samples]
        elif current_samples < self.target_samples:
            waveform = torch.nn.functional.pad(waveform, (0, self.target_samples - current_samples))

        waveform = waveform.to(torch.float32).unsqueeze(0)
        label = torch.tensor(label, dtype=torch.long)

        return waveform, label

def preprocess_audio(audio_path, target_sr=16000, target_duration=10.0):
    waveform, sr = librosa.load(audio_path, sr=target_sr)

    target_samples = int(target_duration * target_sr)
    current_samples = len(waveform)

    if current_samples > target_samples:
        waveform = waveform[:target_samples]
    elif current_samples < target_samples:
        waveform = np.pad(waveform, (0, target_samples - current_samples))

    waveform = torch.tensor(waveform).unsqueeze(0)
    return waveform


DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data containing the open-set portion

real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)

open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)

real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)

open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)

real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)

train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels

closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels

open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels

ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)

train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled

print(f"Train Original FAKE: {len(gen_train)}")
print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, "
      f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}")
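For intuition about the highpass_filter applied to the generative class above: a 5th-order Butterworth high-pass at 500 Hz on 16 kHz audio should suppress a 100 Hz component while leaving a 4 kHz component largely intact. A self-contained check with synthetic sine waves (the test frequencies are arbitrary choices for illustration):

import numpy as np
import scipy.signal as signal

sr = 16000
t = np.arange(sr) / sr                    # 1 second of samples
low = np.sin(2 * np.pi * 100 * t)         # below the 500 Hz cutoff
high = np.sin(2 * np.pi * 4000 * t)       # well above the cutoff

b, a = signal.butter(5, 500 / (0.5 * sr), btype='high', analog=False)
rms = lambda x: np.sqrt(np.mean(x ** 2))
print(rms(signal.lfilter(b, a, low)) / rms(low))    # close to 0: low frequencies removed
print(rms(signal.lfilter(b, a, high)) / rms(high))  # close to 1: high frequencies preserved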
ISMIR_2025/wav2vec/inference.py
ADDED
@@ -0,0 +1,71 @@
import os
import torch
import torch.nn.functional as F
import torchaudio
import argparse
from AI_Music_Detection.Code.model.wav2vec.wav2vec_datalib import preprocess_audio
from networks import Wav2Vec2ForFakeMusic

'''
command: python inference.py --gpu 0 --model_type pretrain --inference <audio>.wav
'''
parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name')
parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory')
parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model')
parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.model_type == 'pretrain':
    model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth")
elif args.model_type == 'finetune':
    model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth")
else:
    raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.")

if not os.path.exists(model_file):
    raise FileNotFoundError(f"Model checkpoint not found: {model_file}")

if args.model_name == 'Wav2Vec2ForFakeMusic':
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune'))
else:
    raise ValueError(f"Invalid model name: {args.model_name}")

def predict(audio_path):
    print(f"\n🔍 Loading model from {model_file}")

    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}")

    model.to(device)
    # restore the trained weights from the checkpoint selected above before scoring the clip
    state_dict = torch.load(model_file, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    input_tensor = preprocess_audio(audio_path).to(device)
    print(f"Input shape after preprocessing: {input_tensor.shape}")

    with torch.no_grad():
        logits, _ = model(input_tensor)  # the model returns (logits, pooled_features)
        print(f"Raw model output (logits): {logits}")

        probabilities = F.softmax(logits, dim=1)
        ai_music_prob = probabilities[0, 1].item()

        print(f"Softmax Probabilities: {probabilities}")
        print(f"AI Music Probability: {ai_music_prob:.4f}")

        if ai_music_prob > 0.5:
            print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})")
        else:
            print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)")


if __name__ == "__main__":
    if args.inference:
        if not os.path.exists(args.inference):
            print(f"[ERROR] No File Found: {args.inference}")
        else:
            predict(args.inference)
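inference.py assumes preprocess_audio() returns a mono 16 kHz waveform tensor of roughly shape (1, num_samples), since facebook/wav2vec2-base expects 16 kHz input. The helper below is a hypothetical stand-in sketched with torchaudio, not the project's actual implementation.

# Hypothetical preprocessing sketch; the real preprocess_audio lives in the datalib module.
import torch
import torchaudio

def preprocess_audio_sketch(path: str, target_sr: int = 16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(path)                              # (channels, num_samples)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    waveform = waveform.mean(dim=0, keepdim=True)                     # downmix to mono -> (1, num_samples)
    return waveform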
ISMIR_2025/wav2vec/main.py
ADDED
@@ -0,0 +1,162 @@
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score, classification_report
import wandb
import argparse
from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels
from networks import Wav2Vec2ForFakeMusic

'''
command: python main.py --gpu 0 --pretrain_epochs 20 --finetune_epochs 10
'''
parser = argparse.ArgumentParser(description='AI Music Detection Training with Wav2Vec 2.0')
parser.add_argument('--gpu', type=str, default='2', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)')
parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)')
parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory')
parser.add_argument('--weight_decay', type=float, default=0.05, help="Weight decay for optimizer")

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

wandb.init(project="", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args)

print("Preparing datasets...")
train_dataset = FakeMusicCapsDataset(train_files, train_labels)
val_dataset = FakeMusicCapsDataset(val_files, val_labels)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

pretrain_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_pretrain_{args.pretrain_epochs}.pth")
finetune_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_finetune_{args.finetune_epochs}.pth")

print("Initializing model...")
model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

def train(model, dataloader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain"):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []
    attention_maps = []

    for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.float()

        outputs, _ = model(inputs)  # the model returns (logits, pooled_features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        if hasattr(model, "get_attention_maps"):
            attention_maps.append(model.get_attention_maps())

    scheduler.step()

    accuracy = total_correct / total_samples
    f1 = f1_score(all_labels, all_preds, average="weighted")
    precision = precision_score(all_labels, all_preds, average="binary")
    recall = recall_score(all_labels, all_preds, average="binary")
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)

    wandb.log({
        f"{phase} Train Loss": total_loss / len(dataloader),
        f"{phase} Train Accuracy": accuracy,
        f"{phase} Train F1 Score": f1,
        f"{phase} Train Precision": precision,
        f"{phase} Train Recall": recall,
        f"{phase} Train Balanced Accuracy": balanced_acc,
    })

    print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, "
          f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}")

def validate(model, dataloader, criterion, device, phase="Validation"):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.squeeze(1)

            outputs, _ = model(inputs)  # unpack logits from (logits, pooled_features)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples
    f1 = f1_score(all_labels, all_preds, average="weighted")
    val_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    val_precision = precision_score(all_labels, all_preds, average="binary")
    val_recall = recall_score(all_labels, all_preds, average="binary")

    wandb.log({
        f"{phase} Val Loss": total_loss / len(dataloader),
        f"{phase} Val Accuracy": accuracy,
        f"{phase} Val F1 Score": f1,
        f"{phase} Val Precision": val_precision,
        f"{phase} Val Recall": val_recall,
        f"{phase} Val Balanced Accuracy": val_bal_acc,
    })
    print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, "
          f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}")
    return total_loss / len(dataloader), accuracy, f1

print("\nStep 1: Self-Supervised Pretraining on REAL Data")
for epoch in range(args.pretrain_epochs):
    train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain")

torch.save(model.state_dict(), pretrain_ckpt)
print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}")

model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device)
model.load_state_dict(torch.load(pretrain_ckpt))
print(f"\n🔍 Loaded Pretrained Model from {pretrain_ckpt}")

optimizer = optim.Adam(model.parameters(), lr=args.learning_rate / 10, weight_decay=args.weight_decay)

print("\nStep 2: Fine-Tuning on REAL + FAKE Data")
for epoch in range(args.finetune_epochs):
    train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Fine-Tune")
    validate(model, val_loader, criterion, device, phase="Fine-Tune Validation")

torch.save(model.state_dict(), finetune_ckpt)
print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")
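For later evaluation, the fine-tuned checkpoint saved above can be restored the same way test.py does; a minimal sketch, with the checkpoint filename as a placeholder matching the default epoch counts:

import torch
from networks import Wav2Vec2ForFakeMusic

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device)
state_dict = torch.load("wav2vec2_finetune_10.pth", map_location=device)
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates buffer-only mismatches
model.eval()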
ISMIR_2025/wav2vec/networks.py
ADDED
@@ -0,0 +1,161 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import Wav2Vec2Model

'''
freeze_feature_extractor=True freezes the feature extractor (pretraining stage);
calling unfreeze_feature_extractor() enables fine-tuning.
'''

class cnn(nn.Module):
    def __init__(self, embed_dim=512):
        super(cnn, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        x = self.conv_block(x)
        B, C, H, W = x.shape
        x = x.view(B, -1)
        x = self.projection(x)
        return x

class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.attention_weights = None

    def forward(self, x, cross_input):
        # Cross-attention between x and cross_input
        attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        self.attention_weights = attn_weights
        x = self.layer_norm(x + attn_output)
        feed_forward_output = self.feed_forward(x)
        x = self.layer_norm(x + feed_forward_output)
        return x

class CrossAttentionViT(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(CrossAttentionViT, self).__init__()

        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x, cross_attention_input):
        self.attention_maps = []
        for layer in self.cross_attention_layers:
            x = layer(x, cross_attention_input)
            self.attention_maps.append(layer.attention_weights)

        x = x.unsqueeze(1).permute(1, 0, 2)
        x = self.transformer(x)
        x = x.mean(dim=0)
        x = self.classifier(x)
        return x

class CCV(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(CCV, self).__init__()
        self.encoder = cnn(embed_dim=embed_dim)
        self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)

    def forward(self, x, cross_attention_input=None):
        x = self.encoder(x)

        if cross_attention_input is None:
            cross_attention_input = x

        x = self.decoder(x, cross_attention_input)

        # keep the attention maps from the decoder for later visualization
        self.attention_maps = self.decoder.attention_maps

        return x

    def get_attention_maps(self):
        return self.attention_maps

class Wav2Vec2ForFakeMusic(nn.Module):
    def __init__(self, num_classes=2, freeze_feature_extractor=True):
        super(Wav2Vec2ForFakeMusic, self).__init__()

        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")

        if freeze_feature_extractor:
            for param in self.wav2vec.parameters():
                param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Linear(self.wav2vec.config.hidden_size, 256),  # 768 -> 256
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)  # 256 -> 2 (binary classification)
        )

    def forward(self, x):
        x = x.squeeze(1)
        output = self.wav2vec(x)
        features = output["last_hidden_state"]        # (batch_size, seq_len, feature_dim)
        pooled_features = features.mean(dim=1)        # mean pooling over time -> (batch_size, feature_dim)
        logits = self.classifier(pooled_features)     # (batch_size, num_classes)

        return logits, pooled_features


def visualize_attention_map(attn_map, mel_spec, layer_idx):
    attn_map = attn_map.mean(dim=1).squeeze().cpu().numpy()  # average over attention heads
    mel_spec = mel_spec.squeeze().cpu().numpy()

    fig, axs = plt.subplots(2, 1, figsize=(10, 8))

    # Log-mel spectrogram
    sns.heatmap(mel_spec, cmap='inferno', ax=axs[0])
    axs[0].set_title("Log-Mel Spectrogram")
    axs[0].set_xlabel("Time Frames")
    axs[0].set_ylabel("Mel Frequency Bins")

    # Attention map
    sns.heatmap(attn_map, cmap='viridis', ax=axs[1])
    axs[1].set_title(f"Attention Map (Layer {layer_idx})")
    axs[1].set_xlabel("Time Frames")
    axs[1].set_ylabel("Query Positions")

    plt.tight_layout()
    # save before show so the written figure is not blank
    plt.savefig("/data/kym/AI_Music_Detection/Code/model/attention_map/crossattn.png")
    plt.show()
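A quick shape check for Wav2Vec2ForFakeMusic as defined above; the 10-second clip length is an illustrative assumption, and the (2, 768) pooled size follows from wav2vec2-base's hidden size.

import torch
from networks import Wav2Vec2ForFakeMusic

model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True)
model.eval()
dummy = torch.randn(2, 1, 16000 * 10)   # (batch, channel, samples); forward() squeezes dim 1
with torch.no_grad():
    logits, pooled = model(dummy)
print(logits.shape, pooled.shape)        # expected: torch.Size([2, 2]) and torch.Size([2, 768])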
ISMIR_2025/wav2vec/test.py
ADDED
@@ -0,0 +1,148 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
from datalib import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels
)
from networks import Wav2Vec2ForFakeMusic
from tqdm import tqdm
import argparse
'''
command: python3 test.py --finetune_test --closed_test | --open_test
'''
parser = argparse.ArgumentParser(description="AI Music Detection Testing with Wav2Vec 2.0")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
parser.add_argument('--pretrain_test', action="store_true", help="Test Pretrained Wav2Vec2 Model")
parser.add_argument('--finetune_test', action="store_true", help="Test Fine-Tuned Wav2Vec2 Model")
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='', help='Path to save test results')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)

if args.pretrain_test:
    ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_20.pth")
    print("\n🔍 Loading Pretrained Model:", ckpt_file)
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)

elif args.finetune_test:
    ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_10.pth")
    print("\n🔍 Loading Fine-Tuned Model:", ckpt_file)
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device)

else:
    raise ValueError("You must specify --pretrain_test or --finetune_test")

if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

# model.load_state_dict(torch.load(ckpt_file, map_location=device))
# model.eval()

ckpt = torch.load(ckpt_file, map_location=device)

keys_to_remove = [key for key in ckpt.keys() if "masked_spec_embed" in key]
for key in keys_to_remove:
    print(f"Removing unexpected key: {key}")
    del ckpt[key]

try:
    model.load_state_dict(ckpt, strict=False)
except RuntimeError as e:
    print("Model loading error:", e)
    print("Trying to load entire model...")
    model = torch.load(ckpt_file, map_location=device)
model.to(device)
model.eval()

torch.cuda.empty_cache()

if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

def Test(model, test_loader, device, phase="Test"):
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.squeeze(1)  # ensure correct input shape

            output, _ = model(inputs)  # the model returns (logits, pooled_features)
            loss = F.cross_entropy(output, labels)

            test_loss += loss.item() * inputs.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\n{phase} Test Results - Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.3f} | "
          f"Test Balanced Acc: {test_bal_acc:.4f} | Test Precision: {test_precision:.3f} | "
          f"Test Recall: {test_recall:.3f} | Test F1: {test_f1:.3f}")

    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{phase}_opentest.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)

print("\nEvaluating Model on Test Set...")
Test(model, test_loader, device, phase="Pretrained Model" if args.pretrain_test else "Fine-Tuned Model")
ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc
ADDED
Binary file (4.79 kB)
ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc
ADDED
Binary file (1.01 kB)
ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc
ADDED
Binary file (16.1 kB)
ISMIR_2025/wav2vec/utils/config.py
ADDED
@@ -0,0 +1,565 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv

import numpy as np

sample_rate = 32000
clip_samples = sample_rate * 10  # Audio clips are 10-second

# Load label
with open(
    "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv",
    "r",
) as f:
    reader = csv.reader(f, delimiter=",")
    lines = list(reader)

labels = []
ids = []  # Each label has a unique id such as "/m/068hy"
for i1 in range(1, len(lines)):
    id = lines[i1][1]
    label = lines[i1][2]
    ids.append(id)
    labels.append(label)

classes_num = len(labels)

lb_to_ix = {label: i for i, label in enumerate(labels)}
ix_to_lb = {i: label for i, label in enumerate(labels)}

id_to_ix = {id: i for i, id in enumerate(ids)}
ix_to_id = {i: id for i, id in enumerate(ids)}

full_samples_per_class = np.array(
    [
        937432, 16344, 7822, 10271, 2043, 14420, 733, 1511, 1258, 424,
        1751, 704, 369, 590, 1063, 1375, 5026, 743, 853, 1648,
        714, 1497, 1251, 2139, 1093, 133, 224, 39469, 6423, 407,
        1559, 4546, 6826, 7464, 2468, 549, 4063, 334, 587, 238,
        1766, 691, 114, 2153, 236, 209, 421, 740, 269, 959,
        137, 4192, 485, 1515, 655, 274, 69, 157, 1128, 807,
        1022, 346, 98, 680, 890, 352, 4169, 2061, 1753, 9883,
        1339, 708, 37857, 18504, 12864, 2475, 2182, 757, 3624, 677,
        1683, 3583, 444, 1780, 2364, 409, 4060, 3097, 3143, 502,
        723, 600, 230, 852, 1498, 1865, 1879, 2429, 5498, 5430,
        2139, 1761, 1051, 831, 2401, 2258, 1672, 1711, 987, 646,
        794, 25061, 5792, 4256, 96, 8126, 2740, 752, 513, 554,
        106, 254, 1592, 556, 331, 615, 2841, 737, 265, 1349,
        358, 1731, 1115, 295, 1070, 972, 174, 937780, 112337, 42509,
        49200, 11415, 6092, 13851, 2665, 1678, 13344, 2329, 1415, 2244,
        1099, 5024, 9872, 10948, 4409, 2732, 1211, 1289, 4807, 5136,
        1867, 16134, 14519, 3086, 19261, 6499, 4273, 2790, 8820, 1228,
        1575, 4420, 3685, 2019, 664, 324, 513, 411, 436, 2997,
        5162, 3806, 1389, 899, 8088, 7004, 1105, 3633, 2621, 9753,
        1082, 26854, 3415, 4991, 2129, 5546, 4489, 2850, 1977, 1908,
        1719, 1106, 1049, 152, 136, 802, 488, 592, 2081, 2712,
        1665, 1128, 250, 544, 789, 2715, 8063, 7056, 2267, 8034,
        6092, 3815, 1833, 3277, 8813, 2111, 4662, 2678, 2954, 5227,
        1472, 2591, 3714, 1974, 1795, 4680, 3751, 6585, 2109, 36617,
        6083, 16264, 17351, 3449, 5034, 3931, 2599, 4134, 3892, 2334,
        2211, 4516, 2766, 2862, 3422, 1788, 2544, 2403, 2892, 4042,
        3460, 1516, 1972, 1563, 1579, 2776, 1647, 4535, 3921, 1261,
        6074, 2922, 3068, 1948, 4407, 712, 1294, 1019, 1572, 3764,
        5218, 975, 1539, 6376, 1606, 6091, 1138, 1169, 7925, 3136,
        1108, 2677, 2680, 1383, 3144, 2653, 1986, 1800, 1308, 1344,
        122231, 12977, 2552, 2678, 7824, 768, 8587, 39503, 3474, 661,
        430, 193, 1405, 1442, 3588, 6280, 10515, 785, 710, 305,
        206, 4990, 5329, 3398, 1771, 3022, 6907, 1523, 8588, 12203,
        666, 2113, 7916, 434, 1636, 5185, 1062, 664, 952, 3490,
        2811, 2749, 2848, 15555, 363, 117, 1494, 1647, 5886, 4021,
        633, 1013, 5951, 11343, 2324, 243, 372, 943, 734, 242,
        3161, 122, 127, 201, 1654, 768, 134, 1467, 642, 1148,
        2156, 1368, 1176, 302, 1909, 61, 223, 1812, 287, 422,
        311, 228, 748, 230, 1876, 539, 1814, 737, 689, 1140,
        591, 943, 353, 289, 198, 490, 7938, 1841, 850, 457,
        814, 146, 551, 728, 1627, 620, 648, 1621, 2731, 535,
        88, 1736, 736, 328, 293, 3170, 344, 384, 7640, 433,
        215, 715, 626, 128, 3059, 1833, 2069, 3732, 1640, 1508,
        836, 567, 2837, 1151, 2068, 695, 1494, 3173, 364, 88,
        188, 740, 677, 273, 1533, 821, 1091, 293, 647, 318,
        1202, 328, 532, 2847, 526, 721, 370, 258, 956, 1269,
        1641, 339, 1322, 4485, 286, 1874, 277, 757, 1393, 1330,
        380, 146, 377, 394, 318, 339, 1477, 1886, 101, 1435,
        284, 1425, 686, 621, 221, 117, 87, 1340, 201, 1243,
        1222, 651, 1899, 421, 712, 1016, 1279, 124, 351, 258,
        7043, 368, 666, 162, 7664, 137, 70159, 26179, 6321, 32236,
        33320, 771, 1169, 269, 1103, 444, 364, 2710, 121, 751,
        1609, 855, 1141, 2287, 1940, 3943, 289,
    ]
)
ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py
ADDED
@@ -0,0 +1,29 @@
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    writer.add_figure("Confusion Matrix", fig, epoch)
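plot_confusion_matrix expects `writer` to be a TensorBoard SummaryWriter, so the figure is logged under the "Confusion Matrix" tag; a minimal usage sketch with toy labels (the log directory is a placeholder):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/confusion_matrix_demo")
y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
plot_confusion_matrix(y_true, y_pred, classes=["real", "generative"], writer=writer, epoch=0)
writer.close()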