nininigold committed
Commit 3cecacc · verified · 1 Parent(s): 18c3d3e

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +1 -0
  2. ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc +0 -0
  3. ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc +0 -0
  4. ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc +0 -0
  5. ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc +0 -0
  6. ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc +0 -0
  7. ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc +0 -0
  8. ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc +0 -0
  9. ISMIR_2025/MERT/datalib.py +203 -0
  10. ISMIR_2025/MERT/main.py +197 -0
  11. ISMIR_2025/MERT/networks.py +107 -0
  12. ISMIR_2025/MERT/test.py +114 -0
  13. ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc +0 -0
  14. ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc +0 -0
  15. ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc +0 -0
  16. ISMIR_2025/MERT/utils/config.py +565 -0
  17. ISMIR_2025/MERT/utils/confusion_matrix_plot.py +29 -0
  18. ISMIR_2025/MERT/utils/freqeuncy.py +24 -0
  19. ISMIR_2025/MERT/utils/hf_vis.py +89 -0
  20. ISMIR_2025/MERT/utils/idr_torch.py +23 -0
  21. ISMIR_2025/MERT/utils/mfcc.py +266 -0
  22. ISMIR_2025/MERT/utils/utilities.py +305 -0
  23. ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc +0 -0
  24. ISMIR_2025/Model/datalib.py +206 -0
  25. ISMIR_2025/Model/main.py +336 -0
  26. ISMIR_2025/Model/networks.py +237 -0
  27. ISMIR_2025/Model/test.py +129 -0
  28. ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc +0 -0
  29. ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc +0 -0
  30. ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc +0 -0
  31. ISMIR_2025/music2vec/datalib.py +144 -0
  32. ISMIR_2025/music2vec/inference.py +64 -0
  33. ISMIR_2025/music2vec/main.py +155 -0
  34. ISMIR_2025/music2vec/networks.py +247 -0
  35. ISMIR_2025/music2vec/test.py +119 -0
  36. ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc +0 -0
  37. ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc +0 -0
  38. ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc +0 -0
  39. ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc +0 -0
  40. ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc +0 -0
  41. ISMIR_2025/wav2vec/datalib.py +139 -0
  42. ISMIR_2025/wav2vec/inference.py +71 -0
  43. ISMIR_2025/wav2vec/main.py +162 -0
  44. ISMIR_2025/wav2vec/networks.py +161 -0
  45. ISMIR_2025/wav2vec/test.py +148 -0
  46. ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc +0 -0
  47. ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc +0 -0
  48. ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc +0 -0
  49. ISMIR_2025/wav2vec/utils/config.py +565 -0
  50. ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py +29 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ v2_request/request_music/AI_Detection/9486_music_id_Lost[[:space:]]Sky[[:space:]]x[[:space:]]Anna[[:space:]]Yvette[[:space:]]-[[:space:]]Carry[[:space:]]On[[:space:]]|[[:space:]]Trap[[:space:]]|[[:space:]]NCS[[:space:]]-[[:space:]]Copyright[[:space:]]Free[[:space:]]Music.wav filter=lfs diff=lfs merge=lfs -text
ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc ADDED
Binary file (10.1 kB).
 
ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc ADDED
Binary file (9.26 kB).
 
ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc ADDED
Binary file (8.69 kB).
 
ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc ADDED
Binary file (11.6 kB).
 
ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc ADDED
Binary file (6.86 kB).
 
ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc ADDED
Binary file (6.62 kB).
 
ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc ADDED
Binary file (4.33 kB).
 
ISMIR_2025/MERT/datalib.py ADDED
@@ -0,0 +1,203 @@
import os
import glob
import torch
import torchaudio
import librosa
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from imblearn.over_sampling import RandomOverSampler
from transformers import Wav2Vec2Processor
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2FeatureExtractor
import scipy.signal as signal

# class FakeMusicCapsDataset(Dataset):
#     def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
#         self.file_paths = file_paths
#         self.labels = labels
#         self.sr = sr
#         self.target_samples = int(target_duration * sr)  # Fixed length: 5 seconds

#     def __len__(self):
#         return len(self.file_paths)

#     def __getitem__(self, idx):
#         audio_path = self.file_paths[idx]
#         label = self.labels[idx]

#         waveform, sr = torchaudio.load(audio_path)
#         waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)
#         waveform = waveform.mean(dim=0)  # Convert to mono
#         waveform = waveform.squeeze(0)

#         current_samples = waveform.shape[0]

#         # **Ensure waveform is exactly `target_samples` long**
#         if current_samples > self.target_samples:
#             waveform = waveform[:self.target_samples]  # Truncate if too long
#         elif current_samples < self.target_samples:
#             pad_length = self.target_samples - current_samples
#             waveform = torch.nn.functional.pad(waveform, (0, pad_length))  # Pad if too short

#         return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long)  # Ensure 2D shape (1, target_samples)

class FakeMusicCapsDataset(Dataset):
    def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
        self.file_paths = file_paths
        self.labels = labels
        self.sr = sr
        self.target_samples = int(target_duration * sr)  # Fixed length: 10 seconds
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

    def __len__(self):
        return len(self.file_paths)

    def highpass_filter(self, y, sr, cutoff=500, order=5):
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(audio_path)

        target_sr = self.processor.sampling_rate

        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
            waveform = resampler(waveform)

        waveform = waveform.mean(dim=0).squeeze(0)  # [Time]

        if label == 1:
            # The waveform is already at target_sr, so design the filter for that rate
            # and convert the filtered numpy array back to a float tensor.
            filtered = self.highpass_filter(waveform.numpy(), target_sr)
            waveform = torch.from_numpy(filtered).float()

        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            waveform = waveform[:self.target_samples]  # Truncate
        elif current_samples < self.target_samples:
            pad_length = self.target_samples - current_samples
            waveform = torch.nn.functional.pad(waveform, (0, pad_length))  # Pad

        if isinstance(waveform, torch.Tensor):
            waveform = waveform.numpy()  # convert only if it is still a Tensor

        inputs = self.processor(waveform, sampling_rate=target_sr, return_tensors="pt", padding=True)

        return inputs["input_values"].squeeze(0), torch.tensor(label, dtype=torch.long)  # [1, time] -> [time]

    @staticmethod
    def collate_fn(batch, target_samples=16000 * 10):

        inputs, labels = zip(*batch)  # Unzip batch

        processed_inputs = []
        for waveform in inputs:
            current_samples = waveform.shape[0]

            if current_samples > target_samples:
                start_idx = (current_samples - target_samples) // 2
                cropped_waveform = waveform[start_idx:start_idx + target_samples]
            else:
                pad_length = target_samples - current_samples
                cropped_waveform = torch.nn.functional.pad(waveform, (0, pad_length))

            processed_inputs.append(cropped_waveform)

        processed_inputs = torch.stack(processed_inputs)  # [batch, target_samples]
        labels = torch.tensor(labels, dtype=torch.long)  # [batch]

        return processed_inputs, labels

def preprocess_audio(audio_path, target_sr=16000, max_length=160000):
    """
    Convert an audio file into the model input format:
    - target_sr: resample to 16 kHz
    - max_length: cap the length at 160,000 samples (10 seconds)
    """
    waveform, sr = torchaudio.load(audio_path)

    # Resample if needed
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)

    # Convert to mono
    waveform = waveform.mean(dim=0).unsqueeze(0)  # (1, sequence_length)

    current_samples = waveform.shape[1]
    if current_samples > max_length:
        start_idx = (current_samples - max_length) // 2
        waveform = waveform[:, start_idx:start_idx + max_length]
    elif current_samples < max_length:
        pad_length = max_length - current_samples
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))

    return waveform


DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data that includes the open-set portion

# Closed test: FakeMusicCaps dataset
real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)

# Open-set test: also include data from SUNOCAPS_PATH
open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)

real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)

open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)

# Closed-set train / validation split
real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)

train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels

# Closed-set test split
closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels

# Open-set test split
open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels

# Apply oversampling
ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)

train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled

print(f"📌 Train Original FAKE: {len(gen_train)}")
print(f"📌 Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, "
      f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
print(f"📌 Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}")
ISMIR_2025/MERT/main.py ADDED
@@ -0,0 +1,197 @@
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score
import wandb
import argparse
from transformers import AutoModel, AutoConfig, Wav2Vec2FeatureExtractor
from ISMIR_2025.MERT.datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels
from ISMIR_2025.MERT.networks import MERTFeatureExtractor

# Command-line arguments (default values are placeholders; adjust to the experiment setup)
parser = argparse.ArgumentParser(description="AI Music Detection training with MERT")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--pretrain_epochs', type=int, default=10)
parser.add_argument('--finetune_epochs', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--finetune_lr', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--checkpoint_dir', type=str, default='./ckpt')
args = parser.parse_args()

# Set device
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Initialize wandb
wandb.init(project="mert", name=f"hpfilter_pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args)

# Load datasets
print("🔍 Preparing datasets...")
train_dataset = FakeMusicCapsDataset(train_files, train_labels)
val_dataset = FakeMusicCapsDataset(val_files, val_labels)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn)

# Model checkpoint paths
pretrain_ckpt = os.path.join(args.checkpoint_dir, f"mert_pretrain_{args.pretrain_epochs}.pth")
finetune_ckpt = os.path.join(args.checkpoint_dir, f"mert_finetune_{args.finetune_epochs}.pth")

# Load MERT model for pretraining
print("🔍 Initializing MERT model for Pretraining...")

config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
if not hasattr(config, "conv_pos_batch_norm"):
    setattr(config, "conv_pos_batch_norm", False)

mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device)
mert_model = MERTFeatureExtractor().to(device)  # the wrapper is what actually gets trained below

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mert_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Training function
def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"):
        labels = labels.to(device)
        inputs = inputs.to(device)

        output = model(inputs)

        # The model may return a tensor, an object with .logits, or a tuple
        if isinstance(output, torch.Tensor):
            logits = output
        elif hasattr(output, "logits"):
            logits = output.logits
        elif isinstance(output, (tuple, list)):
            logits = output[0]
        else:
            raise ValueError("Unexpected model output type")

        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    scheduler.step()

    accuracy = total_correct / total_samples
    f1 = f1_score(all_labels, all_preds, average="binary")
    precision = precision_score(all_labels, all_preds, average="binary")
    recall = recall_score(all_labels, all_preds, average="binary", pos_label=1)
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)

    wandb.log({
        f"{phase} Train Loss": total_loss / len(dataloader),
        f"{phase} Train Accuracy": accuracy,
        f"{phase} Train F1 Score": f1,
        f"{phase} Train Precision": precision,
        f"{phase} Train Recall": recall,
        f"{phase} Train Balanced Accuracy": balanced_acc,
    })

    print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, "
          f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}")

def validate(model, dataloader, optimizer, criterion, device, epoch, phase="Validation"):
    # Evaluation only: no gradient updates are performed here (optimizer is unused)
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"{phase} Validation Epoch {epoch+1}"):
            labels = labels.to(device)
            inputs = inputs.to(device)

            output = model(inputs)

            # The model may return a tensor, an object with .logits, or a tuple
            if isinstance(output, torch.Tensor):
                logits = output
            elif hasattr(output, "logits"):
                logits = output.logits
            elif isinstance(output, (tuple, list)):
                logits = output[0]
            else:
                raise ValueError("Unexpected model output type")

            loss = criterion(logits, labels)

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples
    val_f1 = f1_score(all_labels, all_preds, average="weighted")
    val_precision = precision_score(all_labels, all_preds, average="binary")
    val_recall = recall_score(all_labels, all_preds, average="binary")
    val_bal_acc = balanced_accuracy_score(all_labels, all_preds)

    wandb.log({
        f"{phase} Val Loss": total_loss / len(dataloader),
        f"{phase} Val Accuracy": accuracy,
        f"{phase} Val F1 Score": val_f1,
        f"{phase} Val Precision": val_precision,
        f"{phase} Val Recall": val_recall,
        f"{phase} Val Balanced Accuracy": val_bal_acc,
    })
    print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, "
          f"Val Acc: {accuracy:.4f}, Val F1: {val_f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}")
    return total_loss / len(dataloader), accuracy, val_f1


print("\n🔍 Step 1: Self-Supervised Pretraining on REAL Data")
# for epoch in range(args.pretrain_epochs):
#     train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain")
# torch.save(mert_model.state_dict(), pretrain_ckpt)
# print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}")

# print("\n🔍 Initializing CCV Model for Fine-Tuning...")
# mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device)
# mert_model.feature_extractor.load_state_dict(torch.load(pretrain_ckpt), strict=False)

# optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

print("\n🔍 Step 2: Fine-Tuning CCV Model")
for epoch in range(args.finetune_epochs):
    train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune")

torch.save(mert_model.state_dict(), finetune_ckpt)
print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")

print("\n🔍 Step 2: Fine-Tuning MERT Model")
mert_model.load_state_dict(torch.load(pretrain_ckpt), strict=False)

optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(args.finetune_epochs):
    train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune")

torch.save(mert_model.state_dict(), finetune_ckpt)
print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")
ISMIR_2025/MERT/networks.py ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig

class MERTFeatureExtractor(nn.Module):
    def __init__(self, freeze_feature_extractor=True):
        super(MERTFeatureExtractor, self).__init__()
        config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
        if not hasattr(config, "conv_pos_batch_norm"):
            setattr(config, "conv_pos_batch_norm", False)
        self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True)

        if freeze_feature_extractor:
            self.freeze()

    def forward(self, input_values):
        # input: [batch, time]
        # Extract hidden states from the pretrained MERT (all layers are used here)
        with torch.no_grad():
            outputs = self.mert(input_values, output_hidden_states=True)
        # hidden_states: tuple of [batch, time, feature_dim]
        # Stack the hidden states of every layer, then average over time to obtain one feature per layer
        hidden_states = torch.stack(outputs.hidden_states)  # [num_layers, batch, time, feature_dim]
        hidden_states = hidden_states.detach().clone().requires_grad_(True)
        time_reduced = hidden_states.mean(dim=2)  # [num_layers, batch, feature_dim]
        time_reduced = time_reduced.permute(1, 0, 2)  # [batch, num_layers, feature_dim]
        return time_reduced

    def freeze(self):
        for param in self.mert.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.mert.parameters():
            param.requires_grad = True


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x, cross_input):
        # Attention between x and cross_input
        attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        x = self.layer_norm1(x + attn_output)
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x


class CCV(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
        super(CCV, self).__init__()
        # MERT-based feature extractor (meaningful features from the pretrained weights)
        self.feature_extractor = MERTFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor)
        # Stack of cross-attention layers
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        # Transformer encoder (batch-first)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_values):
        """
        input_values: Tensor [batch, time]
        1. Extract features with MERT -> [batch, num_layers, feature_dim]
        2. (Optionally transpose to match the embedding dimension -> [batch, feature_dim, num_layers])
        3. Apply cross-attention
        4. Transformer encoding followed by mean pooling
        5. Pass through the classifier to obtain the final logits
        """
        features = self.feature_extractor(input_values)  # [batch, num_layers, feature_dim]
        # embed_dim is usually set equal to feature_dim (e.g., 768)
        # features = features.permute(0, 2, 1)  # [batch, embed_dim, num_layers]

        # Apply cross-attention (here the features attend to themselves)
        for layer in self.cross_attention_layers:
            features = layer(features, features)

        # Average over the layer axis (used as the "time" axis here) before the Transformer encoder
        features = features.mean(dim=1).unsqueeze(1)  # [batch, 1, embed_dim]
        encoded = self.transformer(features)  # [batch, 1, embed_dim]
        encoded = encoded.mean(dim=1)  # [batch, embed_dim]
        output = self.classifier(encoded)  # [batch, num_classes]
        return output, encoded

    def unfreeze_feature_extractor(self):
        self.feature_extractor.unfreeze()
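Note: a small shape-check sketch for the CCV classifier above; the batch size and clip length are illustrative, and the MERT-v1-95M weights are downloaded from the Hub on first use.

# Hypothetical shape check (not part of the commit).
import torch
from ISMIR_2025.MERT.networks import CCV

model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2)
model.eval()

clips = torch.randn(2, 16000 * 10)           # two 10-second clips at 16 kHz (illustrative)
with torch.no_grad():
    logits, embedding = model(clips)          # logits: [2, 2], embedding: [2, 768]
print(logits.shape, embedding.shape)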
ISMIR_2025/MERT/test.py ADDED
@@ -0,0 +1,114 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
from datalib import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels
)
from networks import MERTFeatureExtractor
import argparse

parser = argparse.ArgumentParser(description="AI Music Detection Testing with MERT")
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default="/data/kym/AI_Music_Detection/Code/model/MERT/ckpt/1e-3/mert_finetune_10.pth", help='Path to the pretrained checkpoint')
parser.add_argument('--model_name', type=str, default="mert", help="Model name")
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='', help='Path to save test results')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)

model = MERTFeatureExtractor().to(device)

ckpt_file = args.ckpt_path
if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
print(f"\nLoading MERT model from {ckpt_file}")
model.load_state_dict(torch.load(ckpt_file, map_location=device))
model.eval()

torch.cuda.empty_cache()

if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

def test_mert(model, test_loader, device):
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)

            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)

            all_labels.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")

    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)

print("\nEvaluating MERT Model on Test Set...")
test_mert(model, test_loader, device)
ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc ADDED
Binary file (4.79 kB).
 
ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc ADDED
Binary file (1.01 kB).
 
ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc ADDED
Binary file (16.1 kB).
 
ISMIR_2025/MERT/utils/config.py ADDED
@@ -0,0 +1,565 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import csv
5
+
6
+ import numpy as np
7
+
8
+ sample_rate = 32000
9
+ clip_samples = sample_rate * 10 # Audio clips are 10-second
10
+
11
+ # Load label
12
+ with open(
13
+ "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv",
14
+ "r",
15
+ ) as f:
16
+ reader = csv.reader(f, delimiter=",")
17
+ lines = list(reader)
18
+
19
+ labels = []
20
+ ids = [] # Each label has a unique id such as "/m/068hy"
21
+ for i1 in range(1, len(lines)):
22
+ id = lines[i1][1]
23
+ label = lines[i1][2]
24
+ ids.append(id)
25
+ labels.append(label)
26
+
27
+ classes_num = len(labels)
28
+
29
+ lb_to_ix = {label: i for i, label in enumerate(labels)}
30
+ ix_to_lb = {i: label for i, label in enumerate(labels)}
31
+
32
+ id_to_ix = {id: i for i, id in enumerate(ids)}
33
+ ix_to_id = {i: id for i, id in enumerate(ids)}
34
+
35
+ full_samples_per_class = np.array(
36
+ [
37
+ 937432,
38
+ 16344,
39
+ 7822,
40
+ 10271,
41
+ 2043,
42
+ 14420,
43
+ 733,
44
+ 1511,
45
+ 1258,
46
+ 424,
47
+ 1751,
48
+ 704,
49
+ 369,
50
+ 590,
51
+ 1063,
52
+ 1375,
53
+ 5026,
54
+ 743,
55
+ 853,
56
+ 1648,
57
+ 714,
58
+ 1497,
59
+ 1251,
60
+ 2139,
61
+ 1093,
62
+ 133,
63
+ 224,
64
+ 39469,
65
+ 6423,
66
+ 407,
67
+ 1559,
68
+ 4546,
69
+ 6826,
70
+ 7464,
71
+ 2468,
72
+ 549,
73
+ 4063,
74
+ 334,
75
+ 587,
76
+ 238,
77
+ 1766,
78
+ 691,
79
+ 114,
80
+ 2153,
81
+ 236,
82
+ 209,
83
+ 421,
84
+ 740,
85
+ 269,
86
+ 959,
87
+ 137,
88
+ 4192,
89
+ 485,
90
+ 1515,
91
+ 655,
92
+ 274,
93
+ 69,
94
+ 157,
95
+ 1128,
96
+ 807,
97
+ 1022,
98
+ 346,
99
+ 98,
100
+ 680,
101
+ 890,
102
+ 352,
103
+ 4169,
104
+ 2061,
105
+ 1753,
106
+ 9883,
107
+ 1339,
108
+ 708,
109
+ 37857,
110
+ 18504,
111
+ 12864,
112
+ 2475,
113
+ 2182,
114
+ 757,
115
+ 3624,
116
+ 677,
117
+ 1683,
118
+ 3583,
119
+ 444,
120
+ 1780,
121
+ 2364,
122
+ 409,
123
+ 4060,
124
+ 3097,
125
+ 3143,
126
+ 502,
127
+ 723,
128
+ 600,
129
+ 230,
130
+ 852,
131
+ 1498,
132
+ 1865,
133
+ 1879,
134
+ 2429,
135
+ 5498,
136
+ 5430,
137
+ 2139,
138
+ 1761,
139
+ 1051,
140
+ 831,
141
+ 2401,
142
+ 2258,
143
+ 1672,
144
+ 1711,
145
+ 987,
146
+ 646,
147
+ 794,
148
+ 25061,
149
+ 5792,
150
+ 4256,
151
+ 96,
152
+ 8126,
153
+ 2740,
154
+ 752,
155
+ 513,
156
+ 554,
157
+ 106,
158
+ 254,
159
+ 1592,
160
+ 556,
161
+ 331,
162
+ 615,
163
+ 2841,
164
+ 737,
165
+ 265,
166
+ 1349,
167
+ 358,
168
+ 1731,
169
+ 1115,
170
+ 295,
171
+ 1070,
172
+ 972,
173
+ 174,
174
+ 937780,
175
+ 112337,
176
+ 42509,
177
+ 49200,
178
+ 11415,
179
+ 6092,
180
+ 13851,
181
+ 2665,
182
+ 1678,
183
+ 13344,
184
+ 2329,
185
+ 1415,
186
+ 2244,
187
+ 1099,
188
+ 5024,
189
+ 9872,
190
+ 10948,
191
+ 4409,
192
+ 2732,
193
+ 1211,
194
+ 1289,
195
+ 4807,
196
+ 5136,
197
+ 1867,
198
+ 16134,
199
+ 14519,
200
+ 3086,
201
+ 19261,
202
+ 6499,
203
+ 4273,
204
+ 2790,
205
+ 8820,
206
+ 1228,
207
+ 1575,
208
+ 4420,
209
+ 3685,
210
+ 2019,
211
+ 664,
212
+ 324,
213
+ 513,
214
+ 411,
215
+ 436,
216
+ 2997,
217
+ 5162,
218
+ 3806,
219
+ 1389,
220
+ 899,
221
+ 8088,
222
+ 7004,
223
+ 1105,
224
+ 3633,
225
+ 2621,
226
+ 9753,
227
+ 1082,
228
+ 26854,
229
+ 3415,
230
+ 4991,
231
+ 2129,
232
+ 5546,
233
+ 4489,
234
+ 2850,
235
+ 1977,
236
+ 1908,
237
+ 1719,
238
+ 1106,
239
+ 1049,
240
+ 152,
241
+ 136,
242
+ 802,
243
+ 488,
244
+ 592,
245
+ 2081,
246
+ 2712,
247
+ 1665,
248
+ 1128,
249
+ 250,
250
+ 544,
251
+ 789,
252
+ 2715,
253
+ 8063,
254
+ 7056,
255
+ 2267,
256
+ 8034,
257
+ 6092,
258
+ 3815,
259
+ 1833,
260
+ 3277,
261
+ 8813,
262
+ 2111,
263
+ 4662,
264
+ 2678,
265
+ 2954,
266
+ 5227,
267
+ 1472,
268
+ 2591,
269
+ 3714,
270
+ 1974,
271
+ 1795,
272
+ 4680,
273
+ 3751,
274
+ 6585,
275
+ 2109,
276
+ 36617,
277
+ 6083,
278
+ 16264,
279
+ 17351,
280
+ 3449,
281
+ 5034,
282
+ 3931,
283
+ 2599,
284
+ 4134,
285
+ 3892,
286
+ 2334,
287
+ 2211,
288
+ 4516,
289
+ 2766,
290
+ 2862,
291
+ 3422,
292
+ 1788,
293
+ 2544,
294
+ 2403,
295
+ 2892,
296
+ 4042,
297
+ 3460,
298
+ 1516,
299
+ 1972,
300
+ 1563,
301
+ 1579,
302
+ 2776,
303
+ 1647,
304
+ 4535,
305
+ 3921,
306
+ 1261,
307
+ 6074,
308
+ 2922,
309
+ 3068,
310
+ 1948,
311
+ 4407,
312
+ 712,
313
+ 1294,
314
+ 1019,
315
+ 1572,
316
+ 3764,
317
+ 5218,
318
+ 975,
319
+ 1539,
320
+ 6376,
321
+ 1606,
322
+ 6091,
323
+ 1138,
324
+ 1169,
325
+ 7925,
326
+ 3136,
327
+ 1108,
328
+ 2677,
329
+ 2680,
330
+ 1383,
331
+ 3144,
332
+ 2653,
333
+ 1986,
334
+ 1800,
335
+ 1308,
336
+ 1344,
337
+ 122231,
338
+ 12977,
339
+ 2552,
340
+ 2678,
341
+ 7824,
342
+ 768,
343
+ 8587,
344
+ 39503,
345
+ 3474,
346
+ 661,
347
+ 430,
348
+ 193,
349
+ 1405,
350
+ 1442,
351
+ 3588,
352
+ 6280,
353
+ 10515,
354
+ 785,
355
+ 710,
356
+ 305,
357
+ 206,
358
+ 4990,
359
+ 5329,
360
+ 3398,
361
+ 1771,
362
+ 3022,
363
+ 6907,
364
+ 1523,
365
+ 8588,
366
+ 12203,
367
+ 666,
368
+ 2113,
369
+ 7916,
370
+ 434,
371
+ 1636,
372
+ 5185,
373
+ 1062,
374
+ 664,
375
+ 952,
376
+ 3490,
377
+ 2811,
378
+ 2749,
379
+ 2848,
380
+ 15555,
381
+ 363,
382
+ 117,
383
+ 1494,
384
+ 1647,
385
+ 5886,
386
+ 4021,
387
+ 633,
388
+ 1013,
389
+ 5951,
390
+ 11343,
391
+ 2324,
392
+ 243,
393
+ 372,
394
+ 943,
395
+ 734,
396
+ 242,
397
+ 3161,
398
+ 122,
399
+ 127,
400
+ 201,
401
+ 1654,
402
+ 768,
403
+ 134,
404
+ 1467,
405
+ 642,
406
+ 1148,
407
+ 2156,
408
+ 1368,
409
+ 1176,
410
+ 302,
411
+ 1909,
412
+ 61,
413
+ 223,
414
+ 1812,
415
+ 287,
416
+ 422,
417
+ 311,
418
+ 228,
419
+ 748,
420
+ 230,
421
+ 1876,
422
+ 539,
423
+ 1814,
424
+ 737,
425
+ 689,
426
+ 1140,
427
+ 591,
428
+ 943,
429
+ 353,
430
+ 289,
431
+ 198,
432
+ 490,
433
+ 7938,
434
+ 1841,
435
+ 850,
436
+ 457,
437
+ 814,
438
+ 146,
439
+ 551,
440
+ 728,
441
+ 1627,
442
+ 620,
443
+ 648,
444
+ 1621,
445
+ 2731,
446
+ 535,
447
+ 88,
448
+ 1736,
449
+ 736,
450
+ 328,
451
+ 293,
452
+ 3170,
453
+ 344,
454
+ 384,
455
+ 7640,
456
+ 433,
457
+ 215,
458
+ 715,
459
+ 626,
460
+ 128,
461
+ 3059,
462
+ 1833,
463
+ 2069,
464
+ 3732,
465
+ 1640,
466
+ 1508,
467
+ 836,
468
+ 567,
469
+ 2837,
470
+ 1151,
471
+ 2068,
472
+ 695,
473
+ 1494,
474
+ 3173,
475
+ 364,
476
+ 88,
477
+ 188,
478
+ 740,
479
+ 677,
480
+ 273,
481
+ 1533,
482
+ 821,
483
+ 1091,
484
+ 293,
485
+ 647,
486
+ 318,
487
+ 1202,
488
+ 328,
489
+ 532,
490
+ 2847,
491
+ 526,
492
+ 721,
493
+ 370,
494
+ 258,
495
+ 956,
496
+ 1269,
497
+ 1641,
498
+ 339,
499
+ 1322,
500
+ 4485,
501
+ 286,
502
+ 1874,
503
+ 277,
504
+ 757,
505
+ 1393,
506
+ 1330,
507
+ 380,
508
+ 146,
509
+ 377,
510
+ 394,
511
+ 318,
512
+ 339,
513
+ 1477,
514
+ 1886,
515
+ 101,
516
+ 1435,
517
+ 284,
518
+ 1425,
519
+ 686,
520
+ 621,
521
+ 221,
522
+ 117,
523
+ 87,
524
+ 1340,
525
+ 201,
526
+ 1243,
527
+ 1222,
528
+ 651,
529
+ 1899,
530
+ 421,
531
+ 712,
532
+ 1016,
533
+ 1279,
534
+ 124,
535
+ 351,
536
+ 258,
537
+ 7043,
538
+ 368,
539
+ 666,
540
+ 162,
541
+ 7664,
542
+ 137,
543
+ 70159,
544
+ 26179,
545
+ 6321,
546
+ 32236,
547
+ 33320,
548
+ 771,
549
+ 1169,
550
+ 269,
551
+ 1103,
552
+ 444,
553
+ 364,
554
+ 2710,
555
+ 121,
556
+ 751,
557
+ 1609,
558
+ 855,
559
+ 1141,
560
+ 2287,
561
+ 1940,
562
+ 3943,
563
+ 289,
564
+ ]
565
+ )
ISMIR_2025/MERT/utils/confusion_matrix_plot.py ADDED
@@ -0,0 +1,29 @@
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    writer.add_figure("Confusion Matrix", fig, epoch)
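Note: a hedged usage sketch for the helper above with a TensorBoard SummaryWriter; the label arrays and log directory are illustrative only.

# Hypothetical usage sketch (not part of the commit).
from torch.utils.tensorboard import SummaryWriter
from confusion_matrix_plot import plot_confusion_matrix

writer = SummaryWriter(log_dir="./runs/example")   # assumed log directory
y_true = [0, 0, 1, 1, 1, 0]                        # illustrative labels (0=real, 1=generative)
y_pred = [0, 1, 1, 1, 0, 0]
plot_confusion_matrix(y_true, y_pred, classes=["real", "generative"], writer=writer, epoch=0)
writer.close()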
ISMIR_2025/MERT/utils/freqeuncy.py ADDED
@@ -0,0 +1,24 @@
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt

# 🔹 Audio files to load
file_real = "/path/to/real_audio.wav"        # path to a real audio file
file_fake = "/path/to/generative_audio.wav"  # path to an AI-generated audio file

def plot_spectrogram(audio_file, title):
    y, sr = librosa.load(audio_file, sr=16000)  # 16 kHz sampling rate
    D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)  # STFT magnitude in dB

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='magma')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.ylim(4000, 16000)  # show only the high-frequency region above 4 kHz
    plt.show()

# 🔹 Compare real vs. generative spectrograms
plot_spectrogram(file_real, "Real Audio Spectrogram (4kHz+)")
plot_spectrogram(file_fake, "Generative Audio Spectrogram (4kHz+)")
ISMIR_2025/MERT/utils/hf_vis.py ADDED
@@ -0,0 +1,89 @@
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as signal
import torch
import torch.nn as nn
import soundfile as sf

from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention


def highpass_filter(y, sr, cutoff=500, order=5):
    """High-pass filter to remove low frequencies below `cutoff` Hz."""
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered

def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"):
    """Plot waveform comparison and spectrograms in a single figure."""
    fig, axes = plt.subplots(3, 1, figsize=(12, 12))

    # 1️⃣ Waveform Comparison
    time = np.linspace(0, len(y_original) / sr, len(y_original))
    axes[0].plot(time, y_original, label='Original', alpha=0.7)
    axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed')
    axes[0].set_xlabel("Time (s)")
    axes[0].set_ylabel("Amplitude")
    axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)")
    axes[0].legend()

    # 2️⃣ Spectrogram - Original
    S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max)
    img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1])
    axes[1].set_title("Original Spectrogram")
    fig.colorbar(img, ax=axes[1], format="%+2.0f dB")

    # 3️⃣ Spectrogram - High-pass Filtered
    S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max)
    img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2])
    axes[2].set_title("High-pass Filtered Spectrogram")
    fig.colorbar(img, ax=axes[2], format="%+2.0f dB")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()


def load_model(checkpoint_path, model_class, device):
    """Load a trained model from checkpoint."""
    model = model_class()
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict_audio(model, audio_tensor, device):
    """Make predictions using a trained model."""
    with torch.no_grad():
        audio_tensor = audio_tensor.unsqueeze(0).to(device)  # Add batch dimension
        output = model(audio_tensor)
        prediction = torch.argmax(output, dim=1).cpu().numpy()[0]
    return prediction

# Load audio
audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav"  # Replace with actual file path
y, sr = librosa.load(audio_path, sr=None)
y_filtered = highpass_filter(y, sr, cutoff=500)

# Convert audio to tensor
audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0)
audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0)

# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device)
highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device)

# Predict
original_pred = predict_audio(original_model, audio_tensor, device)
highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device)

print(f"Original Model Prediction: {original_pred}")
print(f"High-pass Filter Model Prediction: {highpass_pred}")

# Generate combined visualization (all plots in one image)
plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png")
ISMIR_2025/MERT/utils/idr_torch.py ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python
# coding: utf-8

import os
import hostlist

# get SLURM variables
# rank = int(os.environ["SLURM_PROCID"])
local_rank = int(os.environ["SLURM_LOCALID"])
size = int(os.environ["SLURM_NTASKS"])
cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"])

# get node list from slurm
hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])

# get IDs of reserved GPU
gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")

# define MASTER_ADDR & MASTER_PORT
os.environ["MASTER_ADDR"] = hostnames[0]
os.environ["MASTER_PORT"] = str(
    12345 + int(min(gpu_ids))
)  # to avoid port conflict on the same node
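Note: for context, a hedged sketch of how values from this module are typically consumed when initializing torch.distributed inside a SLURM job step; the NCCL backend and the use of SLURM_PROCID as the global rank are assumptions, not part of this commit.

# Hypothetical usage sketch (not part of the commit).
import os
import torch
import torch.distributed as dist
import idr_torch  # importing it sets MASTER_ADDR / MASTER_PORT as a side effect

rank = int(os.environ["SLURM_PROCID"])       # assumed source of the global rank
dist.init_process_group(
    backend="nccl",                          # assumed backend for GPU training
    init_method="env://",
    world_size=idr_torch.size,
    rank=rank,
)
torch.cuda.set_device(idr_torch.local_rank)  # bind this process to its local GPU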
ISMIR_2025/MERT/utils/mfcc.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import Dataset, DataLoader, random_split
9
+ import torch.nn.functional as F
10
+ from sklearn.metrics import precision_score, recall_score, f1_score
11
+ from tqdm import tqdm
12
+ import argparse
13
+ import wandb
14
+
15
+ class RealFakeDataset(Dataset):
16
+ """
17
+ audio/FakeMusicCaps/
18
+ ├─ real/
19
+ │ └─ MusicCaps/*.wav (label=0)
20
+ └─ generative/
21
+ └─ .../*.wav (label=1)
22
+ """
23
+ def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0):
24
+
25
+ self.sr = sr
26
+ self.n_mels = n_mels
27
+ self.target_duration = target_duration
28
+ self.target_samples = int(target_duration * sr) # 10초 = 160,000 샘플
29
+
30
+ self.file_paths = []
31
+ self.labels = []
32
+
33
+ # Real 데이터 (label=0)
34
+ real_dir = os.path.join(root_dir, "real")
35
+ real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True)
36
+ for f in real_wav_files:
37
+ self.file_paths.append(f)
38
+ self.labels.append(0)
39
+
40
+ # Generative 데이터 (label=1)
41
+ gen_dir = os.path.join(root_dir, "generative")
42
+ gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True)
43
+ for f in gen_wav_files:
44
+ self.file_paths.append(f)
45
+ self.labels.append(1)
46
+
47
+ def __len__(self):
48
+ return len(self.file_paths)
49
+
50
+ def __getitem__(self, idx):
51
+ audio_path = self.file_paths[idx]
52
+ label = self.labels[idx]
53
+ # print(f"[DEBUG] Path: {audio_path}, Label: {label}") # 추가
54
+
55
+ waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True)
56
+
57
+ current_samples = waveform.shape[0]
58
+ if current_samples > self.target_samples:
59
+ waveform = waveform[:self.target_samples]
60
+ elif current_samples < self.target_samples:
61
+ stretch_factor = self.target_samples / current_samples
62
+ waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor)
63
+ waveform = waveform[:self.target_samples]
64
+
65
+ mfcc = librosa.feature.mfcc(
66
+ y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256
67
+ )
68
+ mfcc = librosa.util.normalize(mfcc)
69
+
70
+ mfcc = np.expand_dims(mfcc, axis=0)
71
+ mfcc_tensor = torch.tensor(mfcc, dtype=torch.float)
72
+ label_tensor = torch.tensor(label, dtype=torch.long)
73
+
74
+ return mfcc_tensor, label_tensor
75
+
76
+
77
+
78
+ class AudioCNN(nn.Module):
79
+ def __init__(self, num_classes=2):
80
+ super(AudioCNN, self).__init__()
81
+ self.conv_block = nn.Sequential(
82
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
83
+ nn.ReLU(),
84
+ nn.MaxPool2d(2),
85
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
86
+ nn.ReLU(),
87
+ nn.MaxPool2d(2),
88
+ nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4)
89
+ )
90
+ self.fc_block = nn.Sequential(
91
+ nn.Linear(32*4*4, 128),
92
+ nn.ReLU(),
93
+ nn.Linear(128, num_classes)
94
+ )
95
+
96
+
97
+ def forward(self, x):
98
+ x = self.conv_block(x)
99
+ # x.shape: (B,32,new_freq,new_time)
100
+
101
+ # 1) Flatten
102
+ B, C, H, W = x.shape # 동적 shape
103
+ x = x.view(B, -1) # (B, 32*H*W)
104
+
105
+ # 2) FC
106
+ x = self.fc_block(x)
107
+ return x
108
+
109
+
110
+ def my_collate_fn(batch):
111
+ mel_list, label_list = zip(*batch)
112
+
113
+ max_frames = max(m.shape[2] for m in mel_list)
114
+
115
+ padded = []
116
+ for m in mel_list:
117
+ diff = max_frames - m.shape[2]
118
+ if diff > 0:
119
+ print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}")
120
+ m = F.pad(m, (0, diff), mode='constant', value=0)
121
+ padded.append(m)
122
+
123
+
124
+ mel_batch = torch.stack(padded, dim=0)
125
+ label_batch = torch.tensor(label_list, dtype=torch.long)
126
+ return mel_batch, label_batch
127
+
128
+
129
+ class EarlyStopping:
130
+ def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False):
131
+ self.patience = patience
132
+ self.delta = delta
133
+ self.path = path
134
+ self.verbose = verbose
135
+ self.counter = 0
136
+ self.best_loss = None
137
+ self.early_stop = False
138
+
139
+ def __call__(self, val_loss, model):
140
+ if self.best_loss is None:
141
+ self.best_loss = val_loss
142
+ self._save_checkpoint(val_loss, model)
143
+ elif val_loss > self.best_loss - self.delta:
144
+ self.counter += 1
145
+ if self.verbose:
146
+ print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
147
+ if self.counter >= self.patience:
148
+ self.early_stop = True
149
+ else:
150
+ self.best_loss = val_loss
151
+ self._save_checkpoint(val_loss, model)
152
+ self.counter = 0
153
+
154
+ def _save_checkpoint(self, val_loss, model):
155
+ if self.verbose:
156
+ print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...")
157
+ torch.save(model.state_dict(), self.path)
158
+
159
+ def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"):
160
+ if not os.path.exists("./ckpt/mfcc/"):
161
+ os.makedirs("./ckpt/mfcc/")
162
+
163
+ wandb.init(
164
+ project="AI Music Detection",
165
+ name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}",
166
+ config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate},
167
+ )
168
+
169
+ dataset = RealFakeDataset(root_dir=root_dir)
170
+ n_total = len(dataset)
171
+ n_train = int(n_total * 0.8)
172
+ n_val = n_total - n_train
173
+ train_ds, val_ds = random_split(dataset, [n_train, n_val])
174
+
175
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
176
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn)
177
+
178
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
179
+ model = AudioCNN(num_classes=2).to(device)
180
+ criterion = nn.CrossEntropyLoss()
181
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
182
+
183
+ best_val_loss = float('inf')
184
+ patience = 3
185
+ patience_counter = 0
186
+
187
+ for epoch in range(1, epochs + 1):
188
+ print(f"\n[Epoch {epoch}/{epochs}]")
189
+
190
+ # Training
191
+ model.train()
192
+ train_loss, train_correct, train_total = 0, 0, 0
193
+ train_pbar = tqdm(train_loader, desc="Train", leave=False)
194
+ for mel_batch, labels in train_pbar:
195
+ mel_batch, labels = mel_batch.to(device), labels.to(device)
196
+ optimizer.zero_grad()
197
+ outputs = model(mel_batch)
198
+ loss = criterion(outputs, labels)
199
+ loss.backward()
200
+ optimizer.step()
201
+
202
+ train_loss += loss.item() * mel_batch.size(0)
203
+ preds = outputs.argmax(dim=1)
204
+ train_correct += (preds == labels).sum().item()
205
+ train_total += labels.size(0)
206
+
207
+ train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})
208
+
209
+ train_loss /= train_total
210
+ train_acc = train_correct / train_total
211
+
212
+ # Validation
213
+ model.eval()
214
+ val_loss, val_correct, val_total = 0, 0, 0
215
+ all_preds, all_labels = [], []
216
+ val_pbar = tqdm(val_loader, desc=" Val ", leave=False)
217
+ with torch.no_grad():
218
+ for mel_batch, labels in val_pbar:
219
+ mel_batch, labels = mel_batch.to(device), labels.to(device)
220
+ outputs = model(mel_batch)
221
+ loss = criterion(outputs, labels)
222
+ val_loss += loss.item() * mel_batch.size(0)
223
+ preds = outputs.argmax(dim=1)
224
+ val_correct += (preds == labels).sum().item()
225
+ val_total += labels.size(0)
226
+ all_preds.extend(preds.cpu().numpy())
227
+ all_labels.extend(labels.cpu().numpy())
228
+
229
+ val_loss /= val_total
230
+ val_acc = val_correct / val_total
231
+ val_precision = precision_score(all_labels, all_preds, average="macro")
232
+ val_recall = recall_score(all_labels, all_preds, average="macro")
233
+ val_f1 = f1_score(all_labels, all_preds, average="macro")
234
+
235
+ print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | "
236
+ f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} "
237
+ f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}")
238
+
239
+ wandb.log({"train_loss": train_loss, "train_acc": train_acc,
240
+ "val_loss": val_loss, "val_acc": val_acc,
241
+ "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1})
242
+
243
+ if val_loss < best_val_loss:
244
+ best_val_loss = val_loss
245
+ patience_counter = 0
246
+ best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth"
247
+ torch.save(model.state_dict(), best_model_path)
248
+ print(f"[INFO] New best model saved: {best_model_path}")
249
+ else:
250
+ patience_counter += 1
251
+ if patience_counter >= patience:
252
+ print("Early stopping triggered!")
253
+ break
254
+
255
+ wandb.finish()
256
+
257
+ if __name__ == "__main__":
258
+ parser = argparse.ArgumentParser(description="Train AI Music Detection model.")
259
+ parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training")
260
+ parser.add_argument('--epochs', type=int, required=True, help="Number of epochs")
261
+ parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate")
262
+ parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset")
263
+
264
+ args = parser.parse_args()
265
+
266
+ train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir)
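A typical launch of the trainer above, for reference; the script filename and the hyperparameter values are illustrative, since neither is fixed by this hunk:

python3 mfcc.py --batch_size 32 --epochs 50 --learning_rate 1e-4 --root_dir audio/FakeMusicCaps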
ISMIR_2025/MERT/utils/utilities.py ADDED
@@ -0,0 +1,305 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import logging
6
+ import pickle
7
+
8
+ import numpy as np
9
+
10
+ from scipy import stats
11
+
12
+ import csv
13
+ import json
+ import datetime  # used by StatisticsContainer for timestamped backup filenames
14
+
15
+ def create_folder(fd):
16
+ if not os.path.exists(fd):
17
+ os.makedirs(fd, exist_ok=True)
18
+
19
+
20
+ def get_filename(path):
21
+ path = os.path.realpath(path)
22
+ na_ext = path.split("/")[-1]
23
+ na = os.path.splitext(na_ext)[0]
24
+ return na
25
+
26
+
27
+ def get_sub_filepaths(folder):
28
+ paths = []
29
+ for root, dirs, files in os.walk(folder):
30
+ for name in files:
31
+ path = os.path.join(root, name)
32
+ paths.append(path)
33
+ return paths
34
+
35
+
36
+ def create_logging(log_dir, filemode):
37
+ create_folder(log_dir)
38
+ i1 = 0
39
+
40
+ while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
41
+ i1 += 1
42
+
43
+ log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
44
+ logging.basicConfig(
45
+ level=logging.DEBUG,
46
+ format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
47
+ datefmt="%a, %d %b %Y %H:%M:%S",
48
+ filename=log_path,
49
+ filemode=filemode,
50
+ )
51
+
52
+ # Print to console
53
+ console = logging.StreamHandler()
54
+ console.setLevel(logging.INFO)
55
+ formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
56
+ console.setFormatter(formatter)
57
+ logging.getLogger("").addHandler(console)
58
+
59
+ return logging
60
+
61
+
62
+ def read_metadata(csv_path, audio_dir, classes_num, id_to_ix):
63
+ """Read metadata of AudioSet from a csv file.
64
+
65
+ Args:
66
+ csv_path: str
67
+
68
+ Returns:
69
+ meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
70
+ """
71
+
72
+ with open(csv_path, "r") as fr:
73
+ lines = fr.readlines()
74
+ lines = lines[3:] # Remove heads
75
+
76
+ # first, count the audio names only of existing files on disk only
77
+
78
+ audios_num = 0
79
+ for n, line in enumerate(lines):
80
+ items = line.split(", ")
81
+ """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""
82
+
83
+ # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading
84
+ audio_name = "{}_{}_{}.flac".format(
85
+ items[0], items[1].replace(".", ""), items[2].replace(".", "")
86
+ )
87
+ audio_name = audio_name.replace("_0000_", "_0_")
88
+
89
+ if os.path.exists(os.path.join(audio_dir, audio_name)):
90
+ audios_num += 1
91
+
92
+ print("CSV audio files: %d" % (len(lines)))
93
+ print("Existing audio files: %d" % audios_num)
94
+
95
+ # audios_num = len(lines)
96
+ targets = np.zeros((audios_num, classes_num), dtype=bool)
97
+ audio_names = []
98
+
99
+ n = 0
100
+ for line in lines:
101
+ items = line.split(", ")
102
+ """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""
103
+
104
+ # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading
105
+ audio_name = "{}_{}_{}.flac".format(
106
+ items[0], items[1].replace(".", ""), items[2].replace(".", "")
107
+ )
108
+ audio_name = audio_name.replace("_0000_", "_0_")
109
+
110
+ if not os.path.exists(os.path.join(audio_dir, audio_name)):
111
+ continue
112
+
113
+ label_ids = items[3].split('"')[1].split(",")
114
+
115
+ audio_names.append(audio_name)
116
+
117
+ # Target
118
+ for id in label_ids:
119
+ ix = id_to_ix[id]
120
+ targets[n, ix] = 1
121
+ n += 1
122
+
123
+ meta_dict = {"audio_name": np.array(audio_names), "target": targets}
124
+ return meta_dict
125
+
126
+
127
+ def read_audioset_ontology(id_to_ix):
128
+ with open('../metadata/audioset_ontology.json', 'r') as f:
129
+ data = json.load(f)
130
+
131
+ # Output: {'name': 'Bob', 'languages': ['English', 'French']}
132
+ sentences = []
133
+ for el in data:
134
+ print(el.keys())
135
+ id = el['id']
136
+ if id in id_to_ix:
137
+ name = el['name']
138
+ desc = el['description']
139
+ # if '(' in desc:
140
+ # print(name, '---', desc)
141
+ # print(id_to_ix[id], name, '---', )
142
+
143
+ # sent = name
144
+ # sent = name + ', ' + desc.replace('(', '').replace(')', '').lower()
145
+ # sent = desc.replace('(', '').replace(')', '').lower()
146
+ # sentences.append(sent)
147
+ sentences.append(desc)
148
+ # print(sent)
149
+ # break
150
+ return sentences
151
+
152
+
153
+ def original_read_metadata(csv_path, classes_num, id_to_ix):
154
+ """Read metadata of AudioSet from a csv file.
155
+
156
+ Args:
157
+ csv_path: str
158
+
159
+ Returns:
160
+ meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
161
+ """
162
+
163
+ with open(csv_path, "r") as fr:
164
+ lines = fr.readlines()
165
+ lines = lines[3:] # Remove heads
166
+
167
+ # Thomas Pellegrini: added 02/12/2022
168
+ # check if the audio files indeed exist, otherwise remove from list
169
+
170
+ audios_num = len(lines)
171
+ targets = np.zeros((audios_num, classes_num), dtype=bool)
172
+ audio_names = []
173
+
174
+ for n, line in enumerate(lines):
175
+ items = line.split(", ")
176
+ """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""
177
+
178
+ audio_name = "{}_{}_{}.flac".format(
179
+ items[0], items[1].replace(".", ""), items[2].replace(".", "")
180
+ ) # Audios are started with an extra 'Y' when downloading
181
+ audio_name = audio_name.replace("_0000_", "_0_")
182
+
183
+ label_ids = items[3].split('"')[1].split(",")
184
+
185
+ audio_names.append(audio_name)
186
+
187
+ # Target
188
+ for id in label_ids:
189
+ ix = id_to_ix[id]
190
+ targets[n, ix] = 1
191
+
192
+ meta_dict = {"audio_name": np.array(audio_names), "target": targets}
193
+ return meta_dict
194
+
195
+ def read_audioset_label_tags(class_labels_indices_csv):
196
+ with open(class_labels_indices_csv, 'r') as f:
197
+ reader = csv.reader(f, delimiter=',')
198
+ lines = list(reader)
199
+
200
+ labels = []
201
+ ids = [] # Each label has a unique id such as "/m/068hy"
202
+ for i1 in range(1, len(lines)):
203
+ id = lines[i1][1]
204
+ label = lines[i1][2]
205
+ ids.append(id)
206
+ labels.append(label)
207
+
208
+ classes_num = len(labels)
209
+
210
+ lb_to_ix = {label : i for i, label in enumerate(labels)}
211
+ ix_to_lb = {i : label for i, label in enumerate(labels)}
212
+
213
+ id_to_ix = {id : i for i, id in enumerate(ids)}
214
+ ix_to_id = {i : id for i, id in enumerate(ids)}
215
+
216
+ return lb_to_ix, ix_to_lb, id_to_ix, ix_to_id
217
+
218
+
219
+
220
+ def float32_to_int16(x):
221
+ # assert np.max(np.abs(x)) <= 1.5
222
+ x = np.clip(x, -1, 1)
223
+ return (x * 32767.0).astype(np.int16)
224
+
225
+
226
+ def int16_to_float32(x):
227
+ return (x / 32767.0).astype(np.float32)
228
+
229
+
230
+ def pad_or_truncate(x, audio_length):
231
+ """Pad all audio to specific length."""
232
+ if len(x) <= audio_length:
233
+ return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0)
234
+ else:
235
+ return x[0:audio_length]
236
+
237
+
238
+ def pad_audio(x, audio_length):
239
+ """Pad all audio to specific length."""
240
+ if len(x) <= audio_length:
241
+ return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0)
242
+ else:
243
+ return x
244
+
245
+
246
+ def d_prime(auc):
247
+ d_prime = stats.norm().ppf(auc) * np.sqrt(2.0)
248
+ return d_prime
249
+
250
+
251
+ class Mixup(object):
252
+ def __init__(self, mixup_alpha, random_seed=1234):
253
+ """Mixup coefficient generator."""
254
+ self.mixup_alpha = mixup_alpha
255
+ self.random_state = np.random.RandomState(random_seed)
256
+
257
+ def get_lambda(self, batch_size):
258
+ """Get mixup random coefficients.
259
+ Args:
260
+ batch_size: int
261
+ Returns:
262
+ mixup_lambdas: (batch_size,)
263
+ """
264
+ mixup_lambdas = []
265
+ for n in range(0, batch_size, 2):
266
+ lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
267
+ mixup_lambdas.append(lam)
268
+ mixup_lambdas.append(1.0 - lam)
269
+
270
+ return np.array(mixup_lambdas)
271
+
272
+
273
+ class StatisticsContainer(object):
274
+ def __init__(self, statistics_path):
275
+ """Contain statistics of different training iterations."""
276
+ self.statistics_path = statistics_path
277
+
278
+ self.backup_statistics_path = "{}_{}.pkl".format(
279
+ os.path.splitext(self.statistics_path)[0],
280
+ datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
281
+ )
282
+
283
+ self.statistics_dict = {"bal": [], "test": []}
284
+
285
+ def append(self, iteration, statistics, data_type):
286
+ statistics["iteration"] = iteration
287
+ self.statistics_dict[data_type].append(statistics)
288
+
289
+ def dump(self):
290
+ pickle.dump(self.statistics_dict, open(self.statistics_path, "wb"))
291
+ pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb"))
292
+ logging.info(" Dump statistics to {}".format(self.statistics_path))
293
+ logging.info(" Dump statistics to {}".format(self.backup_statistics_path))
294
+
295
+ def load_state_dict(self, resume_iteration):
296
+ self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))
297
+
298
+ resume_statistics_dict = {"bal": [], "test": []}
299
+
300
+ for key in self.statistics_dict.keys():
301
+ for statistics in self.statistics_dict[key]:
302
+ if statistics["iteration"] <= resume_iteration:
303
+ resume_statistics_dict[key].append(statistics)
304
+
305
+ self.statistics_dict = resume_statistics_dict
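A minimal sketch of the Mixup helper defined above, assuming this file is importable as utilities; the batch contents are random toy data and the mixing partner (the reversed batch) is only for illustration:

import numpy as np
from utilities import Mixup  # assumed import path

mixup = Mixup(mixup_alpha=0.4)
lambdas = mixup.get_lambda(batch_size=8)   # paired coefficients: lam, 1 - lam, lam, 1 - lam, ...
x = np.random.randn(8, 64)                 # toy batch; an even batch size is assumed
x_mixed = lambdas[:, None] * x + (1.0 - lambdas[:, None]) * x[::-1]
print(lambdas.shape, x_mixed.shape)        # (8,) (8, 64)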
ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
ISMIR_2025/Model/datalib.py ADDED
@@ -0,0 +1,206 @@
1
+ import os
+ import glob
+ import random
+ import csv
+ import logging
+ import time
+ import numpy as np
+ import h5py
+ import torch
+ import torchaudio
+ import librosa
+ import scipy.signal as signal
+ from scipy.signal import butter, lfilter
+ import utils
+ from sklearn.model_selection import train_test_split
+ from torch.utils.data import Dataset, DataLoader
+ # Oversampling Lib
+ from imblearn.over_sampling import RandomOverSampler
29
+
30
+ class FakeMusicCapsDataset(Dataset):
31
+ def __init__(self, file_paths, labels, feat_type=['mel'], sr=16000, n_mels=64, target_duration=10.0, augment=True, augment_real=True):
32
+ self.file_paths = file_paths
33
+ self.labels = labels
34
+ self.feat_type = feat_type
35
+ self.sr = sr
36
+ self.n_mels = n_mels
37
+ self.target_duration = target_duration
38
+ self.target_samples = int(target_duration * sr)
39
+ self.augment = augment
40
+ self.augment_real = augment_real
41
+
42
+
43
+ def pre_emphasis(self, x, alpha=0.97):
44
+ return np.append(x[0], x[1:] - alpha * x[:-1])
45
+
46
+ def highpass_filter(self, y, sr, cutoff=1000, order=5):
47
+ nyquist = 0.5 * sr
48
+ normal_cutoff = cutoff / nyquist
49
+ b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
50
+ return signal.lfilter(b, a, y)
51
+
52
+ def augment_audio(self, y, sr):
53
+ if random.random() < 0.5:
54
+ rate = random.uniform(0.8, 1.2)
55
+ y = librosa.effects.time_stretch(y=y, rate=rate)
56
+
57
+ if random.random() < 0.5:
58
+ n_steps = random.randint(-2, 2)
59
+ y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps)
60
+
61
+ if random.random() < 0.5:
62
+ noise_level = np.random.uniform(0.001, 0.005)
63
+ y = y + np.random.normal(0, noise_level, y.shape)
64
+
65
+ if random.random() < 0.5:
66
+ gain = np.random.uniform(0.9, 1.1)
67
+ y = y * gain
68
+
69
+ return y
70
+
71
+
72
+ def __len__(self):
73
+ return len(self.file_paths)
74
+
75
+ def __getitem__(self, idx):
76
+ """
77
+ Load and preprocess audio file.
78
+ """
79
+ audio_path = self.file_paths[idx]
80
+ label = self.labels[idx]
81
+
82
+ waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True)
83
+ if label == 0:
84
+ if self.augment_real:
85
+ waveform = self.augment_audio(waveform, self.sr)
86
+ if label == 1:
87
+ waveform = self.highpass_filter(waveform, self.sr)
88
+ waveform = self.augment_audio(waveform, self.sr)
89
+
90
+ current_samples = waveform.shape[0]
91
+ if current_samples > self.target_samples:
92
+ start_idx = (current_samples - self.target_samples) // 2
93
+ waveform = waveform[start_idx:start_idx + self.target_samples]
94
+ elif current_samples < self.target_samples:
95
+ waveform = np.pad(waveform, (0, self.target_samples - current_samples), mode='constant')
96
+
97
+
98
+ mel_spec = librosa.feature.melspectrogram(
99
+ y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256
100
+ )
101
+ log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
102
+
103
+ log_mel_spec = np.expand_dims(log_mel_spec, axis=0)
104
+ mel_tensor = torch.tensor(log_mel_spec, dtype=torch.float)
105
+ label_tensor = torch.tensor(label, dtype=torch.long)
106
+
107
+ return mel_tensor, label_tensor
108
+
109
+ def extract_feature(self, waveform, feat):
110
+ """Extracts specified feature (mel, stft, cqt) from waveform."""
111
+ try:
112
+ if feat == 'mel':
113
+ mel_spec = librosa.feature.melspectrogram(y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256)
114
+ log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
115
+ return torch.tensor(log_mel_spec, dtype=torch.float).unsqueeze(0)
116
+ elif feat == 'stft':
117
+ stft = librosa.stft(waveform, n_fft=512, hop_length=128, window="hann")
118
+ logSTFT = np.log(np.abs(stft) + 1e-3)
119
+ return torch.tensor(logSTFT, dtype=torch.float).unsqueeze(0)
120
+ elif feat == 'cqt':
121
+ cqt = librosa.cqt(waveform, sr=self.sr, hop_length=128, bins_per_octave=24)
122
+ logCQT = np.log(np.abs(cqt) + 1e-3)
123
+ return torch.tensor(logCQT, dtype=torch.float).unsqueeze(0)
124
+ else:
125
+ raise ValueError(f"[ERROR] Unsupported feature type: {feat}")
126
+ except Exception as e:
127
+ print(f"[ERROR] Feature extraction failed for {feat}: {e}")
128
+ return None
129
+
130
+ def highpass_filter(self, y, sr, cutoff=1000, order=5):
131
+ if isinstance(sr, np.ndarray):
132
+ sr = np.mean(sr)
133
+ if not isinstance(sr, (int, float)):
134
+ raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
135
+ if sr <= 0:
136
+ raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
137
+ nyquist = 0.5 * sr
138
+ if cutoff <= 0 or cutoff >= nyquist:
139
+ print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
140
+ cutoff = max(10, min(cutoff, nyquist - 1))
141
+ normal_cutoff = cutoff / nyquist
142
+ b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
143
+ y_filtered = signal.lfilter(b, a, y)
144
+ return y_filtered
145
+
146
+ def preprocess_audio(audio_path, sr=16000, n_mels=64, target_duration=10.0):
147
+ try:
148
+ waveform, _ = librosa.load(audio_path, sr=sr, mono=True)
149
+
150
+ target_samples = int(target_duration * sr)
151
+ if len(waveform) > target_samples:
152
+ start_idx = (len(waveform) - target_samples) // 2
153
+ waveform = waveform[start_idx:start_idx + target_samples]
154
+ elif len(waveform) < target_samples:
155
+ waveform = np.pad(waveform, (0, target_samples - len(waveform)), mode='constant')
156
+ mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels, n_fft=1024, hop_length=256)
157
+ log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
158
+ return torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
159
+
160
+ except Exception as e:
161
+ print(f"[ERROR] 전처리 실패: {audio_path} | 오류: {e}")
162
+ return None
163
+
164
+
165
+ DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
166
+ SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터
167
+
168
+ real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
169
+ gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)
170
+
171
+ open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
172
+ open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)
173
+
174
+ real_labels = [0] * len(real_files)
175
+ gen_labels = [1] * len(gen_files)
176
+
177
+ open_real_labels = [0] * len(open_real_files)
178
+ open_gen_labels = [1] * len(open_gen_files)
179
+
180
+ real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
181
+ gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)
182
+
183
+ train_files = real_train + gen_train
184
+ train_labels = real_train_labels + gen_train_labels
185
+ val_files = real_val + gen_val
186
+ val_labels = real_val_labels + gen_val_labels
187
+
188
+ closed_test_files = real_files + gen_files
189
+ closed_test_labels = real_labels + gen_labels
190
+
191
+ open_test_files = open_real_files + open_gen_files
192
+ open_test_labels = open_real_labels + open_gen_labels
193
+
194
+ ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
195
+ train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)
196
+
197
+ train_files = train_files_resampled.reshape(-1).tolist()
198
+ train_labels = train_labels_resampled
199
+ print(f"type(train_labels_resampled): {type(train_labels_resampled)}")
200
+
201
+ print(f"Train Org Fake: {len(gen_val)}")
202
+ print(f"Train set (Oversampled) - Real: {sum(1 for label in train_labels if label == 0)}, "
203
+ f"Fake: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
204
+ print(f"Validation set - Real: {len(real_val)}, Fake: {len(gen_val)}, Total: {len(val_files)}")
205
+ print(f"Closed Test set - Real: {len(real_files)}, Fake: {len(gen_files)}, Total: {len(closed_test_files)}")
206
+ print(f"Open Test set - Real: {len(open_real_files)}, Fake: {len(open_gen_files)}, Total: {len(open_test_files)}")
ISMIR_2025/Model/main.py ADDED
@@ -0,0 +1,336 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from tqdm import tqdm
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import wandb
11
+ import matplotlib.pyplot as plt
12
+ from torch.utils.data import DataLoader
13
+ from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score
14
+ from datalib import (
+ FakeMusicCapsDataset,
+ train_files, val_files, train_labels, val_labels,
+ closed_test_files, closed_test_labels,
+ open_test_files, open_test_labels,
+ preprocess_audio
+ )
23
+ from networks import CCV
24
+ from attentionmap import visualize_attention_map
25
+ from confusion_matrix import plot_confusion_matrix
26
+
27
+ def count_parameters(model):
28
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
29
+ '''
30
+ python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --oversample True
31
+
32
+ audiocnn encoder - crossattn based decoder (ViT) model
33
+ '''
34
+ # Argument parsing
35
+ import argparse
36
+ parser = argparse.ArgumentParser(description='AI Music Detection Training')
37
+ parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
38
+ parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV', help='Model name')
39
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
40
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
41
+ parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
42
+ parser.add_argument('--audio_duration', type=float, default=10, help='Length of the audio slice in seconds')
43
+ parser.add_argument('--patience_counter', type=int, default=5, help='Early stopping patience')
44
+ parser.add_argument('--log_dir', type=str, default='', help='TensorBoard log directory')
45
+ parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
46
+ parser.add_argument("--weight_decay", type=float, default=0.05, help="weight decay (default: 0.0)")
47
+ parser.add_argument("--loss_type", type=str, choices=["ce", "weighted_ce", "focal"], default="ce", help="Loss function type")
48
+
49
+ parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
50
+ parser.add_argument("--closed_test", action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
51
+ parser.add_argument("--open_test", action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
52
+ parser.add_argument("--oversample", type=bool, default=True, help="Apply Oversampling to balance classes") # real data oversampling
53
+
54
+
55
+ args = parser.parse_args()
56
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+
59
+ torch.manual_seed(42)
60
+ random.seed(42)
61
+ np.random.seed(42)
62
+ wandb.init(project="",
63
+ name=f"{args.model_name}_lr{args.learning_rate}_ep{args.epochs}_bs{args.batch_size}", config=args)
64
+
65
+ if args.model_name == 'CCV':
66
+ model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).cuda()
67
+ feat_type = 'mel'
68
+ else:
69
+ raise ValueError(f"Invalid model name: {args.model_name}")
70
+
71
+ model = model.to(device)
72
+ print(f"Using model: {args.model_name}, Parameters: {count_parameters(model)}")
73
+ print(f"weight_decay WD: {args.weight_decay}")
74
+
75
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
76
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
77
+
78
+ if args.loss_type == "ce":
79
+ print("Using CrossEntropyLoss")
80
+ criterion = nn.CrossEntropyLoss()
81
+
82
+ elif args.loss_type == "weighted_ce":
83
+ print("Using Weighted CrossEntropyLoss")
84
+
85
+ num_real = sum(1 for label in train_labels if label == 0)
86
+ num_fake = sum(1 for label in train_labels if label == 1)
87
+
88
+ total_samples = num_real + num_fake
89
+ weight_real = total_samples / (2 * num_real)
90
+ weight_fake = total_samples / (2 * num_fake)
91
+ class_weights = torch.tensor([weight_real, weight_fake]).to(device)
92
+
93
+ criterion = nn.CrossEntropyLoss(weight=class_weights)
94
+
95
+ elif args.loss_type == "focal":
96
+ print("Using Focal Loss")
97
+
98
+ class FocalLoss(torch.nn.Module):
99
+ def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
100
+ super(FocalLoss, self).__init__()
101
+ self.alpha = alpha
102
+ self.gamma = gamma
103
+ self.reduction = reduction
104
+
105
+ def forward(self, inputs, targets):
106
+ ce_loss = F.cross_entropy(inputs, targets, reduction='none')
107
+ pt = torch.exp(-ce_loss)
108
+ focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
109
+
110
+ if self.reduction == 'mean':
111
+ return focal_loss.mean()
112
+ elif self.reduction == 'sum':
113
+ return focal_loss.sum()
114
+ else:
115
+ return focal_loss
116
+
117
+ criterion = FocalLoss().to(device)
118
+
119
+ if not os.path.exists(args.ckpt_path):
120
+ os.makedirs(args.ckpt_path)
121
+
122
+ train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type=feat_type, target_duration=args.audio_duration)
123
+ val_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration)
124
+
125
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
126
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
127
+
128
+ def train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args):
129
+ writer = SummaryWriter(log_dir=args.log_dir)
130
+ best_val_bal_acc = 0.0  # balanced accuracy is maximized, so start from the lowest possible value
131
+ early_stop_cnt = 0
132
+ log_interval = 1
133
+
134
+ for epoch in range(args.epochs):
135
+ print(f"\n[Epoch {epoch + 1}/{args.epochs}]")
136
+ model.train()
137
+ train_loss, train_correct, train_total = 0, 0, 0
138
+
139
+ all_train_preds= []
140
+ all_train_labels = []
141
+ attention_maps = []
142
+
143
+ train_pbar = tqdm(train_loader, desc="Train", leave=False)
144
+ for batch_idx, (data, target) in enumerate(train_pbar):
145
+ data = data.to(device)
146
+ target = target.to(device)
147
+ output, _ = model(data)  # CCV.forward returns (logits, embedding); keep the logits for the loss
148
+ loss = criterion(output, target)
149
+
150
+ optimizer.zero_grad()
151
+ loss.backward()
152
+ optimizer.step()
153
+
154
+ train_loss += loss.item() * data.size(0)
155
+ preds = output.argmax(dim=1)
156
+ train_correct += (preds == target).sum().item()
157
+ train_total += target.size(0)
158
+
159
+ all_train_labels.extend(target.cpu().numpy())
160
+ all_train_preds.extend(preds.cpu().numpy())
161
+
162
+ if hasattr(model, "get_attention_maps"):
163
+ attention_maps.append(model.get_attention_maps())
164
+
165
+ train_loss /= train_total
166
+ train_acc = train_correct / train_total
167
+ train_bal_acc = balanced_accuracy_score(all_train_labels, all_train_preds)
168
+ train_precision = precision_score(all_train_labels, all_train_preds, average="binary")
169
+ train_recall = recall_score(all_train_labels, all_train_preds, average="binary")
170
+ train_f1 = f1_score(all_train_labels, all_train_preds, average="binary")
171
+
172
+ wandb.log({
173
+ "Train Loss": train_loss, "Train Accuracy": train_acc,
174
+ "Train Precision": train_precision, "Train Recall": train_recall,
175
+ "Train F1 Score": train_f1, "Train B_ACC": train_bal_acc,
176
+ })
177
+
178
+ print(f"Train Epoch: {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | "
179
+ f"Train B_ACC: {train_bal_acc:.4f} | Train Prec: {train_precision:.3f} | "
180
+ f"Train Rec: {train_recall:.3f} | Train F1: {train_f1:.3f}")
181
+
182
+ model.eval()
183
+ val_loss, val_correct, val_total = 0, 0, 0
184
+ all_val_preds, all_val_labels = [], []
185
+ attention_maps = []
186
+ val_pbar = tqdm(val_loader, desc=" Val ", leave=False)
187
+ with torch.no_grad():
188
+ for data, target in val_pbar:
189
+ data, target = data.to(device), target.to(device)
190
+ output, _ = model(data)  # unpack (logits, embedding)
191
+ loss = criterion(output, target)
192
+ val_loss += loss.item() * data.size(0)
193
+ preds = output.argmax(dim=1)
194
+ val_correct += (preds == target).sum().item()
195
+ val_total += target.size(0)
196
+
197
+ all_val_labels.extend(target.cpu().numpy())
198
+ all_val_preds.extend(preds.cpu().numpy())
199
+
200
+ if hasattr(model, "get_attention_maps"):
201
+ attention_maps.append(model.get_attention_maps())
202
+
203
+ val_loss /= val_total
204
+ val_acc = val_correct / val_total
205
+ val_bal_acc = balanced_accuracy_score(all_val_labels, all_val_preds)
206
+ val_precision = precision_score(all_val_labels, all_val_preds, average="binary")
207
+ val_recall = recall_score(all_val_labels, all_val_preds, average="binary")
208
+ val_f1 = f1_score(all_val_labels, all_val_preds, average="binary")
209
+
210
+ wandb.log({
211
+ "Validation Loss": val_loss, "Validation Accuracy": val_acc,
212
+ "Validation Precision": val_precision, "Validation Recall": val_recall,
213
+ "Validation F1 Score": val_f1, "Validation B_ACC": val_bal_acc,
214
+ })
215
+
216
+ print(f"Val Epoch: {epoch+1} [{batch_idx * len(data)}/{len(val_loader.dataset)} "
217
+ f"({100. * batch_idx / len(val_loader):.0f}%)]\t"
218
+ f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | "
219
+ f"Val B_ACC: {val_bal_acc:.4f} | Val Prec: {val_precision:.3f} | "
220
+ f"Val Rec: {val_recall:.3f} | Val F1: {val_f1:.3f}")
221
+
222
+ if epoch % 1 == 0 and len(attention_maps) > 0:
223
+ print(f"Visualizing Attention Map at Epoch {epoch+1}")
224
+
225
+ if isinstance(attention_maps[0], list):
226
+ attn_map_numpy = np.array([t.detach().cpu().numpy() for t in attention_maps[0]])
227
+ elif isinstance(attention_maps[0], torch.Tensor):
228
+ attn_map_numpy = attention_maps[0].detach().cpu().numpy()
229
+ else:
230
+ attn_map_numpy = np.array(attention_maps[0])
231
+
232
+ print(f"Attention Map Shape: {attn_map_numpy.shape}")
233
+
234
+ if len(attn_map_numpy) > 0:
235
+ fig, ax = plt.subplots(figsize=(10, 8))
236
+ ax.imshow(attn_map_numpy[0], cmap='viridis', interpolation='nearest')
237
+ ax.set_title(f"Attention Map - Epoch {epoch+1}")
238
+ plt.colorbar(ax.imshow(attn_map_numpy[0], cmap='viridis'))
239
+ plt.savefig("")
240
+ plt.show()
241
+ else:
242
+ print(f"Warning: attention_maps[0] is empty! Shape={attn_map_numpy.shape}")
243
+
244
+ if val_bal_acc > best_val_bal_acc:  # higher balanced accuracy is better
245
+ best_val_bal_acc = val_bal_acc
246
+ early_stop_cnt = 0
247
+ torch.save(model.state_dict(), os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"))
248
+ print("Best model saved.")
249
+ else:
250
+ early_stop_cnt += 1
251
+ print(f'PATIENCE {early_stop_cnt}/{args.patience_counter}')
252
+
253
+ if early_stop_cnt >= args.patience_counter:
254
+ print("Early stopping triggered.")
255
+ break
256
+
257
+ scheduler.step()
258
+ plot_confusion_matrix(all_val_labels, all_val_preds, classes=["REAL", "FAKE"], writer=writer, epoch=epoch)
259
+
260
+ wandb.finish()
261
+ writer.close()
262
+
263
+ def predict(audio_path):
264
+ print(f"Loading model from {args.ckpt_path}/celoss_best_model_{args.model_name}.pth")
265
+ model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device))
266
+ model.eval()
267
+
268
+ input_tensor = preprocess_audio(audio_path).to(device)
269
+
270
+ with torch.no_grad():
271
+ output, _ = model(input_tensor)  # unpack (logits, embedding)
272
+ probabilities = F.softmax(output, dim=1)
273
+ ai_music_prob = probabilities[0, 1].item()
274
+
275
+ if ai_music_prob > 0.5:
276
+ print(f"FAKE MUSIC {ai_music_prob:.2%})")
277
+ else:
278
+ print(f"REAL MUSIC {100 - ai_music_prob * 100:.2f}%")
279
+
280
+ def Test(model, test_loader, criterion, device):
281
+ model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device))
282
+ model.eval()
283
+ test_loss, test_correct, test_total = 0, 0, 0
284
+ all_preds, all_labels = [], []
285
+
286
+ with torch.no_grad():
287
+ for data, target in tqdm(test_loader, desc=" Test ", leave=False):
288
+ data, target = data.to(device), target.to(device)
289
+ output, _ = model(data)  # unpack (logits, embedding)
290
+ loss = criterion(output, target)
291
+
292
+ test_loss += loss.item() * data.size(0)
293
+ preds = output.argmax(dim=1)
294
+ test_correct += (preds == target).sum().item()
295
+ test_total += target.size(0)
296
+
297
+ all_labels.extend(target.cpu().numpy())
298
+ all_preds.extend(preds.cpu().numpy())
299
+
300
+ test_loss /= test_total
301
+ test_acc = test_correct / test_total
302
+ test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
303
+ test_precision = precision_score(all_labels, all_preds, average="binary")
304
+ test_recall = recall_score(all_labels, all_preds, average="binary")
305
+ test_f1 = f1_score(all_labels, all_preds, average="binary")
306
+
307
+ print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
308
+ f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
309
+ f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")
310
+
311
+
312
+ if __name__ == "__main__":
313
+ train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args)
314
+ if args.closed_test:
315
+ print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
316
+ test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type=feat_type, target_duration=args.audio_duration)
317
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
318
+
319
+ elif args.open_test:
320
+ print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
321
+ test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type=feat_type, target_duration=args.audio_duration)
322
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
323
+
324
+ else:
325
+ print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
326
+ test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration)
327
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
328
+
329
+ print("\nEvaluating Model on Test Set...")
330
+ Test(model, test_loader, criterion, device)
331
+
332
+ if args.inference:
333
+ if not os.path.exists(args.inference):
334
+ print(f"[ERROR] No File Found: {args.inference}")
335
+ else:
336
+ predict(args.inference)
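For reference, a closed-set run of this script could be launched as below; the checkpoint and log directories are placeholders, since --ckpt_path and --log_dir default to empty strings:

python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --closed_test --ckpt_path ./ckpt --log_dir ./runs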
ISMIR_2025/Model/networks.py ADDED
@@ -0,0 +1,237 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class audiocnn(nn.Module):
5
+ def __init__(self, num_classes=2):
6
+ super(audiocnn, self).__init__()
7
+ self.conv_block = nn.Sequential(
8
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
9
+ nn.ReLU(),
10
+ nn.MaxPool2d(2),
11
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
12
+ nn.ReLU(),
13
+ nn.MaxPool2d(2),
14
+ nn.AdaptiveAvgPool2d((4,4)) # final output -> (B,32,4,4)
15
+ )
16
+ self.fc_block = nn.Sequential(
17
+ nn.Linear(32*4*4, 128),
18
+ nn.ReLU(),
19
+ nn.Linear(128, num_classes)
20
+ )
21
+
22
+ def forward(self, x):
23
+ x = self.conv_block(x)
24
+ # x.shape: (B,32,new_freq,new_time)
25
+
26
+ # 1) Flatten
27
+ B, C, H, W = x.shape # dynamic shape
28
+ x = x.view(B, -1) # (B, 32*H*W)
29
+
30
+ # 2) FC
31
+ x = self.fc_block(x)
32
+ return x
33
+
34
+ class AudioCNN(nn.Module):
35
+ def __init__(self, embed_dim=512):
36
+ super(AudioCNN, self).__init__()
37
+ self.conv_block = nn.Sequential(
38
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
39
+ nn.ReLU(),
40
+ nn.MaxPool2d(2),
41
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
42
+ nn.ReLU(),
43
+ nn.MaxPool2d(2),
44
+ nn.AdaptiveAvgPool2d((4, 4)) # final output -> (B, 32, 4, 4)
45
+ )
46
+ self.projection = nn.Linear(32 * 4 * 4, embed_dim)
47
+
48
+ def forward(self, x):
49
+ x = self.conv_block(x)
50
+ B, C, H, W = x.shape
51
+ x = x.view(B, -1) # Flatten (B, C * H * W)
52
+ x = self.projection(x) # Project to embed_dim
53
+ return x
54
+
55
+ class ViTDecoder(nn.Module):
56
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
57
+ super(ViTDecoder, self).__init__()
58
+
59
+ # Transformer layers
60
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
61
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
62
+
63
+ # Classification head
64
+ self.classifier = nn.Sequential(
65
+ nn.LayerNorm(embed_dim),
66
+ nn.Linear(embed_dim, num_classes)
67
+ )
68
+
69
+ def forward(self, x):
70
+ # Transformer expects input of shape (seq_len, batch, embed_dim)
71
+ x = x.unsqueeze(1).permute(1, 0, 2) # Add sequence dim (1, B, embed_dim)
72
+ x = self.transformer(x) # Pass through Transformer
73
+ x = x.mean(dim=0) # Take the mean over the sequence dimension (B, embed_dim)
74
+
75
+ x = self.classifier(x) # Classification head
76
+ return x
77
+
78
+ class AudioCNNWithViTDecoder(nn.Module):
79
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
80
+ super(AudioCNNWithViTDecoder, self).__init__()
81
+ self.encoder = AudioCNN(embed_dim=embed_dim)
82
+ self.decoder = ViTDecoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
83
+
84
+ def forward(self, x):
85
+ x = self.encoder(x) # Pass through AudioCNN encoder
86
+ x = self.decoder(x) # Pass through ViT decoder
87
+ return x
88
+
89
+
90
+ # class AudioCNN(nn.Module):
91
+ # def __init__(self, num_classes=2):
92
+ # super(AudioCNN, self).__init__()
93
+ # self.conv_block = nn.Sequential(
94
+ # nn.Conv2d(1, 16, kernel_size=3, padding=1),
95
+ # nn.ReLU(),
96
+ # nn.MaxPool2d(2),
97
+ # nn.Conv2d(16, 32, kernel_size=3, padding=1),
98
+ # nn.ReLU(),
99
+ # nn.MaxPool2d(2),
100
+ # nn.AdaptiveAvgPool2d((4,4)) # final output -> (B,32,4,4)
101
+ # )
102
+ # self.fc_block = nn.Sequential(
103
+ # nn.Linear(32*4*4, 128),
104
+ # nn.ReLU(),
105
+ # nn.Linear(128, num_classes)
106
+ # )
107
+
108
+
109
+ # def forward(self, x):
110
+ # x = self.conv_block(x)
111
+ # # x.shape: (B,32,new_freq,new_time)
112
+
113
+ # # 1) Flatten
114
+ # B, C, H, W = x.shape # dynamic shape
115
+ # x = x.view(B, -1) # (B, 32*H*W)
116
+
117
+ # # 2) FC
118
+ # x = self.fc_block(x)
119
+ # return x
120
+
121
+
122
+
123
+ class audio_crossattn(nn.Module):
124
+ def __init__(self, embed_dim=512):
125
+ super(audio_crossattn, self).__init__()
126
+ self.conv_block = nn.Sequential(
127
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
128
+ nn.ReLU(),
129
+ nn.MaxPool2d(2),
130
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
131
+ nn.ReLU(),
132
+ nn.MaxPool2d(2),
133
+ nn.AdaptiveAvgPool2d((4, 4)) # final output -> (B, 32, 4, 4)
134
+ )
135
+ self.projection = nn.Linear(32 * 4 * 4, embed_dim)
136
+
137
+ def forward(self, x):
138
+ x = self.conv_block(x) # Convolutional feature extraction
139
+ B, C, H, W = x.shape
140
+ x = x.view(B, -1) # Flatten (B, C * H * W)
141
+ x = self.projection(x) # Linear projection to embed_dim
142
+ return x
143
+
144
+
145
+ class CrossAttentionLayer(nn.Module):
146
+ def __init__(self, embed_dim, num_heads):
147
+ super(CrossAttentionLayer, self).__init__()
148
+ self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
149
+ self.layer_norm = nn.LayerNorm(embed_dim)
150
+ self.feed_forward = nn.Sequential(
151
+ nn.Linear(embed_dim, embed_dim * 4),
152
+ nn.ReLU(),
153
+ nn.Linear(embed_dim * 4, embed_dim)
154
+ )
155
+
156
+ def forward(self, x, cross_input):
157
+ # Cross-attention between x and cross_input
158
+ attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
159
+ x = self.layer_norm(x + attn_output) # Add & Norm
160
+ feed_forward_output = self.feed_forward(x)
161
+ x = self.layer_norm(x + feed_forward_output) # Add & Norm
162
+ return x
163
+
164
+ class ViTDecoderWithCrossAttention(nn.Module):
165
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
166
+ super(ViTDecoderWithCrossAttention, self).__init__()
167
+
168
+ # Cross-Attention layers
169
+ self.cross_attention_layers = nn.ModuleList([
170
+ CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
171
+ ])
172
+
173
+ # Transformer Encoder layers
174
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
175
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
176
+
177
+ # Classification head
178
+ self.classifier = nn.Sequential(
179
+ nn.LayerNorm(embed_dim),
180
+ nn.Linear(embed_dim, num_classes)
181
+ )
182
+
183
+ def forward(self, x, cross_attention_input):
184
+ # Pass through Cross-Attention layers
185
+ for layer in self.cross_attention_layers:
186
+ x = layer(x, cross_attention_input)
187
+
188
+ # Transformer expects input of shape (seq_len, batch, embed_dim)
189
+ x = x.unsqueeze(1).permute(1, 0, 2) # Add sequence dim (1, B, embed_dim)
190
+ x = self.transformer(x) # Pass through Transformer
191
+ embedding = x.mean(dim=0) # Take the mean over the sequence dimension (B, embed_dim)
192
+
193
+ # Classification head
194
+ x = self.classifier(embedding)
195
+ return x, embedding
196
+
197
+ # class AudioCNNWithViTDecoderAndCrossAttention(nn.Module):
198
+ # def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
199
+ # super(AudioCNNWithViTDecoderAndCrossAttention, self).__init__()
200
+ # self.encoder = audio_crossattn(embed_dim=embed_dim)
201
+ # self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
202
+
203
+ # def forward(self, x, cross_attention_input):
204
+ # # Pass through AudioCNN encoder
205
+ # x = self.encoder(x)
206
+
207
+ # # Pass through ViTDecoder with Cross-Attention
208
+ # x = self.decoder(x, cross_attention_input)
209
+ # return x
210
+ class CCV(nn.Module):
211
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
212
+ super(CCV, self).__init__()
213
+ self.encoder = AudioCNN(embed_dim=embed_dim)
214
+ self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
215
+ if freeze_feature_extractor:
216
+ for param in self.encoder.parameters():
217
+ param.requires_grad = False
218
+ for param in self.decoder.parameters():
219
+ param.requires_grad = False
220
+ def forward(self, x, cross_attention_input=None):
221
+ # Pass through AudioCNN encoder
222
+ x = self.encoder(x)
223
+
224
+ # If cross_attention_input is not provided, use the encoder output
225
+ if cross_attention_input is None:
226
+ cross_attention_input = x
227
+
228
+ # Pass through ViTDecoder with Cross-Attention
229
+ x, embedding = self.decoder(x, cross_attention_input)
230
+ return x, embedding
231
+
232
+ #---------------------------------------------------------
233
+ '''
234
+ audiocnn weight frozen
235
+ crossatten decoder -lora tuning
236
+ '''
237
+
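A minimal shape check for the CCV model above, assuming this file is importable as networks; the input size corresponds to a 10 s, 64-bin log-mel spectrogram, and freezing is disabled here so the sketch also covers training use. Note that forward returns a (logits, embedding) pair:

import torch
from networks import CCV  # assumed import path

model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=False)
dummy = torch.randn(2, 1, 64, 626)    # (batch, channel, n_mels, frames)
logits, embedding = model(dummy)      # CCV.forward returns (logits, embedding)
print(logits.shape, embedding.shape)  # torch.Size([2, 2]) torch.Size([2, 512])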
ISMIR_2025/Model/test.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader
7
+ from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
8
+ from datalib_f import (
9
+ FakeMusicCapsDataset,
10
+ closed_test_files, closed_test_labels,
11
+ open_test_files, open_test_labels,
12
+ val_files, val_labels
13
+ )
14
+ from networks_f import CCV_Wav2Vec2
15
+ import argparse
16
+
17
+ parser = argparse.ArgumentParser(description="AI Music Detection Testing")
18
+ parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
19
+ parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV', 'CCV_Wav2Vec2'], default='CCV_Wav2Vec2', help='Model name')
20
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
21
+ parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/tensorboard/wav2vec', help='Checkpoint directory')
22
+ parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
23
+ parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
24
+ parser.add_argument('--output_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/test_results/w_celoss_repreprocess/wav2vec', help='Path to save test results')
25
+
26
+ args = parser.parse_args()
27
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ def plot_confusion_matrix(y_true, y_pred, classes, output_path):
31
+ cm = confusion_matrix(y_true, y_pred)
32
+ fig, ax = plt.subplots(figsize=(6, 6))
33
+ im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
34
+ ax.figure.colorbar(im, ax=ax)
35
+
36
+ num_classes = cm.shape[0]
37
+ tick_labels = classes[:num_classes]
38
+
39
+ ax.set(xticks=np.arange(num_classes),
40
+ yticks=np.arange(num_classes),
41
+ xticklabels=tick_labels,
42
+ yticklabels=tick_labels,
43
+ ylabel='True label',
44
+ xlabel='Predicted label')
45
+
46
+ thresh = cm.max() / 2.
47
+ for i in range(cm.shape[0]):
48
+ for j in range(cm.shape[1]):
49
+ ax.text(j, i, format(cm[i, j], 'd'),
50
+ ha="center", va="center",
51
+ color="white" if cm[i, j] > thresh else "black")
52
+
53
+ fig.tight_layout()
54
+ plt.savefig(output_path)
55
+ plt.close(fig)
56
+
57
+ if args.model_name == 'CCV_Wav2Vec2':
58
+ model = CCV_Wav2Vec2(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).to(device)
59
+ else:
60
+ raise ValueError(f"Invalid model name: {args.model_name}")
61
+
62
+ ckpt_file = os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth")
63
+ if not os.path.exists(ckpt_file):
64
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
65
+
66
+ print(f"\nLoading model from {ckpt_file}")
67
+
68
+ # model.load_state_dict(torch.load(ckpt_file, map_location=device))
69
+ # parallel (DataParallel) checkpoint: strip the "module." prefix from the keys
70
+ state_dict = torch.load(ckpt_file, map_location=device)
71
+ from collections import OrderedDict
72
+ new_state_dict = OrderedDict()
73
+ for k, v in state_dict.items():
74
+ name = k[7:] if k.startswith("module.") else k
75
+ new_state_dict[name] = v
76
+ model.load_state_dict(new_state_dict)
77
+ # end of parallel (DataParallel) handling
78
+ model.eval()
79
+
80
+ torch.cuda.empty_cache()
81
+
82
+ if args.closed_test:
83
+ print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
84
+ test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type="mel", target_duration=10.0)
85
+ elif args.open_test:
86
+ print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
87
+ test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type="mel", target_duration=10.0)
88
+ else:
89
+ print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
90
+ test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type="mel", target_duration=10.0)
91
+
92
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
93
+
94
+ def Test(model, test_loader, device):
95
+ model.eval()
96
+ test_loss, test_correct, test_total = 0, 0, 0
97
+ all_preds, all_labels = [], []
98
+
99
+ with torch.no_grad():
100
+ for data, target in test_loader:
101
+ data, target = data.to(device), target.to(device)
102
+ output = model(data)
103
+ loss = F.cross_entropy(output, target)
104
+
105
+ test_loss += loss.item() * data.size(0)
106
+ preds = output.argmax(dim=1)
107
+ test_correct += (preds == target).sum().item()
108
+ test_total += target.size(0)
109
+
110
+ all_labels.extend(target.cpu().numpy())
111
+ all_preds.extend(preds.cpu().numpy())
112
+
113
+ test_loss /= test_total
114
+ test_acc = test_correct / test_total
115
+ test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
116
+ test_precision = precision_score(all_labels, all_preds, average="binary")
117
+ test_recall = recall_score(all_labels, all_preds, average="binary")
118
+ test_f1 = f1_score(all_labels, all_preds, average="binary")
119
+
120
+ print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
121
+ f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
122
+ f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")
123
+
124
+ os.makedirs(args.output_path, exist_ok=True)
125
+ conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
126
+ plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)
127
+
128
+ print("\nEvaluating Model on Test Set...")
129
+ Test(model, test_loader, device)
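A typical invocation of the evaluation script above, for reference; the script filename is assumed, and --ckpt_path / --output_path fall back to the defaults hard-coded in the argument parser:

python3 test.py --model_name CCV_Wav2Vec2 --batch_size 32 --closed_test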
ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc ADDED
Binary file (11 kB). View file
 
ISMIR_2025/music2vec/datalib.py ADDED
@@ -0,0 +1,144 @@
1
+ import os
2
+ import glob
3
+ import torch
4
+ import torchaudio
5
+ import librosa
6
+ import numpy as np
7
+ from sklearn.model_selection import train_test_split
8
+ from torch.utils.data import Dataset
9
+ from imblearn.over_sampling import RandomOverSampler
10
+ from transformers import Wav2Vec2Processor
11
+ import torch
12
+ import torchaudio
13
+ from torch.nn.utils.rnn import pad_sequence
14
+ import scipy.signal as signal
15
+ import random
16
+
17
+ class FakeMusicCapsDataset(Dataset):
18
+ def __init__(self, file_paths, labels, sr=16000, target_duration=10.0, augment=True):
19
+ self.file_paths = file_paths
20
+ self.labels = labels
21
+ self.sr = sr
22
+ self.target_samples = int(target_duration * sr)
23
+ self.augment = augment
24
+ def __len__(self):
25
+ return len(self.file_paths)
26
+
27
+ def augment_audio(self, y, sr):
28
+ if isinstance(y, torch.Tensor):
29
+ y = y.numpy()
30
+ if random.random() < 0.5:
31
+ rate = random.uniform(0.8, 1.2)
32
+ y = librosa.effects.time_stretch(y=y, rate=rate)
33
+ if random.random() < 0.5:
34
+ n_steps = random.randint(-2, 2)
35
+ y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps)
36
+ if random.random() < 0.5:
37
+ noise_level = np.random.uniform(0.001, 0.005)
38
+ y = y + np.random.normal(0, noise_level, y.shape)
39
+ if random.random() < 0.5:
40
+ gain = np.random.uniform(0.9, 1.1)
41
+ y = y * gain
42
+ return torch.tensor(y, dtype=torch.float32)
43
+
44
+
45
+ def __getitem__(self, idx):
46
+ audio_path = self.file_paths[idx]
47
+ label = self.labels[idx]
48
+
49
+ waveform, sr = torchaudio.load(audio_path)
50
+ waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)
51
+ waveform = waveform.mean(dim=0)
52
+ current_samples = waveform.shape[0]
53
+
54
+ if label == 0:
55
+ waveform = self.augment_audio(waveform, self.sr)
56
+ if label == 1:
57
+ waveform = self.highpass_filter(waveform, self.sr)
58
+ waveform = self.augment_audio(waveform, self.sr)
59
+
60
+ current_samples = waveform.shape[0]  # recompute: time-stretch augmentation can change the length
+ if current_samples > self.target_samples:
61
+ waveform = waveform[:self.target_samples]
62
+ elif current_samples < self.target_samples:
63
+ pad_length = self.target_samples - current_samples
64
+ waveform = torch.nn.functional.pad(waveform, (0, pad_length))
65
+
66
+ # waveform = waveform.squeeze(0)
67
+ if isinstance(waveform, np.ndarray):
68
+ waveform = torch.tensor(waveform, dtype=torch.float32)
69
+
70
+ return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long)
71
+
72
+ def highpass_filter(self, y, sr, cutoff=500, order=5):
73
+ if isinstance(sr, np.ndarray):
74
+ sr = np.mean(sr)
75
+ if not isinstance(sr, (int, float)):
76
+ raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
77
+ if sr <= 0:
78
+ raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
79
+ nyquist = 0.5 * sr
80
+ if cutoff <= 0 or cutoff >= nyquist:
81
+ print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
82
+ cutoff = max(10, min(cutoff, nyquist - 1))
83
+ normal_cutoff = cutoff / nyquist
84
+ b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
85
+ y_filtered = signal.lfilter(b, a, y)
86
+ return y_filtered
87
+
88
+ def preprocess_audio(audio_path, target_sr=16000, max_length=160000):
89
+ waveform, sr = torchaudio.load(audio_path)
90
+ if sr != target_sr:
91
+ waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
92
+
93
+ waveform = waveform.mean(dim=0).unsqueeze(0)
94
+
95
+ current_samples = waveform.shape[1]
96
+ if current_samples > max_length:
97
+ start_idx = (current_samples - max_length) // 2
98
+ waveform = waveform[:, start_idx:start_idx + max_length]
99
+ elif current_samples < max_length:
100
+ pad_length = max_length - current_samples
101
+ waveform = torch.nn.functional.pad(waveform, (0, pad_length))
102
+
103
+ return waveform
104
+
105
+
106
+ DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
107
+ SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터
108
+
109
+ real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
110
+ gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)
111
+
112
+ open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
113
+ open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)
114
+
115
+ real_labels = [0] * len(real_files)
116
+ gen_labels = [1] * len(gen_files)
117
+
118
+ open_real_labels = [0] * len(open_real_files)
119
+ open_gen_labels = [1] * len(open_gen_files)
120
+
121
+ real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
122
+ gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)
123
+
124
+ train_files = real_train + gen_train
125
+ train_labels = real_train_labels + gen_train_labels
126
+ val_files = real_val + gen_val
127
+ val_labels = real_val_labels + gen_val_labels
128
+
129
+ closed_test_files = real_files + gen_files
130
+ closed_test_labels = real_labels + gen_labels
131
+
132
+ open_test_files = open_real_files + open_gen_files
133
+ open_test_labels = open_real_labels + open_gen_labels
134
+
135
+ ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
136
+ train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)
137
+
138
+ train_files = train_files_resampled.reshape(-1).tolist()
139
+ train_labels = train_labels_resampled
140
+
141
+ print(f"Train Original FAKE: {len(gen_train)}")
142
+ print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, "
143
+ f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
144
+ print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}")
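A minimal sketch of how the splits defined above would typically be wrapped in DataLoaders and sanity-checked. The batch size is arbitrary, and the exact waveform shape depends on the channel handling in __getitem__ above.

from torch.utils.data import DataLoader

train_ds = FakeMusicCapsDataset(train_files, train_labels)
val_ds = FakeMusicCapsDataset(val_files, val_labels)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=2)

waveform, label = train_ds[0]
print(waveform.shape, label)  # roughly [1, 160000] for a 10 s clip at 16 kHz, plus a scalar label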
ISMIR_2025/music2vec/inference.py ADDED
@@ -0,0 +1,64 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ import argparse
6
+ from datalib import preprocess_audio
7
+ from networks import Wav2Vec2ForFakeMusic
8
+
9
+ # Argument Parsing
10
+ parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference")
11
+ parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
12
+ parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name')
13
+ parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory')
14
+ parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model')
15
+ parser.add_argument('--inference', type=str, required=True, help='Path to a .wav file for inference')
16
+ args = parser.parse_args()
17
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ # Load Model Checkpoint
21
+ if args.model_type == 'pretrain':
22
+ model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth")
23
+ elif args.model_type == 'finetune':
24
+ model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth")
25
+ else:
26
+ raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.")
27
+
28
+ if not os.path.exists(model_file):
29
+ raise FileNotFoundError(f"Model checkpoint not found: {model_file}")
30
+
31
+ if args.model_name == 'Wav2Vec2ForFakeMusic':
32
+ model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune'))
33
+ else:
34
+ raise ValueError(f"Invalid model name: {args.model_name}")
35
+
36
+ def predict(audio_path):
37
+ print(f"\n🔍 Loading model from {model_file}")
38
+
39
+ if not os.path.exists(audio_path):
40
+ raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}")
41
+
42
+ model.to(device)
43
+ model.eval()
44
+
45
+ input_tensor = preprocess_audio(audio_path).to(device)
46
+ print(f"Input shape after preprocessing: {input_tensor.shape}")
47
+
48
+ with torch.no_grad():
49
+ output = model(input_tensor)
+ if isinstance(output, tuple):  # some model variants return (logits, features)
+ output = output[0]
+ print(f"Raw model output (logits): {output}")
+
+ probabilities = F.softmax(output, dim=1)
53
+ ai_music_prob = probabilities[0, 1].item()
54
+
55
+ print(f"Softmax Probabilities: {probabilities}")
56
+ print(f"AI Music Probability: {ai_music_prob:.4f}")
57
+
58
+ if ai_music_prob > 0.5:
59
+ print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})")
60
+ else:
61
+ print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)")
62
+
63
+ if __name__ == "__main__":
64
+ predict(args.inference)
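If more than one file needs to be scored, the single-file predict() above can be wrapped in a small loop. A sketch, assuming a hypothetical directory of .wav files:

import glob

def predict_folder(folder):
    # Run predict() on every .wav found under the (placeholder) folder path.
    for wav_path in sorted(glob.glob(os.path.join(folder, "**", "*.wav"), recursive=True)):
        try:
            predict(wav_path)
        except Exception as e:
            print(f"[WARN] skipping {wav_path}: {e}")

# predict_folder("/path/to/eval_wavs")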
ISMIR_2025/music2vec/main.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score
10
+ import wandb
11
+ import argparse
12
+ from transformers import Wav2Vec2Processor
13
+ from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels
14
+ from networks import Music2VecClassifier, CCV
15
+
16
+ parser = argparse.ArgumentParser(description='AI Music Detection Training with Music2Vec + CCV')
17
+ parser.add_argument('--gpu', type=str, default='2', help='GPU ID')
18
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
19
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
20
+ parser.add_argument('--finetune_lr', type=float, default=1e-3, help='Fine-Tune Learning rate')
21
+ parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)')
22
+ parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)')
23
+ parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory')
24
+ parser.add_argument('--weight_decay', type=float, default=0.001, help="Weight decay for optimizer")
25
+
26
+ args = parser.parse_args()
27
+
28
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ torch.manual_seed(42)
32
+ random.seed(42)
33
+ np.random.seed(42)
34
+
35
+ wandb.init(project="music2vec_ccv", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args)
36
+
37
+ print("Preparing datasets...")
38
+ train_dataset = FakeMusicCapsDataset(train_files, train_labels)
39
+ val_dataset = FakeMusicCapsDataset(val_files, val_labels)
40
+
41
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
42
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
43
+
44
+ pretrain_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_pretrain_{args.pretrain_epochs}.pth")
45
+ finetune_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_ccv_finetune_{args.finetune_epochs}.pth")
46
+
47
+ print("Initializing Music2Vec model for Pretraining...")
48
+ processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
49
+ model = Music2VecClassifier(freeze_feature_extractor=False).to(device)  # backbone left trainable here; set True to freeze it during pretraining
50
+
51
+ criterion = nn.CrossEntropyLoss()
52
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
53
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
54
+
55
+ def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"):
56
+ model.train()
57
+ total_loss, total_correct, total_samples = 0, 0, 0
58
+ all_preds, all_labels = [], []
59
+
60
+ for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"):
61
+ labels = labels.to(device)
62
+ inputs = inputs.to(device)
63
+
64
+ logits, _ = model(inputs)  # Music2VecClassifier returns (logits, pooled features)
+ loss = criterion(logits, labels)
66
+
67
+ optimizer.zero_grad()
68
+ loss.backward()
69
+ optimizer.step()
70
+
71
+ total_loss += loss.item()
72
+ preds = logits.argmax(dim=1)
73
+ total_correct += (preds == labels).sum().item()
74
+ total_samples += labels.size(0)
75
+ all_preds.extend(preds.cpu().numpy())
76
+ all_labels.extend(labels.cpu().numpy())
77
+
78
+ scheduler.step()
79
+ accuracy = total_correct / total_samples
80
+ f1 = f1_score(all_labels, all_preds, average="binary")
81
+ balanced_acc = balanced_accuracy_score(all_labels, all_preds)
82
+ precision = precision_score(all_labels, all_preds, average="binary")
83
+ recall = recall_score(all_labels, all_preds, average="binary")
84
+
85
+ wandb.log({
86
+ f"{phase} Train Loss": total_loss / len(dataloader),
87
+ f"{phase} Train Accuracy": accuracy,
88
+ f"{phase} Train F1 Score": f1,
89
+ f"{phase} Train Precision": precision,
90
+ f"{phase} Train Recall": recall,
91
+ f"{phase} Train Balanced Accuracy": balanced_acc,
92
+ })
93
+
94
+ print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, "
95
+ f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}")
96
+
97
+ def validate(model, dataloader, criterion, device, phase="Validation"):
98
+ model.eval()
99
+ total_loss, total_correct, total_samples = 0, 0, 0
100
+ all_preds, all_labels = [], []
101
+
102
+ with torch.no_grad():
103
+ for inputs, labels in tqdm(dataloader, desc=f"{phase}"):
104
+ inputs, labels = inputs.to(device), labels.to(device)
105
+ inputs = inputs.squeeze(1)
106
+ outputs, _ = model(inputs)  # unpack (logits, pooled features)
+ loss = criterion(outputs, labels)
108
+
109
+ total_loss += loss.item()
110
+ preds = outputs.argmax(dim=1)
111
+ total_correct += (preds == labels).sum().item()
112
+ total_samples += labels.size(0)
113
+
114
+ all_preds.extend(preds.cpu().numpy())
115
+ all_labels.extend(labels.cpu().numpy())
116
+
117
+ accuracy = total_correct / total_samples
118
+ f1 = f1_score(all_labels, all_preds, average="weighted")
119
+ val_bal_acc = balanced_accuracy_score(all_labels, all_preds)
120
+ val_precision = precision_score(all_labels, all_preds, average="binary")
121
+ val_recall = recall_score(all_labels, all_preds, average="binary")
122
+
123
+ wandb.log({
124
+ f"{phase} Val Loss": total_loss / len(dataloader),
125
+ f"{phase} Val Accuracy": accuracy,
126
+ f"{phase} Val F1 Score": f1,
127
+ f"{phase} Val Precision": val_precision,
128
+ f"{phase} Val Recall": val_recall,
129
+ f"{phase} Val Balanced Accuracy": val_bal_acc,
130
+ })
131
+ print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, "
132
+ f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}")
133
+ return total_loss / len(dataloader), accuracy, f1
134
+
135
+ print("\nStep 1: Self-Supervised Pretraining on REAL Data")
136
+ for epoch in range(args.pretrain_epochs):
137
+ train(model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain")
138
+
139
+ torch.save(model.state_dict(), pretrain_ckpt)
140
+ print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}")
141
+
142
+ print("\nInitializing Music2Vec + CCV Model for Fine-Tuning...")
143
+ model.load_state_dict(torch.load(pretrain_ckpt))
144
+
145
+ # model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
146
+ model = Music2VecClassifier(freeze_feature_extractor=False).to(device)
147
+ optimizer = optim.Adam(model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay)
148
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
149
+
150
+ print("\nStep 2: Fine-Tuning CCV Model using Music2Vec Features")
151
+ for epoch in range(args.finetune_epochs):
152
+ train(model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune")
153
+
154
+ torch.save(model.state_dict(), finetune_ckpt)
155
+ print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")
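A possible follow-up step, not in the original script: reload the fine-tuned weights and run the validate() helper defined above once on the validation split, so the saved checkpoint is evaluated with the same metrics that were logged during training.

eval_model = Music2VecClassifier(freeze_feature_extractor=True).to(device)
eval_model.load_state_dict(torch.load(finetune_ckpt, map_location=device))
validate(eval_model, val_loader, criterion, device, phase="Fine-Tune Final Validation")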
ISMIR_2025/music2vec/networks.py ADDED
@@ -0,0 +1,247 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import Data2VecAudioModel, Wav2Vec2Processor
4
+
5
+ class Music2VecClassifier(nn.Module):
6
+ def __init__(self, num_classes=2, freeze_feature_extractor=True):
7
+ super(Music2VecClassifier, self).__init__()
8
+
9
+ self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
10
+ self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
11
+
12
+ if freeze_feature_extractor:
13
+ for param in self.music2vec.parameters():
14
+ param.requires_grad = False
15
+
16
+ # Conv1d for learnable weighted average across layers
17
+ self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
18
+
19
+ # Classification head
20
+ self.classifier = nn.Sequential(
21
+ nn.Linear(self.music2vec.config.hidden_size, 256),
22
+ nn.ReLU(),
23
+ nn.Dropout(0.3),
24
+ nn.Linear(256, num_classes)
25
+ )
26
+
27
+ def forward(self, input_values):
28
+ input_values = input_values.squeeze(1) # Ensure shape [batch, time]
29
+
30
+ with torch.no_grad():
31
+ outputs = self.music2vec(input_values, output_hidden_states=True)
32
+ hidden_states = torch.stack(outputs.hidden_states)
33
+ time_reduced = hidden_states.mean(dim=2)
34
+ time_reduced = time_reduced.permute(1, 0, 2)
35
+ weighted_avg = self.conv1d(time_reduced).squeeze(1)
36
+
37
+ return self.classifier(weighted_avg), weighted_avg
38
+
39
+ def unfreeze_feature_extractor(self):
40
+ for param in self.music2vec.parameters():
41
+ param.requires_grad = True
42
+
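For reference, the tensor shapes through Music2VecClassifier.forward work out as follows (13 hidden states = the embedding layer plus 12 transformer layers of the base model, matching in_channels=13 of the Conv1d; hidden_size is 768 for this backbone):

# input_values              [batch, 1, time]  -> squeeze(1)      -> [batch, time]
# outputs.hidden_states     13 x [batch, frames, 768]
# torch.stack(...)          [13, batch, frames, 768]
# .mean(dim=2)              [13, batch, 768] -> permute(1, 0, 2) -> [batch, 13, 768]
# conv1d (13 -> 1 channels) [batch, 1, 768]  -> squeeze(1)       -> [batch, 768]
# classifier                [batch, num_classes]; forward returns (logits, weighted_avg)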
43
+ class Music2VecFeatureExtractor(nn.Module):
44
+ def __init__(self, freeze_feature_extractor=True):
45
+ super(Music2VecFeatureExtractor, self).__init__()
46
+ self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
47
+ self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
48
+
49
+ if freeze_feature_extractor:
50
+ for param in self.music2vec.parameters():
51
+ param.requires_grad = False
52
+
53
+ # Conv1d for learnable weighted average across layers
54
+ self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
55
+
56
+ def forward(self, input_values):
57
+ # input_values: [batch, time]
58
+ input_values = input_values.squeeze(1)
59
+ with torch.no_grad():
60
+ outputs = self.music2vec(input_values, output_hidden_states=True)
61
+ hidden_states = torch.stack(outputs.hidden_states) # [num_layers, batch, time, hidden_dim]
62
+ time_reduced = hidden_states.mean(dim=2) # [num_layers, batch, hidden_dim]
63
+ time_reduced = time_reduced.permute(1, 0, 2) # [batch, num_layers, hidden_dim]
64
+ weighted_avg = self.conv1d(time_reduced).squeeze(1) # [batch, hidden_dim]
65
+ return weighted_avg
66
+
67
+ '''
68
+ music2vec+CCV
69
+ # '''
70
+ # import torch
71
+ # import torch.nn as nn
72
+ # from transformers import Data2VecAudioModel, Wav2Vec2Processor
73
+ # import torch.nn.functional as F
74
+
75
+
76
+ # ### Music2Vec Feature Extractor (Pretrained Model)
77
+ # class Music2VecFeatureExtractor(nn.Module):
78
+ # def __init__(self, freeze_feature_extractor=True):
79
+ # super(Music2VecFeatureExtractor, self).__init__()
80
+
81
+ # self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
82
+ # self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
83
+
84
+ # if freeze_feature_extractor:
85
+ # for param in self.music2vec.parameters():
86
+ # param.requires_grad = False
87
+
88
+ # # Conv1d for learnable weighted average across layers
89
+ # self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
90
+
91
+ # def forward(self, input_values):
92
+ # with torch.no_grad():
93
+ # outputs = self.music2vec(input_values, output_hidden_states=True)
94
+
95
+ # hidden_states = torch.stack(outputs.hidden_states) # [13, batch, time, hidden_size]
96
+ # time_reduced = hidden_states.mean(dim=2) # mean pooling over time: [13, batch, hidden_size]
97
+ # time_reduced = time_reduced.permute(1, 0, 2) # [batch, 13, hidden_size]
98
+ # weighted_avg = self.conv1d(time_reduced).squeeze(1) # [batch, hidden_size]
99
+
100
+ # return weighted_avg # Extracted feature representation
101
+
102
+
103
+ # def unfreeze_feature_extractor(self):
104
+ # for param in self.music2vec.parameters():
105
+ # param.requires_grad = True # Unfreeze for Fine-tuning
106
+
107
+ # ### CNN Feature Extractor for CCV
108
+ class CNNEncoder(nn.Module):
109
+ def __init__(self, embed_dim=512):
110
+ super(CNNEncoder, self).__init__()
111
+ self.conv_block = nn.Sequential(
112
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
113
+ nn.ReLU(),
114
+ nn.MaxPool2d((2,1)), # changed from MaxPool2d(2) to MaxPool2d((2,1))
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
+ nn.ReLU(),
+ nn.MaxPool2d((1,1)), # MaxPool2d((1,1)) added to keep the spatial size unchanged
+ nn.AdaptiveAvgPool2d((4, 4)) # final size adjustment
119
+ )
120
+ self.projection = nn.Linear(32 * 4 * 4, embed_dim)
121
+
122
+ def forward(self, x):
123
+ # print(f"Input shape before CNNEncoder: {x.shape}") # debug print
124
+ x = self.conv_block(x)
125
+ B, C, H, W = x.shape
126
+ x = x.view(B, -1)
127
+ x = self.projection(x)
128
+ return x
129
+
130
+
131
+ ### Cross-Attention Module
132
+ class CrossAttentionLayer(nn.Module):
133
+ def __init__(self, embed_dim, num_heads):
134
+ super(CrossAttentionLayer, self).__init__()
135
+ self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
136
+ self.layer_norm = nn.LayerNorm(embed_dim)
137
+ self.feed_forward = nn.Sequential(
138
+ nn.Linear(embed_dim, embed_dim * 4),
139
+ nn.ReLU(),
140
+ nn.Linear(embed_dim * 4, embed_dim)
141
+ )
142
+ self.attention_weights = None
143
+
144
+ def forward(self, x, cross_input):
145
+ attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input)
146
+ self.attention_weights = attn_weights
147
+ x = self.layer_norm(x + attn_output)
148
+ feed_forward_output = self.feed_forward(x)
149
+ x = self.layer_norm(x + feed_forward_output)
150
+ return x
151
+
152
+ ### Cross-Attention Transformer
153
+ class CrossAttentionViT(nn.Module):
154
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
155
+ super(CrossAttentionViT, self).__init__()
156
+
157
+ self.cross_attention_layers = nn.ModuleList([
158
+ CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
159
+ ])
160
+
161
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
162
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
163
+
164
+ self.classifier = nn.Sequential(
165
+ nn.LayerNorm(embed_dim),
166
+ nn.Linear(embed_dim, num_classes)
167
+ )
168
+
169
+ def forward(self, x, cross_attention_input):
170
+ self.attention_maps = []
171
+ for layer in self.cross_attention_layers:
172
+ x = layer(x, cross_attention_input)
173
+ self.attention_maps.append(layer.attention_weights)
174
+
175
+ x = x.unsqueeze(1).permute(1, 0, 2)
176
+ x = self.transformer(x)
177
+ x = x.mean(dim=0)
178
+ x = self.classifier(x)
179
+ return x
180
+
181
+ ### CCV Model (Final Classifier)
182
+ # class CCV(nn.Module):
183
+ # def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
184
+ # super(CCV, self).__init__()
185
+
186
+ # self.music2vec_extractor = Music2VecClassifier(freeze_feature_extractor=freeze_feature_extractor)
187
+
188
+ # # CNN Encoder for Image Representation
189
+ # self.encoder = CNNEncoder(embed_dim=embed_dim)
190
+
191
+ # # Transformer with Cross-Attention
192
+ # self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
193
+
194
+ # def forward(self, x, cross_attention_input=None):
195
+ # x = self.music2vec_extractor(x)
196
+ # # print(f"After Music2VecExtractor: {x.shape}") # prints (batch, 2)
197
+
198
+ # # match the input size expected by CNNEncoder
+ # x = x.unsqueeze(1).unsqueeze(-1) # reshape to (batch, 1, 2, 1)
+ # # print(f"Before CNNEncoder: {x.shape}") # check the CNN input
201
+
202
+ # x = self.encoder(x)
203
+
204
+ # if cross_attention_input is None:
205
+ # cross_attention_input = x
206
+
207
+ # x = self.decoder(x, cross_attention_input)
208
+
209
+ # return x
210
+
211
+ class CCV(nn.Module):
212
+ def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
213
+ super(CCV, self).__init__()
214
+ self.feature_extractor = Music2VecFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor)
215
+
216
+ # Cross-Attention Transformer
217
+ self.cross_attention_layers = nn.ModuleList([
218
+ CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
219
+ ])
220
+
221
+ # Transformer Encoder
222
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
223
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
224
+
225
+ # Classification Head
226
+ self.classifier = nn.Sequential(
227
+ nn.LayerNorm(embed_dim),
228
+ nn.Linear(embed_dim, num_classes)
229
+ )
230
+
231
+ def forward(self, input_values):
232
+ # Extract feature embeddings
233
+ features = self.feature_extractor(input_values) # [batch, feature_dim]
234
+ # Average over layer dimension if necessary (already [batch, hidden_dim] at this point)
235
+ # Apply Cross-Attention Layers
236
+ for layer in self.cross_attention_layers:
237
+ features = layer(features.unsqueeze(1), features.unsqueeze(1)).squeeze(1)
238
+ # Transformer Encoding
239
+ encoded = self.transformer(features.unsqueeze(1))
240
+ encoded = encoded.mean(dim=1)
241
+ # Classification Head
242
+ logits = self.classifier(encoded)
243
+ return logits
244
+
245
+ def get_attention_maps(self):
246
+ # implement this if the attention maps collected by CrossAttentionLayer are needed
247
+ return None
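A minimal smoke test (illustrative only, assuming the pretrained backbones can be downloaded): push random audio through both heads defined above and print the resulting shapes.

if __name__ == "__main__":
    dummy = torch.randn(2, 1, 16000)  # two 1-second clips of 16 kHz noise
    logits, emb = Music2VecClassifier()(dummy)
    print("Music2VecClassifier:", logits.shape, emb.shape)  # [2, 2], [2, 768]
    print("CCV:", CCV()(dummy).shape)                       # [2, 2]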
ISMIR_2025/music2vec/test.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader
7
+ from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
8
+ from datalib import (
9
+ FakeMusicCapsDataset,
10
+ closed_test_files, closed_test_labels,
11
+ open_test_files, open_test_labels,
12
+ val_files, val_labels
13
+ )
14
+ from networks import Music2VecClassifier
15
+ import argparse
16
+
17
+ '''
18
+ python3 test.py --gpu 1 --closed_test --ckpt_path ""
19
+ '''
20
+ parser = argparse.ArgumentParser(description="AI Music Detection Testing with Music2Vec")
21
+ parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
22
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
23
+ parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/music2vec_pretrain_10.pth', help='Checkpoint directory')
24
+ parser.add_argument('--model_name', type=str, default="music2vec", help="Model name")
25
+ parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
26
+ parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
27
+ parser.add_argument('--output_path', type=str, default='', help='Path to save test results')
28
+
29
+ args = parser.parse_args()
30
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
+ def plot_confusion_matrix(y_true, y_pred, classes, output_path):
34
+ cm = confusion_matrix(y_true, y_pred)
35
+ fig, ax = plt.subplots(figsize=(6, 6))
36
+ im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
37
+ ax.figure.colorbar(im, ax=ax)
38
+
39
+ num_classes = cm.shape[0]
40
+ tick_labels = classes[:num_classes]
41
+
42
+ ax.set(xticks=np.arange(num_classes),
43
+ yticks=np.arange(num_classes),
44
+ xticklabels=tick_labels,
45
+ yticklabels=tick_labels,
46
+ ylabel='True label',
47
+ xlabel='Predicted label')
48
+
49
+ thresh = cm.max() / 2.
50
+ for i in range(cm.shape[0]):
51
+ for j in range(cm.shape[1]):
52
+ ax.text(j, i, format(cm[i, j], 'd'),
53
+ ha="center", va="center",
54
+ color="white" if cm[i, j] > thresh else "black")
55
+
56
+ fig.tight_layout()
57
+ plt.savefig(output_path)
58
+ plt.close(fig)
59
+
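Usage sketch for the helper above, with toy labels and a placeholder output file name:

plot_confusion_matrix(
    y_true=[0, 0, 1, 1],
    y_pred=[0, 1, 1, 1],
    classes=["real", "generative"],
    output_path="confusion_matrix_example.png",
)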
60
+ model = Music2VecClassifier().to(device)
61
+
62
+ ckpt_file = os.path.join(args.ckpt_path)
63
+ if not os.path.exists(ckpt_file):
64
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
65
+
66
+ print(f"\nLoading model from {ckpt_file}")
67
+ model.load_state_dict(torch.load(ckpt_file, map_location=device))
68
+ model.eval()
69
+
70
+ torch.cuda.empty_cache()
71
+
72
+ if args.closed_test:
73
+ print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
74
+ test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0)
75
+ elif args.open_test:
76
+ print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
77
+ test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0)
78
+ else:
79
+ print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
80
+ test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0)
81
+
82
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
83
+
84
+ def Test(model, test_loader, device):
85
+ model.eval()
86
+ test_loss, test_correct, test_total = 0, 0, 0
87
+ all_preds, all_labels = [], []
88
+
89
+ with torch.no_grad():
90
+ for data, target in test_loader:
91
+ data, target = data.to(device), target.to(device)
92
+ output, _ = model(data)  # Music2VecClassifier returns (logits, features)
+ loss = F.cross_entropy(output, target)
94
+
95
+ test_loss += loss.item() * data.size(0)
96
+ preds = output.argmax(dim=1)
97
+ test_correct += (preds == target).sum().item()
98
+ test_total += target.size(0)
99
+
100
+ all_labels.extend(target.cpu().numpy())
101
+ all_preds.extend(preds.cpu().numpy())
102
+
103
+ test_loss /= test_total
104
+ test_acc = test_correct / test_total
105
+ test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
106
+ test_precision = precision_score(all_labels, all_preds, average="binary")
107
+ test_recall = recall_score(all_labels, all_preds, average="binary")
108
+ test_f1 = f1_score(all_labels, all_preds, average="binary")
109
+
110
+ print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
111
+ f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
112
+ f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")
113
+
114
+ os.makedirs(args.output_path, exist_ok=True)
115
+ conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
116
+ plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)
117
+
118
+ print("\nEvaluating Model on Test Set...")
119
+ Test(model, test_loader, device)
ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc ADDED
Binary file (9.2 kB). View file
 
ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc ADDED
Binary file (1.77 kB). View file
 
ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc ADDED
Binary file (9.47 kB). View file
 
ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc ADDED
Binary file (9.15 kB). View file
 
ISMIR_2025/wav2vec/datalib.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import glob
3
+ import random
4
+ import torch
5
+ import librosa
6
+ import numpy as np
7
+ import utils
8
+ from sklearn.model_selection import train_test_split
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import scipy.signal as signal
11
+ import scipy.signal
12
+ from scipy.signal import butter, lfilter
13
+ import numpy as np
14
+ import scipy.signal as signal
15
+ import librosa
16
+ import torch
17
+ import random
18
+ from torch.utils.data import Dataset
19
+ import logging
20
+ import csv
21
+ import logging
22
+ import time
23
+ import numpy as np
24
+ import h5py
25
+ import torch
26
+ import torchaudio
27
+ from imblearn.over_sampling import RandomOverSampler
28
+ from networks import Wav2Vec2ForFakeMusic
29
+ from transformers import Wav2Vec2Processor
30
+ import torchaudio.transforms as T
31
+
32
+ class FakeMusicCapsDataset(Dataset):
33
+ def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
34
+ self.file_paths = file_paths
35
+ self.labels = labels
36
+ self.sr = sr
37
+ self.target_duration = target_duration
38
+ self.target_samples = int(target_duration * sr)
39
+
40
+ self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
41
+
42
+ def highpass_filter(self, y, sr, cutoff=500, order=5):
43
+ if isinstance(sr, np.ndarray):
44
+ sr = np.mean(sr)
45
+ if not isinstance(sr, (int, float)):
46
+ raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
47
+ if sr <= 0:
48
+ raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
49
+ nyquist = 0.5 * sr
50
+ if cutoff <= 0 or cutoff >= nyquist:
51
+ print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
52
+ cutoff = max(10, min(cutoff, nyquist - 1))
53
+ normal_cutoff = cutoff / nyquist
54
+ b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
55
+ y_filtered = signal.lfilter(b, a, y)
56
+ return y_filtered
57
+
58
+ def __len__(self):
59
+ return len(self.file_paths)
60
+
61
+ def __getitem__(self, idx):
62
+ audio_path = self.file_paths[idx]
63
+ label = self.labels[idx]
64
+
65
+ waveform, sr = torchaudio.load(audio_path)
66
+ waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)
67
+
68
+ waveform = waveform.squeeze(0)
69
+ if label == 0 and hasattr(self, "augment_audio"):  # augment_audio is referenced but not defined in this file; skip if absent
+ waveform = self.augment_audio(waveform, self.sr)
71
+ if label == 1:
72
+ waveform = self.highpass_filter(waveform, self.sr)
73
+
74
+ current_samples = waveform.shape[0]
75
+ if current_samples > self.target_samples:
76
+ start_idx = (current_samples - self.target_samples) // 2
77
+ waveform = waveform[start_idx:start_idx + self.target_samples]
78
+ elif current_samples < self.target_samples:
79
+ waveform = torch.nn.functional.pad(waveform, (0, self.target_samples - current_samples))
80
+
81
+ waveform = torch.as_tensor(waveform, dtype=torch.float32).unsqueeze(0)  # handles both ndarray and tensor inputs
82
+ label = torch.tensor(label, dtype=torch.long)
83
+
84
+ return waveform, label
85
+
86
+ def preprocess_audio(audio_path, target_sr=16000, target_duration=10.0):
87
+ waveform, sr = librosa.load(audio_path, sr=target_sr)
88
+
89
+ target_samples = int(target_duration * target_sr)
90
+ current_samples = len(waveform)
91
+
92
+ if current_samples > target_samples:
93
+ waveform = waveform[:target_samples]
94
+ elif current_samples < target_samples:
95
+ waveform = np.pad(waveform, (0, target_samples - current_samples))
96
+
97
+ waveform = torch.tensor(waveform).unsqueeze(0)
98
+ return waveform
99
+
100
+
101
+ DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
102
+ SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data added for the open-set evaluation
103
+
104
+ real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
105
+ gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)
106
+
107
+ open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
108
+ open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)
109
+
110
+ real_labels = [0] * len(real_files)
111
+ gen_labels = [1] * len(gen_files)
112
+
113
+ open_real_labels = [0] * len(open_real_files)
114
+ open_gen_labels = [1] * len(open_gen_files)
115
+
116
+ real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
117
+ gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)
118
+
119
+ train_files = real_train + gen_train
120
+ train_labels = real_train_labels + gen_train_labels
121
+ val_files = real_val + gen_val
122
+ val_labels = real_val_labels + gen_val_labels
123
+
124
+ closed_test_files = real_files + gen_files
125
+ closed_test_labels = real_labels + gen_labels
126
+
127
+ open_test_files = open_real_files + open_gen_files
128
+ open_test_labels = open_real_labels + open_gen_labels
129
+
130
+ ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
131
+ train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)
132
+
133
+ train_files = train_files_resampled.reshape(-1).tolist()
134
+ train_labels = train_labels_resampled
135
+
136
+ print(f"Train Original FAKE: {len(gen_train)}")
137
+ print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, "
138
+ f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}")
139
+ print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}")
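A quick item-level sanity check for this dataset (illustrative; run from the same directory so the module imports resolve):

from datalib import FakeMusicCapsDataset, val_files, val_labels

ds = FakeMusicCapsDataset(val_files, val_labels)      # 16 kHz, 10 s clips
waveform, label = ds[0]
print(waveform.shape, waveform.dtype, label.item())   # expected: torch.Size([1, 160000]) torch.float32 0 or 1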
ISMIR_2025/wav2vec/inference.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ import argparse
6
+ from AI_Music_Detection.Code.model.wav2vec.wav2vec_datalib import preprocess_audio
7
+ from networks import Wav2Vec2ForFakeMusic
8
+
9
+ '''
10
+ command: python inference.py --gpu 0 --model_type pretrain --inference .wav
11
+ '''
12
+ parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference")
13
+ parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
14
+ parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name')
15
+ parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory')
16
+ parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model')
17
+ parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
18
+ args = parser.parse_args()
19
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
+ if args.model_type == 'pretrain':
23
+ model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth")
24
+ elif args.model_type == 'finetune':
25
+ model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth")
26
+ else:
27
+ raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.")
28
+
29
+ if not os.path.exists(model_file):
30
+ raise FileNotFoundError(f"Model checkpoint not found: {model_file}")
31
+
32
+ if args.model_name == 'Wav2Vec2ForFakeMusic':
33
+ model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune'))
34
+ else:
35
+ raise ValueError(f"Invalid model name: {args.model_name}")
36
+
37
+ def predict(audio_path):
38
+ print(f"\n🔍 Loading model from {model_file}")
39
+
40
+ if not os.path.exists(audio_path):
41
+ raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}")
42
+
43
+ model.load_state_dict(torch.load(model_file, map_location=device))  # load the selected checkpoint
+ model.to(device)
+ model.eval()
44
+
45
+ input_tensor = preprocess_audio(audio_path).to(device)
46
+ print(f"Input shape after preprocessing: {input_tensor.shape}")
47
+
48
+ with torch.no_grad():
49
+ output, _ = model(input_tensor)  # Wav2Vec2ForFakeMusic returns (logits, pooled features)
+ print(f"Raw model output (logits): {output}")
+
+ probabilities = F.softmax(output, dim=1)
53
+ ai_music_prob = probabilities[0, 1].item()
54
+
55
+ print(f"Softmax Probabilities: {probabilities}")
56
+ print(f"AI Music Probability: {ai_music_prob:.4f}")
57
+
58
+ if ai_music_prob > 0.5:
59
+ print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})")
60
+ else:
61
+ print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)")
62
+
63
+
64
+
65
+ if __name__ == "__main__":
66
+ if args.inference:
67
+ if not os.path.exists(args.inference):
68
+ print(f"[ERROR] No File Found: {args.inference}")
69
+ else:
70
+ predict(args.inference)
71
+
ISMIR_2025/wav2vec/main.py ADDED
@@ -0,0 +1,162 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score, classification_report
10
+ import wandb
11
+ import argparse
12
+ from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels
13
+ from networks import Wav2Vec2ForFakeMusic
14
+
15
+ '''
16
+ python inference.py --gpu 0 --model_type finetune --inference
17
+ '''
18
+ parser = argparse.ArgumentParser(description='AI Music Detection Training with Wav2Vec 2.0')
19
+ parser.add_argument('--gpu', type=str, default='2', help='GPU ID')
20
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
21
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
22
+ parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)')
23
+ parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)')
24
+ parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory')
25
+ parser.add_argument('--weight_decay', type=float, default=0.05, help="Weight decay for optimizer")
26
+
27
+ args = parser.parse_args()
28
+
29
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+ torch.manual_seed(42)
33
+ random.seed(42)
34
+ np.random.seed(42)
35
+
36
+ wandb.init(project="", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args)
37
+
38
+ print("Preparing datasets...")
39
+ train_dataset = FakeMusicCapsDataset(train_files, train_labels)
40
+ val_dataset = FakeMusicCapsDataset(val_files, val_labels)
41
+
42
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
43
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
44
+
45
+ pretrain_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_pretrain_{args.pretrain_epochs}.pth")
46
+ finetune_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_finetune_{args.finetune_epochs}.pth")
47
+
48
+ print("Initializing model...")
49
+ model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)
50
+
51
+ criterion = nn.CrossEntropyLoss()
52
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
53
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
54
+
55
+ def train(model, dataloader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain"):
56
+ model.train()
57
+ total_loss, total_correct, total_samples = 0, 0, 0
58
+ all_preds, all_labels = [], []
59
+ attention_maps = []
60
+
61
+ for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"):
62
+ inputs, labels = inputs.to(device), labels.to(device)
63
+ inputs = inputs.float()
64
+
65
+ outputs, _ = model(inputs)  # unpack (logits, pooled features)
+ loss = criterion(outputs, labels)
67
+
68
+ optimizer.zero_grad()
69
+ loss.backward()
70
+ optimizer.step()
71
+
72
+ total_loss += loss.item()
73
+ preds = outputs.argmax(dim=1)
74
+ total_correct += (preds == labels).sum().item()
75
+ total_samples += labels.size(0)
76
+
77
+ all_preds.extend(preds.cpu().numpy())
78
+ all_labels.extend(labels.cpu().numpy())
79
+
80
+ if hasattr(model, "get_attention_maps"):
81
+ attention_maps.append(model.get_attention_maps())
82
+
83
+ scheduler.step()
84
+
85
+ accuracy = total_correct / total_samples
86
+ f1 = f1_score(all_labels, all_preds, average="weighted")
87
+ precision = precision_score(all_labels, all_preds, average="binary")
88
+ recall = recall_score(all_labels, all_preds, average="binary")
89
+ balanced_acc = balanced_accuracy_score(all_labels, all_preds)
90
+
91
+ wandb.log({
92
+ f"{phase} Train Loss": total_loss / len(dataloader),
93
+ f"{phase} Train Accuracy": accuracy,
94
+ f"{phase} Train F1 Score": f1,
95
+ f"{phase} Train Precision": precision,
96
+ f"{phase} Train Recall": recall,
97
+ f"{phase} Train Balanced Accuracy": balanced_acc,
98
+ })
99
+
100
+ print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, "
101
+ f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}")
102
+
103
+ def validate(model, dataloader, criterion, device, phase="Validation"):
104
+ model.eval()
105
+ total_loss, total_correct, total_samples = 0, 0, 0
106
+ all_preds, all_labels = [], []
107
+
108
+ with torch.no_grad():
109
+ for inputs, labels in tqdm(dataloader, desc=f"{phase}"):
110
+ inputs, labels = inputs.to(device), labels.to(device)
111
+ inputs = inputs.squeeze(1)
112
+
113
+ outputs, _ = model(inputs)  # unpack (logits, pooled features)
+ loss = criterion(outputs, labels)
115
+
116
+
117
+ total_loss += loss.item()
118
+ preds = outputs.argmax(dim=1)
119
+ total_correct += (preds == labels).sum().item()
120
+ total_samples += labels.size(0)
121
+
122
+ all_preds.extend(preds.cpu().numpy())
123
+ all_labels.extend(labels.cpu().numpy())
124
+
125
+ accuracy = total_correct / total_samples
126
+ f1 = f1_score(all_labels, all_preds, average="weighted")
127
+ val_bal_acc = balanced_accuracy_score(all_labels, all_preds)
128
+ val_precision = precision_score(all_labels, all_preds, average="binary")
129
+ val_recall = recall_score(all_labels, all_preds, average="binary")
130
+
131
+ wandb.log({
132
+ f"{phase} Val Loss": total_loss / len(dataloader),
133
+ f"{phase} Val Accuracy": accuracy,
134
+ f"{phase} Val F1 Score": f1,
135
+ f"{phase} Val Precision": val_precision,
136
+ f"{phase} Val Recall": val_recall,
137
+ f"{phase} Val Balanced Accuracy": val_bal_acc,
138
+ })
139
+ print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, "
140
+ f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}")
141
+ return total_loss / len(dataloader), accuracy, f1
142
+
143
+ print("\nStep 1: Self-Supervised Pretraining on REAL Data")
144
+ for epoch in range(args.pretrain_epochs):
145
+ train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain")
146
+
147
+ torch.save(model.state_dict(), pretrain_ckpt)
148
+ print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}")
149
+
150
+ model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device)
151
+ model.load_state_dict(torch.load(pretrain_ckpt))
152
+ print(f"\n🔍 Loaded Pretrained Model from {pretrain_ckpt}")
153
+
154
+ optimizer = optim.Adam(model.parameters(), lr=args.learning_rate / 10, weight_decay=args.weight_decay)
155
+
156
+ print("\nStep 2: Fine-Tuning on REAL + FAKE Data")
157
+ for epoch in range(args.finetune_epochs):
158
+ train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Fine-Tune")
159
+ validate(model, val_loader, criterion, device, phase="Fine-Tune Validation")
160
+
161
+ torch.save(model.state_dict(), finetune_ckpt)
162
+ print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}")
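A small sketch (not part of the training script) to confirm what freeze_feature_extractor actually freezes, by counting trainable versus total parameters for both configurations:

def count_params(m):
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    total = sum(p.numel() for p in m.parameters())
    return trainable, total

print("frozen backbone   :", count_params(Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True)))
print("trainable backbone:", count_params(Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False)))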
ISMIR_2025/wav2vec/networks.py ADDED
@@ -0,0 +1,161 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+
7
+ '''
+ With freeze_feature_extractor=True the feature extractor is frozen (pretraining);
+ calling unfreeze_feature_extractor() unfreezes it for fine-tuning.
+ '''
11
+ from transformers import Wav2Vec2Model
17
+
18
+ class cnn(nn.Module):
19
+ def __init__(self, embed_dim=512):
20
+ super(cnn, self).__init__()
21
+ self.conv_block = nn.Sequential(
22
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
23
+ nn.ReLU(),
24
+ nn.MaxPool2d(2),
25
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
26
+ nn.ReLU(),
27
+ nn.MaxPool2d(2),
28
+ nn.AdaptiveAvgPool2d((4, 4))
29
+ )
30
+ self.projection = nn.Linear(32 * 4 * 4, embed_dim)
31
+
32
+ def forward(self, x):
33
+ x = self.conv_block(x)
34
+ B, C, H, W = x.shape
35
+ x = x.view(B, -1)
36
+ x = self.projection(x)
37
+ return x
38
+
39
+ class CrossAttentionLayer(nn.Module):
40
+ def __init__(self, embed_dim, num_heads):
41
+ super(CrossAttentionLayer, self).__init__()
42
+ self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
43
+ self.layer_norm = nn.LayerNorm(embed_dim)
44
+ self.feed_forward = nn.Sequential(
45
+ nn.Linear(embed_dim, embed_dim * 4),
46
+ nn.ReLU(),
47
+ nn.Linear(embed_dim * 4, embed_dim)
48
+ )
49
+ self.attention_weights = None
50
+
51
+ def forward(self, x, cross_input):
52
+ # Cross-attention between x and cross_input
53
+ attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input)
54
+ self.attention_weights = attn_weights
55
+ x = self.layer_norm(x + attn_output)
56
+ feed_forward_output = self.feed_forward(x)
57
+ x = self.layer_norm(x + feed_forward_output)
58
+ return x
59
+
60
+ class CrossAttentionViT(nn.Module):
61
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
62
+ super(CrossAttentionViT, self).__init__()
63
+
64
+ self.cross_attention_layers = nn.ModuleList([
65
+ CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
66
+ ])
67
+
68
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
69
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
70
+
71
+ self.classifier = nn.Sequential(
72
+ nn.LayerNorm(embed_dim),
73
+ nn.Linear(embed_dim, num_classes)
74
+ )
75
+
76
+ def forward(self, x, cross_attention_input):
77
+ self.attention_maps = []
78
+ for layer in self.cross_attention_layers:
79
+ x = layer(x, cross_attention_input)
80
+ self.attention_maps.append(layer.attention_weights)
81
+
82
+ x = x.unsqueeze(1).permute(1, 0, 2)
83
+ x = self.transformer(x)
84
+ x = x.mean(dim=0)
85
+ x = self.classifier(x)
86
+ return x
87
+
88
+ class CCV(nn.Module):
89
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
90
+ super(CCV, self).__init__()
91
+ self.encoder = cnn(embed_dim=embed_dim)
92
+ self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes)
93
+
94
+ def forward(self, x, cross_attention_input=None):
95
+ x = self.encoder(x)
96
+
97
+ if cross_attention_input is None:
98
+ cross_attention_input = x
99
+
100
+ x = self.decoder(x, cross_attention_input)
101
+
102
+ # store the attention maps from the decoder
103
+ self.attention_maps = self.decoder.attention_maps
104
+
105
+ return x
106
+
107
+ def get_attention_maps(self):
108
+ return self.attention_maps
109
+
110
+ from transformers import Wav2Vec2Model
113
+
114
+ class Wav2Vec2ForFakeMusic(nn.Module):
115
+ def __init__(self, num_classes=2, freeze_feature_extractor=True):
116
+ super(Wav2Vec2ForFakeMusic, self).__init__()
117
+
118
+ self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
119
+
120
+ if freeze_feature_extractor:
121
+ for param in self.wav2vec.parameters():
122
+ param.requires_grad = False
123
+
124
+ self.classifier = nn.Sequential(
125
+ nn.Linear(self.wav2vec.config.hidden_size, 256), # 768 → 256
126
+ nn.ReLU(),
127
+ nn.Dropout(0.3),
128
+ nn.Linear(256, num_classes) # 256 → 2 (Binary Classification)
129
+ )
130
+
131
+ def forward(self, x):
132
+ x = x.squeeze(1)
133
+ output = self.wav2vec(x)
134
+ features = output["last_hidden_state"] # (batch_size, seq_len, feature_dim)
135
+ pooled_features = features.mean(dim=1) # mean pooling over time (batch_size, feature_dim)
136
+ logits = self.classifier(pooled_features) # (batch_size, num_classes)
137
+
138
+ return logits, pooled_features
139
+
140
+
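Illustrative check of the tuple returned by Wav2Vec2ForFakeMusic (assumes the facebook/wav2vec2-base weights can be downloaded):

model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True)
logits, pooled = model(torch.randn(2, 1, 16000))  # two 1-second dummy clips
print(logits.shape, pooled.shape)                 # torch.Size([2, 2]) torch.Size([2, 768])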
141
+ def visualize_attention_map(attn_map, mel_spec, layer_idx):
142
+ attn_map = attn_map.mean(dim=1).squeeze().cpu().numpy() # average over attention heads
143
+ mel_spec = mel_spec.squeeze().cpu().numpy()
144
+
145
+ fig, axs = plt.subplots(2, 1, figsize=(10, 8))
146
+
147
+ # plot the log-mel spectrogram
148
+ sns.heatmap(mel_spec, cmap='inferno', ax=axs[0])
149
+ axs[0].set_title("Log-Mel Spectrogram")
150
+ axs[0].set_xlabel("Time Frames")
151
+ axs[0].set_ylabel("Mel Frequency Bins")
152
+
153
+ # plot the attention map
154
+ sns.heatmap(attn_map, cmap='viridis', ax=axs[1])
155
+ axs[1].set_title(f"Attention Map (Layer {layer_idx})")
156
+ axs[1].set_xlabel("Time Frames")
157
+ axs[1].set_ylabel("Query Positions")
158
+
159
+ plt.tight_layout()
160
+ plt.show()
161
+ plt.savefig("/data/kym/AI_Music_Detection/Code/model/attention_map/crossattn.png")
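A usage sketch for the CCV path above (illustrative; the log-mel input is a random stand-in and a reasonably recent PyTorch is assumed). The attention maps collected here are what visualize_attention_map expects alongside the spectrogram.

ccv = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2)
mel_spec = torch.randn(1, 1, 128, 400)   # [batch, channel, mel bins, frames] stand-in
logits = ccv(mel_spec)
attn_maps = ccv.get_attention_maps()     # one attention-weight tensor per cross-attention layer
print(logits.shape, len(attn_maps))      # torch.Size([1, 2]) 6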
ISMIR_2025/wav2vec/test.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader
7
+ from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
8
+ from datalib import (
9
+ FakeMusicCapsDataset,
10
+ closed_test_files, closed_test_labels,
11
+ open_test_files, open_test_labels,
12
+ val_files, val_labels
13
+ )
14
+ from networks import Wav2Vec2ForFakeMusic
15
+ import tqdm
16
+ from tqdm import tqdm
17
+ import argparse
18
+ '''
19
+ python3 test.py --finetune_test --closed_test | --open_test
20
+ '''
21
+ parser = argparse.ArgumentParser(description="AI Music Detection Testing with Wav2Vec 2.0")
22
+ parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
23
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
24
+ parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
25
+ parser.add_argument('--pretrain_test', action="store_true", help="Test Pretrained Wav2Vec2 Model")
26
+ parser.add_argument('--finetune_test', action="store_true", help="Test Fine-Tuned Wav2Vec2 Model")
27
+ parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
28
+ parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
29
+ parser.add_argument('--output_path', type=str, default='', help='Path to save test results')
30
+
31
+ args = parser.parse_args()
32
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+
35
+ def plot_confusion_matrix(y_true, y_pred, classes, output_path):
36
+ cm = confusion_matrix(y_true, y_pred)
37
+ fig, ax = plt.subplots(figsize=(6, 6))
38
+ im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
39
+ ax.figure.colorbar(im, ax=ax)
40
+
41
+ num_classes = cm.shape[0]
42
+ tick_labels = classes[:num_classes]
43
+
44
+ ax.set(xticks=np.arange(num_classes),
45
+ yticks=np.arange(num_classes),
46
+ xticklabels=tick_labels,
47
+ yticklabels=tick_labels,
48
+ ylabel='True label',
49
+ xlabel='Predicted label')
50
+
51
+ thresh = cm.max() / 2.
52
+ for i in range(cm.shape[0]):
53
+ for j in range(cm.shape[1]):
54
+ ax.text(j, i, format(cm[i, j], 'd'),
55
+ ha="center", va="center",
56
+ color="white" if cm[i, j] > thresh else "black")
57
+
58
+ fig.tight_layout()
59
+ plt.savefig(output_path)
60
+ plt.close(fig)
61
+
62
+ if args.pretrain_test:
63
+ ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_20.pth")
64
+ print("\n🔍 Loading Pretrained Model:", ckpt_file)
65
+ model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)
66
+
67
+ elif args.finetune_test:
68
+ ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_10.pth")
69
+ print("\n🔍 Loading Fine-Tuned Model:", ckpt_file)
70
+ model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device)
71
+
72
+ else:
73
+ raise ValueError("You must specify --pretrain_test or --finetune_test")
74
+
75
+ if not os.path.exists(ckpt_file):
76
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
77
+
78
+ # model.load_state_dict(torch.load(ckpt_file, map_location=device))
79
+ # model.eval()
80
+
81
+ ckpt = torch.load(ckpt_file, map_location=device)
82
+
83
+ keys_to_remove = [key for key in ckpt.keys() if "masked_spec_embed" in key]
84
+ for key in keys_to_remove:
85
+ print(f"Removing unexpected key: {key}")
86
+ del ckpt[key]
87
+
88
+ try:
89
+ model.load_state_dict(ckpt, strict=False)
90
+ except RuntimeError as e:
91
+ print("Model loading error:", e)
92
+ print("Trying to load entire model...")
93
+ model = torch.load(ckpt_file, map_location=device)
94
+ model.to(device)
95
+ model.eval()
96
+
97
+ torch.cuda.empty_cache()
98
+
99
+ if args.closed_test:
100
+ print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
101
+ test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels)
102
+ elif args.open_test:
103
+ print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
104
+ test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels)
105
+ else:
106
+ print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
107
+ test_dataset = FakeMusicCapsDataset(val_files, val_labels)
108
+
109
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
110
+
111
+ def Test(model, test_loader, device, phase="Test"):
112
+ model.eval()
113
+ test_loss, test_correct, test_total = 0, 0, 0
114
+ all_preds, all_labels = [], []
115
+
116
+ with torch.no_grad():
117
+ for inputs, labels in tqdm(test_loader, desc=f"{phase}"):
118
+ inputs, labels = inputs.to(device), labels.to(device)
119
+ inputs = inputs.squeeze(1) # Ensure correct input shape
120
+
121
+ output, _ = model(inputs)  # unpack (logits, pooled features)
+ loss = F.cross_entropy(output, labels)
123
+
124
+ test_loss += loss.item() * inputs.size(0)
125
+ preds = output.argmax(dim=1)
126
+ test_correct += (preds == labels).sum().item()
127
+ test_total += labels.size(0)
128
+
129
+ all_labels.extend(labels.cpu().numpy())
130
+ all_preds.extend(preds.cpu().numpy())
131
+
132
+ test_loss /= test_total
133
+ test_acc = test_correct / test_total
134
+ test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
135
+ test_precision = precision_score(all_labels, all_preds, average="binary")
136
+ test_recall = recall_score(all_labels, all_preds, average="binary")
137
+ test_f1 = f1_score(all_labels, all_preds, average="binary")
138
+
139
+ print(f"\n{phase} Test Results - Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.3f} | "
140
+ f"Test Balanced Acc: {test_bal_acc:.4f} | Test Precision: {test_precision:.3f} | "
141
+ f"Test Recall: {test_recall:.3f} | Test F1: {test_f1:.3f}")
142
+
143
+ os.makedirs(args.output_path, exist_ok=True)
144
+ conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{phase}_opentest.png")
145
+ plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)
146
+
147
+ print("\nEvaluating Model on Test Set...")
148
+ Test(model, test_loader, device, phase="Pretrained Model" if args.pretrain_test else "Fine-Tuned Model")
ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc ADDED
Binary file (4.79 kB). View file
 
ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc ADDED
Binary file (1.01 kB). View file
 
ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
ISMIR_2025/wav2vec/utils/config.py ADDED
@@ -0,0 +1,565 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ import csv
+
+ import numpy as np
+
+ sample_rate = 32000
+ clip_samples = sample_rate * 10  # Audio clips are 10-second
+
+ # Load label
+ with open(
+     "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv",
+     "r",
+ ) as f:
+     reader = csv.reader(f, delimiter=",")
+     lines = list(reader)
+
+ labels = []
+ ids = []  # Each label has a unique id such as "/m/068hy"
+ for i1 in range(1, len(lines)):
+     id = lines[i1][1]
+     label = lines[i1][2]
+     ids.append(id)
+     labels.append(label)
+
+ classes_num = len(labels)
+
+ lb_to_ix = {label: i for i, label in enumerate(labels)}
+ ix_to_lb = {i: label for i, label in enumerate(labels)}
+
+ id_to_ix = {id: i for i, id in enumerate(ids)}
+ ix_to_id = {i: id for i, id in enumerate(ids)}
+
+ full_samples_per_class = np.array(
+     [
+         937432,
+         16344,
+         7822,
+         10271,
+         2043,
+         14420,
+         733,
+         1511,
+         1258,
+         424,
+         1751,
+         704,
+         369,
+         590,
+         1063,
+         1375,
+         5026,
+         743,
+         853,
+         1648,
+         714,
+         1497,
+         1251,
+         2139,
+         1093,
+         133,
+         224,
+         39469,
+         6423,
+         407,
+         1559,
+         4546,
+         6826,
+         7464,
+         2468,
+         549,
+         4063,
+         334,
+         587,
+         238,
+         1766,
+         691,
+         114,
+         2153,
+         236,
+         209,
+         421,
+         740,
+         269,
+         959,
+         137,
+         4192,
+         485,
+         1515,
+         655,
+         274,
+         69,
+         157,
+         1128,
+         807,
+         1022,
+         346,
+         98,
+         680,
+         890,
+         352,
+         4169,
+         2061,
+         1753,
+         9883,
+         1339,
+         708,
+         37857,
+         18504,
+         12864,
+         2475,
+         2182,
+         757,
+         3624,
+         677,
+         1683,
+         3583,
+         444,
+         1780,
+         2364,
+         409,
+         4060,
+         3097,
+         3143,
+         502,
+         723,
+         600,
+         230,
+         852,
+         1498,
+         1865,
+         1879,
+         2429,
+         5498,
+         5430,
+         2139,
+         1761,
+         1051,
+         831,
+         2401,
+         2258,
+         1672,
+         1711,
+         987,
+         646,
+         794,
+         25061,
+         5792,
+         4256,
+         96,
+         8126,
+         2740,
+         752,
+         513,
+         554,
+         106,
+         254,
+         1592,
+         556,
+         331,
+         615,
+         2841,
+         737,
+         265,
+         1349,
+         358,
+         1731,
+         1115,
+         295,
+         1070,
+         972,
+         174,
+         937780,
+         112337,
+         42509,
+         49200,
+         11415,
+         6092,
+         13851,
+         2665,
+         1678,
+         13344,
+         2329,
+         1415,
+         2244,
+         1099,
+         5024,
+         9872,
+         10948,
+         4409,
+         2732,
+         1211,
+         1289,
+         4807,
+         5136,
+         1867,
+         16134,
+         14519,
+         3086,
+         19261,
+         6499,
+         4273,
+         2790,
+         8820,
+         1228,
+         1575,
+         4420,
+         3685,
+         2019,
+         664,
+         324,
+         513,
+         411,
+         436,
+         2997,
+         5162,
+         3806,
+         1389,
+         899,
+         8088,
+         7004,
+         1105,
+         3633,
+         2621,
+         9753,
+         1082,
+         26854,
+         3415,
+         4991,
+         2129,
+         5546,
+         4489,
+         2850,
+         1977,
+         1908,
+         1719,
+         1106,
+         1049,
+         152,
+         136,
+         802,
+         488,
+         592,
+         2081,
+         2712,
+         1665,
+         1128,
+         250,
+         544,
+         789,
+         2715,
+         8063,
+         7056,
+         2267,
+         8034,
+         6092,
+         3815,
+         1833,
+         3277,
+         8813,
+         2111,
+         4662,
+         2678,
+         2954,
+         5227,
+         1472,
+         2591,
+         3714,
+         1974,
+         1795,
+         4680,
+         3751,
+         6585,
+         2109,
+         36617,
+         6083,
+         16264,
+         17351,
+         3449,
+         5034,
+         3931,
+         2599,
+         4134,
+         3892,
+         2334,
+         2211,
+         4516,
+         2766,
+         2862,
+         3422,
+         1788,
+         2544,
+         2403,
+         2892,
+         4042,
+         3460,
+         1516,
+         1972,
+         1563,
+         1579,
+         2776,
+         1647,
+         4535,
+         3921,
+         1261,
+         6074,
+         2922,
+         3068,
+         1948,
+         4407,
+         712,
+         1294,
+         1019,
+         1572,
+         3764,
+         5218,
+         975,
+         1539,
+         6376,
+         1606,
+         6091,
+         1138,
+         1169,
+         7925,
+         3136,
+         1108,
+         2677,
+         2680,
+         1383,
+         3144,
+         2653,
+         1986,
+         1800,
+         1308,
+         1344,
+         122231,
+         12977,
+         2552,
+         2678,
+         7824,
+         768,
+         8587,
+         39503,
+         3474,
+         661,
+         430,
+         193,
+         1405,
+         1442,
+         3588,
+         6280,
+         10515,
+         785,
+         710,
+         305,
+         206,
+         4990,
+         5329,
+         3398,
+         1771,
+         3022,
+         6907,
+         1523,
+         8588,
+         12203,
+         666,
+         2113,
+         7916,
+         434,
+         1636,
+         5185,
+         1062,
+         664,
+         952,
+         3490,
+         2811,
+         2749,
+         2848,
+         15555,
+         363,
+         117,
+         1494,
+         1647,
+         5886,
+         4021,
+         633,
+         1013,
+         5951,
+         11343,
+         2324,
+         243,
+         372,
+         943,
+         734,
+         242,
+         3161,
+         122,
+         127,
+         201,
+         1654,
+         768,
+         134,
+         1467,
+         642,
+         1148,
+         2156,
+         1368,
+         1176,
+         302,
+         1909,
+         61,
+         223,
+         1812,
+         287,
+         422,
+         311,
+         228,
+         748,
+         230,
+         1876,
+         539,
+         1814,
+         737,
+         689,
+         1140,
+         591,
+         943,
+         353,
+         289,
+         198,
+         490,
+         7938,
+         1841,
+         850,
+         457,
+         814,
+         146,
+         551,
+         728,
+         1627,
+         620,
+         648,
+         1621,
+         2731,
+         535,
+         88,
+         1736,
+         736,
+         328,
+         293,
+         3170,
+         344,
+         384,
+         7640,
+         433,
+         215,
+         715,
+         626,
+         128,
+         3059,
+         1833,
+         2069,
+         3732,
+         1640,
+         1508,
+         836,
+         567,
+         2837,
+         1151,
+         2068,
+         695,
+         1494,
+         3173,
+         364,
+         88,
+         188,
+         740,
+         677,
+         273,
+         1533,
+         821,
+         1091,
+         293,
+         647,
+         318,
+         1202,
+         328,
+         532,
+         2847,
+         526,
+         721,
+         370,
+         258,
+         956,
+         1269,
+         1641,
+         339,
+         1322,
+         4485,
+         286,
+         1874,
+         277,
+         757,
+         1393,
+         1330,
+         380,
+         146,
+         377,
+         394,
+         318,
+         339,
+         1477,
+         1886,
+         101,
+         1435,
+         284,
+         1425,
+         686,
+         621,
+         221,
+         117,
+         87,
+         1340,
+         201,
+         1243,
+         1222,
+         651,
+         1899,
+         421,
+         712,
+         1016,
+         1279,
+         124,
+         351,
+         258,
+         7043,
+         368,
+         666,
+         162,
+         7664,
+         137,
+         70159,
+         26179,
+         6321,
+         32236,
+         33320,
+         771,
+         1169,
+         269,
+         1103,
+         444,
+         364,
+         2710,
+         121,
+         751,
+         1609,
+         855,
+         1141,
+         2287,
+         1940,
+         3943,
+         289,
+     ]
+ )
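Editor's note: the label/index mappings and per-class counts defined in this config can be exercised on their own. The snippet below is a hedged illustration with stand-in class names (the first three counts are taken from the array above, but the names are made up); the inverse-frequency weighting shown is only one common use of `full_samples_per_class` and is not asserted to be how this repository's training code consumes it.

import numpy as np

# Stand-in values; in the repository these come from the labels list and
# full_samples_per_class array built in config.py above.
labels = ["class_a", "class_b", "class_c"]
full_samples_per_class = np.array([937432, 16344, 7822], dtype=np.float64)

lb_to_ix = {label: i for i, label in enumerate(labels)}
ix_to_lb = {i: label for i, label in enumerate(labels)}

# Example use: inverse-frequency weights, normalised so they average to 1,
# a common recipe for long-tailed class balancing.
weights = full_samples_per_class.sum() / (len(labels) * full_samples_per_class)

print(lb_to_ix["class_b"])      # -> 1
print(ix_to_lb[2])              # -> "class_c"
print(np.round(weights, 3))     # rarer classes receive larger weights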
ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py ADDED
@@ -0,0 +1,29 @@
+ from sklearn.metrics import confusion_matrix
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
+     cm = confusion_matrix(y_true, y_pred)
+     fig, ax = plt.subplots(figsize=(6, 6))
+     im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
+     ax.figure.colorbar(im, ax=ax)
+
+     num_classes = cm.shape[0]
+     tick_labels = classes[:num_classes]
+
+     ax.set(xticks=np.arange(num_classes),
+            yticks=np.arange(num_classes),
+            xticklabels=tick_labels,
+            yticklabels=tick_labels,
+            ylabel='True label',
+            xlabel='Predicted label')
+
+     thresh = cm.max() / 2.
+     for i in range(cm.shape[0]):
+         for j in range(cm.shape[1]):
+             ax.text(j, i, format(cm[i, j], 'd'),
+                     ha="center", va="center",
+                     color="white" if cm[i, j] > thresh else "black")
+
+     fig.tight_layout()
+     writer.add_figure("Confusion Matrix", fig, epoch)
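Editor's note: unlike the `output_path`-based call in `test.py` above (which presumably imports a different plotting helper), this version logs the figure to TensorBoard via `writer.add_figure` rather than saving a PNG. A minimal usage sketch, assuming `torch`/`tensorboard` are installed and that the module is importable from the `utils` package shown in the file path; the log directory and toy labels are illustrative:

from torch.utils.tensorboard import SummaryWriter

# Adjust the import to where this file lives relative to your working directory.
from utils.confusion_matrix_plot import plot_confusion_matrix

# Toy predictions for a binary real-vs-generative task.
y_true = [0, 0, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]

writer = SummaryWriter(log_dir="runs/confusion_demo")
plot_confusion_matrix(y_true, y_pred, classes=["real", "generative"], writer=writer, epoch=0)
writer.close()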