diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..1c29e60ddafaf3934309b999feb8f866b0fdd09d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +v2_request/request_music/AI_Detection/9486_music_id_Lost[[:space:]]Sky[[:space:]]x[[:space:]]Anna[[:space:]]Yvette[[:space:]]-[[:space:]]Carry[[:space:]]On[[:space:]]|[[:space:]]Trap[[:space:]]|[[:space:]]NCS[[:space:]]-[[:space:]]Copyright[[:space:]]Free[[:space:]]Music.wav filter=lfs diff=lfs merge=lfs -text diff --git a/ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc b/ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c25d929f202b8609fdadccf09fa33f2815fb7df Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc b/ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1270e9134dfdf68b1658a320f107df6ef3a2ba92 Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc b/ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f5ac2f8487445f7c584207de635871e35e59f3 Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc b/ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b49c972a99e034e49974bd9f3dfe2e32b703c78e Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc b/ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98e1f8197ddd0a2f39c8d17f12beb6d92e0f54dd Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19fa5c4c32a3fc21c0d310f7a4b7ef65ec85291b Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc b/ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1904703b75742f5e6c37fd084069210cc08116af Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc differ diff --git a/ISMIR_2025/MERT/datalib.py b/ISMIR_2025/MERT/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..96a74e056c54a2e8cc7ced35fe63f2ce7d58d54f --- /dev/null +++ b/ISMIR_2025/MERT/datalib.py @@ -0,0 +1,203 @@ +import os +import glob +import torch +import torchaudio +import librosa +import numpy as np +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset +from imblearn.over_sampling import RandomOverSampler +from transformers import Wav2Vec2Processor +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +from 
transformers import Wav2Vec2FeatureExtractor +import scipy.signal as signal +import scipy.signal +# class FakeMusicCapsDataset(Dataset): +# def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): +# self.file_paths = file_paths +# self.labels = labels +# self.sr = sr +# self.target_samples = int(target_duration * sr) # Fixed length: 5 seconds + +# def __len__(self): +# return len(self.file_paths) + +# def __getitem__(self, idx): +# audio_path = self.file_paths[idx] +# label = self.labels[idx] + +# waveform, sr = torchaudio.load(audio_path) +# waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform) +# waveform = waveform.mean(dim=0) # Convert to mono +# waveform = waveform.squeeze(0) + + +# current_samples = waveform.shape[0] + +# # **Ensure waveform is exactly `target_samples` long** +# if current_samples > self.target_samples: +# waveform = waveform[:self.target_samples] # Truncate if too long +# elif current_samples < self.target_samples: +# pad_length = self.target_samples - current_samples +# waveform = torch.nn.functional.pad(waveform, (0, pad_length)) # Pad if too short + +# return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long) # Ensure 2D shape (1, target_samples) + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_samples = int(target_duration * sr) # Fixed length: 10 seconds + self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) + + def __len__(self): + return len(self.file_paths) + + def highpass_filter(self, y, sr, cutoff=500, order=5): + if isinstance(sr, np.ndarray): + # print(f"[ERROR] sr is an array, taking mean value. Original sr: {sr}") + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + # print(f"[DEBUG] Highpass filter using sr={sr}, cutoff={cutoff}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + # print(f"[DEBUG] Nyquist frequency={nyquist}") + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + # print(f"[DEBUG] Adjusted cutoff={cutoff}, normal_cutoff={normal_cutoff}") + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + + target_sr = self.processor.sampling_rate + + if sr != target_sr: + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr) + waveform = resampler(waveform) + + waveform = waveform.mean(dim=0).squeeze(0) # [Time] + + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] # Truncate + elif current_samples < self.target_samples: + pad_length = self.target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) # Pad + + if isinstance(waveform, torch.Tensor): + waveform = waveform.numpy() # Tensor일 경우에만 변환 + + inputs = self.processor(waveform, sampling_rate=target_sr, return_tensors="pt", padding=True) + + return inputs["input_values"].squeeze(0), torch.tensor(label, dtype=torch.long) # [1, time] → [time] + + @staticmethod + def collate_fn(batch, target_samples=16000 * 10): + + inputs, labels = zip(*batch) # Unzip batch + + processed_inputs = [] + for waveform in inputs: + current_samples = waveform.shape[0] + + if current_samples > target_samples: + start_idx = (current_samples - target_samples) // 2 + cropped_waveform = waveform[start_idx:start_idx + target_samples] + else: + pad_length = target_samples - current_samples + cropped_waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + processed_inputs.append(cropped_waveform) + + processed_inputs = torch.stack(processed_inputs) # [batch, target_samples] + labels = torch.tensor(labels, dtype=torch.long) # [batch] + + return processed_inputs, labels + + def preprocess_audio(audio_path, target_sr=16000, max_length=160000): + """ + 오디오를 모델 입력에 맞게 변환 + - target_sr: 16kHz로 변환 + - max_length: 최대 길이 160000 (10초) + """ + waveform, sr = torchaudio.load(audio_path) + + # Resample if needed + if sr != target_sr: + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) + + # Convert to mono + waveform = waveform.mean(dim=0).unsqueeze(0) # (1, sequence_length) + + current_samples = waveform.shape[1] + if current_samples > max_length: + start_idx = (current_samples - max_length) // 2 + waveform = waveform[:, start_idx:start_idx + max_length] + elif current_samples < max_length: + pad_length = max_length - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + return waveform + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +# Closed Test: FakeMusicCaps 데이터셋 사용 +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +# Open Set Test: SUNOCAPS_PATH 데이터 포함 +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), 
recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +# Closed Train, Val +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +# Closed Set Test용 데이터셋 +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +# Open Set Test용 데이터셋 +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +# Oversampling 적용 +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled + +print(f"📌 Train Original FAKE: {len(gen_train)}") +print(f"📌 Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, " + f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"📌 Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") diff --git a/ISMIR_2025/MERT/main.py b/ISMIR_2025/MERT/main.py new file mode 100644 index 0000000000000000000000000000000000000000..2992ffcb72eea51d096f2f40fca04c524d309cd3 --- /dev/null +++ b/ISMIR_2025/MERT/main.py @@ -0,0 +1,197 @@ +import os +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score +import wandb +import argparse +from transformers import AutoModel, AutoConfig, Wav2Vec2FeatureExtractor +from ISMIR_2025.MERT.datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels +from ISMIR_2025.MERT.networks import MERTFeatureExtractor +# Set device +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Seed for reproducibility +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) + +# Initialize wandb +wandb.init(project="mert", name=f"hpfilter_pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args) + +# Load datasets +print("🔍 Preparing datasets...") +train_dataset = FakeMusicCapsDataset(train_files, train_labels) +val_dataset = FakeMusicCapsDataset(val_files, val_labels) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn) + +# Model Checkpoint Paths +pretrain_ckpt = os.path.join(args.checkpoint_dir, f"mert_pretrain_{args.pretrain_epochs}.pth") +finetune_ckpt = os.path.join(args.checkpoint_dir, f"mert_finetune_{args.finetune_epochs}.pth") + +# Load Music2Vec Model for Pretraining +print("🔍 Initializing MERT model for 
Pretraining...") + +config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) +if not hasattr(config, "conv_pos_batch_norm"): + setattr(config, "conv_pos_batch_norm", False) + +mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device) +mert_model = MERTFeatureExtractor().to(device) + +# Loss and Optimizer +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(mert_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +# Training function +def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"): + model.train() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"): + labels = labels.to(device) + inputs = inputs.to(device) + + # inputs = inputs.float() + # output = model(inputs) + output = model(inputs) + + # Check if the output is a tensor or an object with logits + if isinstance(output, torch.Tensor): + logits = output + elif hasattr(output, "logits"): + logits = output.logits + elif isinstance(output, (tuple, list)): + logits = output[0] + else: + raise ValueError("Unexpected model output type") + + loss = criterion(logits, labels) + + + # loss = criterion(output, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = output.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + scheduler.step() + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="binary") + precision = precision_score(all_labels, all_preds, average="binary") + recall = recall_score(all_labels, all_preds, average="binary", pos_label=1) + balanced_acc = balanced_accuracy_score(all_labels, all_preds) + + + wandb.log({ + f"{phase} Train Loss": total_loss / len(dataloader), + f"{phase} Train Accuracy": accuracy, + f"{phase} Train F1 Score": f1, + f"{phase} Train Precision": precision, + f"{phase} Train Recall": recall, + f"{phase} Train Balanced Accuracy": balanced_acc, + }) + + print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, " + f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}") + +def validate(model, dataloader, optimizer, criterion, device, epoch, phase="Validation"): + model.eval() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Validation Epoch {epoch+1}"): + labels = labels.to(device) + inputs = inputs.to(device) + + output = model(inputs) + + # Check if the output is a tensor or an object with logits + if isinstance(output, torch.Tensor): + logits = output + elif hasattr(output, "logits"): + logits = output.logits + elif isinstance(output, (tuple, list)): + logits = output[0] + else: + raise ValueError("Unexpected model output type") + + loss = criterion(logits, labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + 
scheduler.step() + accuracy = total_correct / total_samples + val_f1 = f1_score(all_labels, all_preds, average="weighted") + val_precision = precision_score(all_labels, all_preds, average="binary") + val_recall = recall_score(all_labels, all_preds, average="binary") + val_bal_acc = balanced_accuracy_score(all_labels, all_preds) + + wandb.log({ + f"{phase} Val Loss": total_loss / len(dataloader), + f"{phase} Val Accuracy": accuracy, + f"{phase} Val F1 Score": val_f1, + f"{phase} Val Precision": val_precision, + f"{phase} Val Recall": val_recall, + f"{phase} Val Balanced Accuracy": val_bal_acc, + }) + print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, " + f"Val Acc: {accuracy:.4f}, Val F1: {val_f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}") + return total_loss / len(dataloader), accuracy, val_f1 + + +print("\n🔍 Step 1: Self-Supervised Pretraining on REAL Data") +# for epoch in range(args.pretrain_epochs): +# train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain") +# torch.save(mert_model.state_dict(), pretrain_ckpt) +# print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}") + +# print("\n🔍 Initializing CCV Model for Fine-Tuning...") +# mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device) +# mert_model.feature_extractor.load_state_dict(torch.load(pretrain_ckpt), strict=False) + +# optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +print("\n🔍 Step 2: Fine-Tuning CCV Model") +for epoch in range(args.finetune_epochs): + train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune") + +torch.save(mert_model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}") + +print("\n🔍 Step 2: Fine-Tuning MERT Model") +mert_model.load_state_dict(torch.load(pretrain_ckpt), strict=False) + +optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +for epoch in range(args.finetune_epochs): + train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune") + +torch.save(mert_model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! 
Model saved at: {finetune_ckpt}") \ No newline at end of file diff --git a/ISMIR_2025/MERT/networks.py b/ISMIR_2025/MERT/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..2e985e909ccd7b63f5a44bc82d7e368903e92c07 --- /dev/null +++ b/ISMIR_2025/MERT/networks.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from transformers import AutoModel, AutoConfig + +class MERTFeatureExtractor(nn.Module): + def __init__(self, freeze_feature_extractor=True): + super(MERTFeatureExtractor, self).__init__() + config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) + if not hasattr(config, "conv_pos_batch_norm"): + setattr(config, "conv_pos_batch_norm", False) + self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True) + + if freeze_feature_extractor: + self.freeze() + + def forward(self, input_values): + # 입력: [batch, time] + # 사전학습된 MERT의 hidden_states 추출 (예시로 모든 레이어의 hidden state 사용) + with torch.no_grad(): + outputs = self.mert(input_values, output_hidden_states=True) + # hidden_states: tuple of [batch, time, feature_dim] + # 여러 레이어의 hidden state를 스택한 뒤 시간축에 대해 평균하여 feature를 얻음 + hidden_states = torch.stack(outputs.hidden_states) # [num_layers, batch, time, feature_dim] + hidden_states = hidden_states.detach().clone().requires_grad_(True) + time_reduced = hidden_states.mean(dim=2) # [num_layers, batch, feature_dim] + time_reduced = time_reduced.permute(1, 0, 2) # [batch, num_layers, feature_dim] + return time_reduced + + def freeze(self): + for param in self.mert.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.mert.parameters(): + param.requires_grad = True + + +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm1 = nn.LayerNorm(embed_dim) + self.layer_norm2 = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + + def forward(self, x, cross_input): + # x와 cross_input 간의 어텐션 수행 + attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input) + x = self.layer_norm1(x + attn_output) + ff_output = self.feed_forward(x) + x = self.layer_norm2(x + ff_output) + return x + + +class CCV(nn.Module): + def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True): + super(CCV, self).__init__() + # MERT 기반 feature extractor (pretraining weight로부터 유의미한 피쳐 추출) + self.feature_extractor = MERTFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor) + # Cross-Attention 레이어 여러 층 + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + # Transformer Encoder (배치 차원 고려) + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + # 분류기 + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, num_classes) + ) + + + def forward(self, input_values): + """ + input_values: Tensor [batch, time] + 1. MERT로부터 feature 추출 → [batch, num_layers, feature_dim] + 2. 임베딩 차원 맞추기 위해 transpose → [batch, feature_dim, num_layers] + 3. Cross-Attention 적용 + 4. 
Transformer Encoding 후 평균 풀링 + 5. 분류기 통과하여 최종 출력(logits) 반환 + """ + features = self.feature_extractor(input_values) # [batch, num_layers, feature_dim] + # embed_dim는 보통 feature_dim과 동일하게 맞춤 (예시: 768) + # features = features.permute(0, 2, 1) # [batch, embed_dim, num_layers] + + # Cross-Attention 적용 (여기서는 자기자신과의 어텐션으로 예시) + for layer in self.cross_attention_layers: + features = layer(features, features) + + # Transformer Encoder를 위해 시간 축(여기서는 num_layers 축)에 대해 평균 + features = features.mean(dim=1).unsqueeze(1) # [batch, 1, embed_dim] + encoded = self.transformer(features) # [batch, 1, embed_dim] + encoded = encoded.mean(dim=1) # [batch, embed_dim] + output = self.classifier(encoded) # [batch, num_classes] + return output, encoded + + def unfreeze_feature_extractor(self): + self.feature_extractor.unfreeze() diff --git a/ISMIR_2025/MERT/test.py b/ISMIR_2025/MERT/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0dba71d3b798ffd56b0afc123e659dca49a5f024 --- /dev/null +++ b/ISMIR_2025/MERT/test.py @@ -0,0 +1,114 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks import MERTFeatureExtractor +import argparse +parser = argparse.ArgumentParser(description="AI Music Detection Testing with MERT") +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default="/data/kym/AI_Music_Detection/Code/model/MERT/ckpt/1e-3/mert_finetune_10.pth", help='Path to the pretrained checkpoint') +parser.add_argument('--model_name', type=str, default="mert", help="Model name") +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='', help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
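The `CCV` model in `networks.py` consumes raw waveform tensors of shape `[batch, time]` and returns a `(logits, embedding)` pair. A quick shape check — assuming the `m-a-p/MERT-v1-95M` weights can be fetched on first use, and using the same 160,000-sample clip length as `collate_fn` in `datalib.py` — might look like this:

```python
import torch
from ISMIR_2025.MERT.networks import CCV  # import path as used by main.py in this PR

model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2).eval()
waveforms = torch.randn(2, 16000 * 10)      # [batch, time]: two clips at collate_fn's target length
with torch.no_grad():
    logits, embedding = model(waveforms)
print(logits.shape)     # expected: torch.Size([2, 2])
print(embedding.shape)  # expected: torch.Size([2, 768])
```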
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +model = MERTFeatureExtractor().to(device) + +ckpt_file = args.ckpt_path +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") +print(f"\nLoading MERT model from {ckpt_file}") +model.load_state_dict(torch.load(ckpt_file, map_location=device)) +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) + +def test_mert(model, test_loader, device): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + loss = F.cross_entropy(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating MERT Model on Test Set...") +test_mert(model, test_loader, device) diff --git a/ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc b/ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5105baa6b5cc20b8d2fe0a9b029b40c9e6f7c6af Binary files /dev/null and b/ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc b/ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb9379cca8c0e244eb2a7b169fa2ce307f2bab05 Binary files /dev/null and b/ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc b/ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..97ed2e09ebbddb16a923974a1ef3c30f2259ac34 Binary files /dev/null and b/ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/utils/config.py b/ISMIR_2025/MERT/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..69f72ecd472eed266bb9a0d811d7eeb07a3c06db --- /dev/null +++ b/ISMIR_2025/MERT/utils/config.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv + +import numpy as np + +sample_rate = 32000 +clip_samples = sample_rate * 10 # Audio clips are 10-second + +# Load label +with open( + "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv", + "r", +) as f: + reader = csv.reader(f, delimiter=",") + lines = list(reader) + +labels = [] +ids = [] # Each label has a unique id such as "/m/068hy" +for i1 in range(1, len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + +classes_num = len(labels) + +lb_to_ix = {label: i for i, label in enumerate(labels)} +ix_to_lb = {i: label for i, label in enumerate(labels)} + +id_to_ix = {id: i for i, id in enumerate(ids)} +ix_to_id = {i: id for i, id in enumerate(ids)} + +full_samples_per_class = np.array( + [ + 937432, + 16344, + 7822, + 10271, + 2043, + 14420, + 733, + 1511, + 1258, + 424, + 1751, + 704, + 369, + 590, + 1063, + 1375, + 5026, + 743, + 853, + 1648, + 714, + 1497, + 1251, + 2139, + 1093, + 133, + 224, + 39469, + 6423, + 407, + 1559, + 4546, + 6826, + 7464, + 2468, + 549, + 4063, + 334, + 587, + 238, + 1766, + 691, + 114, + 2153, + 236, + 209, + 421, + 740, + 269, + 959, + 137, + 4192, + 485, + 1515, + 655, + 274, + 69, + 157, + 1128, + 807, + 1022, + 346, + 98, + 680, + 890, + 352, + 4169, + 2061, + 1753, + 9883, + 1339, + 708, + 37857, + 18504, + 12864, + 2475, + 2182, + 757, + 3624, + 677, + 1683, + 3583, + 444, + 1780, + 2364, + 409, + 4060, + 3097, + 3143, + 502, + 723, + 600, + 230, + 852, + 1498, + 1865, + 1879, + 2429, + 5498, + 5430, + 2139, + 1761, + 1051, + 831, + 2401, + 2258, + 1672, + 1711, + 987, + 646, + 794, + 25061, + 5792, + 4256, + 96, + 8126, + 2740, + 752, + 513, + 554, + 106, + 254, + 1592, + 556, + 331, + 615, + 2841, + 737, + 265, + 1349, + 358, + 1731, + 1115, + 295, + 1070, + 972, + 174, + 937780, + 112337, + 42509, + 49200, + 11415, + 6092, + 13851, + 2665, + 1678, + 13344, + 2329, + 1415, + 2244, + 1099, + 5024, + 9872, + 10948, + 4409, + 2732, + 1211, + 1289, + 4807, + 5136, + 1867, + 16134, + 14519, + 3086, + 19261, + 6499, + 4273, + 2790, + 8820, + 1228, + 1575, + 4420, + 3685, + 2019, + 664, + 324, + 513, + 411, + 436, + 2997, + 5162, + 3806, + 1389, + 899, + 8088, + 7004, + 1105, + 3633, + 2621, + 9753, + 1082, + 26854, + 3415, + 4991, + 2129, + 5546, + 4489, + 2850, + 1977, + 1908, + 1719, + 1106, + 1049, + 152, + 136, + 802, + 488, + 592, + 2081, + 2712, + 1665, + 1128, + 250, + 544, + 789, + 2715, + 8063, + 7056, + 2267, + 8034, + 6092, + 3815, + 1833, + 3277, + 8813, + 2111, + 4662, + 2678, + 2954, + 5227, + 1472, + 2591, + 3714, + 1974, + 1795, + 4680, + 3751, + 6585, + 2109, + 36617, + 6083, + 16264, + 17351, + 3449, + 5034, + 3931, + 2599, + 4134, + 3892, + 2334, + 2211, + 4516, + 2766, + 2862, + 3422, + 1788, + 2544, + 2403, + 2892, + 4042, + 3460, + 1516, + 1972, + 1563, + 1579, + 2776, + 1647, + 4535, + 3921, + 1261, + 6074, + 2922, + 3068, + 1948, + 4407, + 712, + 1294, + 1019, + 1572, + 3764, + 5218, + 975, + 1539, + 6376, + 1606, + 6091, + 1138, + 1169, + 7925, + 3136, 
+ 1108, + 2677, + 2680, + 1383, + 3144, + 2653, + 1986, + 1800, + 1308, + 1344, + 122231, + 12977, + 2552, + 2678, + 7824, + 768, + 8587, + 39503, + 3474, + 661, + 430, + 193, + 1405, + 1442, + 3588, + 6280, + 10515, + 785, + 710, + 305, + 206, + 4990, + 5329, + 3398, + 1771, + 3022, + 6907, + 1523, + 8588, + 12203, + 666, + 2113, + 7916, + 434, + 1636, + 5185, + 1062, + 664, + 952, + 3490, + 2811, + 2749, + 2848, + 15555, + 363, + 117, + 1494, + 1647, + 5886, + 4021, + 633, + 1013, + 5951, + 11343, + 2324, + 243, + 372, + 943, + 734, + 242, + 3161, + 122, + 127, + 201, + 1654, + 768, + 134, + 1467, + 642, + 1148, + 2156, + 1368, + 1176, + 302, + 1909, + 61, + 223, + 1812, + 287, + 422, + 311, + 228, + 748, + 230, + 1876, + 539, + 1814, + 737, + 689, + 1140, + 591, + 943, + 353, + 289, + 198, + 490, + 7938, + 1841, + 850, + 457, + 814, + 146, + 551, + 728, + 1627, + 620, + 648, + 1621, + 2731, + 535, + 88, + 1736, + 736, + 328, + 293, + 3170, + 344, + 384, + 7640, + 433, + 215, + 715, + 626, + 128, + 3059, + 1833, + 2069, + 3732, + 1640, + 1508, + 836, + 567, + 2837, + 1151, + 2068, + 695, + 1494, + 3173, + 364, + 88, + 188, + 740, + 677, + 273, + 1533, + 821, + 1091, + 293, + 647, + 318, + 1202, + 328, + 532, + 2847, + 526, + 721, + 370, + 258, + 956, + 1269, + 1641, + 339, + 1322, + 4485, + 286, + 1874, + 277, + 757, + 1393, + 1330, + 380, + 146, + 377, + 394, + 318, + 339, + 1477, + 1886, + 101, + 1435, + 284, + 1425, + 686, + 621, + 221, + 117, + 87, + 1340, + 201, + 1243, + 1222, + 651, + 1899, + 421, + 712, + 1016, + 1279, + 124, + 351, + 258, + 7043, + 368, + 666, + 162, + 7664, + 137, + 70159, + 26179, + 6321, + 32236, + 33320, + 771, + 1169, + 269, + 1103, + 444, + 364, + 2710, + 121, + 751, + 1609, + 855, + 1141, + 2287, + 1940, + 3943, + 289, + ] +) \ No newline at end of file diff --git a/ISMIR_2025/MERT/utils/confusion_matrix_plot.py b/ISMIR_2025/MERT/utils/confusion_matrix_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..e57d6d77e51949970ea76d8400d78ed6540cc155 --- /dev/null +++ b/ISMIR_2025/MERT/utils/confusion_matrix_plot.py @@ -0,0 +1,29 @@ +from sklearn.metrics import confusion_matrix +import matplotlib.pyplot as plt +import numpy as np + +def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
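Unlike the `plot_confusion_matrix` in `test.py`, which saves a PNG, this `utils/confusion_matrix_plot.py` variant hands the figure to a TensorBoard writer. A minimal usage sketch (log directory and example labels are illustrative; the import path assumes the script runs from `ISMIR_2025/MERT/`, as `test.py` does):

```python
from torch.utils.tensorboard import SummaryWriter
from utils.confusion_matrix_plot import plot_confusion_matrix  # repo-relative import, illustrative

writer = SummaryWriter(log_dir="runs/mert_eval")   # illustrative log directory
y_true = [0, 0, 1, 1]                              # 0 = real, 1 = generative
y_pred = [0, 1, 1, 1]
plot_confusion_matrix(y_true, y_pred, classes=["real", "generative"], writer=writer, epoch=0)
writer.close()
```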
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + writer.add_figure("Confusion Matrix", fig, epoch) \ No newline at end of file diff --git a/ISMIR_2025/MERT/utils/freqeuncy.py b/ISMIR_2025/MERT/utils/freqeuncy.py new file mode 100644 index 0000000000000000000000000000000000000000..b21c5222467ec4906c63e5b9d02052a69aeb67e2 --- /dev/null +++ b/ISMIR_2025/MERT/utils/freqeuncy.py @@ -0,0 +1,24 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt + +# 🔹 오디오 파일 로드 +file_real = "/path/to/real_audio.wav" # Real 오디오 경로 +file_fake = "/path/to/generative_audio.wav" # AI 생성 오디오 경로 + +def plot_spectrogram(audio_file, title): + y, sr = librosa.load(audio_file, sr=16000) # 샘플링 레이트 16kHz + D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max) # STFT 변환 + + plt.figure(figsize=(10, 4)) + librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='magma') + plt.colorbar(format='%+2.0f dB') + plt.title(title) + plt.ylim(4000, 16000) # 4kHz 이상 고주파 영역만 표시 + plt.show() + +# 🔹 Real vs Generative Spectrogram 비교 +plot_spectrogram(file_real, "Real Audio Spectrogram (4kHz+)") +plot_spectrogram(file_fake, "Generative Audio Spectrogram (4kHz+)") + diff --git a/ISMIR_2025/MERT/utils/hf_vis.py b/ISMIR_2025/MERT/utils/hf_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..c99b61bfb27f99880b0c44313daf476e6c0c278f --- /dev/null +++ b/ISMIR_2025/MERT/utils/hf_vis.py @@ -0,0 +1,89 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as signal +import torch +import torch.nn as nn +import soundfile as sf + +from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention + + +def highpass_filter(y, sr, cutoff=500, order=5): + """High-pass filter to remove low frequencies below `cutoff` Hz.""" + nyquist = 0.5 * sr + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"): + """Plot waveform comparison and spectrograms in a single figure.""" + fig, axes = plt.subplots(3, 1, figsize=(12, 12)) + + # 1️⃣ Waveform Comparison + time = np.linspace(0, len(y_original) / sr, len(y_original)) + axes[0].plot(time, y_original, label='Original', alpha=0.7) + axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed') + axes[0].set_xlabel("Time (s)") + axes[0].set_ylabel("Amplitude") + axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)") + axes[0].legend() + + # 2️⃣ Spectrogram - Original + S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max) + img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1]) + axes[1].set_title("Original Spectrogram") + fig.colorbar(img, ax=axes[1], format="%+2.0f dB") + + # 3️⃣ Spectrogram - High-pass Filtered + S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max) + img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2]) + axes[2].set_title("High-pass Filtered Spectrogram") + fig.colorbar(img, ax=axes[2], format="%+2.0f dB") + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.show() + + +def 
load_model(checkpoint_path, model_class, device): + """Load a trained model from checkpoint.""" + model = model_class() + model.load_state_dict(torch.load(checkpoint_path, map_location=device)) + model.to(device) + model.eval() + return model + +def predict_audio(model, audio_tensor, device): + """Make predictions using a trained model.""" + with torch.no_grad(): + audio_tensor = audio_tensor.unsqueeze(0).to(device) # Add batch dimension + output = model(audio_tensor) + prediction = torch.argmax(output, dim=1).cpu().numpy()[0] + return prediction + +# Load audio +audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav" # Replace with actual file path +y, sr = librosa.load(audio_path, sr=None) +y_filtered = highpass_filter(y, sr, cutoff=500) + +# Convert audio to tensor +audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0) +audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0) + +# Load models +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device) +highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device) + +# Predict +original_pred = predict_audio(original_model, audio_tensor, device) +highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device) + +print(f"Original Model Prediction: {original_pred}") +print(f"High-pass Filter Model Prediction: {highpass_pred}") + +# Generate combined visualization (all plots in one image) +plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png") diff --git a/ISMIR_2025/MERT/utils/idr_torch.py b/ISMIR_2025/MERT/utils/idr_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e76040394ce27390c27bd8ef022e126d8e55dc --- /dev/null +++ b/ISMIR_2025/MERT/utils/idr_torch.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import hostlist + +# get SLURM variables +# rank = int(os.environ["SLURM_PROCID"]) +local_rank = int(os.environ["SLURM_LOCALID"]) +size = int(os.environ["SLURM_NTASKS"]) +cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"]) + +# get node list from slurm +hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) + +# get IDs of reserved GPU +gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + +# define MASTER_ADD & MASTER_PORT +os.environ["MASTER_ADDR"] = hostnames[0] +os.environ["MASTER_PORT"] = str( + 12345 + int(min(gpu_ids)) +) # to avoid port conflict on the same node \ No newline at end of file diff --git a/ISMIR_2025/MERT/utils/mfcc.py b/ISMIR_2025/MERT/utils/mfcc.py new file mode 100644 index 0000000000000000000000000000000000000000..5d63db14375fedcc1cc60f2ef3cecf5c70e9a8fb --- /dev/null +++ b/ISMIR_2025/MERT/utils/mfcc.py @@ -0,0 +1,266 @@ +import os +import glob +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader, random_split +import torch.nn.functional as F +from sklearn.metrics import precision_score, recall_score, f1_score +from tqdm import tqdm +import argparse +import wandb + +class RealFakeDataset(Dataset): + """ + 
audio/FakeMusicCaps/ + ├─ real/ + │ └─ MusicCaps/*.wav (label=0) + └─ generative/ + └─ .../*.wav (label=1) + """ + def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0): + + self.sr = sr + self.n_mels = n_mels + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) # 10초 = 160,000 샘플 + + self.file_paths = [] + self.labels = [] + + # Real 데이터 (label=0) + real_dir = os.path.join(root_dir, "real") + real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True) + for f in real_wav_files: + self.file_paths.append(f) + self.labels.append(0) + + # Generative 데이터 (label=1) + gen_dir = os.path.join(root_dir, "generative") + gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True) + for f in gen_wav_files: + self.file_paths.append(f) + self.labels.append(1) + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + # print(f"[DEBUG] Path: {audio_path}, Label: {label}") # 추가 + + waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] + elif current_samples < self.target_samples: + stretch_factor = self.target_samples / current_samples + waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor) + waveform = waveform[:self.target_samples] + + mfcc = librosa.feature.mfcc( + y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256 + ) + mfcc = librosa.util.normalize(mfcc) + + mfcc = np.expand_dims(mfcc, axis=0) + mfcc_tensor = torch.tensor(mfcc, dtype=torch.float) + label_tensor = torch.tensor(label, dtype=torch.long) + + return mfcc_tensor, label_tensor + + + +class AudioCNN(nn.Module): + def __init__(self, num_classes=2): + super(AudioCNN, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) + ) + self.fc_block = nn.Sequential( + nn.Linear(32*4*4, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + + + def forward(self, x): + x = self.conv_block(x) + # x.shape: (B,32,new_freq,new_time) + + # 1) Flatten + B, C, H, W = x.shape # 동적 shape + x = x.view(B, -1) # (B, 32*H*W) + + # 2) FC + x = self.fc_block(x) + return x + + +def my_collate_fn(batch): + mel_list, label_list = zip(*batch) + + max_frames = max(m.shape[2] for m in mel_list) + + padded = [] + for m in mel_list: + diff = max_frames - m.shape[2] + if diff > 0: + print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}") + m = F.pad(m, (0, diff), mode='constant', value=0) + padded.append(m) + + + mel_batch = torch.stack(padded, dim=0) + label_batch = torch.tensor(label_list, dtype=torch.long) + return mel_batch, label_batch + + +class EarlyStopping: + def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False): + self.patience = patience + self.delta = delta + self.path = path + self.verbose = verbose + self.counter = 0 + self.best_loss = None + self.early_stop = False + + def __call__(self, val_loss, model): + if self.best_loss is None: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + elif val_loss > self.best_loss - self.delta: + self.counter += 1 + if 
self.verbose: + print(f"EarlyStopping counter: {self.counter} out of {self.patience}") + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + self.counter = 0 + + def _save_checkpoint(self, val_loss, model): + if self.verbose: + print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...") + torch.save(model.state_dict(), self.path) + +def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"): + if not os.path.exists("./ckpt/mfcc/"): + os.makedirs("./ckpt/mfcc/") + + wandb.init( + project="AI Music Detection", + name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}", + config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate}, + ) + + dataset = RealFakeDataset(root_dir=root_dir) + n_total = len(dataset) + n_train = int(n_total * 0.8) + n_val = n_total - n_train + train_ds, val_ds = random_split(dataset, [n_train, n_val]) + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = AudioCNN(num_classes=2).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + best_val_loss = float('inf') + patience = 3 + patience_counter = 0 + + for epoch in range(1, epochs + 1): + print(f"\n[Epoch {epoch}/{epochs}]") + + # Training + model.train() + train_loss, train_correct, train_total = 0, 0, 0 + train_pbar = tqdm(train_loader, desc="Train", leave=False) + for mel_batch, labels in train_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(mel_batch) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + train_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + train_correct += (preds == labels).sum().item() + train_total += labels.size(0) + + train_pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + train_loss /= train_total + train_acc = train_correct / train_total + + # Validation + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_preds, all_labels = [], [] + val_pbar = tqdm(val_loader, desc=" Val ", leave=False) + with torch.no_grad(): + for mel_batch, labels in val_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + outputs = model(mel_batch) + loss = criterion(outputs, labels) + val_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + val_correct += (preds == labels).sum().item() + val_total += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + val_loss /= val_total + val_acc = val_correct / val_total + val_precision = precision_score(all_labels, all_preds, average="macro") + val_recall = recall_score(all_labels, all_preds, average="macro") + val_f1 = f1_score(all_labels, all_preds, average="macro") + + print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | " + f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} " + f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}") + + wandb.log({"train_loss": train_loss, "train_acc": train_acc, + "val_loss": val_loss, "val_acc": val_acc, + "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1}) + + if val_loss < best_val_loss: + best_val_loss = val_loss + 
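An `EarlyStopping` helper is defined earlier in `mfcc.py`, while the patience logic inside `train()` is written inline; a rough sketch of driving the loop through the helper instead is shown below. All names come from the surrounding `train()` function, and the checkpoint path is passed explicitly here because the class default is a plain string containing unformatted `{...}` placeholders.

```python
# Hedged sketch: wiring the EarlyStopping helper into train(); path is illustrative.
stopper = EarlyStopping(
    patience=3,
    path=f"./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth",
    verbose=True,
)

for epoch in range(1, epochs + 1):
    # ... run the training and validation passes shown above, producing val_loss ...
    stopper(val_loss, model)
    if stopper.early_stop:
        print("Early stopping triggered!")
        break
```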
patience_counter = 0 + best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth" + torch.save(model.state_dict(), best_model_path) + print(f"[INFO] New best model saved: {best_model_path}") + else: + patience_counter += 1 + if patience_counter >= patience: + print("Early stopping triggered!") + break + + wandb.finish() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train AI Music Detection model.") + parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training") + parser.add_argument('--epochs', type=int, required=True, help="Number of epochs") + parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate") + parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset") + + args = parser.parse_args() + + train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir) diff --git a/ISMIR_2025/MERT/utils/utilities.py b/ISMIR_2025/MERT/utils/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..e0be98e8645b8bb1c838d3dc9ae49daac706df62 --- /dev/null +++ b/ISMIR_2025/MERT/utils/utilities.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import logging +import pickle + +import numpy as np + +from scipy import stats + +import csv +import json + +def create_folder(fd): + if not os.path.exists(fd): + os.makedirs(fd, exist_ok=True) + + +def get_filename(path): + path = os.path.realpath(path) + na_ext = path.split("/")[-1] + na = os.path.splitext(na_ext)[0] + return na + + +def get_sub_filepaths(folder): + paths = [] + for root, dirs, files in os.walk(folder): + for name in files: + path = os.path.join(root, name) + paths.append(path) + return paths + + +def create_logging(log_dir, filemode): + create_folder(log_dir) + i1 = 0 + + while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): + i1 += 1 + + log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=log_path, + filemode=filemode, + ) + + # Print to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") + console.setFormatter(formatter) + logging.getLogger("").addHandler(console) + + return logging + + +def read_metadata(csv_path, audio_dir, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # first, count the audio names only of existing files on disk only + + audios_num = 0 + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if os.path.exists(os.path.join(audio_dir, audio_name)): + audios_num += 1 + + print("CSV audio files: %d" % (len(lines))) + print("Existing audio files: %d" % audios_num) + + # audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + n = 0 + for line in lines: + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if not os.path.exists(os.path.join(audio_dir, audio_name)): + continue + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + n += 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + + +def read_audioset_ontology(id_to_ix): + with open('../metadata/audioset_ontology.json', 'r') as f: + data = json.load(f) + + # Output: {'name': 'Bob', 'languages': ['English', 'French']} + sentences = [] + for el in data: + print(el.keys()) + id = el['id'] + if id in id_to_ix: + name = el['name'] + desc = el['description'] + # if '(' in desc: + # print(name, '---', desc) + # print(id_to_ix[id], name, '---', ) + + # sent = name + # sent = name + ', ' + desc.replace('(', '').replace(')', '').lower() + # sent = desc.replace('(', '').replace(')', '').lower() + # sentences.append(sent) + sentences.append(desc) + # print(sent) + # break + return sentences + + +def original_read_metadata(csv_path, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # Thomas Pellegrini: added 02/12/2022 + # check if the audio files indeed exist, otherwise remove from list + + audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) # Audios are started with an extra 'Y' when downloading + audio_name = audio_name.replace("_0000_", "_0_") + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + +def read_audioset_label_tags(class_labels_indices_csv): + with open(class_labels_indices_csv, 'r') as f: + reader = csv.reader(f, delimiter=',') + lines = list(reader) + + labels = [] + ids = [] # Each label has a unique id such as "/m/068hy" + for i1 in range(1, len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + + classes_num = len(labels) + + lb_to_ix = {label : i for i, label in enumerate(labels)} + ix_to_lb = {i : label for i, label in enumerate(labels)} + + id_to_ix = {id : i for i, id in enumerate(ids)} + ix_to_id = {i : id for i, id in enumerate(ids)} + + return lb_to_ix, ix_to_lb, id_to_ix, ix_to_id + + + +def float32_to_int16(x): + # assert np.max(np.abs(x)) <= 1.5 + x = np.clip(x, -1, 1) + return (x * 32767.0).astype(np.int16) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def pad_or_truncate(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x[0:audio_length] + + +def pad_audio(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x + + +def d_prime(auc): + d_prime = stats.norm().ppf(auc) * np.sqrt(2.0) + return d_prime + + +class Mixup(object): + def __init__(self, mixup_alpha, random_seed=1234): + """Mixup coefficient generator.""" + self.mixup_alpha = mixup_alpha + self.random_state = np.random.RandomState(random_seed) + + def get_lambda(self, batch_size): + """Get mixup random coefficients. 
+ Args: + batch_size: int + Returns: + mixup_lambdas: (batch_size,) + """ + mixup_lambdas = [] + for n in range(0, batch_size, 2): + lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0] + mixup_lambdas.append(lam) + mixup_lambdas.append(1.0 - lam) + + return np.array(mixup_lambdas) + + +class StatisticsContainer(object): + def __init__(self, statistics_path): + """Contain statistics of different training iterations.""" + self.statistics_path = statistics_path + + self.backup_statistics_path = "{}_{}.pkl".format( + os.path.splitext(self.statistics_path)[0], + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + + self.statistics_dict = {"bal": [], "test": []} + + def append(self, iteration, statistics, data_type): + statistics["iteration"] = iteration + self.statistics_dict[data_type].append(statistics) + + def dump(self): + pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) + pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) + logging.info(" Dump statistics to {}".format(self.statistics_path)) + logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) + + def load_state_dict(self, resume_iteration): + self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) + + resume_statistics_dict = {"bal": [], "test": []} + + for key in self.statistics_dict.keys(): + for statistics in self.statistics_dict[key]: + if statistics["iteration"] <= resume_iteration: + resume_statistics_dict[key].append(statistics) + + self.statistics_dict = resume_statistics_dict \ No newline at end of file diff --git a/ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e29a727e05815865175b9cd4c06134b0656bf4bd Binary files /dev/null and b/ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/Model/datalib.py b/ISMIR_2025/Model/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..f2478eab2f6439506b085bc2c305a048c30ff827 --- /dev/null +++ b/ISMIR_2025/Model/datalib.py @@ -0,0 +1,206 @@ +import os +import glob +import random +import torch +import librosa +import numpy as np +import utils +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset, DataLoader +import scipy.signal as signal +import scipy.signal +from scipy.signal import butter, lfilter +import numpy as np +import scipy.signal as signal +import librosa +import torch +import random +from torch.utils.data import Dataset +import logging +import csv +import logging +import time +import numpy as np +import h5py +import torch +import torchaudio +# Oversampling Lib +from imblearn.over_sampling import RandomOverSampler + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, feat_type=['mel'], sr=16000, n_mels=64, target_duration=10.0, augment=True, augment_real=True): + self.file_paths = file_paths + self.labels = labels + self.feat_type = feat_type + self.sr = sr + self.n_mels = n_mels + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) + self.augment = augment + self.augment_real = augment_real + + + def pre_emphasis(self, x, alpha=0.97): + return np.append(x[0], x[1:] - alpha * x[:-1]) + + def highpass_filter(self, y, sr, cutoff=1000, order=5): + nyquist = 0.5 * sr + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + return 
signal.lfilter(b, a, y) + + def augment_audio(self, y, sr): + if random.random() < 0.5: + rate = random.uniform(0.8, 1.2) + y = librosa.effects.time_stretch(y=y, rate=rate) + + if random.random() < 0.5: + n_steps = random.randint(-2, 2) + y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps) + + if random.random() < 0.5: + noise_level = np.random.uniform(0.001, 0.005) + y = y + np.random.normal(0, noise_level, y.shape) + + if random.random() < 0.5: + gain = np.random.uniform(0.9, 1.1) + y = y * gain + + return y + + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + """ + Load and preprocess audio file. + """ + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True) + if label == 0: + if self.augment_real: + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + waveform = self.augment_audio(waveform, self.sr) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + start_idx = (current_samples - self.target_samples) // 2 + waveform = waveform[start_idx:start_idx + self.target_samples] + elif current_samples < self.target_samples: + waveform = np.pad(waveform, (0, self.target_samples - current_samples), mode='constant') + + + mel_spec = librosa.feature.melspectrogram( + y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256 + ) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + + log_mel_spec = np.expand_dims(log_mel_spec, axis=0) + mel_tensor = torch.tensor(log_mel_spec, dtype=torch.float) + label_tensor = torch.tensor(label, dtype=torch.long) + + return mel_tensor, label_tensor + + def extract_feature(self, waveform, feat): + """Extracts specified feature (mel, stft, cqt) from waveform.""" + try: + if feat == 'mel': + mel_spec = librosa.feature.melspectrogram(y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + return torch.tensor(log_mel_spec, dtype=torch.float).unsqueeze(0) + elif feat == 'stft': + stft = librosa.stft(waveform, n_fft=512, hop_length=128, window="hann") + logSTFT = np.log(np.abs(stft) + 1e-3) + return torch.tensor(logSTFT, dtype=torch.float).unsqueeze(0) + elif feat == 'cqt': + cqt = librosa.cqt(waveform, sr=self.sr, hop_length=128, bins_per_octave=24) + logCQT = np.log(np.abs(cqt) + 1e-3) + return torch.tensor(logCQT, dtype=torch.float).unsqueeze(0) + else: + raise ValueError(f"[ERROR] Unsupported feature type: {feat}") + except Exception as e: + print(f"[ERROR] Feature extraction failed for {feat}: {e}") + return None + + def highpass_filter(self, y, sr, cutoff=1000, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def preprocess_audio(audio_path, sr=16000, n_mels=64, target_duration=10.0): + try: + waveform, _ = librosa.load(audio_path, sr=sr, mono=True) + + target_samples = int(target_duration * sr) + if len(waveform) > target_samples: + start_idx = (len(waveform) - target_samples) // 2 + waveform = waveform[start_idx:start_idx + target_samples] + elif len(waveform) < target_samples: + waveform = np.pad(waveform, (0, target_samples - len(waveform)), mode='constant') + mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels, n_fft=1024, hop_length=256) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + return torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0) + + except Exception as e: + print(f"[ERROR] 전처리 실패: {audio_path} | 오류: {e}") + return None + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled +print(f"type(train_labels_resampled): {type(train_labels_resampled)}") + +print(f"Train Org Fake: {len(gen_val)}") +print(f"Train set (Oversampled) - Real: {sum(1 for label in train_labels if label == 0)}, " + f"Fake: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"Validation set - Real: {len(real_val)}, Fake: {len(gen_val)}, Total: {len(val_files)}") +print(f"Closed Test set - Real: {len(real_files)}, Fake: {len(gen_files)}, Total: {len(closed_test_files)}") +print(f"Open Test set - Real: {len(open_real_files)}, Fake: {len(open_gen_files)}, Total: {len(open_test_files)}") \ No newline at end of file diff --git a/ISMIR_2025/Model/main.py b/ISMIR_2025/Model/main.py new file mode 100644 
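# Illustrative sketch (toy paths, not taken from the dataset above): the datalib balances
# classes by oversampling file *paths* before any audio is loaded. imblearn's fit_resample
# expects a 2-D feature array, hence the reshape(-1, 1) / reshape(-1) round trip.
import numpy as np
from imblearn.over_sampling import RandomOverSampler

toy_paths = ["real_0.wav", "real_1.wav", "fake_0.wav", "fake_1.wav", "fake_2.wav", "fake_3.wav"]
toy_labels = [0, 0, 1, 1, 1, 1]
ros = RandomOverSampler(sampling_strategy="auto", random_state=42)
paths_rs, labels_rs = ros.fit_resample(np.array(toy_paths).reshape(-1, 1), toy_labels)
paths_rs = paths_rs.reshape(-1).tolist()   # flat list of paths, minority class duplicated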
index 0000000000000000000000000000000000000000..68ce313c7adc981df67673e8e2f1c40472f22725 --- /dev/null +++ b/ISMIR_2025/Model/main.py @@ -0,0 +1,336 @@ +import os +import random +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score +from datalib import FakeMusicCapsDataset +from datalib import ( + FakeMusicCapsDataset, + train_files, val_files, train_labels, val_labels, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + preprocess_audio +) +from datalib import preprocess_audio +from networks import CCV +from attentionmap import visualize_attention_map +from confusion_matrix import plot_confusion_matrix + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) +''' +python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --oversample True + +audiocnn encoder - crossattn based decoder (ViT) model +''' +# Argument parsing +import argparse +parser = argparse.ArgumentParser(description='AI Music Detection Training') +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV', help='Model name') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--epochs', type=int, default=10, help='Number of epochs') +parser.add_argument('--audio_duration', type=float, default=10, help='Length of the audio slice in seconds') +parser.add_argument('--patience_counter', type=int, default=5, help='Early stopping patience') +parser.add_argument('--log_dir', type=str, default='', help='TensorBoard log directory') +parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory') +parser.add_argument("--weight_decay", type=float, default=0.05, help="weight decay (default: 0.0)") +parser.add_argument("--loss_type", type=str, choices=["ce", "weighted_ce", "focal"], default="ce", help="Loss function type") + +parser.add_argument('--inference', type=str, help='Path to a .wav file for inference') +parser.add_argument("--closed_test", action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument("--open_test", action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument("--oversample", type=bool, default=True, help="Apply Oversampling to balance classes") # real data oversampling + + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) +wandb.init(project="", + name=f"{args.model_name}_lr{args.learning_rate}_ep{args.epochs}_bs{args.batch_size}", config=args) + +if args.model_name == 'CCV': + model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).cuda() + feat_type = 'mel' +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +model = model.to(device) +print(f"Using model: {args.model_name}, Parameters: {count_parameters(model)}") +print(f"weight_decay WD: {args.weight_decay}") + +optimizer = 
optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) + +if args.loss_type == "ce": + print("Using CrossEntropyLoss") + criterion = nn.CrossEntropyLoss() + +elif args.loss_type == "weighted_ce": + print("Using Weighted CrossEntropyLoss") + + num_real = sum(1 for label in train_labels if label == 0) + num_fake = sum(1 for label in train_labels if label == 1) + + total_samples = num_real + num_fake + weight_real = total_samples / (2 * num_real) + weight_fake = total_samples / (2 * num_fake) + class_weights = torch.tensor([weight_real, weight_fake]).to(device) + + criterion = nn.CrossEntropyLoss(weight=class_weights) + +elif args.loss_type == "focal": + print("Using Focal Loss") + + class FocalLoss(torch.nn.Module): + def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, inputs, targets): + ce_loss = F.cross_entropy(inputs, targets, reduction='none') + pt = torch.exp(-ce_loss) + focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss + + if self.reduction == 'mean': + return focal_loss.mean() + elif self.reduction == 'sum': + return focal_loss.sum() + else: + return focal_loss + + criterion = FocalLoss().to(device) + +if not os.path.exists(args.ckpt_path): + os.makedirs(args.ckpt_path) + +train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type=feat_type, target_duration=args.audio_duration) +val_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + +def train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args): + writer = SummaryWriter(log_dir=args.log_dir) + best_val_bal_acc = float('inf') + early_stop_cnt = 0 + log_interval = 1 + + for epoch in range(args.epochs): + print(f"\n[Epoch {epoch + 1}/{args.epochs}]") + model.train() + train_loss, train_correct, train_total = 0, 0, 0 + + all_train_preds= [] + all_train_labels = [] + attention_maps = [] + + train_pbar = tqdm(train_loader, desc="Train", leave=False) + for batch_idx, (data, target) in enumerate(train_pbar): + data = data.to(device) + target = target.to(device) + output = model(data) + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + train_correct += (preds == target).sum().item() + train_total += target.size(0) + + all_train_labels.extend(target.cpu().numpy()) + all_train_preds.extend(preds.cpu().numpy()) + + if hasattr(model, "get_attention_maps"): + attention_maps.append(model.get_attention_maps()) + + train_loss /= train_total + train_acc = train_correct / train_total + train_bal_acc = balanced_accuracy_score(all_train_labels, all_train_preds) + train_precision = precision_score(all_train_labels, all_train_preds, average="binary") + train_recall = recall_score(all_train_labels, all_train_preds, average="binary") + train_f1 = f1_score(all_train_labels, all_train_preds, average="binary") + + wandb.log({ + "Train Loss": train_loss, "Train Accuracy": train_acc, + "Train Precision": train_precision, "Train Recall": train_recall, + "Train F1 Score": train_f1, "Train 
B_ACC": train_bal_acc, + }) + + print(f"Train Epoch: {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | " + f"Train B_ACC: {train_bal_acc:.4f} | Train Prec: {train_precision:.3f} | " + f"Train Rec: {train_recall:.3f} | Train F1: {train_f1:.3f}") + + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_val_preds, all_val_labels = [], [] + attention_maps = [] + val_pbar = tqdm(val_loader, desc=" Val ", leave=False) + with torch.no_grad(): + for data, target in val_pbar: + data, target = data.to(device), target.to(device) + output = model(data) + loss = criterion(output, target) + val_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + val_correct += (preds == target).sum().item() + val_total += target.size(0) + + all_val_labels.extend(target.cpu().numpy()) + all_val_preds.extend(preds.cpu().numpy()) + + if hasattr(model, "get_attention_maps"): + attention_maps.append(model.get_attention_maps()) + + val_loss /= val_total + val_acc = val_correct / val_total + val_bal_acc = balanced_accuracy_score(all_val_labels, all_val_preds) + val_precision = precision_score(all_val_labels, all_val_preds, average="binary") + val_recall = recall_score(all_val_labels, all_val_preds, average="binary") + val_f1 = f1_score(all_val_labels, all_val_preds, average="binary") + + wandb.log({ + "Validation Loss": val_loss, "Validation Accuracy": val_acc, + "Validation Precision": val_precision, "Validation Recall": val_recall, + "Validation F1 Score": val_f1, "Validation B_ACC": val_bal_acc, + }) + + print(f"Val Epoch: {epoch+1} [{batch_idx * len(data)}/{len(val_loader.dataset)} " + f"({100. * batch_idx / len(val_loader):.0f}%)]\t" + f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | " + f"Val B_ACC: {val_bal_acc:.4f} | Val Prec: {val_precision:.3f} | " + f"Val Rec: {val_recall:.3f} | Val F1: {val_f1:.3f}") + + if epoch % 1 == 0 and len(attention_maps) > 0: + print(f"Visualizing Attention Map at Epoch {epoch+1}") + + if isinstance(attention_maps[0], list): + attn_map_numpy = np.array([t.detach().cpu().numpy() for t in attention_maps[0]]) + elif isinstance(attention_maps[0], torch.Tensor): + attn_map_numpy = attention_maps[0].detach().cpu().numpy() + else: + attn_map_numpy = np.array(attention_maps[0]) + + print(f"Attention Map Shape: {attn_map_numpy.shape}") + + if len(attn_map_numpy) > 0: + fig, ax = plt.subplots(figsize=(10, 8)) + ax.imshow(attn_map_numpy[0], cmap='viridis', interpolation='nearest') + ax.set_title(f"Attention Map - Epoch {epoch+1}") + plt.colorbar(ax.imshow(attn_map_numpy[0], cmap='viridis')) + plt.savefig("") + plt.show() + else: + print(f"Warning: attention_maps[0] is empty! 
Shape={attn_map_numpy.shape}") + + if val_bal_acc < best_val_bal_acc: + best_val_bal_acc = val_bal_acc + early_stop_cnt = 0 + torch.save(model.state_dict(), os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth")) + print("Best model saved.") + else: + early_stop_cnt += 1 + print(f'PATIENCE {early_stop_cnt}/{args.patience_counter}') + + if early_stop_cnt >= args.patience_counter: + print("Early stopping triggered.") + break + + scheduler.step() + plot_confusion_matrix(all_val_labels, all_val_preds, classes=["REAL", "FAKE"], writer=writer, epoch=epoch) + + wandb.finish() + writer.close() + +def predict(audio_path): + print(f"Loading model from {args.ckpt_path}/celoss_best_model_{args.model_name}.pth") + model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device)) + model.eval() + + input_tensor = preprocess_audio(audio_path).to(device) + + with torch.no_grad(): + output = model(input_tensor) + probabilities = F.softmax(output, dim=1) + ai_music_prob = probabilities[0, 1].item() + + if ai_music_prob > 0.5: + print(f"FAKE MUSIC {ai_music_prob:.2%})") + else: + print(f"REAL MUSIC {100 - ai_music_prob * 100:.2f}%") + +def Test(model, test_loader, criterion, device): + model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device)) + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in tqdm(test_loader, desc=" Test ", leave=False): + data, target = data.to(device), target.to(device) + output = model(data) + loss = criterion(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + +if __name__ == "__main__": + train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args) + if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type=feat_type, target_duration=args.audio_duration) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + + elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type=feat_type, target_duration=args.audio_duration) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + + else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + + 
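# Whichever branch ran above, test_loader now holds the selected split (closed set,
# open set, or the 20% validation split); the shared Test() call below reports loss,
# accuracy, balanced accuracy, precision, recall and F1 on it.
# Example invocations (illustrative):
#   python3 main.py --closed_test --ckpt_path <ckpt_dir>
#   python3 main.py --open_test --ckpt_path <ckpt_dir>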
print("\nEvaluating Model on Test Set...") + Test(model, test_loader, criterion, device) + + if args.inference: + if not os.path.exists(args.inference): + print(f"[ERROR] No File Found: {args.inference}") + else: + predict(args.inference) diff --git a/ISMIR_2025/Model/networks.py b/ISMIR_2025/Model/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..07914ba386647553f4696c452323c5c0a77f37ce --- /dev/null +++ b/ISMIR_2025/Model/networks.py @@ -0,0 +1,237 @@ +import torch +import torch.nn as nn + +class audiocnn(nn.Module): + def __init__(self, num_classes=2): + super(audiocnn, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) + ) + self.fc_block = nn.Sequential( + nn.Linear(32*4*4, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + + def forward(self, x): + x = self.conv_block(x) + # x.shape: (B,32,new_freq,new_time) + + # 1) Flatten + B, C, H, W = x.shape # 동적 shape + x = x.view(B, -1) # (B, 32*H*W) + + # 2) FC + x = self.fc_block(x) + return x + +class AudioCNN(nn.Module): + def __init__(self, embed_dim=512): + super(AudioCNN, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4, 4)) # 최종 -> (B, 32, 4, 4) + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + x = self.conv_block(x) + B, C, H, W = x.shape + x = x.view(B, -1) # Flatten (B, C * H * W) + x = self.projection(x) # Project to embed_dim + return x + +class ViTDecoder(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(ViTDecoder, self).__init__() + + # Transformer layers + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Classification head + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x): + # Transformer expects input of shape (seq_len, batch, embed_dim) + x = x.unsqueeze(1).permute(1, 0, 2) # Add sequence dim (1, B, embed_dim) + x = self.transformer(x) # Pass through Transformer + x = x.mean(dim=0) # Take the mean over the sequence dimension (B, embed_dim) + + x = self.classifier(x) # Classification head + return x + +class AudioCNNWithViTDecoder(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(AudioCNNWithViTDecoder, self).__init__() + self.encoder = AudioCNN(embed_dim=embed_dim) + self.decoder = ViTDecoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + + def forward(self, x): + x = self.encoder(x) # Pass through AudioCNN encoder + x = self.decoder(x) # Pass through ViT decoder + return x + + +# class AudioCNN(nn.Module): +# def __init__(self, num_classes=2): +# super(AudioCNN, self).__init__() +# self.conv_block = nn.Sequential( +# nn.Conv2d(1, 16, kernel_size=3, padding=1), +# nn.ReLU(), +# nn.MaxPool2d(2), +# nn.Conv2d(16, 32, kernel_size=3, padding=1), +# nn.ReLU(), +# nn.MaxPool2d(2), +# nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) +# ) +# self.fc_block = nn.Sequential( +# nn.Linear(32*4*4, 128), +# nn.ReLU(), +# nn.Linear(128, num_classes) +# ) 
+ + +# def forward(self, x): +# x = self.conv_block(x) +# # x.shape: (B,32,new_freq,new_time) + +# # 1) Flatten +# B, C, H, W = x.shape # 동적 shape +# x = x.view(B, -1) # (B, 32*H*W) + +# # 2) FC +# x = self.fc_block(x) +# return x + + + +class audio_crossattn(nn.Module): + def __init__(self, embed_dim=512): + super(audio_crossattn, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4, 4)) # 최종 출력 -> (B, 32, 4, 4) + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + x = self.conv_block(x) # Convolutional feature extraction + B, C, H, W = x.shape + x = x.view(B, -1) # Flatten (B, C * H * W) + x = self.projection(x) # Linear projection to embed_dim + return x + + +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + + def forward(self, x, cross_input): + # Cross-attention between x and cross_input + attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input) + x = self.layer_norm(x + attn_output) # Add & Norm + feed_forward_output = self.feed_forward(x) + x = self.layer_norm(x + feed_forward_output) # Add & Norm + return x + +class ViTDecoderWithCrossAttention(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(ViTDecoderWithCrossAttention, self).__init__() + + # Cross-Attention layers + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + # Transformer Encoder layers + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Classification head + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x, cross_attention_input): + # Pass through Cross-Attention layers + for layer in self.cross_attention_layers: + x = layer(x, cross_attention_input) + + # Transformer expects input of shape (seq_len, batch, embed_dim) + x = x.unsqueeze(1).permute(1, 0, 2) # Add sequence dim (1, B, embed_dim) + x = self.transformer(x) # Pass through Transformer + embedding = x.mean(dim=0) # Take the mean over the sequence dimension (B, embed_dim) + + # Classification head + x = self.classifier(embedding) + return x, embedding + +# class AudioCNNWithViTDecoderAndCrossAttention(nn.Module): +# def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): +# super(AudioCNNWithViTDecoderAndCrossAttention, self).__init__() +# self.encoder = audio_crossattn(embed_dim=embed_dim) +# self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + +# def forward(self, x, cross_attention_input): +# # Pass through AudioCNN encoder +# x = self.encoder(x) + +# # Pass through ViTDecoder with Cross-Attention +# x = self.decoder(x, cross_attention_input) +# return x +class CCV(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, 
freeze_feature_extractor=True): + super(CCV, self).__init__() + self.encoder = AudioCNN(embed_dim=embed_dim) + self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + if freeze_feature_extractor: + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + def forward(self, x, cross_attention_input=None): + # Pass through AudioCNN encoder + x = self.encoder(x) + + # If cross_attention_input is not provided, use the encoder output + if cross_attention_input is None: + cross_attention_input = x + + # Pass through ViTDecoder with Cross-Attention + x, embedding = self.decoder(x, cross_attention_input) + return x, embedding + +#--------------------------------------------------------- +''' +audiocnn weight frozen +crossatten decoder -lora tuning +''' + diff --git a/ISMIR_2025/Model/test.py b/ISMIR_2025/Model/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0837d7094ef5affc3111a56bdfe4adfe67333bf6 --- /dev/null +++ b/ISMIR_2025/Model/test.py @@ -0,0 +1,129 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib_f import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks_f import CCV_Wav2Vec2 +import argparse + +parser = argparse.ArgumentParser(description="AI Music Detection Testing") +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV_Wav2Vec2', help='Model name') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/tensorboard/wav2vec', help='Checkpoint directory') +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/test_results/w_celoss_repreprocess/wav2vec', help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
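# Cells brighter than half the largest count get white text, the rest black,
# so the annotations stay readable against the Blues colormap.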
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +if args.model_name == 'CCV_Wav2Vec2': + model = CCV_Wav2Vec2(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).to(device) +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +ckpt_file = os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth") +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") + +print(f"\nLoading model from {ckpt_file}") + +# model.load_state_dict(torch.load(ckpt_file, map_location=device)) +# 병렬 +state_dict = torch.load(ckpt_file, map_location=device) +from collections import OrderedDict +new_state_dict = OrderedDict() +for k, v in state_dict.items(): + name = k[7:] if k.startswith("module.") else k + new_state_dict[name] = v +model.load_state_dict(new_state_dict) +# 병렬 +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type="mel", target_duration=10.0) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type="mel", target_duration=10.0) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type="mel", target_duration=10.0) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + +def Test(model, test_loader, device): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + loss = F.cross_entropy(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating Model on Test Set...") +Test(model, test_loader, device) diff --git a/ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc b/ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ed2fc175e6d0bd0ac67453d34fe473075f3a1c7 Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc differ diff --git 
a/ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc b/ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..392b7845da3002c5a1d4588c1162b82c9ec40b75 Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc differ diff --git a/ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb6c0f9728de048a1dc2be8f3a720842f1791a3 Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/music2vec/datalib.py b/ISMIR_2025/music2vec/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..cac3cc9f878d5db91d2cfb3311ec859d8ac826c5 --- /dev/null +++ b/ISMIR_2025/music2vec/datalib.py @@ -0,0 +1,144 @@ +import os +import glob +import torch +import torchaudio +import librosa +import numpy as np +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset +from imblearn.over_sampling import RandomOverSampler +from transformers import Wav2Vec2Processor +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import scipy.signal as signal +import random + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0, augment=True): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_samples = int(target_duration * sr) + self.augment = augment + def __len__(self): + return len(self.file_paths) + + def augment_audio(self, y, sr): + if isinstance(y, torch.Tensor): + y = y.numpy() + if random.random() < 0.5: + rate = random.uniform(0.8, 1.2) + y = librosa.effects.time_stretch(y=y, rate=rate) + if random.random() < 0.5: + n_steps = random.randint(-2, 2) + y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps) + if random.random() < 0.5: + noise_level = np.random.uniform(0.001, 0.005) + y = y + np.random.normal(0, noise_level, y.shape) + if random.random() < 0.5: + gain = np.random.uniform(0.9, 1.1) + y = y * gain + return torch.tensor(y, dtype=torch.float32) + + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform) + waveform = waveform.mean(dim=0) + current_samples = waveform.shape[0] + + if label == 0: + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + waveform = self.augment_audio(waveform, self.sr) + + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] + elif current_samples < self.target_samples: + pad_length = self.target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + # waveform = waveform.squeeze(0) + if isinstance(waveform, np.ndarray): + waveform = torch.tensor(waveform, dtype=torch.float32) + + return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long) + + def highpass_filter(self, y, sr, cutoff=500, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + def preprocess_audio(audio_path, target_sr=16000, max_length=160000): + waveform, sr = torchaudio.load(audio_path) + if sr != target_sr: + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) + + waveform = waveform.mean(dim=0).unsqueeze(0) + + current_samples = waveform.shape[1] + if current_samples > max_length: + start_idx = (current_samples - max_length) // 2 + waveform = waveform[:, start_idx:start_idx + max_length] + elif current_samples < max_length: + pad_length = max_length - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + return waveform + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled + +print(f"Train Original FAKE: {len(gen_train)}") +print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, " + f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") diff --git a/ISMIR_2025/music2vec/inference.py b/ISMIR_2025/music2vec/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f3200bff3f3f86877f747fb54e5602f912c16106 --- /dev/null +++ b/ISMIR_2025/music2vec/inference.py @@ -0,0 +1,64 @@ +import os +import torch +import torch.nn.functional as F +import torchaudio +import argparse +from datalib import preprocess_audio +from networks import Wav2Vec2ForFakeMusic + +# Argument Parsing +parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference") +parser.add_argument('--gpu', 
type=str, default='0', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory') +parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model') +parser.add_argument('--inference', type=str, required=True, help='Path to a .wav file for inference') +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Load Model Checkpoint +if args.model_type == 'pretrain': + model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth") +elif args.model_type == 'finetune': + model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth") +else: + raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.") + +if not os.path.exists(model_file): + raise FileNotFoundError(f"Model checkpoint not found: {model_file}") + +if args.model_name == 'Wav2Vec2ForFakeMusic': + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune')) +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +def predict(audio_path): + print(f"\n🔍 Loading model from {model_file}") + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}") + + model.to(device) + model.eval() + + input_tensor = preprocess_audio(audio_path).to(device) + print(f"Input shape after preprocessing: {input_tensor.shape}") + + with torch.no_grad(): + output = model(input_tensor) + print(f"Raw model output (logits): {output}") + + probabilities = F.softmax(output, dim=1) + ai_music_prob = probabilities[0, 1].item() + + print(f"Softmax Probabilities: {probabilities}") + print(f"AI Music Probability: {ai_music_prob:.4f}") + + if ai_music_prob > 0.5: + print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})") + else: + print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)") + +if __name__ == "__main__": + predict(args.inference) diff --git a/ISMIR_2025/music2vec/main.py b/ISMIR_2025/music2vec/main.py new file mode 100644 index 0000000000000000000000000000000000000000..7a06ae4cf712d8496142acd34d506f6e14a391a2 --- /dev/null +++ b/ISMIR_2025/music2vec/main.py @@ -0,0 +1,155 @@ +import os +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score +import wandb +import argparse +from transformers import Wav2Vec2Processor +from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels +from networks import Music2VecClassifier, CCV + +parser = argparse.ArgumentParser(description='AI Music Detection Training with Music2Vec + CCV') +parser.add_argument('--gpu', type=str, default='2', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--finetune_lr', type=float, default=1e-3, help='Fine-Tune Learning rate') +parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)') +parser.add_argument('--finetune_epochs', 
type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)') +parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory') +parser.add_argument('--weight_decay', type=float, default=0.001, help="Weight decay for optimizer") + +args = parser.parse_args() + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) + +wandb.init(project="music2vec_ccv", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args) + +print("Preparing datasets...") +train_dataset = FakeMusicCapsDataset(train_files, train_labels) +val_dataset = FakeMusicCapsDataset(val_files, val_labels) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) + +pretrain_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_pretrain_{args.pretrain_epochs}.pth") +finetune_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_ccv_finetune_{args.finetune_epochs}.pth") + +print("Initializing Music2Vec model for Pretraining...") +processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") +model = Music2VecClassifier(freeze_feature_extractor=False).to(device) # Pretraining에서는 freeze + +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"): + model.train() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"): + labels = labels.to(device) + inputs = inputs.to(device) + + logits = model(inputs) + loss = criterion(logits, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = logits.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + scheduler.step() + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="binary") + balanced_acc = balanced_accuracy_score(all_labels, all_preds) + precision = precision_score(all_labels, all_preds, average="binary") + recall = recall_score(all_labels, all_preds, average="binary") + + wandb.log({ + f"{phase} Train Loss": total_loss / len(dataloader), + f"{phase} Train Accuracy": accuracy, + f"{phase} Train F1 Score": f1, + f"{phase} Train Precision": precision, + f"{phase} Train Recall": recall, + f"{phase} Train Balanced Accuracy": balanced_acc, + }) + + print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, " + f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}") + +def validate(model, dataloader, criterion, device, phase="Validation"): + model.eval() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for inputs, labels in tqdm(dataloader, desc=f"{phase}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.squeeze(1) + outputs = model(inputs) + loss = criterion(outputs, labels) + + total_loss += 
loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="weighted") + val_bal_acc = balanced_accuracy_score(all_labels, all_preds) + val_precision = precision_score(all_labels, all_preds, average="binary") + val_recall = recall_score(all_labels, all_preds, average="binary") + + wandb.log({ + f"{phase} Val Loss": total_loss / len(dataloader), + f"{phase} Val Accuracy": accuracy, + f"{phase} Val F1 Score": f1, + f"{phase} Val Precision": val_precision, + f"{phase} Val Recall": val_recall, + f"{phase} Val Balanced Accuracy": val_bal_acc, + }) + print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, " + f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}") + return total_loss / len(dataloader), accuracy, f1 + +print("\nStep 1: Self-Supervised Pretraining on REAL Data") +for epoch in range(args.pretrain_epochs): + train(model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain") + +torch.save(model.state_dict(), pretrain_ckpt) +print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}") + +print("\nInitializing Music2Vec + CCV Model for Fine-Tuning...") +model.load_state_dict(torch.load(pretrain_ckpt)) + +# model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) +model = Music2VecClassifier(freeze_feature_extractor=False).to(device) +optimizer = optim.Adam(model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +print("\nStep 2: Fine-Tuning CCV Model using Music2Vec Features") +for epoch in range(args.finetune_epochs): + train(model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune") + +torch.save(model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! 
Model saved at: {finetune_ckpt}") diff --git a/ISMIR_2025/music2vec/networks.py b/ISMIR_2025/music2vec/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..621f30c95a3933ac286b9a2a54c8b4be2e7aa550 --- /dev/null +++ b/ISMIR_2025/music2vec/networks.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +from transformers import Data2VecAudioModel, Wav2Vec2Processor + +class Music2VecClassifier(nn.Module): + def __init__(self, num_classes=2, freeze_feature_extractor=True): + super(Music2VecClassifier, self).__init__() + + self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") + self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1") + + if freeze_feature_extractor: + for param in self.music2vec.parameters(): + param.requires_grad = False + + # Conv1d for learnable weighted average across layers + self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1) + + # Classification head + self.classifier = nn.Sequential( + nn.Linear(self.music2vec.config.hidden_size, 256), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, num_classes) + ) + + def forward(self, input_values): + input_values = input_values.squeeze(1) # Ensure shape [batch, time] + + with torch.no_grad(): + outputs = self.music2vec(input_values, output_hidden_states=True) + hidden_states = torch.stack(outputs.hidden_states) + time_reduced = hidden_states.mean(dim=2) + time_reduced = time_reduced.permute(1, 0, 2) + weighted_avg = self.conv1d(time_reduced).squeeze(1) + + return self.classifier(weighted_avg), weighted_avg + + def unfreeze_feature_extractor(self): + for param in self.music2vec.parameters(): + param.requires_grad = True + +class Music2VecFeatureExtractor(nn.Module): + def __init__(self, freeze_feature_extractor=True): + super(Music2VecFeatureExtractor, self).__init__() + self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") + self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1") + + if freeze_feature_extractor: + for param in self.music2vec.parameters(): + param.requires_grad = False + + # Conv1d for learnable weighted average across layers + self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1) + + def forward(self, input_values): + # input_values: [batch, time] + input_values = input_values.squeeze(1) + with torch.no_grad(): + outputs = self.music2vec(input_values, output_hidden_states=True) + hidden_states = torch.stack(outputs.hidden_states) # [num_layers, batch, time, hidden_dim] + time_reduced = hidden_states.mean(dim=2) # [num_layers, batch, hidden_dim] + time_reduced = time_reduced.permute(1, 0, 2) # [batch, num_layers, hidden_dim] + weighted_avg = self.conv1d(time_reduced).squeeze(1) # [batch, hidden_dim] + return weighted_avg + +''' +music2vec+CCV +# ''' +# import torch +# import torch.nn as nn +# from transformers import Data2VecAudioModel, Wav2Vec2Processor +# import torch.nn.functional as F + + +# ### Music2Vec Feature Extractor (Pretrained Model) +# class Music2VecFeatureExtractor(nn.Module): +# def __init__(self, freeze_feature_extractor=True): +# super(Music2VecFeatureExtractor, self).__init__() + +# self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") +# self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1") + +# if freeze_feature_extractor: +# for param in self.music2vec.parameters(): +# param.requires_grad = False + +# # Conv1d for learnable weighted average across layers +# self.conv1d = 
nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1) + +# def forward(self, input_values): +# with torch.no_grad(): +# outputs = self.music2vec(input_values, output_hidden_states=True) + +# hidden_states = torch.stack(outputs.hidden_states) # [13, batch, time, hidden_size] +# time_reduced = hidden_states.mean(dim=2) # 평균 풀링: [13, batch, hidden_size] +# time_reduced = time_reduced.permute(1, 0, 2) # [batch, 13, hidden_size] +# weighted_avg = self.conv1d(time_reduced).squeeze(1) # [batch, hidden_size] + +# return weighted_avg # Extracted feature representation + + +# def unfreeze_feature_extractor(self): +# for param in self.music2vec.parameters(): +# param.requires_grad = True # Unfreeze for Fine-tuning + +# ### CNN Feature Extractor for CCV +class CNNEncoder(nn.Module): + def __init__(self, embed_dim=512): + super(CNNEncoder, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d((2,1)), # 기존 MaxPool2d(2)를 MaxPool2d((2,1))으로 변경 + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d((1,1)), # 추가된 MaxPool2d(1,1)로 크기 유지 + nn.AdaptiveAvgPool2d((4, 4)) # 최종 크기 조정 + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + # print(f"Input shape before CNNEncoder: {x.shape}") # 디버깅용 출력 + x = self.conv_block(x) + B, C, H, W = x.shape + x = x.view(B, -1) + x = self.projection(x) + return x + + +### Cross-Attention Module +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + self.attention_weights = None + + def forward(self, x, cross_input): + attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input) + self.attention_weights = attn_weights + x = self.layer_norm(x + attn_output) + feed_forward_output = self.feed_forward(x) + x = self.layer_norm(x + feed_forward_output) + return x + +### Cross-Attention Transformer +class CrossAttentionViT(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(CrossAttentionViT, self).__init__() + + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x, cross_attention_input): + self.attention_maps = [] + for layer in self.cross_attention_layers: + x = layer(x, cross_attention_input) + self.attention_maps.append(layer.attention_weights) + + x = x.unsqueeze(1).permute(1, 0, 2) + x = self.transformer(x) + x = x.mean(dim=0) + x = self.classifier(x) + return x + +### CCV Model (Final Classifier) +# class CCV(nn.Module): +# def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True): +# super(CCV, self).__init__() + +# self.music2vec_extractor = Music2VecClassifier(freeze_feature_extractor=freeze_feature_extractor) + +# # CNN Encoder for Image Representation +# self.encoder = CNNEncoder(embed_dim=embed_dim) + +# # Transformer with Cross-Attention 
+# self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + +# def forward(self, x, cross_attention_input=None): +# x = self.music2vec_extractor(x) +# # print(f"After Music2VecExtractor: {x.shape}") # (batch, 2) 출력됨 + +# # CNNEncoder가 기대하는 입력 크기 맞추기 +# x = x.unsqueeze(1).unsqueeze(-1) # (batch, 1, 2, 1) 형태로 변환 +# # print(f"Before CNNEncoder: {x.shape}") # CNN 입력 확인 + +# x = self.encoder(x) + +# if cross_attention_input is None: +# cross_attention_input = x + +# x = self.decoder(x, cross_attention_input) + +# return x + +class CCV(nn.Module): + def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True): + super(CCV, self).__init__() + self.feature_extractor = Music2VecFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor) + + # Cross-Attention Transformer + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + # Transformer Encoder + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Classification Head + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, input_values): + # Extract feature embeddings + features = self.feature_extractor(input_values) # [batch, feature_dim] + # Average over layer dimension if necessary (여기서는 이미 [batch, hidden_dim]) + # Apply Cross-Attention Layers + for layer in self.cross_attention_layers: + features = layer(features.unsqueeze(1), features.unsqueeze(1)).squeeze(1) + # Transformer Encoding + encoded = self.transformer(features.unsqueeze(1)) + encoded = encoded.mean(dim=1) + # Classification Head + logits = self.classifier(encoded) + return logits + + def get_attention_maps(self): + # 만약 CrossAttentionLayer의 attention_maps를 사용하고 싶다면 구현 + return None diff --git a/ISMIR_2025/music2vec/test.py b/ISMIR_2025/music2vec/test.py new file mode 100644 index 0000000000000000000000000000000000000000..010c83dcd2e5bb47d7cd79fcb613ee30e5e3d6c1 --- /dev/null +++ b/ISMIR_2025/music2vec/test.py @@ -0,0 +1,119 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks import Music2VecClassifier +import argparse + +''' +python3 test.py --gpu 1 --closed_test --ckpt_path "" +''' +parser = argparse.ArgumentParser(description="AI Music Detection Testing with Music2Vec") +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/music2vec_pretrain_10.pth', help='Checkpoint directory') +parser.add_argument('--model_name', type=str, default="music2vec", help="Model name") +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='', 
help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +model = Music2VecClassifier().to(device) + +ckpt_file = os.path.join(args.ckpt_path) +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") + +print(f"\nLoading model from {ckpt_file}") +model.load_state_dict(torch.load(ckpt_file, map_location=device)) +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) + +def Test(model, test_loader, device): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + loss = F.cross_entropy(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating Model on Test Set...") +Test(model, test_loader, device) diff --git a/ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6d01365bf65abec49d0a9dc587356b45c18bc946 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..784044f606e8cc8e5530ac9e685e3bbde15466ad Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3098268ad5d0adf9a55381ccd3f9d04fe59a86e3 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66822dd6b37c07f737cfe0cf7e15e8f67c17ccd0 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea8412e40b9fecb971b99a19a4888affcd9a93d0 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/datalib.py b/ISMIR_2025/wav2vec/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d13fe30e562a272c953e58c6579c581caf372 --- /dev/null +++ b/ISMIR_2025/wav2vec/datalib.py @@ -0,0 +1,139 @@ +import os +import glob +import random +import torch +import librosa +import numpy as np +import utils +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset, DataLoader +import scipy.signal as signal +import scipy.signal +from scipy.signal import butter, lfilter +import numpy as np +import scipy.signal as signal +import librosa +import torch +import random +from torch.utils.data import Dataset +import logging +import csv +import logging +import time +import numpy as np +import h5py +import torch +import torchaudio +from imblearn.over_sampling import RandomOverSampler +from networks import Wav2Vec2ForFakeMusic +from transformers import Wav2Vec2Processor +import torchaudio.transforms as T + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) + + self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base") + + def highpass_filter(self, y, sr, cutoff=500, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform) + + waveform = waveform.squeeze(0) + if label == 0: + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + start_idx = (current_samples - self.target_samples) // 2 + waveform = waveform[start_idx:start_idx + self.target_samples] + elif current_samples < self.target_samples: + waveform = torch.nn.functional.pad(waveform, (0, self.target_samples - current_samples)) + + waveform = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0) + label = torch.tensor(label, dtype=torch.long) + + return waveform, label + +def preprocess_audio(audio_path, target_sr=16000, target_duration=10.0): + waveform, sr = librosa.load(audio_path, sr=target_sr) + + target_samples = int(target_duration * target_sr) + current_samples = len(waveform) + + if current_samples > target_samples: + waveform = waveform[:target_samples] + elif current_samples < target_samples: + waveform = np.pad(waveform, (0, target_samples - current_samples)) + + waveform = torch.tensor(waveform).unsqueeze(0) + return waveform + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled + +print(f"Train Original FAKE: {len(gen_train)}") +print(f"Train set (Oversampled) - REAL: {sum(1 for label in 
train_labels if label == 0)}, " + f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") diff --git a/ISMIR_2025/wav2vec/inference.py b/ISMIR_2025/wav2vec/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d0951f644b839155028b6b5fcd2811fd0126cbbf --- /dev/null +++ b/ISMIR_2025/wav2vec/inference.py @@ -0,0 +1,71 @@ +import os +import torch +import torch.nn.functional as F +import torchaudio +import argparse +from AI_Music_Detection.Code.model.wav2vec.wav2vec_datalib import preprocess_audio +from networks import Wav2Vec2ForFakeMusic + +''' +command: python inference.py --gpu 0 --model_type pretrain --inference .wav +''' +parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference") +parser.add_argument('--gpu', type=str, default='0', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory') +parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model') +parser.add_argument('--inference', type=str, help='Path to a .wav file for inference') +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +if args.model_type == 'pretrain': + model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth") +elif args.model_type == 'finetune': + model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth") +else: + raise ValueError("Invalid model type. 
Choose between 'pretrain' or 'finetune'.") + +if not os.path.exists(model_file): + raise FileNotFoundError(f"Model checkpoint not found: {model_file}") + +if args.model_name == 'Wav2Vec2ForFakeMusic': + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune')) +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +def predict(audio_path): + print(f"\n🔍 Loading model from {model_file}") + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}") + + model.to(device) + + input_tensor = preprocess_audio(audio_path).to(device) + print(f"Input shape after preprocessing: {input_tensor.shape}") + + with torch.no_grad(): + output = model(input_tensor) + print(f"Raw model output (logits): {output}") + + probabilities = F.softmax(output, dim=1) + ai_music_prob = probabilities[0, 1].item() + + print(f"Softmax Probabilities: {probabilities}") + print(f"AI Music Probability: {ai_music_prob:.4f}") + + if ai_music_prob > 0.5: + print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})") + else: + print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)") + + + +if __name__ == "__main__": + if args.inference: + if not os.path.exists(args.inference): + print(f"[ERROR] No File Found: {args.inference}") + else: + predict(args.inference) + diff --git a/ISMIR_2025/wav2vec/main.py b/ISMIR_2025/wav2vec/main.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0fca7906256ada966b563a613bf6842c475c18 --- /dev/null +++ b/ISMIR_2025/wav2vec/main.py @@ -0,0 +1,162 @@ +import os +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score, classification_report +import wandb +import argparse +from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels +from networks import Wav2Vec2ForFakeMusic + +''' +python inference.py --gpu 0 --model_type finetune --inference +''' +parser = argparse.ArgumentParser(description='AI Music Detection Training with Wav2Vec 2.0') +parser.add_argument('--gpu', type=str, default='2', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)') +parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)') +parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory') +parser.add_argument('--weight_decay', type=float, default=0.05, help="Weight decay for optimizer") + +args = parser.parse_args() + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) + +wandb.init(project="", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args) + +print("Preparing datasets...") +train_dataset = FakeMusicCapsDataset(train_files, train_labels) +val_dataset = FakeMusicCapsDataset(val_files, val_labels) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) + 
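+# The two checkpoint paths below correspond to the two training stages in this script:
+# stage 1 trains the classification head with the wav2vec2 encoder frozen, and stage 2
+# reloads that checkpoint and fine-tunes the full model at a lower learning rate.
+# A minimal, hedged sanity check of the loaders built above (assuming FakeMusicCapsDataset
+# yields 10 s mono clips at 16 kHz, i.e. tensors of shape (batch, 1, 160000) with
+# integer labels 0 = real, 1 = fake):
+#
+#     waveforms, labels = next(iter(train_loader))
+#     print(waveforms.shape, labels.dtype)  # expected with defaults: torch.Size([32, 1, 160000]) torch.int64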
+pretrain_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_pretrain_{args.pretrain_epochs}.pth") +finetune_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_finetune_{args.finetune_epochs}.pth") + +print("Initializing model...") +model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) + +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) + +def train(model, dataloader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain"): + model.train() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + attention_maps = [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.float() + + outputs = model(inputs) + loss = criterion(outputs, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + if hasattr(model, "get_attention_maps"): + attention_maps.append(model.get_attention_maps()) + + scheduler.step() + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="weighted") + precision = precision_score(all_labels, all_preds, average="binary") + recall = recall_score(all_labels, all_preds, average="binary") + balanced_acc = balanced_accuracy_score(all_labels, all_preds) + + wandb.log({ + f"{phase} Train Loss": total_loss / len(dataloader), + f"{phase} Train Accuracy": accuracy, + f"{phase} Train F1 Score": f1, + f"{phase} Train Precision": precision, + f"{phase} Train Recall": recall, + f"{phase} Train Balanced Accuracy": balanced_acc, + }) + + print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, " + f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}") + +def validate(model, dataloader, criterion, device, phase="Validation"): + model.eval() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for inputs, labels in tqdm(dataloader, desc=f"{phase}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.squeeze(1) + + outputs = model(inputs) + loss = criterion(outputs, labels) + + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="weighted") + val_bal_acc = balanced_accuracy_score(all_labels, all_preds) + val_precision = precision_score(all_labels, all_preds, average="binary") + val_recall = recall_score(all_labels, all_preds, average="binary") + + wandb.log({ + f"{phase} Val Loss": total_loss / len(dataloader), + f"{phase} Val Accuracy": accuracy, + f"{phase} Val F1 Score": f1, + f"{phase} Val Precision": val_precision, + f"{phase} Val Recall": val_recall, + f"{phase} Val Balanced Accuracy": val_bal_acc, + }) + print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, " + f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val 
Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}") + return total_loss / len(dataloader), accuracy, f1 + +print("\nStep 1: Self-Supervised Pretraining on REAL Data") +for epoch in range(args.pretrain_epochs): + train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain") + +torch.save(model.state_dict(), pretrain_ckpt) +print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}") + +model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device) +model.load_state_dict(torch.load(pretrain_ckpt)) +print(f"\n🔍 Loaded Pretrained Model from {pretrain_ckpt}") + +optimizer = optim.Adam(model.parameters(), lr=args.learning_rate / 10, weight_decay=args.weight_decay) + +print("\nStep 2: Fine-Tuning on REAL + FAKE Data") +for epoch in range(args.finetune_epochs): + train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Fine-Tune") + validate(model, val_loader, criterion, device, phase="Fine-Tune Validation") + +torch.save(model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}") diff --git a/ISMIR_2025/wav2vec/networks.py b/ISMIR_2025/wav2vec/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..63634b0f5323a5ae7cd0d2618c448c8ad261a7e7 --- /dev/null +++ b/ISMIR_2025/wav2vec/networks.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import seaborn as sns + +''' +freeze_feature_extractor=True 시 Feature Extractor를 동결 (Pretraining) +unfreeze_feature_extractor()를 호출하면 Fine-Tuning 가능 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import seaborn as sns +from transformers import Wav2Vec2Model + +class cnn(nn.Module): + def __init__(self, embed_dim=512): + super(cnn, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4, 4)) + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + x = self.conv_block(x) + B, C, H, W = x.shape + x = x.view(B, -1) + x = self.projection(x) + return x + +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + self.attention_weights = None + + def forward(self, x, cross_input): + # Cross-attention between x and cross_input + attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input) + self.attention_weights = attn_weights + x = self.layer_norm(x + attn_output) + feed_forward_output = self.feed_forward(x) + x = self.layer_norm(x + feed_forward_output) + return x + +class CrossAttentionViT(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(CrossAttentionViT, self).__init__() + + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, 
num_layers=num_layers) + + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x, cross_attention_input): + self.attention_maps = [] + for layer in self.cross_attention_layers: + x = layer(x, cross_attention_input) + self.attention_maps.append(layer.attention_weights) + + x = x.unsqueeze(1).permute(1, 0, 2) + x = self.transformer(x) + x = x.mean(dim=0) + x = self.classifier(x) + return x + +class CCV(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(CCV, self).__init__() + self.encoder = cnn(embed_dim=embed_dim) + self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + + def forward(self, x, cross_attention_input=None): + x = self.encoder(x) + + if cross_attention_input is None: + cross_attention_input = x + + x = self.decoder(x, cross_attention_input) + + # Attention Map 저장 + self.attention_maps = self.decoder.attention_maps + + return x + + def get_attention_maps(self): + return self.attention_maps + +import torch +import torch.nn as nn +from transformers import Wav2Vec2Model + +class Wav2Vec2ForFakeMusic(nn.Module): + def __init__(self, num_classes=2, freeze_feature_extractor=True): + super(Wav2Vec2ForFakeMusic, self).__init__() + + self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base") + + if freeze_feature_extractor: + for param in self.wav2vec.parameters(): + param.requires_grad = False + + self.classifier = nn.Sequential( + nn.Linear(self.wav2vec.config.hidden_size, 256), # 768 → 256 + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, num_classes) # 256 → 2 (Binary Classification) + ) + + def forward(self, x): + x = x.squeeze(1) + output = self.wav2vec(x) + features = output["last_hidden_state"] # (batch_size, seq_len, feature_dim) + pooled_features = features.mean(dim=1) # ✅ Mean Pooling 적용 (batch_size, feature_dim) + logits = self.classifier(pooled_features) # (batch_size, num_classes) + + return logits, pooled_features + + +def visualize_attention_map(attn_map, mel_spec, layer_idx): + attn_map = attn_map.mean(dim=1).squeeze().cpu().numpy() # 여러 head 평균 + mel_spec = mel_spec.squeeze().cpu().numpy() + + fig, axs = plt.subplots(2, 1, figsize=(10, 8)) + + # 1Log-Mel Spectrogram 시각화 + sns.heatmap(mel_spec, cmap='inferno', ax=axs[0]) + axs[0].set_title("Log-Mel Spectrogram") + axs[0].set_xlabel("Time Frames") + axs[0].set_ylabel("Mel Frequency Bins") + + # Attention Map 시각화 + sns.heatmap(attn_map, cmap='viridis', ax=axs[1]) + axs[1].set_title(f"Attention Map (Layer {layer_idx})") + axs[1].set_xlabel("Time Frames") + axs[1].set_ylabel("Query Positions") + + plt.tight_layout() + plt.show() + plt.savefig("/data/kym/AI_Music_Detection/Code/model/attention_map/crossattn.png") diff --git a/ISMIR_2025/wav2vec/test.py b/ISMIR_2025/wav2vec/test.py new file mode 100644 index 0000000000000000000000000000000000000000..d74d07983c843a247b1a15a4112d67f82ef34521 --- /dev/null +++ b/ISMIR_2025/wav2vec/test.py @@ -0,0 +1,148 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks import Wav2Vec2ForFakeMusic +import tqdm +from tqdm 
import tqdm +import argparse +''' +python3 test.py --finetune_test --closed_test | --open_test +''' +parser = argparse.ArgumentParser(description="AI Music Detection Testing with Wav2Vec 2.0") +parser.add_argument('--gpu', type=str, default='0', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory') +parser.add_argument('--pretrain_test', action="store_true", help="Test Pretrained Wav2Vec2 Model") +parser.add_argument('--finetune_test', action="store_true", help="Test Fine-Tuned Wav2Vec2 Model") +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='', help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +if args.pretrain_test: + ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_20.pth") + print("\n🔍 Loading Pretrained Model:", ckpt_file) + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) + +elif args.finetune_test: + ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_10.pth") + print("\n🔍 Loading Fine-Tuned Model:", ckpt_file) + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device) + +else: + raise ValueError("You must specify --pretrain_test or --finetune_test") + +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") + +# model.load_state_dict(torch.load(ckpt_file, map_location=device)) +# model.eval() + +ckpt = torch.load(ckpt_file, map_location=device) + +keys_to_remove = [key for key in ckpt.keys() if "masked_spec_embed" in key] +for key in keys_to_remove: + print(f"Removing unexpected key: {key}") + del ckpt[key] + +try: + model.load_state_dict(ckpt, strict=False) +except RuntimeError as e: + print("Model loading error:", e) + print("Trying to load entire model...") + model = torch.load(ckpt_file, map_location=device) +model.to(device) +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, 
val_labels) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + +def Test(model, test_loader, device, phase="Test"): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for inputs, labels in tqdm(test_loader, desc=f"{phase}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.squeeze(1) # Ensure correct input shape + + output = model(inputs) + loss = F.cross_entropy(output, labels) + + test_loss += loss.item() * inputs.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == labels).sum().item() + test_total += labels.size(0) + + all_labels.extend(labels.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\n{phase} Test Results - Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.3f} | " + f"Test Balanced Acc: {test_bal_acc:.4f} | Test Precision: {test_precision:.3f} | " + f"Test Recall: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{phase}_opentest.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating Model on Test Set...") +Test(model, test_loader, device, phase="Pretrained Model" if args.pretrain_test else "Fine-Tuned Model") diff --git a/ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc b/ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5105baa6b5cc20b8d2fe0a9b029b40c9e6f7c6af Binary files /dev/null and b/ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc b/ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb9379cca8c0e244eb2a7b169fa2ce307f2bab05 Binary files /dev/null and b/ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc b/ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ed2e09ebbddb16a923974a1ef3c30f2259ac34 Binary files /dev/null and b/ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/utils/config.py b/ISMIR_2025/wav2vec/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..69f72ecd472eed266bb9a0d811d7eeb07a3c06db --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/config.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv + +import numpy as np + +sample_rate = 32000 +clip_samples = sample_rate * 10 # Audio clips are 10-second + +# Load label +with open( + "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv", + "r", +) as f: + reader = csv.reader(f, delimiter=",") + lines = list(reader) + +labels = [] +ids = [] # Each label has a unique id such as "/m/068hy" +for i1 in range(1, 
len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + +classes_num = len(labels) + +lb_to_ix = {label: i for i, label in enumerate(labels)} +ix_to_lb = {i: label for i, label in enumerate(labels)} + +id_to_ix = {id: i for i, id in enumerate(ids)} +ix_to_id = {i: id for i, id in enumerate(ids)} + +full_samples_per_class = np.array( + [ + 937432, + 16344, + 7822, + 10271, + 2043, + 14420, + 733, + 1511, + 1258, + 424, + 1751, + 704, + 369, + 590, + 1063, + 1375, + 5026, + 743, + 853, + 1648, + 714, + 1497, + 1251, + 2139, + 1093, + 133, + 224, + 39469, + 6423, + 407, + 1559, + 4546, + 6826, + 7464, + 2468, + 549, + 4063, + 334, + 587, + 238, + 1766, + 691, + 114, + 2153, + 236, + 209, + 421, + 740, + 269, + 959, + 137, + 4192, + 485, + 1515, + 655, + 274, + 69, + 157, + 1128, + 807, + 1022, + 346, + 98, + 680, + 890, + 352, + 4169, + 2061, + 1753, + 9883, + 1339, + 708, + 37857, + 18504, + 12864, + 2475, + 2182, + 757, + 3624, + 677, + 1683, + 3583, + 444, + 1780, + 2364, + 409, + 4060, + 3097, + 3143, + 502, + 723, + 600, + 230, + 852, + 1498, + 1865, + 1879, + 2429, + 5498, + 5430, + 2139, + 1761, + 1051, + 831, + 2401, + 2258, + 1672, + 1711, + 987, + 646, + 794, + 25061, + 5792, + 4256, + 96, + 8126, + 2740, + 752, + 513, + 554, + 106, + 254, + 1592, + 556, + 331, + 615, + 2841, + 737, + 265, + 1349, + 358, + 1731, + 1115, + 295, + 1070, + 972, + 174, + 937780, + 112337, + 42509, + 49200, + 11415, + 6092, + 13851, + 2665, + 1678, + 13344, + 2329, + 1415, + 2244, + 1099, + 5024, + 9872, + 10948, + 4409, + 2732, + 1211, + 1289, + 4807, + 5136, + 1867, + 16134, + 14519, + 3086, + 19261, + 6499, + 4273, + 2790, + 8820, + 1228, + 1575, + 4420, + 3685, + 2019, + 664, + 324, + 513, + 411, + 436, + 2997, + 5162, + 3806, + 1389, + 899, + 8088, + 7004, + 1105, + 3633, + 2621, + 9753, + 1082, + 26854, + 3415, + 4991, + 2129, + 5546, + 4489, + 2850, + 1977, + 1908, + 1719, + 1106, + 1049, + 152, + 136, + 802, + 488, + 592, + 2081, + 2712, + 1665, + 1128, + 250, + 544, + 789, + 2715, + 8063, + 7056, + 2267, + 8034, + 6092, + 3815, + 1833, + 3277, + 8813, + 2111, + 4662, + 2678, + 2954, + 5227, + 1472, + 2591, + 3714, + 1974, + 1795, + 4680, + 3751, + 6585, + 2109, + 36617, + 6083, + 16264, + 17351, + 3449, + 5034, + 3931, + 2599, + 4134, + 3892, + 2334, + 2211, + 4516, + 2766, + 2862, + 3422, + 1788, + 2544, + 2403, + 2892, + 4042, + 3460, + 1516, + 1972, + 1563, + 1579, + 2776, + 1647, + 4535, + 3921, + 1261, + 6074, + 2922, + 3068, + 1948, + 4407, + 712, + 1294, + 1019, + 1572, + 3764, + 5218, + 975, + 1539, + 6376, + 1606, + 6091, + 1138, + 1169, + 7925, + 3136, + 1108, + 2677, + 2680, + 1383, + 3144, + 2653, + 1986, + 1800, + 1308, + 1344, + 122231, + 12977, + 2552, + 2678, + 7824, + 768, + 8587, + 39503, + 3474, + 661, + 430, + 193, + 1405, + 1442, + 3588, + 6280, + 10515, + 785, + 710, + 305, + 206, + 4990, + 5329, + 3398, + 1771, + 3022, + 6907, + 1523, + 8588, + 12203, + 666, + 2113, + 7916, + 434, + 1636, + 5185, + 1062, + 664, + 952, + 3490, + 2811, + 2749, + 2848, + 15555, + 363, + 117, + 1494, + 1647, + 5886, + 4021, + 633, + 1013, + 5951, + 11343, + 2324, + 243, + 372, + 943, + 734, + 242, + 3161, + 122, + 127, + 201, + 1654, + 768, + 134, + 1467, + 642, + 1148, + 2156, + 1368, + 1176, + 302, + 1909, + 61, + 223, + 1812, + 287, + 422, + 311, + 228, + 748, + 230, + 1876, + 539, + 1814, + 737, + 689, + 1140, + 591, + 943, + 353, + 289, + 198, + 490, + 7938, + 1841, + 850, + 457, + 814, + 146, + 551, + 728, + 1627, + 620, + 648, + 1621, + 2731, + 
535, + 88, + 1736, + 736, + 328, + 293, + 3170, + 344, + 384, + 7640, + 433, + 215, + 715, + 626, + 128, + 3059, + 1833, + 2069, + 3732, + 1640, + 1508, + 836, + 567, + 2837, + 1151, + 2068, + 695, + 1494, + 3173, + 364, + 88, + 188, + 740, + 677, + 273, + 1533, + 821, + 1091, + 293, + 647, + 318, + 1202, + 328, + 532, + 2847, + 526, + 721, + 370, + 258, + 956, + 1269, + 1641, + 339, + 1322, + 4485, + 286, + 1874, + 277, + 757, + 1393, + 1330, + 380, + 146, + 377, + 394, + 318, + 339, + 1477, + 1886, + 101, + 1435, + 284, + 1425, + 686, + 621, + 221, + 117, + 87, + 1340, + 201, + 1243, + 1222, + 651, + 1899, + 421, + 712, + 1016, + 1279, + 124, + 351, + 258, + 7043, + 368, + 666, + 162, + 7664, + 137, + 70159, + 26179, + 6321, + 32236, + 33320, + 771, + 1169, + 269, + 1103, + 444, + 364, + 2710, + 121, + 751, + 1609, + 855, + 1141, + 2287, + 1940, + 3943, + 289, + ] +) \ No newline at end of file diff --git a/ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py b/ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..e57d6d77e51949970ea76d8400d78ed6540cc155 --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py @@ -0,0 +1,29 @@ +from sklearn.metrics import confusion_matrix +import matplotlib.pyplot as plt +import numpy as np + +def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
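+    # Cells whose count exceeds half of the maximum are annotated in white below,
+    # darker cells in black, so the numbers stay readable on the Blues colormap.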
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + writer.add_figure("Confusion Matrix", fig, epoch) \ No newline at end of file diff --git a/ISMIR_2025/wav2vec/utils/freqeuncy.py b/ISMIR_2025/wav2vec/utils/freqeuncy.py new file mode 100644 index 0000000000000000000000000000000000000000..b21c5222467ec4906c63e5b9d02052a69aeb67e2 --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/freqeuncy.py @@ -0,0 +1,24 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt + +# 🔹 오디오 파일 로드 +file_real = "/path/to/real_audio.wav" # Real 오디오 경로 +file_fake = "/path/to/generative_audio.wav" # AI 생성 오디오 경로 + +def plot_spectrogram(audio_file, title): + y, sr = librosa.load(audio_file, sr=16000) # 샘플링 레이트 16kHz + D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max) # STFT 변환 + + plt.figure(figsize=(10, 4)) + librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='magma') + plt.colorbar(format='%+2.0f dB') + plt.title(title) + plt.ylim(4000, 16000) # 4kHz 이상 고주파 영역만 표시 + plt.show() + +# 🔹 Real vs Generative Spectrogram 비교 +plot_spectrogram(file_real, "Real Audio Spectrogram (4kHz+)") +plot_spectrogram(file_fake, "Generative Audio Spectrogram (4kHz+)") + diff --git a/ISMIR_2025/wav2vec/utils/hf_vis.py b/ISMIR_2025/wav2vec/utils/hf_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..c99b61bfb27f99880b0c44313daf476e6c0c278f --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/hf_vis.py @@ -0,0 +1,89 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as signal +import torch +import torch.nn as nn +import soundfile as sf + +from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention + + +def highpass_filter(y, sr, cutoff=500, order=5): + """High-pass filter to remove low frequencies below `cutoff` Hz.""" + nyquist = 0.5 * sr + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"): + """Plot waveform comparison and spectrograms in a single figure.""" + fig, axes = plt.subplots(3, 1, figsize=(12, 12)) + + # 1️⃣ Waveform Comparison + time = np.linspace(0, len(y_original) / sr, len(y_original)) + axes[0].plot(time, y_original, label='Original', alpha=0.7) + axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed') + axes[0].set_xlabel("Time (s)") + axes[0].set_ylabel("Amplitude") + axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)") + axes[0].legend() + + # 2️⃣ Spectrogram - Original + S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max) + img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1]) + axes[1].set_title("Original Spectrogram") + fig.colorbar(img, ax=axes[1], format="%+2.0f dB") + + # 3️⃣ Spectrogram - High-pass Filtered + S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max) + img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2]) + axes[2].set_title("High-pass Filtered Spectrogram") + fig.colorbar(img, ax=axes[2], format="%+2.0f dB") + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.show() + + +def 
load_model(checkpoint_path, model_class, device): + """Load a trained model from checkpoint.""" + model = model_class() + model.load_state_dict(torch.load(checkpoint_path, map_location=device)) + model.to(device) + model.eval() + return model + +def predict_audio(model, audio_tensor, device): + """Make predictions using a trained model.""" + with torch.no_grad(): + audio_tensor = audio_tensor.unsqueeze(0).to(device) # Add batch dimension + output = model(audio_tensor) + prediction = torch.argmax(output, dim=1).cpu().numpy()[0] + return prediction + +# Load audio +audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav" # Replace with actual file path +y, sr = librosa.load(audio_path, sr=None) +y_filtered = highpass_filter(y, sr, cutoff=500) + +# Convert audio to tensor +audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0) +audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0) + +# Load models +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device) +highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device) + +# Predict +original_pred = predict_audio(original_model, audio_tensor, device) +highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device) + +print(f"Original Model Prediction: {original_pred}") +print(f"High-pass Filter Model Prediction: {highpass_pred}") + +# Generate combined visualization (all plots in one image) +plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png") diff --git a/ISMIR_2025/wav2vec/utils/idr_torch.py b/ISMIR_2025/wav2vec/utils/idr_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e76040394ce27390c27bd8ef022e126d8e55dc --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/idr_torch.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import hostlist + +# get SLURM variables +# rank = int(os.environ["SLURM_PROCID"]) +local_rank = int(os.environ["SLURM_LOCALID"]) +size = int(os.environ["SLURM_NTASKS"]) +cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"]) + +# get node list from slurm +hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) + +# get IDs of reserved GPU +gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + +# define MASTER_ADD & MASTER_PORT +os.environ["MASTER_ADDR"] = hostnames[0] +os.environ["MASTER_PORT"] = str( + 12345 + int(min(gpu_ids)) +) # to avoid port conflict on the same node \ No newline at end of file diff --git a/ISMIR_2025/wav2vec/utils/mfcc.py b/ISMIR_2025/wav2vec/utils/mfcc.py new file mode 100644 index 0000000000000000000000000000000000000000..5d63db14375fedcc1cc60f2ef3cecf5c70e9a8fb --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/mfcc.py @@ -0,0 +1,266 @@ +import os +import glob +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader, random_split +import torch.nn.functional as F +from sklearn.metrics import precision_score, recall_score, f1_score +from tqdm import tqdm +import argparse +import wandb + +class RealFakeDataset(Dataset): + 
""" + audio/FakeMusicCaps/ + ├─ real/ + │ └─ MusicCaps/*.wav (label=0) + └─ generative/ + └─ .../*.wav (label=1) + """ + def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0): + + self.sr = sr + self.n_mels = n_mels + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) # 10초 = 160,000 샘플 + + self.file_paths = [] + self.labels = [] + + # Real 데이터 (label=0) + real_dir = os.path.join(root_dir, "real") + real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True) + for f in real_wav_files: + self.file_paths.append(f) + self.labels.append(0) + + # Generative 데이터 (label=1) + gen_dir = os.path.join(root_dir, "generative") + gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True) + for f in gen_wav_files: + self.file_paths.append(f) + self.labels.append(1) + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + # print(f"[DEBUG] Path: {audio_path}, Label: {label}") # 추가 + + waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] + elif current_samples < self.target_samples: + stretch_factor = self.target_samples / current_samples + waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor) + waveform = waveform[:self.target_samples] + + mfcc = librosa.feature.mfcc( + y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256 + ) + mfcc = librosa.util.normalize(mfcc) + + mfcc = np.expand_dims(mfcc, axis=0) + mfcc_tensor = torch.tensor(mfcc, dtype=torch.float) + label_tensor = torch.tensor(label, dtype=torch.long) + + return mfcc_tensor, label_tensor + + + +class AudioCNN(nn.Module): + def __init__(self, num_classes=2): + super(AudioCNN, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) + ) + self.fc_block = nn.Sequential( + nn.Linear(32*4*4, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + + + def forward(self, x): + x = self.conv_block(x) + # x.shape: (B,32,new_freq,new_time) + + # 1) Flatten + B, C, H, W = x.shape # 동적 shape + x = x.view(B, -1) # (B, 32*H*W) + + # 2) FC + x = self.fc_block(x) + return x + + +def my_collate_fn(batch): + mel_list, label_list = zip(*batch) + + max_frames = max(m.shape[2] for m in mel_list) + + padded = [] + for m in mel_list: + diff = max_frames - m.shape[2] + if diff > 0: + print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}") + m = F.pad(m, (0, diff), mode='constant', value=0) + padded.append(m) + + + mel_batch = torch.stack(padded, dim=0) + label_batch = torch.tensor(label_list, dtype=torch.long) + return mel_batch, label_batch + + +class EarlyStopping: + def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False): + self.patience = patience + self.delta = delta + self.path = path + self.verbose = verbose + self.counter = 0 + self.best_loss = None + self.early_stop = False + + def __call__(self, val_loss, model): + if self.best_loss is None: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + elif val_loss > self.best_loss - self.delta: + self.counter += 1 + if 
self.verbose: + print(f"EarlyStopping counter: {self.counter} out of {self.patience}") + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + self.counter = 0 + + def _save_checkpoint(self, val_loss, model): + if self.verbose: + print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...") + torch.save(model.state_dict(), self.path) + +def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"): + if not os.path.exists("./ckpt/mfcc/"): + os.makedirs("./ckpt/mfcc/") + + wandb.init( + project="AI Music Detection", + name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}", + config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate}, + ) + + dataset = RealFakeDataset(root_dir=root_dir) + n_total = len(dataset) + n_train = int(n_total * 0.8) + n_val = n_total - n_train + train_ds, val_ds = random_split(dataset, [n_train, n_val]) + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = AudioCNN(num_classes=2).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + best_val_loss = float('inf') + patience = 3 + patience_counter = 0 + + for epoch in range(1, epochs + 1): + print(f"\n[Epoch {epoch}/{epochs}]") + + # Training + model.train() + train_loss, train_correct, train_total = 0, 0, 0 + train_pbar = tqdm(train_loader, desc="Train", leave=False) + for mel_batch, labels in train_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(mel_batch) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + train_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + train_correct += (preds == labels).sum().item() + train_total += labels.size(0) + + train_pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + train_loss /= train_total + train_acc = train_correct / train_total + + # Validation + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_preds, all_labels = [], [] + val_pbar = tqdm(val_loader, desc=" Val ", leave=False) + with torch.no_grad(): + for mel_batch, labels in val_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + outputs = model(mel_batch) + loss = criterion(outputs, labels) + val_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + val_correct += (preds == labels).sum().item() + val_total += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + val_loss /= val_total + val_acc = val_correct / val_total + val_precision = precision_score(all_labels, all_preds, average="macro") + val_recall = recall_score(all_labels, all_preds, average="macro") + val_f1 = f1_score(all_labels, all_preds, average="macro") + + print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | " + f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} " + f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}") + + wandb.log({"train_loss": train_loss, "train_acc": train_acc, + "val_loss": val_loss, "val_acc": val_acc, + "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1}) + + if val_loss < best_val_loss: + best_val_loss = val_loss + 
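+            # New best validation loss: reset the patience counter and checkpoint the
+            # weights below (this inline logic plays the role of the EarlyStopping
+            # helper defined above, which is not instantiated in this script).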
patience_counter = 0 + best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth" + torch.save(model.state_dict(), best_model_path) + print(f"[INFO] New best model saved: {best_model_path}") + else: + patience_counter += 1 + if patience_counter >= patience: + print("Early stopping triggered!") + break + + wandb.finish() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train AI Music Detection model.") + parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training") + parser.add_argument('--epochs', type=int, required=True, help="Number of epochs") + parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate") + parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset") + + args = parser.parse_args() + + train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir) diff --git a/ISMIR_2025/wav2vec/utils/utilities.py b/ISMIR_2025/wav2vec/utils/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..e0be98e8645b8bb1c838d3dc9ae49daac706df62 --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/utilities.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import logging +import pickle + +import numpy as np + +from scipy import stats + +import csv +import json + +def create_folder(fd): + if not os.path.exists(fd): + os.makedirs(fd, exist_ok=True) + + +def get_filename(path): + path = os.path.realpath(path) + na_ext = path.split("/")[-1] + na = os.path.splitext(na_ext)[0] + return na + + +def get_sub_filepaths(folder): + paths = [] + for root, dirs, files in os.walk(folder): + for name in files: + path = os.path.join(root, name) + paths.append(path) + return paths + + +def create_logging(log_dir, filemode): + create_folder(log_dir) + i1 = 0 + + while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): + i1 += 1 + + log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=log_path, + filemode=filemode, + ) + + # Print to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") + console.setFormatter(formatter) + logging.getLogger("").addHandler(console) + + return logging + + +def read_metadata(csv_path, audio_dir, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # first, count the audio names only of existing files on disk only + + audios_num = 0 + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if os.path.exists(os.path.join(audio_dir, audio_name)): + audios_num += 1 + + print("CSV audio files: %d" % (len(lines))) + print("Existing audio files: %d" % audios_num) + + # audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + n = 0 + for line in lines: + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if not os.path.exists(os.path.join(audio_dir, audio_name)): + continue + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + n += 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + + +def read_audioset_ontology(id_to_ix): + with open('../metadata/audioset_ontology.json', 'r') as f: + data = json.load(f) + + # Output: {'name': 'Bob', 'languages': ['English', 'French']} + sentences = [] + for el in data: + print(el.keys()) + id = el['id'] + if id in id_to_ix: + name = el['name'] + desc = el['description'] + # if '(' in desc: + # print(name, '---', desc) + # print(id_to_ix[id], name, '---', ) + + # sent = name + # sent = name + ', ' + desc.replace('(', '').replace(')', '').lower() + # sent = desc.replace('(', '').replace(')', '').lower() + # sentences.append(sent) + sentences.append(desc) + # print(sent) + # break + return sentences + + +def original_read_metadata(csv_path, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # Thomas Pellegrini: added 02/12/2022 + # check if the audio files indeed exist, otherwise remove from list + + audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) # Audios are started with an extra 'Y' when downloading + audio_name = audio_name.replace("_0000_", "_0_") + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + +def read_audioset_label_tags(class_labels_indices_csv): + with open(class_labels_indices_csv, 'r') as f: + reader = csv.reader(f, delimiter=',') + lines = list(reader) + + labels = [] + ids = [] # Each label has a unique id such as "/m/068hy" + for i1 in range(1, len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + + classes_num = len(labels) + + lb_to_ix = {label : i for i, label in enumerate(labels)} + ix_to_lb = {i : label for i, label in enumerate(labels)} + + id_to_ix = {id : i for i, id in enumerate(ids)} + ix_to_id = {i : id for i, id in enumerate(ids)} + + return lb_to_ix, ix_to_lb, id_to_ix, ix_to_id + + + +def float32_to_int16(x): + # assert np.max(np.abs(x)) <= 1.5 + x = np.clip(x, -1, 1) + return (x * 32767.0).astype(np.int16) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def pad_or_truncate(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x[0:audio_length] + + +def pad_audio(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x + + +def d_prime(auc): + d_prime = stats.norm().ppf(auc) * np.sqrt(2.0) + return d_prime + + +class Mixup(object): + def __init__(self, mixup_alpha, random_seed=1234): + """Mixup coefficient generator.""" + self.mixup_alpha = mixup_alpha + self.random_state = np.random.RandomState(random_seed) + + def get_lambda(self, batch_size): + """Get mixup random coefficients. 
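Mixup.get_lambda draws one Beta(mixup_alpha, mixup_alpha) coefficient per pair of consecutive samples and returns it interleaved with its complement, e.g. [l0, 1-l0, l1, 1-l1, ...]. These AudioSet utilities are not referenced by the training script above; the sketch below only illustrates how such pairwise coefficients are typically applied to a batch (the array names and the mixing step are illustrative, not taken from this file).

import numpy as np

batch_x = np.random.randn(4, 16000).astype(np.float32)   # four 1-second waveforms
batch_y = np.eye(2)[[0, 1, 0, 1]].astype(np.float32)     # one-hot labels

mixer = Mixup(mixup_alpha=1.0)          # class defined above
lam = mixer.get_lambda(batch_size=4)    # e.g. [l0, 1-l0, l1, 1-l1]

# Consecutive pairs are blended with complementary weights, halving the batch size.
mixed_x = lam[0::2][:, None] * batch_x[0::2] + lam[1::2][:, None] * batch_x[1::2]
mixed_y = lam[0::2][:, None] * batch_y[0::2] + lam[1::2][:, None] * batch_y[1::2]
print(mixed_x.shape, mixed_y.shape)     # (2, 16000) (2, 2)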
+ Args: + batch_size: int + Returns: + mixup_lambdas: (batch_size,) + """ + mixup_lambdas = [] + for n in range(0, batch_size, 2): + lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0] + mixup_lambdas.append(lam) + mixup_lambdas.append(1.0 - lam) + + return np.array(mixup_lambdas) + + +class StatisticsContainer(object): + def __init__(self, statistics_path): + """Contain statistics of different training iterations.""" + self.statistics_path = statistics_path + + self.backup_statistics_path = "{}_{}.pkl".format( + os.path.splitext(self.statistics_path)[0], + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + + self.statistics_dict = {"bal": [], "test": []} + + def append(self, iteration, statistics, data_type): + statistics["iteration"] = iteration + self.statistics_dict[data_type].append(statistics) + + def dump(self): + pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) + pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) + logging.info(" Dump statistics to {}".format(self.statistics_path)) + logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) + + def load_state_dict(self, resume_iteration): + self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) + + resume_statistics_dict = {"bal": [], "test": []} + + for key in self.statistics_dict.keys(): + for statistics in self.statistics_dict[key]: + if statistics["iteration"] <= resume_iteration: + resume_statistics_dict[key].append(statistics) + + self.statistics_dict = resume_statistics_dict \ No newline at end of file diff --git a/__pycache__/celery_app.cpython-39.pyc b/__pycache__/celery_app.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c08b1351de8e2de48b7c83de5567bb9b6f4d4431 Binary files /dev/null and b/__pycache__/celery_app.cpython-39.pyc differ diff --git a/__pycache__/database.cpython-39.pyc b/__pycache__/database.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca38004965f54eb0157b7f290d76faaa2e243df5 Binary files /dev/null and b/__pycache__/database.cpython-39.pyc differ diff --git a/__pycache__/dataset_f.cpython-39.pyc b/__pycache__/dataset_f.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d625befb1ec00cc9487bd9c899f67a11814b0135 Binary files /dev/null and b/__pycache__/dataset_f.cpython-39.pyc differ diff --git a/__pycache__/inference.cpython-39.pyc b/__pycache__/inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69de5c7c7de7b181468ed2904ce41fb3b6557cce Binary files /dev/null and b/__pycache__/inference.cpython-39.pyc differ diff --git a/__pycache__/model.cpython-39.pyc b/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cf8638dfed944d9886879536c90198b5c45b5e7 Binary files /dev/null and b/__pycache__/model.cpython-39.pyc differ diff --git a/__pycache__/model_with_pure_bert.cpython-39.pyc b/__pycache__/model_with_pure_bert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82df6b1a9279c41bda9169d6bb267a3a3e9bdd65 Binary files /dev/null and b/__pycache__/model_with_pure_bert.cpython-39.pyc differ diff --git a/__pycache__/preprocess.cpython-39.pyc b/__pycache__/preprocess.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfc4df308056e0758702ea2a21cff3f2a867ef5f Binary files /dev/null and b/__pycache__/preprocess.cpython-39.pyc differ diff --git 
a/__pycache__/web_inference.cpython-39.pyc b/__pycache__/web_inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5876d16b05753086aa5412012d200329746f3784 Binary files /dev/null and b/__pycache__/web_inference.cpython-39.pyc differ diff --git a/__pycache__/worker.cpython-39.pyc b/__pycache__/worker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c97cfb9c721c6c77059d3da3ca5bc1439ad389ac Binary files /dev/null and b/__pycache__/worker.cpython-39.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..963f42d014051c702dc993826901ccd5dc869e88 --- /dev/null +++ b/app.py @@ -0,0 +1,11 @@ +import gradio as gr +from inference import inference_with_audio + +def inference(audio): + + result = inference_with_audio(audio) + + return result + +iface = gr.Interface(fn=inference, inputs=gr.Audio(type="filepath"), outputs="text") +iface.launch(share=True, repo_id="Ai-Detection-Segment-Transformer") \ No newline at end of file diff --git a/celery_app.py b/celery_app.py new file mode 100644 index 0000000000000000000000000000000000000000..beb7630661b35841de68fdd6dcb1666f945cca85 --- /dev/null +++ b/celery_app.py @@ -0,0 +1,56 @@ +from celery import Celery +from dotenv import load_dotenv +import os +import torch +import random +import numpy as np +import torch +import os +import hashlib +import sys +import uuid +from db_models.music import * +from db_models.billing import * +from db_models.notification import * +from db_models.user import * +from db_models.data import * + +os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'quick-hangout-422604-e3-5f7544b49207.json' +# 환경 변수 로드 +load_dotenv() + + +REDIS_PASSWORD = os.getenv('TEST_REDIS_PASSWORD') +REDIS_PORT = os.getenv('TEST_REDIS_PORT') +REDIS_HOST = os.getenv('TEST_REDIS_HOST') + + +# # 환경 변수에서 Redis 연결 정보 가져오기 +# REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') +# REDIS_PORT = os.getenv('REDIS_PORT') +# REDIS_HOST = os.getenv("REDIS_HOST") + +common_conf = { + 'broker_connection_retry_on_startup': True, # 시작 시 연결 재시도 활성화 + 'broker_transport_options': { + 'max_retries': 5, # 최대 재시도 횟수 + 'interval_start': 0.1, # 첫 재시도 간격 (초) + 'interval_step': 0.2, # 재시도 간격 증가 (초) + 'interval_max': 0.5, # 최대 재시도 간격 (초) + }, + 'broker_connection_timeout': None, # 연결 시간 초과 (초), None으로 설정하여 무제한 + 'worker_concurrency': 1, # 동시성을 1로 설정 +} + + +# 신버전 음원 처리용 Celery 애플리케이션 +AI_detection_celery_app = Celery( + 'mippia', + broker=f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0', + backend=f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/1', +) + +AI_detection_celery_app.conf.update(common_conf) +AI_detection_celery_app.conf.task_routes = {"worker.AI_detection_task": "AI-detection-queue"} +AI_detection_celery_app.conf.task_default_queue = "AI-detection-queue" +from worker import AI_detection_task # worker에서 태스크를 등록합니다. \ No newline at end of file diff --git a/commands.txt b/commands.txt new file mode 100644 index 0000000000000000000000000000000000000000..f49639413bd0747fdc7dfbfc517673bfd6fc20b3 --- /dev/null +++ b/commands.txt @@ -0,0 +1,13 @@ +celery -A celery_app.AI_detection_celery_app worker --loglevel=info --pool=solo + +CUDA_VISIBLE_DEVICES="5" python embedding_train.py MERT +CUDA_VISIBLE_DEVICES="6" python embedding_train.py music2vec +CUDA_VISIBLE_DEVICES="7" python embedding_train.py wav2vec +CUDA_VISIBLE_DEVICES="4" python embedding_train.py ccv + + +1. 
Default Experiments + +python test.py wav2vec with_embedding_wav2vec_768_embedding_process_comp/EmbeddingModel_wav2vec_768-epoch=0469-val_loss=0.4800-val_acc=0.9389-val_f1=0.9377-val_precision=0.9219-val_recall=0.9540.ckpt both +python test.py music2vec /data/gsh/AI_detection/with_embedding_music2vec_768_embedding_process_comp/EmbeddingModel_music2vec_768-epoch=0147-val_loss=0.1940-val_acc=0.9419-val_f1=0.9400-val_precision=0.9370-val_recall=0.9429.ckpt both +python test.py MERT /data/gsh/AI_detection/with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt both diff --git a/database.py b/database.py new file mode 100644 index 0000000000000000000000000000000000000000..486fd124568cea57538328872ede4d8265303477 --- /dev/null +++ b/database.py @@ -0,0 +1,28 @@ +import os +from dotenv import load_dotenv +from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Boolean, DateTime +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +load_dotenv() +SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL") + + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, + pool_size=10, + max_overflow=20, + pool_timeout=30, + pool_recycle=1800, +) + +Base = declarative_base() +Base.metadata.create_all(bind=engine) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/dataset_f.py b/dataset_f.py new file mode 100644 index 0000000000000000000000000000000000000000..a03a0a0b3dabb797539f687e849aa9f887c7ce4e --- /dev/null +++ b/dataset_f.py @@ -0,0 +1,158 @@ +import os +import glob +import torch +import torchaudio +import librosa +import numpy as np +from torch.utils.data import Dataset +import torch +import torchaudio +from transformers import Wav2Vec2FeatureExtractor +import scipy.signal as signal +import scipy.signal +import random + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_samples = int(target_duration * sr) # Fixed length: 10 seconds + self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) + + def __len__(self): + return len(self.file_paths) + + def pre_emphasis(self, x, alpha=0.97): + return np.append(x[0], x[1:] - alpha * x[:-1]) + + def highpass_filter(self, y, sr, cutoff=1000, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + # 시간 조절(Time Stretch), 이퀄라이저 조정(EQ), 리버브 추가 + def augment_audio(self, y, sr): + if isinstance(y, torch.Tensor): + y = y.numpy() # Tensor → Numpy 변환 + + if random.random() < 0.5: # 시간 조절 (Time Stretch) + rate = random.uniform(0.8, 1.2) + y = librosa.effects.time_stretch(y=y, rate=rate) + + if random.random() < 0.5: # 피치 시프트 (Pitch Shift) + n_steps = random.randint(-2, 2) + y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps) + + if random.random() < 0.5: # 화이트 노이즈 추가 (White Noise Addition) + noise_level = np.random.uniform(0.001, 0.005) + y = y + np.random.normal(0, noise_level, y.shape) + + + return torch.tensor(y, dtype=torch.float32) # 다시 Tensor로 변환 + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + + target_sr = self.processor.sampling_rate + + if sr != target_sr: + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr) + waveform = resampler(waveform) + + waveform = waveform.mean(dim=0).squeeze(0) + if label == 0: + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + # waveform = self.pre_emphasis(waveform) + waveform = self.augment_audio(waveform, self.sr) + # if label == 1: + # waveform = self.pre_emphasis(waveform) + # waveform = torch.tensor(waveform, dtype=torch.float32) + + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] # Truncate + elif current_samples < self.target_samples: + pad_length = self.target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) # Pad + + if isinstance(waveform, torch.Tensor): + waveform = waveform.numpy() # Tensor일 경우에만 변환 + print(waveform.shape) + inputs = self.processor(waveform, sampling_rate=target_sr, return_tensors="pt", padding=True) + print(inputs["input_values"].shape) + + return inputs["input_values"].squeeze(0), torch.tensor(label, dtype=torch.long) # [1, time] → [time] + + @staticmethod + def collate_fn(batch, target_samples=16000 * 10): + + inputs, labels = zip(*batch) # Unzip batch + + processed_inputs = [] + for waveform in inputs: + current_samples = waveform.shape[0] + + if current_samples > target_samples: + start_idx = (current_samples - target_samples) // 2 + cropped_waveform = waveform[start_idx:start_idx + target_samples] + else: + pad_length = target_samples - current_samples + cropped_waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + processed_inputs.append(cropped_waveform) + + processed_inputs = torch.stack(processed_inputs) # [batch, target_samples] + labels = torch.tensor(labels, dtype=torch.long) # [batch] + + return processed_inputs, labels + +def preprocess_audio(audio_path, target_sr=16000, max_length=160000): + """ + 오디오를 모델 입력에 맞게 변환 + - target_sr: 16kHz로 변환 + - max_length: 최대 길이 160000 (10초) + """ + waveform, sr = torchaudio.load(audio_path) + + # Resample if needed + if sr != target_sr: + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) + + # Convert to mono + waveform = 
waveform.mean(dim=0).unsqueeze(0) # (1, sequence_length) + + current_samples = waveform.shape[1] + if current_samples > max_length: + start_idx = (current_samples - max_length) // 2 + waveform = waveform[:, start_idx:start_idx + max_length] + elif current_samples < max_length: + pad_length = max_length - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + return waveform + + +def collect_files(base_path): + real_files = glob.glob(os.path.join(base_path, "real", "**", "*.wav"), recursive=True) + gen_files = glob.glob(os.path.join(base_path, "generative", "**", "*.wav"), recursive=True) + real_labels = [0] * len(real_files) + gen_labels = [1] * len(gen_files) + return real_files + gen_files, real_labels + gen_labels diff --git a/db_models/__init__.py b/db_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/db_models/__pycache__/__init__.cpython-312.pyc b/db_models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc694465d6fd66efac3d32fb376b2049fe0b142a Binary files /dev/null and b/db_models/__pycache__/__init__.cpython-312.pyc differ diff --git a/db_models/__pycache__/__init__.cpython-39.pyc b/db_models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..737adc5930c6165d2a708245cc0b736c58768b91 Binary files /dev/null and b/db_models/__pycache__/__init__.cpython-39.pyc differ diff --git a/db_models/__pycache__/analysis.cpython-312.pyc b/db_models/__pycache__/analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59d26f0c4d3b934030a0b0cd442d6c07f6abe02d Binary files /dev/null and b/db_models/__pycache__/analysis.cpython-312.pyc differ diff --git a/db_models/__pycache__/analysis.cpython-39.pyc b/db_models/__pycache__/analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61b48df2cb6ade7c5b3be0a38d8954877e83b85d Binary files /dev/null and b/db_models/__pycache__/analysis.cpython-39.pyc differ diff --git a/db_models/__pycache__/billing.cpython-312.pyc b/db_models/__pycache__/billing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9051afdae411d00422ac800fba6cfda087f24804 Binary files /dev/null and b/db_models/__pycache__/billing.cpython-312.pyc differ diff --git a/db_models/__pycache__/billing.cpython-39.pyc b/db_models/__pycache__/billing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d866e86197effa0d3ab023f83b0bcd67e5af9743 Binary files /dev/null and b/db_models/__pycache__/billing.cpython-39.pyc differ diff --git a/db_models/__pycache__/data.cpython-312.pyc b/db_models/__pycache__/data.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ae400db386957b124af8853567120da13c2a985 Binary files /dev/null and b/db_models/__pycache__/data.cpython-312.pyc differ diff --git a/db_models/__pycache__/data.cpython-39.pyc b/db_models/__pycache__/data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc516ff58b2785e112673abee054e32340fff857 Binary files /dev/null and b/db_models/__pycache__/data.cpython-39.pyc differ diff --git a/db_models/__pycache__/music.cpython-312.pyc b/db_models/__pycache__/music.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c8bdab77f7813ec4698c5455c054c74214b78e4 Binary files /dev/null and 
b/db_models/__pycache__/music.cpython-312.pyc differ diff --git a/db_models/__pycache__/music.cpython-39.pyc b/db_models/__pycache__/music.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79974404b32b7e20afe152eee3f75963e569c92f Binary files /dev/null and b/db_models/__pycache__/music.cpython-39.pyc differ diff --git a/db_models/__pycache__/notification.cpython-312.pyc b/db_models/__pycache__/notification.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4540169d88ced2e0965028bfb8a554f228a56c33 Binary files /dev/null and b/db_models/__pycache__/notification.cpython-312.pyc differ diff --git a/db_models/__pycache__/notification.cpython-39.pyc b/db_models/__pycache__/notification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75349876746740d79c1a595fa39f5af3bc91d306 Binary files /dev/null and b/db_models/__pycache__/notification.cpython-39.pyc differ diff --git a/db_models/__pycache__/user.cpython-312.pyc b/db_models/__pycache__/user.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..876451f42b65da2032f46d3fc02ed174750d6409 Binary files /dev/null and b/db_models/__pycache__/user.cpython-312.pyc differ diff --git a/db_models/__pycache__/user.cpython-39.pyc b/db_models/__pycache__/user.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7630f6936befa25f4b5796cd5bfc18d577813de5 Binary files /dev/null and b/db_models/__pycache__/user.cpython-39.pyc differ diff --git a/db_models/analysis.py b/db_models/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..c4011e63a3dcaa25a397e37ed967b4a456754b21 --- /dev/null +++ b/db_models/analysis.py @@ -0,0 +1,82 @@ +from sqlalchemy import Column, Integer, Float, String, DateTime, Boolean, JSON, ForeignKey, Enum as SQLEnum +from enum import Enum +from sqlalchemy import Enum as SQLAlchemyEnum +from sqlalchemy.orm import relationship, backref +from datetime import datetime +from sqlalchemy.ext.declarative import declarative_base +from database import Base + + + + +class AnalysisStatus(str, Enum): + PENDING = "pending" # 분석 대기 중 + PROCESSING = "processing" # 분석 진행 중 + COMPARING = "comparing" # 분석 결과 비교 중 + COMPLETED = "completed" # 분석 완료 + FAILED = "failed" # 분석 실패 + +class Music_Analysis(Base): # Music 간 분석 관계를 나타내는 테이블 + __tablename__ = "music_analysis" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id")) + user = relationship("User", back_populates="music_analysis") + music1_id = Column(Integer, ForeignKey("musics.id"), nullable=False) # 첫 번째 Music + music2_id = Column(Integer, ForeignKey("musics.id"), nullable=False) # 두 번째 Music + + analysis_results = relationship("Analysis_Result", back_populates="music_analysis", cascade="all, delete-orphan") + compare_results = relationship("Compare_Result", back_populates="music_analysis", cascade="all, delete-orphan") + music1 = relationship("Music", foreign_keys=[music1_id]) # 관계 정의 + music2 = relationship("Music", foreign_keys=[music2_id]) # 관계 정의 + requested_at = Column(DateTime, default=datetime.utcnow) # 요청 시간 + status = Column(SQLAlchemyEnum(AnalysisStatus), default=AnalysisStatus.COMPLETED) + progress = Column(Float, default=0.0) # 1대1 분석 진행률. 
야매로 넣습니다 + deleted = Column(Boolean, default=False) # 삭제 여부 + + + def __repr__(self): + return f"" + + @property + def user_name(self): + return self.user.user_info.name if self.user and self.user.user_info else None + + @property + def music1_title(self): + return self.music1.title if self.music1 else None + + @property + def music2_title(self): + return self.music2.title if self.music2 else None + + def __repr__(self): + return f"" + +class Analysis_Result(Base): + __tablename__ = "analysis_results" + + id = Column(Integer, primary_key=True, index=True) + analysis_id = Column(Integer, ForeignKey("music_analysis.id"), nullable=False) # Music_Analysis와의 외래 키 관계 + music_analysis = relationship("Music_Analysis", back_populates="analysis_results") + + requested_at = Column(DateTime, default=datetime.utcnow) # 요청 시간 + music1_json_path = Column(String) # 첫 번째 Music JSON 파일 경로 + music2_json_path = Column(String) # 두 번째 Music JSON 파일 경로 + + + def __repr__(self): + return f"" + +class Compare_Result(Base): + __tablename__ = "compare_results" + + id = Column(Integer, primary_key=True, index=True) + analysis_id = Column(Integer, ForeignKey("music_analysis.id"), nullable=False) # Music_Analysis와의 외래 키 관계 + music_analysis = relationship("Music_Analysis", back_populates="compare_results") # 여기를 수정 + + requested_at = Column(DateTime, default=datetime.utcnow) # 요청 시간 + compare_result_path = Column(String) # 분석 결과 파일 경로 + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/db_models/billing.py b/db_models/billing.py new file mode 100644 index 0000000000000000000000000000000000000000..93e3920ec4d1c5687cb59af6d4be044c1f27a874 --- /dev/null +++ b/db_models/billing.py @@ -0,0 +1,80 @@ +from sqlalchemy import Column, Integer, Float, String, DateTime, Boolean, JSON, ForeignKey, Enum as SQLEnum +from sqlalchemy.orm import relationship, backref +from datetime import datetime +from sqlalchemy.ext.declarative import declarative_base +from database import Base + +class Subscription(Base): + __tablename__ = "subscriptions" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + plan_name = Column(String) # 구독 플랜 이름 + start_time = Column(DateTime) # 구독 시작 시간 + end_time = Column(DateTime, nullable=True) # 구독 종료 시간 + billing_key = Column(String) # 결제 키 + schedule_id = Column(String) # 스케줄 ID + issue_id = Column(String) # 이슈 ID + customer_id = Column(String) # 고객 ID + billing_method = Column(String, nullable = True) # 결제 수단, billing_key와 역할이 겹칠수도 있긴함 + billing_cycle = Column(String, nullable = True) # 결제 주기. 
year or month + next_billing_time = Column(DateTime, nullable=True) # 다음 결제 시간 #같은 일자이고, 없는 경우 말일 + is_active = Column(Boolean, default=True) # 구독 활성 상태 + deleted = Column(Boolean, default = False) + user = relationship("User", back_populates="subscriptions") + def __repr__(self): + return f"" + +class Payment(Base): + __tablename__ = "payments" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + subscription_id = Column(Integer, ForeignKey('subscriptions.id'), nullable=False) + portone_tx_id = Column(String) # 포트원 주문번호(트랜잭션 ID) + amount = Column(Float) # 결제 금액 + currency = Column(String) # 결제 통화 + dollar_amount = Column(Float) # Scale 적용없는, 진짜 달러 + payment_time = Column(DateTime, default=datetime.utcnow) # 결제 날짜 + payment_method = Column(String) # 결제 방법 + payment_id = Column(String) # 결제 ID + issue_id = Column(String) + receipt_url = Column(String) # 영수증 URL + status = Column(String) # 결제 상태 + user = relationship("User", back_populates="payments") + subscription = relationship("Subscription", backref=backref("payments", cascade="all, delete-orphan")) + deleted = Column(Boolean, default = False) + payment_number = Column(String, unique=True) # 결제 번호 + + def __repr__(self): + return f"" + + def generate_payment_number(self): + print(self.payment_time) + unique_string = f"{self.payment_time.strftime('%Y%m%d')}{self.user_id}{self.id:010d}" + return unique_string + +class Billing_Key(Base): + __tablename__ = "billing_keys" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + billing_key = Column(String) # 결제 키 + deleted = Column(Boolean, default = False) + requested_at = Column(DateTime, default=datetime.utcnow) # 요청 날짜 + payment_method = Column(String) # toss payments or paypal + billing_key_json = Column(JSON) + + def __repr__(self): + return f"" + +class Exchange_Rate(Base): + __tablename__ = "exchange_rates" + + id = Column(Integer, primary_key=True, index=True) + currency = Column(String) # 통화 + rate = Column(Float) # 환율 + updated_at = Column(DateTime, default=datetime.utcnow) # 업데이트 날짜 + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/db_models/data.py b/db_models/data.py new file mode 100644 index 0000000000000000000000000000000000000000..705f52569ee55de7c5534793241783cf2f6aac44 --- /dev/null +++ b/db_models/data.py @@ -0,0 +1,177 @@ +from sqlalchemy import Column, Integer, String, ARRAY, DateTime, Boolean, ForeignKey, Text, Enum +from sqlalchemy.orm import relationship +from datetime import datetime +from database import Base +from enum import Enum as PyEnum + +class Like(Base): + __tablename__ = 'likes' + id = Column(Integer, primary_key=True, index=True) + post_id = Column(Integer, ForeignKey('posts.id'), nullable=False) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + + post = relationship("Post", back_populates="likes") + user = relationship("User", back_populates="likes") + +class Post(Base): + __tablename__ = "posts" + + id = Column(Integer, primary_key=True, index=True) + title = Column(String, nullable=False) + content = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + user_id = Column(Integer, ForeignKey("users.id")) + user = relationship("User", back_populates="posts") + requested_at = Column(DateTime, default=datetime.utcnow) + likes = relationship("Like", back_populates="post", cascade="all, delete-orphan") + comments = relationship('Comment', 
back_populates='post', cascade='all, delete-orphan') + deleted = Column(Boolean, default=False) + + + @property + def user_name(self): + return self.user.user_info.name if self.user and self.user.user_info else None + + @property + def user_nickname(self): + return self.user.user_info.nickname if self.user and self.user.user_info else None + + @property + def user_profile_image(self): + return self.user.user_info.profile_image_link if self.user and self.user.user_info else "https://storage.cloud.google.com/mippia-userfile-bucket/userfiles/blank-profile-picture.svg" + + + def to_dict(self): + comment_count = 0 + for comment in self.comments: + if not comment.deleted: + comment_count += 1 + return { + "id": self.id, + "content": self.content, + "created_at": self.created_at, + "nickname": self.user_nickname, # 프로퍼티를 포함 + "user_name": self.user_name, + "title" : self.title, + "likes" : len(self.likes), + "profile_image" : self.user_profile_image, + "comment_count" : comment_count, + "user_id" : self.user_id + # 다른 필드 추가 가능 + } + + def __repr__(self): + return f"" + +class Comment(Base): + __tablename__ = "comments" + + id = Column(Integer, primary_key=True, index=True) + content = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + post_id = Column(Integer, ForeignKey('posts.id')) + user_id = Column(Integer, ForeignKey('users.id')) # 댓글을 작성한 사용자의 ID를 나타내는 외래 키 추가 + parent_id = Column(Integer, ForeignKey('comments.id'), nullable=True) + deleted = Column(Boolean, default=False) + post = relationship('Post', back_populates='comments') # Comment와 Post 간의 관계 설정 + user = relationship('User', back_populates='comments') # Comment와 User 간의 관계 설정 + parent = relationship('Comment', remote_side=[id], backref='children') # 부모-자식 관계 설정 + + @property + def user_name(self): + return self.user.user_info.name if self.user and self.user.user_info else None + + @property + def user_nickname(self): + return self.user.user_info.nickname if self.user and self.user.user_info else None + + @property + def profile_image(self): + return ( + self.user.user_info.profile_image_link + if self.user and self.user.user_info + else "https://storage.cloud.google.com/mippia-userfile-bucket/userfiles/blank-profile-picture.svg" + ) + + def __repr__(self): + return f"" + + def to_dict(self): + return { + "id": self.id, + "content": self.content, + "created_at": self.created_at, + "post_id": self.post_id, # 프로퍼티를 포함 + "user_id": self.user_id, + "parent_id" : self.parent_id, + "user_name" : self.user_name, + "nickname" : self.user_nickname, + "profile_image":self.profile_image, + # 다른 필드 추가 가능 + } + + +class Ask(Base): + __tablename__ = "asks" + + id = Column(Integer, primary_key=True, index=True) + title = Column(String, nullable=False) + content = Column(String, nullable=False) + response = Column(String, nullable=True) + responsed = Column(String, nullable=False) #will deprecate + is_responsed = Column(Boolean, default = False, nullable = False) + is_read = Column(Boolean, default = False, nullable = False) + created_at = Column(DateTime, default=datetime.utcnow) + user_id = Column(Integer, ForeignKey("users.id"), nullable = True) + user_name = Column(String) + email = Column(String) + ask_type = Column(String) + file_path = Column(String) + user = relationship("User", back_populates="asks") + requested_at = Column(DateTime, default=datetime.utcnow) + privacyAgree = Column(Boolean, default=False) + deleted = Column(Boolean, default=False) + + + def __repr__(self): + return f"" + + def 
to_dict(self): + return { + "id": self.id, + "title": self.title, + "content": self.content, + "created_at": self.created_at, + "user_id": self.user_id, + "user_name" : self.user_name, + "is_responsed" : self.is_responsed, + "response" : self.response + # 다른 필드 추가 가능 + } + +class Notice(Base): + __tablename__ = "notice" + id = Column(Integer, primary_key=True, index=True) + ko_title = Column(String, nullable=False) + en_title = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + en_content = Column(Text, nullable=False) + ko_content = Column(Text, nullable=False) + def __repr__(self): + return f", en_content='{self.en_content}'" + + +class Report(Base): + __tablename__ = "reports" + + id = Column(Integer, primary_key=True, index=True) + report_type = Column(String, nullable=False) + target_id = Column(Integer, nullable=False) # 신고 대상 type의 ID + user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # 신고한 유저의 ID + reason = Column(String, nullable=True) + response = Column(String, nullable=True) + status = Column(String, default="Pending") + created_at = Column(DateTime, default=datetime.utcnow) + reviewed_at = Column(DateTime, nullable=True) + + user = relationship("User", back_populates="reports") \ No newline at end of file diff --git a/db_models/music.py b/db_models/music.py new file mode 100644 index 0000000000000000000000000000000000000000..856b1f54297f6a028b1d47090ab29d82331285a9 --- /dev/null +++ b/db_models/music.py @@ -0,0 +1,230 @@ +from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float,JSON +from sqlalchemy.orm import relationship +from datetime import datetime +from database import Base +import random + +def get_random_album_art(): + number = random.randint(1, 30) + return f"https://storage.googleapis.com/mippia-userfile-bucket/userfiles/default_albumart/thumbnail_240_{number:02}.png" + +class Result(Base): # 한글 path, 영어 path, data_pkl # 3개가 result가 있음 1통합 -> migration을하면 지금 돌고있는 사이트에도 영향이감 + __tablename__ = "files" + + id = Column(Integer, primary_key=True, index=True) + path = Column(String, index=True, default = None) # deprecated + music_id = Column(Integer, ForeignKey("musics.id")) + rank = Column(Integer) + requested_at = Column(DateTime, default=datetime.utcnow) + title = Column(String, index=True) + music = relationship("Music", back_populates="results") + plagiarism_rate = Column(Float) + data_pkl = Column(String) + + def __repr__(self): + return f"" + + +class Visual_Result(Base): # 한글 path, 영어 path, data_pkl # 3개가 result가 있음 1통합 -> migration을하면 지금 돌고있는 사이트에도 영향이감 + __tablename__ = "visual_results" + + id = Column(Integer, primary_key=True) + ko_path = Column(String, default = None) # default = None + en_path = Column(String, default = None) # default = None, path2개는 거의 deprecated된 상태일듯? + music_id = Column(Integer, ForeignKey("musics.id"), index=True) + rank = Column(Integer) + requested_at = Column(DateTime, default=datetime.utcnow) + title = Column(String, index=True) + music = relationship("Music", back_populates="visual_results") + plagiarism_rate = Column(Float) + data_pkl = Column(String) # data json인데 ㅋㅋ.. column name 바꾸려다가 터질 수 있어서.. 
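For orientation, the models in this module are meant to be used through the session factory declared in database.py. The snippet below is a minimal, illustrative sketch of attaching a Visual_Result to a Music row via the relationships declared here; it assumes SQLALCHEMY_DATABASE_URL points at a reachable database and uses placeholder field values.

from database import SessionLocal
from db_models.music import Music, Visual_Result

db = SessionLocal()
try:
    music = Music(title="demo track", status="pending", music_task="plagiarism_check")
    # Appending through the relationship fills in Visual_Result.music_id automatically.
    music.visual_results.append(
        Visual_Result(rank=1, title="demo track", plagiarism_rate=12.3, data_pkl="results/demo.json")
    )
    db.add(music)
    db.commit()

    saved = db.query(Music).filter(Music.title == "demo track").first()
    print(saved.id, [v.rank for v in saved.visual_results])
finally:
    db.close()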
+ opened = Column(Boolean, default = False) + share_token = Column(String, default = None) + + def __repr__(self): + return f"" + + +class Monitoring_Visual_Result(Base): + __tablename__ = "monitoring_visual_results" + id = Column(Integer, primary_key=True) + music_id = Column(Integer, ForeignKey("musics.id"), index=True) + rank = Column(Integer) + requested_at = Column(DateTime, default=datetime.utcnow) + title = Column(String, index=True) + + primary_artist = Column(String, index=True) + other_artist = Column(String) + distributor = Column(String, index=True) + release_date = Column(String) + additional_info = Column(JSON) + copyright_owner = Column(String) + + + music = relationship("Music", back_populates="monitoring_visual_results") + plagiarism_rate = Column(Float) + data_pkl = Column(String) + opened = Column(Boolean, default = False) + share_token = Column(String, default = None) + inst = Column(String) + + def __repr__(self): + return f"" + + + +class Music(Base): # 신청 받은 곡 + __tablename__ = "musics" + + id = Column(Integer, primary_key=True, index=True) + music_path = Column(String) + preview_music_path = Column(String) + + title = Column(String) # 이거 아마 업뎃해야할듯? 아마 다운 이후에 구해질거임. 물론 곡을 업로드한경우에는 바로 업뎃도 가능할듯 + user_input = Column(String) # 이게 곡이면 music_path랑 똑같고, 링크면, 이거는 링크인데 music_path는 여기 로컬 path임 + user_id = Column(Integer, ForeignKey("users.id")) + user = relationship("User", back_populates="musics") + results = relationship("Result", back_populates="music", cascade="all, delete-orphan") + visual_results = relationship("Visual_Result", back_populates="music", cascade="all, delete-orphan") + + key_change = Column(Integer, default=0) + key = Column(String, nullable=True) + + status = Column(String) + inst = Column(String) + requested_at = Column(DateTime, default=datetime.utcnow) + bpm = Column(Integer, nullable=True) + language = Column(String) + processed_in_aibackend = Column(Boolean) + is_monitoring = Column(Boolean, default=False) # 새로운 속성 추가 + monitoring_music = relationship("Monitoring_Music", back_populates="music") + monitoring_visual_results = relationship("Monitoring_Visual_Result", back_populates="music", cascade="all, delete-orphan") + ai_detection_musics = relationship("AI_Detection_Music", back_populates="music", cascade="all, delete-orphan") + + deleted = Column(Boolean, default=False) + info_link = Column(String) + new_info_link = Column(String) + + duration = Column(Float) + new_version = Column(Boolean, default = True) + credit = Column(Integer) + + music_task = Column(String) #['monitoring', 'analysis', 'plagiarism_check','ai_detector',...] + # 이거로 monitoring, 1v1, 1대다 등등 구분할 수 있도록. 엥간하면 Admin확인용으로!. 
알고리즘으로 써도 되긴할건데 enum형태는 아니라 까다로울거임 + + opened = Column(Boolean, default = False) + + + + def __repr__(self): + return f"" + + @property + def user_name(self): + return self.user.user_info.name if self.user and self.user.user_info else None + + + +class Monitoring_Music(Base): + __tablename__ = "monitoring_musics" + + id = Column(Integer, primary_key=True, index=True) + music_id = Column(Integer, ForeignKey("musics.id")) + music = relationship("Music", back_populates="monitoring_music") + title = Column(String) + artist = Column(String) + requested_at = Column(DateTime, default=datetime.utcnow) + release_date = Column(DateTime) + album_art = Column(String, default=get_random_album_art) + detected_count = Column(Integer, default=0) + find_count = Column(Integer, default=0) + criteria = Column(Integer) + + mail_agree = Column(Boolean, default=False) + + deleted = Column(Boolean, default=False) + deleted_reason = Column(String, nullable=True, default=None) + is_monitoring = Column(Boolean, default=True) + + + # 앨범과의 관계. 여기서 앨범은 모니터링에서만 사용되는 기능이라고 가정합니다 + album_id = Column(Integer, ForeignKey("albums.id"), nullable=True) # 앨범 ID (옵션) + album = relationship("Album", back_populates="monitoring_musics") # 앨범에 대한 역참조\ + track_number = Column(Integer, nullable=True) + + def __repr__(self): + return f"" + + def to_dict(self): + music_dict = {"id": self.id, "title": self.title, "artist": self.artist, + "requested_at": self.requested_at, "release_date": self.release_date, + "album_art": self.album_art, "detected_count": self.detected_count} + return music_dict + +class Album(Base): + __tablename__ = "albums" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id")) # 앨범 소유자 + user = relationship("User", back_populates="albums") + title = Column(String, nullable=False) # 앨범 제목 + artist = Column(String) + album_art = Column(String, default=get_random_album_art) + created_at = Column(DateTime, default=datetime.utcnow) + deleted = Column(Boolean, default=False) + music_number = Column(Integer, default=0) + + + # 모니터링 곡과의 관계 + monitoring_musics = relationship("Monitoring_Music", back_populates="album") + + def __repr__(self): + return f"" + +class Now_Monitoring(Base): + __tablename__ = "now_monitoring" + + id = Column(Integer, primary_key=True, index=True) + music_title = Column(String) + primary_artist = Column(String, index=True) + other_artist = Column(String) + distributor = Column(String, index=True) + release_date = Column(String) + additional_info = Column(JSON) + copyright_owner = Column(String) + now_index = Column(Integer, default=0) + + +class AI_Detection_Music(Base): + __tablename__ = "ai_detection_musics" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id")) + music_id = Column(Integer, ForeignKey("musics.id")) + music = relationship("Music", back_populates="ai_detection_musics") + title = Column(String) + ai_percentage = Column(Float) + processed_in_ai_backend = Column(Boolean, default=False) + requested_at = Column(DateTime, default=datetime.utcnow) + vocal_percentage = Column(Float) + inst_percentage = Column(Float) + opened = Column(Boolean, default = False) + link = Column(String, default = None) + status = Column(String, default = None) + deleted = Column(Boolean, default=False) + user = relationship("User", back_populates="ai_detection_musics") + + + def to_dict(self): + return { + "id": self.id, + "music_id": self.music_id, + "title": self.title, + "ai_percentage": self.ai_percentage, 
+ "processed_in_ai_backend": self.processed_in_ai_backend, + "requested_at": self.requested_at.isoformat() if self.requested_at else None, + "vocal_percentage": self.vocal_percentage, + "inst_percentage": self.inst_percentage + } + + diff --git a/db_models/notification.py b/db_models/notification.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba363cd596d9b08ea2b42daa3dc468aa6cd4bea --- /dev/null +++ b/db_models/notification.py @@ -0,0 +1,94 @@ +from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float, Text +from sqlalchemy.orm import relationship +from datetime import datetime +from database import Base +import json + +class Notification(Base): + __tablename__ = 'notifications' + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('users.id'),nullable=True) + user = relationship("User", back_populates="notifications") + template_key = Column(String, ForeignKey('notification_templates.key')) # 템플릿 키 참조 + template = relationship("NotificationTemplate") + created_at = Column(DateTime, default=datetime.utcnow) + data = Column(String) + # 'type' 필드를 통해 다양한 알림 타입을 구분할 수 있습니다. + type = Column(String) + is_read = Column(Boolean, default=False) + deleted = Column(Boolean, default=False) + + ko_template = Column(String, default=None) # 한국어 메시지 템플릿 + ko_status_template = Column(String, default=None) + en_template = Column(String, default=None) # + en_status_template = Column(String, default=None) + + def get_message(self, lang): + template = self.template + if lang == 'ko': + if not self.ko_template or not self.ko_status_template: # 처리해둔게 없으면 해당 template에 맞게 계산합니다 + template_string = template.ko_template + template_status = template.ko_status_template + data_dict = json.loads(self.data) + result = template_string.format(**data_dict) + self.ko_template = result + self.ko_status_template = template_status + + else: # # 처리해둔게 있으면 그냥 가져옵니다. + result = self.ko_template + template_status = self.ko_status_template + elif lang == 'en': + if not self.en_template or not self.en_status_template: + template_string = template.en_template + template_status = template.en_status_template + data_dict = json.loads(self.data) + result = template_string.format(**data_dict) + self.en_template = result + self.en_status_template = template_status + else: + result = self.en_template + template_status = self.en_status_template + else: + raise ValueError(f"Unsupported language: {lang}") + + + + return result, template_status + +class NotificationTemplate(Base): + __tablename__ = 'notification_templates' + id = Column(Integer, primary_key=True) + key = Column(String, unique=True) + ko_template = Column(String) # 한국어 메시지 템플릿 + ko_status_template = Column(String) + en_template = Column(String) # + en_status_template = Column(String) + +""" 여기는 사용 예시 + template = NotificationTemplate( + key='music_test_complete', + ko_template='음원 검사가 완료되었습니다! 표절률은 {rate}% 입니다.', + en_template='Music Test Complete! The Plagiarism Rate is {rate}%.' + ) + session.add(template) + session.commit() -> 이렇게 잔뜩 넣어두고 쓰라는거 아녀.. + alert = Notification( + user_id=1, + template_key='music_test_complete', + data=json.dumps({'rate': 30}) # 포맷팅에 사용될 데이터를 JSON 문자열로 저장 + ) + session.add(alert) + session.commit() + + # 알림 메시지 조회 + alert = session.query(Notification).first() + print(alert.get_message('ko')) # 한국어 메시지 출력 + print(alert.get_message('en')) # 영어 메시지 출력 +""" + +""" 여기는 key와 포맷팅 데이터 모음. 
꼴은 Dict로 써두긴 하는데 'key' : 'format' 느낌으로 걍 보고 쓰면 되는거임 +keys_and_formattings={ + 'music_test_complete' : {'rate': float.2f} + 'update_notice_templage' : {'date': String} # YYYY-MM-DD +} +""" \ No newline at end of file diff --git a/db_models/user.py b/db_models/user.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b8c3c0a873bd90e7a2b884a9ba533e04ce180a --- /dev/null +++ b/db_models/user.py @@ -0,0 +1,178 @@ +from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Enum as SQLEnum +from sqlalchemy.orm import relationship,Mapped +from datetime import datetime +from database import Base + + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + internal_id = Column(String, unique=True, index=True) + email = Column(String, index=True) + + musics = relationship("Music", back_populates="user", + cascade="delete, delete-orphan") + + music_analysis = relationship("Music_Analysis", back_populates="user", cascade="all, delete-orphan") + ai_detection_musics = relationship("AI_Detection_Music", back_populates="user") + posts = relationship("Post", back_populates="user", + cascade="delete, delete-orphan") + asks = relationship("Ask", back_populates="user", + cascade="delete, delete-orphan") + albums = relationship("Album", back_populates="user") + user_info = relationship("UserInfo", back_populates="user", + uselist=False, cascade="delete, delete-orphan") + access_logs = relationship( + "UserAccessLog", back_populates="user", cascade="delete, delete-orphan") + comments = relationship('Comment', back_populates='user') + likes = relationship("Like", back_populates="user") + notifications = relationship( + "Notification", back_populates="user", cascade="all, delete-orphan") + survey_responses = relationship("SurveyResponse", back_populates="user", cascade="all, delete-orphan") + auth_provider = Column( + SQLEnum("Google", "Naver", "Kakao", "None", name="oauth_provier_enum")) + auth_provider_id = Column(String, nullable=True) # OAuth 공급자의 고유 식별자 + + + created_at = Column(DateTime, default=datetime.utcnow) + + subscriptions = relationship( + "Subscription", back_populates="user", cascade="all, delete-orphan") + payments = relationship( + "Payment", back_populates="user", cascade="all, delete-orphan") + billing_info: Mapped["UserBillingInfo"] = relationship( + "UserBillingInfo", back_populates="user", uselist=False, cascade="all, delete-orphan") + + additional_info_submitted = Column(Boolean, default=False) + + deleted = Column(Boolean, default=False, nullable=False) + reports = relationship("Report", back_populates="user") + + + last_monitoring_email_sent = Column(DateTime, nullable=True) + + last_notice_seen = Column(DateTime, nullable=True) + + refresh_token = Column(String, index=True) # 리프레시 토큰 저장 + token_expires = Column(DateTime, default=datetime.utcnow) + + def __repr__(self): + return f"" + + @property + def user_name(self): + return self.user_info.name if self.user_info else None + + @property + def nickname(self): + return self.user_info.nickname if self.user_info else None + + +class UserBillingInfo(Base): + __tablename__ = "user_billing_info" + id = Column(Integer, ForeignKey("users.id"), primary_key=True) + user = relationship("User", back_populates="billing_info") + subscription = Column(String, default='free') # Deprecated. 
엥..간해서는 사용 안해야함 + """ +# "basic-usd-monthly" + "standard-usd-monthly" + "premium-usd-monthly" + "basic-usd-yearly" + "standard-usd-yearly" + "premium-usd-yearly" + "basic-krw-monthly" + "standard-krw-monthly" + "premium-krw-monthly" + "basic-krw-yearly" + "standard-krw-yearly" + "premium-krw-yearly" + """ + remain_credit = Column(Integer, default=600) # 해당 요소는 사용하지 않아도 될 꺼 같음. + # monitoring from subscription + monitoring_number = Column(Integer, default=0) + additional_monitoring_number = Column( + Integer, default=1) # monitoring from addtional buy + first_month_used = Column(Boolean, default=False) # deprecated maybe. + + deleted_credit = Column(Integer, default=0) # 복구용? + + def to_dict(self): + return { + 'remain_credit': self.remain_credit, + 'monitoring_number': self.monitoring_number, + 'additional_monitoring_number': self.additional_monitoring_number, + 'first_month_used': self.first_month_used + } + + +class UserInfo(Base): + __tablename__ = "user_info" + + id: Column[int] = Column(Integer, ForeignKey("users.id"), primary_key=True) + user = relationship("User", back_populates="user_info") + name:Column[str] = Column(String) + nickname: Column[str] = Column(String, unique=True) + + phone = Column(String) + birthdate = Column(String) + termsofservice = Column(Boolean, default=False) + marketingconsent = Column(Boolean, default=False) + organization = Column(String, nullable=True) + deleted = Column(Boolean, default=False) + profile_image_link = Column( + String, default="https://storage.cloud.google.com/mippia-userfile-bucket/userfiles/blank-profile-picture.svg") + song_available = Column(Integer, default=10) + + def __repr__(self): + return f"" + + +class UserAccessLog(Base): + __tablename__ = "user_access_logs" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey('users.id')) + access_time = Column(DateTime, default=datetime.utcnow) + ip_address = Column(String) + user_agent = Column(String) + + user = relationship("User", back_populates="access_logs") + + def __repr__(self): + return f"" + + @property + def user_name(self): + return self.user.user_info.name if self.user and self.user.user_info else None + +class SurveyResponse(Base): + __tablename__ = "survey_responses" + + id = Column(Integer, primary_key=True, index=True) + question_id = Column(String, ForeignKey("survey_questions.question_id"), index=True) + answer = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("users.id")) + created_at = Column(DateTime, default=datetime.utcnow) + + user = relationship("User", back_populates="survey_responses") + question = relationship( + "SurveyQuestion", + primaryjoin="foreign(SurveyResponse.question_id) == SurveyQuestion.question_id", + viewonly=True, + sync_backref=False + ) + + @property + def question_text(self): + return self.question.question_kr if self.question else None + + +class SurveyQuestion(Base): + __tablename__ = "survey_questions" + + id = Column(Integer, primary_key=True, index=True) + question_id = Column(String, unique=True, index=True) + question_kr = Column(String, nullable=False) + question_en = Column(String, nullable=False) \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7f02d28d83e0549a6ac7e417f6d118a91b2ba1 --- /dev/null +++ b/inference.py @@ -0,0 +1,307 @@ +import os +from pathlib import Path +import json +from beat_this.inference import File2Beats +import numpy as np +import torch +from typing import List, 
Tuple, Optional +import pytorch_lightning as pl +from model import MusicClassifier, MusicAudioClassifier +import argparse +import torch +import torchaudio +import deepspeed +import scipy.signal as signal +from typing import Dict, List +from google.cloud import storage +from dataset_f import FakeMusicCapsDataset + +from preprocess import get_segments_from_wav, find_optimal_segment_length + +#not for ismir + +def download_from_gcs(bucket_name, source_blob_name, destination_file_name): + destination_dir = os.path.dirname(destination_file_name) + if not os.path.exists(destination_dir): + os.makedirs(destination_dir) + + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + + +def highpass_filter(y, sr, cutoff=1000, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}") + + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + cutoff = max(10, min(cutoff, nyquist - 1)) + + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]: + """ + 오디오 파일을 불러와 세그먼트로 분할합니다. + 고정된 길이의 세그먼트를 최대 48개 추출하고, 부족한 경우 패딩을 추가합니다. + + Args: + audio_path: 오디오 파일 경로 + sr: 목표 샘플링 레이트 (기본값 24000) + + Returns: + Tuple containing: + - 오디오 파형이 담긴 텐서 (48, 1, 240000) + - 패딩 마스크 텐서 (48), True = 패딩, False = 실제 오디오 + """ + + beats, downbeats = get_segments_from_wav(audio_path) + optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats) + + waveform, sample_rate = torchaudio.load(audio_path) + # 데이터 타입을 float32로 변환 + waveform = waveform.to(torch.float32) + + if sample_rate != sr: + resampler = torchaudio.transforms.Resample(sample_rate, sr) + waveform = resampler(waveform) + + # 모노로 변환 (필요한 경우) + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + # 240000 샘플 = 10초 @ 24kHz + fixed_samples = 240000 + + # 각 downbeat에서 시작하는 segment 생성 + segments = [] + + # 다운비트가 없거나 매우 적을 경우 전체 오디오를 일정 간격으로 분할 + if len(cleaned_downbeats) < 2: + # 오디오 총 길이 (초) + total_duration = waveform.size(1) / sr + + # 5초 간격으로 세그먼트 시작점 생성 (또는 더 짧은 간격으로 설정 가능) + # 240000 샘플은 10초이므로 5초 간격은 세그먼트 간 50% 오버랩 + segment_interval = 5.0 # 초 단위 + + # 시작 시간 목록 생성 (0초부터 시작) + start_times = [t for t in np.arange(0, total_duration - (fixed_samples/sr) + 0.01, segment_interval)] + + # 최소한 하나의 세그먼트는 보장 + if not start_times and total_duration > 0: + start_times = [0.0] + else: + # 기존 방식대로 다운비트 사용 + start_times = cleaned_downbeats + + # 세그먼트 추출 + for i, start_time in enumerate(start_times): + # 시작 샘플 인덱스 계산 + start_sample = int(start_time * sr) + + # 끝 샘플 인덱스 계산 (시작 지점 + 고정 길이) + end_sample = start_sample + fixed_samples + + # 파일 끝을 넘어가는지 확인 + if end_sample > waveform.size(1): + # 짧은 곡의 경우: 끝에서부터 거꾸로 세그먼트 추출 시도 + if start_sample < waveform.size(1) and waveform.size(1) >= fixed_samples: + start_sample = waveform.size(1) - fixed_samples + end_sample = waveform.size(1) + else: + continue + + # 정확히 fixed_samples 길이의 세그먼트 추출 + segment = waveform[:, start_sample:end_sample] + # 하이패스 필터 적용 - 채널 차원 유지 + filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) + + segments.append(filtered) + + # 최대 48개 세그먼트만 사용 + if len(segments) >= 48: + break + + # 세그먼트가 없는 경우, 곡이 너무 짧아서 고정 길이 세그먼트를 
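# ---------------------------------------------------------------------------
# Illustrative sketch (not a definitive part of this file): load_audio() is
# meant to always return a fixed batch of 48 ten-second segments plus a boolean
# padding mask in which True marks zero-padded slots. A minimal sanity check,
# assuming some local "example.wav" exists (the filename is hypothetical):
import torch

def check_load_audio_contract(path: str = "example.wav") -> None:
    segments, mask = load_audio(path, sr=24000)
    assert segments.shape == (48, 1, 240000)          # 48 segments x 10 s @ 24 kHz
    assert mask.shape == (48,) and mask.dtype == torch.bool
    n_real = int((~mask).sum())                       # real (non-padded) segments
    print(f"{n_real} real segments, {48 - n_real} padded")
# ---------------------------------------------------------------------------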
만들 수 없는 경우 + if not segments: + if waveform.size(1) > 0: # 오디오가 있지만 매우 짧은 경우 + # 전체 오디오를 하나의 세그먼트로 사용하고 나머지는 제로 패딩 + segment = waveform + # 필요한 길이에 맞게 패딩 추가 + padding_length = fixed_samples - segment.size(1) + if padding_length > 0: + segment = torch.nn.functional.pad(segment, (0, padding_length)) + + # 하이패스 필터 적용 + filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) + segments.append(filtered) + else: + # 완전히 빈 오디오일 경우 + return torch.zeros((48, 1, fixed_samples), dtype=torch.float32), torch.ones(48, dtype=torch.bool) + + # 스택하여 텐서로 변환 - (n_segments, 1, time_samples) 형태 유지 + stacked_segments = torch.stack(segments) + + # 실제 세그먼트 수 (패딩 아님) + num_segments = stacked_segments.shape[0] + + # 패딩 마스크 생성 (False = 실제 오디오, True = 패딩) + padding_mask = torch.zeros(48, dtype=torch.bool) + + # 48개 미만인 경우 패딩 추가 + if num_segments < 48: + # 빈 세그먼트로 패딩 (zeros) + padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32) + stacked_segments = torch.cat([stacked_segments, padding], dim=0) + + # 패딩 마스크 설정 (True = 패딩) + padding_mask[num_segments:] = True + + return stacked_segments, padding_mask + + +def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict: + """ + Run inference on audio segments. + + Args: + model: The loaded model + audio_segments: Preprocessed audio segments tensor (48, 1, 240000) + device: Device to run inference on + + Returns: + Dictionary with prediction results + """ + model.eval() + model.to(device) + model = model.half() + + print(padding_mask.shape) + with torch.no_grad(): + # 데이터 형태 확인 및 조정 + # wav_collate_with_mask 함수와 일치하도록 처리 + if audio_segments.shape[1] == 1: # (48, 1, 240000) 형태 + # 채널 차원 제거하고 배치 차원 추가 + audio_segments = audio_segments[:, 0, :].unsqueeze(0) # (1, 48, 240000) + else: + audio_segments = audio_segments.unsqueeze(0) # (1, 48, 768) # 사실 audio가 아니라 embedding segments일수도 + # 데이터를 half 타입으로 변환 + if padding_mask.dim() == 1: + padding_mask = padding_mask.unsqueeze(0) # [48] -> [1, 48] + audio_segments = audio_segments.to(device).half() + + mask = padding_mask.to(device) + + print(f"Input shape: {audio_segments.shape}") + print(f"Mask shape: {mask.shape}") + print(f"Mask: {mask}") + + # 추론 실행 (마스크 포함) + outputs = model(audio_segments, mask) + print(f"Output type: {type(outputs)}") + print(f"Output: {outputs}") + + # 모델 출력 구조에 따라 처리 + if isinstance(outputs, dict): + result = outputs + else: + # 단일 텐서인 경우 (로짓) + logits = outputs.squeeze() + prob = torch.sigmoid(logits).item() + + result = { + "prediction": "Fake" if prob > 0.5 else "Real", + "confidence": f"{max(prob, 1-prob)*100:.2f}%", + "fake_probability": f"{prob:.4f}", + "real_probability": f"{1-prob:.4f}", + "raw_output": logits.cpu().numpy().tolist() + } + + return result + +def get_model(model_type, device): + """Load the specified model.""" + if model_type == "MERT": + from ISMIR_2025.MERT.networks import CCV + model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) + ckpt_file = 'mert_finetune_10.pth' + model.load_state_dict(torch.load(ckpt_file, map_location=device)) + embed_dim = 768 + else: + raise ValueError(f"Unknown model type: {model_type}") + + model.eval() + return model, embed_dim + + """ + elif model_type == "music2vec": + from ISMIR_2025.music2vec.networks import Music2VecClassifier + model = Music2VecClassifier(freeze_feature_extractor=True).to(device) + ckpt_file = 
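# ---------------------------------------------------------------------------
# Illustrative sketch: run_inference() above reduces the classifier output to a
# single logit and thresholds its sigmoid at 0.5. The same mapping in isolation
# (the logit value 2.0 is invented for the example):
import torch

def logit_to_result(logit: float) -> dict:
    prob = torch.sigmoid(torch.tensor(logit)).item()      # probability of "Fake"
    return {
        "prediction": "Fake" if prob > 0.5 else "Real",
        "confidence": f"{max(prob, 1 - prob) * 100:.2f}%",
        "fake_probability": f"{prob:.4f}",
        "real_probability": f"{1 - prob:.4f}",
    }

# logit_to_result(2.0) -> {"prediction": "Fake", "confidence": "88.08%", ...}
# ---------------------------------------------------------------------------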
'/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/fakemusicretrain/musiv2vec_processor/finetune_10.pth' + embed_dim = 768 + + elif model_type == "wav2vec": + from ISMIR_2025.wav2vec.networks import Wav2Vec2ForFakeMusic + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) + ckpt_file = '/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/split/wav2vec_processor/wav2vec2_finetune_10.pth' + embed_dim = 768 + + elif model_type == "ccv": + from ISMIR_2025.Model.networks import CCV + model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) + ckpt_file = '/data/kym/AI_Music_Detection/Code/model/ckpt/datasplit/hp1000/best_model_CCV.pth' + embed_dim = 512 + """ + + +def inference_with_audio(audio_path): + #audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3" + + model_type = "MERT" + checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt" + device = 'cuda' + # Note: Model loading would be handled by your code + print(f"Loading model of type {model_type} from {checkpoint_path}") + + backbone_model, input_dim = get_model(model_type, device) + segments, padding_mask = load_audio(audio_path, sr=24000) + segments = segments.to(device).to(torch.float32) + padding_mask = padding_mask.to(device).unsqueeze(0) + logits,embedding = backbone_model(segments.squeeze(1)) + test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0) + test_data, test_target = test_dataset[0] + test_data = test_data.to(device).to(torch.float32) + test_target = test_target.to(device) + output, _ = backbone_model(test_data.unsqueeze(0)) + + + # 모델 로드 부분 추가 + model = MusicAudioClassifier.load_from_checkpoint( + checkpoint_path, + input_dim=input_dim, + #emb_model=backbone_model + is_emb = True, + mode = 'both' + ) + + + # Run inference + print(f"Segments shape: {segments.shape}") + print("Running inference...") + results = run_inference(model, embedding, padding_mask, device=device) + + # 결과 출력 + print(f"Results: {results}") + return str(results) + +if __name__ == "__main__": + inference_with_audio() + diff --git a/mert_finetune_10.pth b/mert_finetune_10.pth new file mode 100644 index 0000000000000000000000000000000000000000..29161efce885a2a9a3e46a4f906fd28ec27500e8 --- /dev/null +++ b/mert_finetune_10.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:140059fcdb45d72653b96dd129bcb61cc5e359b48f8d28c5512ef7706cb1742f +size 680885656 diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..42e64e9d3d484ecbf6fbff6e807f7f15ba949890 --- /dev/null +++ b/model.py @@ -0,0 +1,490 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from torch.utils.data import Dataset, DataLoader +from typing import List, Tuple, Optional +import numpy as np +from pathlib import Path +from deepspeed.ops.adam import FusedAdam + +class MusicClassifier(pl.LightningModule): + def __init__(self, + input_dim: int, + hidden_dim: int = 256, + learning_rate: float = 1e-4, + emb_model: Optional[nn.Module] = None, + is_emb: bool = False): + super().__init__() + self.save_hyperparameters() + + self.model = SegmentTransformer( + input_dim=input_dim, + hidden_dim=hidden_dim + ) + self.emb_model = emb_model + self.learning_rate = learning_rate + self.is_emb = is_emb + + def 
forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(x, mask) + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + x, y, mask = batch + if self.is_emb == False: + _,x = self.emb_model(x) + y_hat = self(x, mask) + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + + # Log metrics + self.log('train_loss', loss,on_epoch=True, prog_bar=True) + preds = (torch.sigmoid(y_hat.squeeze()) > 0.5).float() + acc = (preds == y.float()).float().mean() + self.log('train_acc', acc,on_epoch=True, prog_bar=True) + + return loss + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None: + x, y, mask = batch + if self.is_emb == False: + _,x = self.emb_model(x) + y_hat = self(x, mask) + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + + # Calculate accuracy + preds = (torch.sigmoid(y_hat.squeeze()) > 0.5).float() + acc = (preds == y.float()).float().mean() + + self.log('val_loss', loss, prog_bar=True) + self.log('val_acc', acc, prog_bar=True) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.learning_rate, + weight_decay=0.01 + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100, # Adjust based on your training epochs + eta_min=1e-6 + ) + return { + 'optimizer': optimizer, + 'lr_scheduler': scheduler, + 'monitor': 'val_loss' + } + +class MusicAudioClassifier(pl.LightningModule): + def __init__(self, + input_dim: int, + hidden_dim: int = 256, + learning_rate: float = 1e-4, + emb_model: Optional[nn.Module] = None, + is_emb: bool = False, + mode: str = 'only_emb', + share_parameter: bool = False): + super().__init__() + self.save_hyperparameters() + + self.model = SegmentTransformer( + input_dim=input_dim, + hidden_dim=hidden_dim, + mode = mode, + share_parameter = share_parameter + ) + self.emb_model = emb_model + self.learning_rate = learning_rate + self.is_emb = is_emb + + def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor: + + B, S = x.shape[:2] # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb + x = x.view(B*S, *x.shape[2:]) # [B*S, C, M, T] + if self.is_emb == False: + _, embeddings = self.emb_model(x) # [B*S, emb_dim] + else: + embeddings = x + if embeddings.dim() == 3: + pooled_features = embeddings.mean(dim=1) # transformer + else: + pooled_features = embeddings # CCV..? 
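# ---------------------------------------------------------------------------
# Illustrative sketch: _process_audio_batch() flattens the per-song segment
# stack [B, S, ...] to [B*S, ...], embeds it (or passes it through when
# is_emb=True), and folds it back to [B, S, emb_dim] for the SegmentTransformer.
# Shape-only walk-through with random data standing in for MERT embeddings:
import torch

B, S, emb_dim = 2, 48, 768
x = torch.randn(B, S, emb_dim)        # already-extracted embeddings (is_emb=True)
flat = x.view(B * S, emb_dim)         # [96, 768]; would pass through emb_model otherwise
folded = flat.view(B, S, -1)          # [2, 48, 768], one row of segments per song
assert folded.shape == (B, S, emb_dim)
# ---------------------------------------------------------------------------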
no need to pooling + return pooled_features.view(B, S, -1) # [B, S, emb_dim] + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self._process_audio_batch(x) # 이걸 freeze하고 쓰는게 사실상 윗버전임 + x = x.half() + return self.model(x, mask) + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + x, y, mask = batch + x = x.half() + y_hat = self(x, mask) + + # 배치 크기가 1인 경우 예외처리 + if y_hat.size(0) == 1: + loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten()) + probs = torch.sigmoid(y_hat.flatten()) + y_true = y.float().flatten() + else: + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + probs = torch.sigmoid(y_hat.squeeze()) + y_true = y.float() + + # 간단한 배치 손실만 로깅 (step 수준) + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장 + self.training_step_outputs.append({'preds': probs, 'targets': y_true}) + + return loss + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None: + x, y, mask = batch + x = x.half() + y_hat = self(x, mask) + + # 배치 크기가 1인 경우 예외처리 + if y_hat.size(0) == 1: + loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten()) + probs = torch.sigmoid(y_hat.flatten()) + y_true = y.float().flatten() + else: + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + probs = torch.sigmoid(y_hat.squeeze()) + y_true = y.float() + + # 간단한 배치 손실만 로깅 (step 수준) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장 + self.validation_step_outputs.append({'preds': probs, 'targets': y_true}) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None: + x, y, mask = batch + x = x.half() + y_hat = self(x, mask) + + # 배치 크기가 1인 경우 예외처리 + if y_hat.size(0) == 1: + loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten()) + probs = torch.sigmoid(y_hat.flatten()) + y_true = y.float().flatten() + else: + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + probs = torch.sigmoid(y_hat.squeeze()) + y_true = y.float() + + # 간단한 배치 손실만 로깅 (step 수준) + self.log('test_loss', loss, on_epoch=True, prog_bar=True) + + # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장 + self.test_step_outputs.append({'preds': probs, 'targets': y_true}) + + def on_train_epoch_start(self): + # 에폭 시작 시 결과 저장용 리스트 초기화 + self.training_step_outputs = [] + + def on_validation_epoch_start(self): + # 에폭 시작 시 결과 저장용 리스트 초기화 + self.validation_step_outputs = [] + + def on_test_epoch_start(self): + # 에폭 시작 시 결과 저장용 리스트 초기화 + self.test_step_outputs = [] + + def on_train_epoch_end(self): + # 에폭이 끝날 때 전체 데이터에 대한 메트릭 계산 + if not hasattr(self, 'training_step_outputs') or not self.training_step_outputs: + return + + all_preds = torch.cat([x['preds'] for x in self.training_step_outputs]) + all_targets = torch.cat([x['targets'] for x in self.training_step_outputs]) + + # 전체 데이터에 대한 메트릭 계산 + binary_preds = (all_preds > 0.5).float() + + # 정확도 계산 + acc = (binary_preds == all_targets).float().mean() + + # 혼동 행렬 요소 계산 + tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float() + fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float() + tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float() + fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float() + + # 메트릭 계산 + precision = tp 
/ (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device) + recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device) + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device) + specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device) + + # 로깅 - 일관된 이름 사용 + self.log('train_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train_precision', precision, on_epoch=True, sync_dist=True) + self.log('train_recall', recall, on_epoch=True, sync_dist=True) + self.log('train_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train_specificity', specificity, on_epoch=True, sync_dist=True) + + def on_validation_epoch_end(self): + # 에폭이 끝날 때 전체 데이터에 대한 메트릭 계산 + if not hasattr(self, 'validation_step_outputs') or not self.validation_step_outputs: + return + + all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs]) + all_targets = torch.cat([x['targets'] for x in self.validation_step_outputs]) + + # ROC-AUC 계산 (간단한 근사) + sorted_indices = torch.argsort(all_preds, descending=True) + sorted_targets = all_targets[sorted_indices] + + n_pos = torch.sum(all_targets) + n_neg = len(all_targets) - n_pos + + if n_pos > 0 and n_neg > 0: + # TPR과 FPR을 누적합으로 계산 + tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos + fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg + + # AUC 계산 (사다리꼴 법칙) + width = fpr_curve[1:] - fpr_curve[:-1] + height = (tpr_curve[1:] + tpr_curve[:-1]) / 2 + auc_approx = torch.sum(width * height) + + self.log('val_auc', auc_approx, on_epoch=True) + + # 전체 데이터에 대한 메트릭 계산 + binary_preds = (all_preds > 0.5).float() + + # 정확도 계산 + acc = (binary_preds == all_targets).float().mean() + + # 혼동 행렬 요소 계산 + tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float() + fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float() + tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float() + fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float() + + # 메트릭 계산 + precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device) + recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device) + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device) + specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device) + + # 로깅 - 일관된 이름 사용 (val_epoch_f1 대신 val_f1 사용) + self.log('val_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val_precision', precision, on_epoch=True, sync_dist=True) + self.log('val_recall', recall, on_epoch=True, sync_dist=True) + self.log('val_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val_specificity', specificity, on_epoch=True, sync_dist=True) + + def on_test_epoch_end(self): + # 에폭이 끝날 때 전체 테스트 데이터에 대한 메트릭 계산 + if not hasattr(self, 'test_step_outputs') or not self.test_step_outputs: + return + + all_preds = torch.cat([x['preds'] for x in self.test_step_outputs]) + all_targets = torch.cat([x['targets'] for x in self.test_step_outputs]) + + # ROC-AUC 계산 (간단한 근사) + sorted_indices = torch.argsort(all_preds, descending=True) + sorted_targets = all_targets[sorted_indices] + + n_pos = torch.sum(all_targets) + n_neg = len(all_targets) - n_pos + + if n_pos > 0 and n_neg > 0: + # TPR과 FPR을 누적합으로 계산 + tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos + fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg + + # AUC 계산 (사다리꼴 법칙) + width = 
fpr_curve[1:] - fpr_curve[:-1] + height = (tpr_curve[1:] + tpr_curve[:-1]) / 2 + auc_approx = torch.sum(width * height) + + self.log('test_auc', auc_approx, on_epoch=True, sync_dist=True) + + # 전체 데이터에 대한 메트릭 계산 + binary_preds = (all_preds > 0.5).float() + + # 정확도 계산 + acc = (binary_preds == all_targets).float().mean() + + # 혼동 행렬 요소 계산 + tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float() + fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float() + tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float() + fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float() + + # 메트릭 계산 + precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device) + recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device) + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device) + specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device) + balanced_acc = (recall + specificity) / 2 + + # 로깅 - 일관된 이름 사용 + self.log('test_acc', acc, on_epoch=True, prog_bar=True) + self.log('test_precision', precision, on_epoch=True) + self.log('test_recall', recall, on_epoch=True) + self.log('test_f1', f1, on_epoch=True, prog_bar=True) + self.log('test_specificity', specificity, on_epoch=True) + self.log('test_balanced_acc', balanced_acc, on_epoch=True) + + def configure_optimizers(self): + optimizer = FusedAdam(self.parameters(),lr=self.learning_rate, + weight_decay=0.01) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100, # Adjust based on your training epochs + eta_min=1e-6 + ) + + return { + 'optimizer': optimizer, + 'lr_scheduler': scheduler, + 'monitor': 'val_loss', + } + + +def pad_sequence_with_mask(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Collate function for DataLoader that creates padded sequences and attention masks with fixed length (48).""" + embeddings, labels = zip(*batch) + fixed_len = 48 # 고정 길이 + + batch_size = len(embeddings) + feat_dim = embeddings[0].shape[-1] + + padded = torch.zeros((batch_size, fixed_len, feat_dim)) # 고정 길이로 패딩된 텐서 + mask = torch.ones((batch_size, fixed_len), dtype=torch.bool) # True는 padding을 의미 + + for i, emb in enumerate(embeddings): + length = emb.shape[0] + + # 길이가 고정 길이보다 길면 자르고, 짧으면 패딩 + if length > fixed_len: + padded[i, :] = emb[:fixed_len] # fixed_len보다 긴 부분을 잘라서 채운다. + mask[i, :] = False + else: + padded[i, :length] = emb # 실제 데이터 길이에 맞게 채운다. 
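# ---------------------------------------------------------------------------
# Illustrative sketch: the epoch-end hooks above approximate ROC-AUC by sorting
# predictions, accumulating TPR/FPR, and integrating with the trapezoidal rule.
# The same computation on a toy batch (values invented for the example):
import torch

preds   = torch.tensor([0.9, 0.8, 0.3, 0.2])
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])

order          = torch.argsort(preds, descending=True)
sorted_targets = targets[order]
n_pos = targets.sum()
n_neg = len(targets) - n_pos
tpr = torch.cumsum(sorted_targets, dim=0) / n_pos
fpr = torch.cumsum(1 - sorted_targets, dim=0) / n_neg
auc = torch.sum((fpr[1:] - fpr[:-1]) * (tpr[1:] + tpr[:-1]) / 2)
# auc == 0.75 here, matching the pairwise definition: 3 of the 4
# (positive, negative) pairs are ranked correctly.
# ---------------------------------------------------------------------------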
+ mask[i, :length] = False # 패딩이 아닌 부분은 False로 설정 + + return padded, torch.tensor(labels), mask + + +class SegmentTransformer(nn.Module): + def __init__(self, + input_dim: int, + hidden_dim: int = 256, + num_heads: int = 8, + num_layers: int = 4, + dropout: float = 0.1, + max_sequence_length: int = 1000, + mode: str = 'only_emb', + share_parameter: bool = False): + super().__init__() + + # Original sequence processing + self.input_projection = nn.Linear(input_dim, hidden_dim) + self.mode = mode + self.share_parameter = share_parameter + # Positional encoding + position = torch.arange(max_sequence_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim)) + pos_encoding = torch.zeros(max_sequence_length, hidden_dim) + pos_encoding[:, 0::2] = torch.sin(position * div_term) + pos_encoding[:, 1::2] = torch.cos(position * div_term) + self.register_buffer('pos_encoding', pos_encoding) + + # Transformer for original sequence + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=num_heads, + dim_feedforward=hidden_dim * 4, + dropout=dropout, + batch_first=True + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + self.sim_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Self-similarity stream processing + self.similarity_projection = nn.Sequential( + nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1), + nn.ReLU(), + nn.Dropout(dropout) + ) + + # Transformer for similarity stream + self.similarity_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Final classification head + self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim + self.classification_head = nn.Sequential( + nn.Linear(self.classification_head_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.LayerNorm(hidden_dim // 2), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, 1) + ) + + def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + + # 1. Process original sequence + x1 = self.input_projection(x) + x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0) + x1 = self.transformer(x1, src_key_padding_mask=padding_mask) # padding_mask 사용 + + # 2. 
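# ---------------------------------------------------------------------------
# Illustrative sketch: the boolean masks built above follow PyTorch's
# src_key_padding_mask convention (True = ignore that position in attention),
# so they can be handed to the encoders unchanged. Minimal shape check with
# random stand-in data:
import torch
import torch.nn as nn

layer   = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
x    = torch.randn(2, 48, 256)                  # [batch, segments, hidden]
mask = torch.zeros(2, 48, dtype=torch.bool)
mask[:, 40:] = True                             # last 8 slots are padding
out = encoder(x, src_key_padding_mask=mask)     # -> [2, 48, 256]
# ---------------------------------------------------------------------------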
Calculate and process self-similarity + x_expanded = x.unsqueeze(2) + x_transposed = x.unsqueeze(1) + distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1) + similarity_matrix = torch.exp(-distances) # (batch_size, seq_len, seq_len) + + # 자기 유사도 마스크 생성 및 적용 (각 시점에 대한 마스크 개별 적용) + if padding_mask is not None: + similarity_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2) # (batch_size, seq_len, seq_len) + similarity_matrix = similarity_matrix.masked_fill(similarity_mask, 0.0) + + # Process similarity matrix row by row using Conv1d + x2 = similarity_matrix.unsqueeze(1) # (batch_size, 1, seq_len, seq_len) + x2 = x2.view(batch_size * seq_len, 1, seq_len) # Reshape for Conv1d + x2 = self.similarity_projection(x2) # (batch_size * seq_len, hidden_dim, seq_len) + x2 = x2.mean(dim=2) # Pool across sequence dimension + x2 = x2.view(batch_size, seq_len, -1) # Reshape back + + x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0) + if self.share_parameter: + x2 = self.transformer(x2, src_key_padding_mask=padding_mask) + else: + x2 = self.sim_transformer(x2, src_key_padding_mask=padding_mask) # padding_mask 사용 + + # 3. Global average pooling for both streams + if padding_mask is not None: + mask_expanded = (~padding_mask).float().unsqueeze(-1) + x1 = (x1 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1) + x2 = (x2 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1) + else: + x1 = x1.mean(dim=1) + x2 = x2.mean(dim=1) + + # 4. Combine both streams and classify + #x = x1 # only emb + #x = x2 # only structure + #x = torch.cat([x1, x2], dim=-1) + if self.mode == 'only_emb': + x = x1 + elif self.mode == 'only_structure': + x = x2 + elif self.mode == 'both': + x = torch.cat([x1, x2], dim=-1) + x = x.half() + return self.classification_head(x) diff --git a/model_with_pure_bert.py b/model_with_pure_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..94c291a0b35667357ec37a0d17ce333e0269e88a --- /dev/null +++ b/model_with_pure_bert.py @@ -0,0 +1,423 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from torch.utils.data import Dataset, DataLoader +from typing import List, Tuple, Optional +import numpy as np +from pathlib import Path +from deepspeed.ops.adam import FusedAdam + + +class MusicAudioClassifier(pl.LightningModule): + def __init__(self, + input_dim: int, + hidden_dim: int = 256, + learning_rate: float = 1e-4, + emb_model: Optional[nn.Module] = None, + is_emb: bool = False, + mode: str = 'both', + share_parameter: bool = False): + super().__init__() + self.save_hyperparameters() + + self.model = SegmentTransformer( + input_dim=input_dim, + hidden_dim=hidden_dim, + mode = mode, + share_parameter = share_parameter + ) + self.emb_model = emb_model + self.learning_rate = learning_rate + self.is_emb = is_emb + + def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor: + + B, S = x.shape[:2] # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb + x = x.view(B*S, *x.shape[2:]) # [B*S, C, M, T] + if self.is_emb == False: + embeddings = self.emb_model(x) # [B*S, emb_dim] + else: + embeddings = x + if embeddings.dim() == 3: + pooled_features = embeddings.mean(dim=1) # transformer + else: + pooled_features = embeddings # CCV..? 
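# Note: unlike model.py, this variant calls `self.emb_model(x)` and uses the
# return value directly as the embeddings (no (logits, embeddings) tuple),
# which presumably matches the MERTFeatureExtractor that worker.py passes in
# as emb_model with strict=False.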
no need to pooling + return pooled_features.view(B, S, -1) # [B, S, emb_dim] + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self._process_audio_batch(x) # 이걸 freeze하고 쓰는게 사실상 윗버전임 + x = x.half() + return self.model(x, mask) + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + x, y, mask = batch + x = x.half() + y_hat = self(x, mask) + + # 배치 크기가 1인 경우 예외처리 + if y_hat.size(0) == 1: + loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten()) + probs = torch.sigmoid(y_hat.flatten()) + y_true = y.float().flatten() + else: + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + probs = torch.sigmoid(y_hat.squeeze()) + y_true = y.float() + + # 간단한 배치 손실만 로깅 (step 수준) + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장 + self.training_step_outputs.append({'preds': probs, 'targets': y_true}) + + return loss + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None: + x, y, mask = batch + x = x.half() + y_hat = self(x, mask) + + # 배치 크기가 1인 경우 예외처리 + if y_hat.size(0) == 1: + loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten()) + probs = torch.sigmoid(y_hat.flatten()) + y_true = y.float().flatten() + else: + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + probs = torch.sigmoid(y_hat.squeeze()) + y_true = y.float() + + # 간단한 배치 손실만 로깅 (step 수준) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장 + self.validation_step_outputs.append({'preds': probs, 'targets': y_true}) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None: + x, y, mask = batch + x = x.half() + y_hat = self(x, mask) + + # 배치 크기가 1인 경우 예외처리 + if y_hat.size(0) == 1: + loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float().flatten()) + probs = torch.sigmoid(y_hat.flatten()) + y_true = y.float().flatten() + else: + loss = F.binary_cross_entropy_with_logits(y_hat.squeeze(), y.float()) + probs = torch.sigmoid(y_hat.squeeze()) + y_true = y.float() + + # 간단한 배치 손실만 로깅 (step 수준) + self.log('test_loss', loss, on_epoch=True, prog_bar=True) + + # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장 + self.test_step_outputs.append({'preds': probs, 'targets': y_true}) + + def on_train_epoch_start(self): + # 에폭 시작 시 결과 저장용 리스트 초기화 + self.training_step_outputs = [] + + def on_validation_epoch_start(self): + # 에폭 시작 시 결과 저장용 리스트 초기화 + self.validation_step_outputs = [] + + def on_test_epoch_start(self): + # 에폭 시작 시 결과 저장용 리스트 초기화 + self.test_step_outputs = [] + + def on_train_epoch_end(self): + # 에폭이 끝날 때 전체 데이터에 대한 메트릭 계산 + if not hasattr(self, 'training_step_outputs') or not self.training_step_outputs: + return + + all_preds = torch.cat([x['preds'] for x in self.training_step_outputs]) + all_targets = torch.cat([x['targets'] for x in self.training_step_outputs]) + + # 전체 데이터에 대한 메트릭 계산 + binary_preds = (all_preds > 0.5).float() + + # 정확도 계산 + acc = (binary_preds == all_targets).float().mean() + + # 혼동 행렬 요소 계산 + tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float() + fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float() + tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float() + fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float() + + # 메트릭 계산 + precision = tp 
/ (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device) + recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device) + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device) + specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device) + + # 로깅 - 일관된 이름 사용 + self.log('train_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train_precision', precision, on_epoch=True, sync_dist=True) + self.log('train_recall', recall, on_epoch=True, sync_dist=True) + self.log('train_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('train_specificity', specificity, on_epoch=True, sync_dist=True) + + def on_validation_epoch_end(self): + # 에폭이 끝날 때 전체 데이터에 대한 메트릭 계산 + if not hasattr(self, 'validation_step_outputs') or not self.validation_step_outputs: + return + + all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs]) + all_targets = torch.cat([x['targets'] for x in self.validation_step_outputs]) + + # ROC-AUC 계산 (간단한 근사) + sorted_indices = torch.argsort(all_preds, descending=True) + sorted_targets = all_targets[sorted_indices] + + n_pos = torch.sum(all_targets) + n_neg = len(all_targets) - n_pos + + if n_pos > 0 and n_neg > 0: + # TPR과 FPR을 누적합으로 계산 + tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos + fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg + + # AUC 계산 (사다리꼴 법칙) + width = fpr_curve[1:] - fpr_curve[:-1] + height = (tpr_curve[1:] + tpr_curve[:-1]) / 2 + auc_approx = torch.sum(width * height) + + self.log('val_auc', auc_approx, on_epoch=True) + + # 전체 데이터에 대한 메트릭 계산 + binary_preds = (all_preds > 0.5).float() + + # 정확도 계산 + acc = (binary_preds == all_targets).float().mean() + + # 혼동 행렬 요소 계산 + tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float() + fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float() + tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float() + fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float() + + # 메트릭 계산 + precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device) + recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device) + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device) + specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device) + + # 로깅 - 일관된 이름 사용 (val_epoch_f1 대신 val_f1 사용) + self.log('val_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val_precision', precision, on_epoch=True, sync_dist=True) + self.log('val_recall', recall, on_epoch=True, sync_dist=True) + self.log('val_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val_specificity', specificity, on_epoch=True, sync_dist=True) + + def on_test_epoch_end(self): + # 에폭이 끝날 때 전체 테스트 데이터에 대한 메트릭 계산 + if not hasattr(self, 'test_step_outputs') or not self.test_step_outputs: + return + + all_preds = torch.cat([x['preds'] for x in self.test_step_outputs]) + all_targets = torch.cat([x['targets'] for x in self.test_step_outputs]) + + # ROC-AUC 계산 (간단한 근사) + sorted_indices = torch.argsort(all_preds, descending=True) + sorted_targets = all_targets[sorted_indices] + + n_pos = torch.sum(all_targets) + n_neg = len(all_targets) - n_pos + + if n_pos > 0 and n_neg > 0: + # TPR과 FPR을 누적합으로 계산 + tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos + fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg + + # AUC 계산 (사다리꼴 법칙) + width = 
fpr_curve[1:] - fpr_curve[:-1] + height = (tpr_curve[1:] + tpr_curve[:-1]) / 2 + auc_approx = torch.sum(width * height) + + self.log('test_auc', auc_approx, on_epoch=True, sync_dist=True) + + # 전체 데이터에 대한 메트릭 계산 + binary_preds = (all_preds > 0.5).float() + + # 정확도 계산 + acc = (binary_preds == all_targets).float().mean() + + # 혼동 행렬 요소 계산 + tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float() + fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float() + tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float() + fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float() + + # 메트릭 계산 + precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device) + recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device) + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device) + specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device) + balanced_acc = (recall + specificity) / 2 + + # 로깅 - 일관된 이름 사용 + self.log('test_acc', acc, on_epoch=True, prog_bar=True) + self.log('test_precision', precision, on_epoch=True) + self.log('test_recall', recall, on_epoch=True) + self.log('test_f1', f1, on_epoch=True, prog_bar=True) + self.log('test_specificity', specificity, on_epoch=True) + self.log('test_balanced_acc', balanced_acc, on_epoch=True) + + def configure_optimizers(self): + optimizer = FusedAdam(self.parameters(),lr=self.learning_rate, + weight_decay=0.01) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100, # Adjust based on your training epochs + eta_min=1e-6 + ) + + return { + 'optimizer': optimizer, + 'lr_scheduler': scheduler, + 'monitor': 'val_loss', + } + + +def pad_sequence_with_mask(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Collate function for DataLoader that creates padded sequences and attention masks with fixed length (48).""" + embeddings, labels = zip(*batch) + fixed_len = 48 # 고정 길이 + + batch_size = len(embeddings) + feat_dim = embeddings[0].shape[-1] + + padded = torch.zeros((batch_size, fixed_len, feat_dim)) # 고정 길이로 패딩된 텐서 + mask = torch.ones((batch_size, fixed_len), dtype=torch.bool) # True는 padding을 의미 + + for i, emb in enumerate(embeddings): + length = emb.shape[0] + + # 길이가 고정 길이보다 길면 자르고, 짧으면 패딩 + if length > fixed_len: + padded[i, :] = emb[:fixed_len] # fixed_len보다 긴 부분을 잘라서 채운다. + mask[i, :] = False + else: + padded[i, :length] = emb # 실제 데이터 길이에 맞게 채운다. 
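# ---------------------------------------------------------------------------
# Illustrative sketch: pad_sequence_with_mask() is a DataLoader collate_fn that
# pads or truncates every item to 48 segments and returns the boolean padding
# mask the classifier expects. Toy usage (the 30/60-segment items are invented):
import torch
from torch.utils.data import DataLoader

toy_items = [(torch.randn(30, 768), 0), (torch.randn(60, 768), 1)]
loader = DataLoader(toy_items, batch_size=2, collate_fn=pad_sequence_with_mask)
padded, labels, mask = next(iter(loader))
# padded: [2, 48, 768], labels: [2], mask: [2, 48] with True marking padded slots
# ---------------------------------------------------------------------------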
+ mask[i, :length] = False # 패딩이 아닌 부분은 False로 설정 + + return padded, torch.tensor(labels), mask + + +class SegmentTransformer(nn.Module): + def __init__(self, + input_dim: int, + hidden_dim: int = 256, + num_heads: int = 8, + num_layers: int = 4, + dropout: float = 0.1, + max_sequence_length: int = 1000, + mode: str = 'only_emb', + share_parameter: bool = False): + super().__init__() + + # Original sequence processing + self.input_projection = nn.Linear(input_dim, hidden_dim) + self.mode = mode + self.share_parameter = share_parameter + # Positional encoding + position = torch.arange(max_sequence_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim)) + pos_encoding = torch.zeros(max_sequence_length, hidden_dim) + pos_encoding[:, 0::2] = torch.sin(position * div_term) + pos_encoding[:, 1::2] = torch.cos(position * div_term) + self.register_buffer('pos_encoding', pos_encoding) + + # Transformer for original sequence + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=num_heads, + dim_feedforward=hidden_dim * 4, + dropout=dropout, + batch_first=True + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Self-similarity stream processing + self.similarity_projection = nn.Sequential( + nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1), + nn.ReLU(), + nn.Dropout(dropout) + ) + + # Transformer for similarity stream + self.similarity_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Final classification head + self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim + self.classification_head = nn.Sequential( + nn.Linear(self.classification_head_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.LayerNorm(hidden_dim // 2), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, 1) + ) + + def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + + # 1. Process original sequence + x1 = self.input_projection(x) + x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0) + x1 = self.transformer(x1, src_key_padding_mask=padding_mask) # padding_mask 사용 + + # 2. Calculate and process self-similarity + x_expanded = x.unsqueeze(2) + x_transposed = x.unsqueeze(1) + distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1) + similarity_matrix = torch.exp(-distances) # (batch_size, seq_len, seq_len) + + # 자기 유사도 마스크 생성 및 적용 (각 시점에 대한 마스크 개별 적용) + if padding_mask is not None: + similarity_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2) # (batch_size, seq_len, seq_len) + similarity_matrix = similarity_matrix.masked_fill(similarity_mask, 0.0) + + # Process similarity matrix row by row using Conv1d + x2 = similarity_matrix.unsqueeze(1) # (batch_size, 1, seq_len, seq_len) + x2 = x2.view(batch_size * seq_len, 1, seq_len) # Reshape for Conv1d + x2 = self.similarity_projection(x2) # (batch_size * seq_len, hidden_dim, seq_len) + x2 = x2.mean(dim=2) # Pool across sequence dimension + x2 = x2.view(batch_size, seq_len, -1) # Reshape back + + x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0) + if self.share_parameter: + x2 = self.transformer(x2, src_key_padding_mask=padding_mask) + else: + x2 = self.transformer(x2, src_key_padding_mask=padding_mask) # padding_mask 사용 + + # 3. 
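# ---------------------------------------------------------------------------
# Illustrative sketch: step 2 above builds an RBF-style self-similarity matrix
# exp(-mean squared distance) over segment embeddings, masks rows/columns that
# touch a padded slot, and feeds each row through the Conv1d projection. The
# core computation in isolation (random stand-in embeddings):
import torch

x = torch.randn(2, 48, 768)                                      # [B, S, emb_dim]
d = torch.mean((x.unsqueeze(2) - x.unsqueeze(1)) ** 2, dim=-1)   # [B, S, S]
sim = torch.exp(-d)                                              # 1.0 on the diagonal
assert torch.allclose(sim.diagonal(dim1=1, dim2=2), torch.ones(2, 48))
# ---------------------------------------------------------------------------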
Global average pooling for both streams + if padding_mask is not None: + mask_expanded = (~padding_mask).float().unsqueeze(-1) + x1 = (x1 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1) + x2 = (x2 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1) + else: + x1 = x1.mean(dim=1) + x2 = x2.mean(dim=1) + + # 4. Combine both streams and classify + #x = x1 # only emb + #x = x2 # only structure + #x = torch.cat([x1, x2], dim=-1) + if self.mode == 'only_emb': + x = x1 + elif self.mode == 'only_structure': + x = x2 + elif self.mode == 'both': + x = torch.cat([x1, x2], dim=-1) + x = x.half() + return self.classification_head(x) diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..bacb9a10a09c0165e0ee6ba54d24e6bb87d16992 --- /dev/null +++ b/preprocess.py @@ -0,0 +1,235 @@ +from beat_this.inference import File2Beats +import torchaudio +import torch +from pathlib import Path +import numpy as np +from collections import Counter +import os +import argparse +from tqdm import tqdm +import shutil +import concurrent.futures +from madmom.features.downbeats import DBNDownBeatTrackingProcessor +from madmom.features.downbeats import RNNDownBeatProcessor + + +def get_segments_from_wav(wav_path, device="cuda"): + """오디오 파일에서 비트와 다운비트를 추출합니다.""" + #try: + file2beats = File2Beats(checkpoint_path="final0", device="cuda", dbn=False) + all_models = ["final0", "final1", "final2", "small0", "small1", "small2","single_final0", "single_final1", "single_final2"] + beats, downbeats = file2beats(wav_path) + if len(downbeats) <1: + proc = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4], fps=100) + act = RNNDownBeatProcessor()(wav_path) + array = proc(act) + downbeats = array[:, 0] + return beats, downbeats#beats는 빈거 맞음 + + return beats, downbeats + #except Exception as e: + # print(f"Error extracting beats from {wav_path}: {str(e)}") + # return None, None + +def find_optimal_segment_length(downbeats, round_decimal=1, bar_length = 4): + """다운비트 간격들의 분포를 분석하여 최적의 4마디 길이와 정제된 다운비트 위치들을 반환합니다.""" + if len(downbeats) < 2: + return 10.0, downbeats # 기본 10초 길이 반환 + + # 연속된 downbeat 간의 간격 계산 + intervals = np.diff(downbeats) + rounded_intervals = np.round(intervals, round_decimal) + + # 가장 흔한 간격 찾기 (1마디 길이) + interval_counter = Counter(rounded_intervals) + most_common_interval = interval_counter.most_common(1)[0][0] + + # 정제된 downbeat 위치 찾기 + cleaned_downbeats = [downbeats[0]] # 첫 번째 위치는 항상 포함 + + for i in range(1, len(downbeats)): + interval = rounded_intervals[i-1] + # 현재 간격이 가장 흔한 간격과 비슷한지 확인 (10% 오차 허용) + if abs(interval - most_common_interval) <= most_common_interval * 0.1: + cleaned_downbeats.append(downbeats[i]) + + return float(most_common_interval * bar_length), np.array(cleaned_downbeats) + +def process_audio_file(audio_file, output_dir, temp_dir, device="cuda"): + """단일 오디오 파일을 처리하고 세그먼트를 추출합니다.""" + try: + output_dir = Path(output_dir) # output_dir을 Path 객체로 변환 + beats, downbeats = get_segments_from_wav(str(audio_file), device=device) + for bar_length in [1,2,3]: + # 문자열로 변환 후 "segments_wav"를 "segments_wav_숫자"로 대체 + dir_str = str(output_dir) + if "segments_wav" in dir_str: + new_dir_str = dir_str.replace("segments_wav", f"segments_wav_{bar_length}") + base_dir = Path(new_dir_str) + else: + # segments_wav가 없는 경우 처리 + base_dir = output_dir.parent / f"{output_dir.name}_{bar_length}" + + file_seg_dir = base_dir / audio_file.stem + file_seg_dir.mkdir(exist_ok=True, parents=True) + + # 비트 정보 추출 + + if beats is None or downbeats is None or len(downbeats) == 0: 
+ print(f"No beat information extracted for {audio_file.name}, skipping...") + return 0 + + # 최적의 세그먼트 길이와 정제된 다운비트 찾기 + optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats, bar_length=bar_length) + + # 오디오 로드 + waveform, sample_rate = torchaudio.load(str(audio_file)) + if waveform.size(0) > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + total_duration = waveform.size(1) / sample_rate + segments_count = 0 + + # 각 다운비트에서 시작하는 세그먼트 생성 + for i, start_time in enumerate(cleaned_downbeats): + end_time = start_time + optimal_length + + # 마지막 세그먼트가 파일 길이를 초과하면 건너뛰기 + if end_time > total_duration: + continue + + start_sample = int(start_time * sample_rate) + end_sample = int(end_time * sample_rate) + + # 세그먼트 추출 및 저장 + segment = waveform[:, start_sample:end_sample] + save_path = file_seg_dir / f"segment_{i}.wav" + torchaudio.save(str(save_path), segment, sample_rate) + segments_count += 1 + + # 임시 비트 정보 저장 (필요시) + if temp_dir: + segments_data = {'beat': beats, 'downbeat': downbeats} + temp_path = temp_dir / f"{audio_file.stem}_segments.npy" + np.save(str(temp_path), segments_data) + + return segments_count + + except Exception as e: + print(f"Error processing {audio_file.name}: {str(e)}") + return 0 + +def segment_dataset(base_dir, output_base_dir, temp_dir=None, num_workers=4, device="cuda"): + """ISMIR2025 데이터셋의 full_length 폴더에서 세그먼트를 추출합니다.""" + base_path = Path(base_dir) + output_base_path = Path(output_base_dir) + + # 처리 통계 + stats = { + "processed_files": 0, + "extracted_segments": 0, + "failed_files": 0 + } + + # 임시 디렉토리 생성 (비트 정보 저장용) + if temp_dir: + temp_dir = Path(temp_dir) + temp_dir.mkdir(exist_ok=True) + + # Real과 Fake 오디오 모두 처리 + for label in ["real", "fake"]: + for split in ["train", "valid", "test"]: + input_dir = base_path / label / split + output_dir = output_base_path / label / split + + if not input_dir.exists(): + print(f"Directory not found: {input_dir}") + continue + + print(f"Processing {label}/{split} files...") + audio_files = list(input_dir.glob("*.wav")) + list(input_dir.glob("*.mp3")) + + if not audio_files: + print(f"No audio files found in {input_dir}") + continue + + # 병렬 처리 설정 + if num_workers > 1: + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + future_to_file = { + executor.submit(process_audio_file, file, output_dir, temp_dir, device): file + for file in audio_files + } + + for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(audio_files)): + file = future_to_file[future] + try: + segments_count = future.result() + if segments_count > 0: + stats["processed_files"] += 1 + stats["extracted_segments"] += segments_count + else: + stats["failed_files"] += 1 + except Exception as e: + print(f"Error processing {file.name}: {str(e)}") + stats["failed_files"] += 1 + else: + # 직렬 처리 + for file in tqdm(audio_files): + segments_count = process_audio_file(file, output_dir, temp_dir, device) + if segments_count > 0: + stats["processed_files"] += 1 + stats["extracted_segments"] += segments_count + else: + stats["failed_files"] += 1 + + # 최종 통계 보고 + print("\n=== Segmentation Summary ===") + print(f"Successfully processed files: {stats['processed_files']}") + print(f"Failed files: {stats['failed_files']}") + print(f"Total extracted segments: {stats['extracted_segments']}") + print(f"Average segments per file: {stats['extracted_segments'] / max(1, stats['processed_files']):.2f}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Extract segments from audio 
files in ISMIR2025 dataset") + parser.add_argument("--input", type=str, default="/data/datasets/ISMIR2025/full_length_audio", + help="Input directory with full_length audio files") + parser.add_argument("--output", type=str, default="/data/datasets/ISMIR2025/segments_wav", + help="Output directory for segments") + parser.add_argument("--temp", type=str, default=None, + help="Temporary directory for beat information (optional)") + parser.add_argument("--workers", type=int, default=4, + help="Number of parallel workers") + parser.add_argument("--device", type=str, default="cuda", + help="Device for beat extraction (cuda or cpu)") + + args = parser.parse_args() + + # 디렉토리 유효성 검사 + input_path = Path(args.input) + if not input_path.exists(): + print(f"Input directory not found: {args.input}") + # 다른 가능한 위치 확인 + alternatives = [ + "/data/datasets/ISMIR2025/full_length", + "/data/ISMIR2025/full_length_audio", + "/data/ISMIR2025/full_length" + ] + + for alt_path in alternatives: + if os.path.exists(alt_path): + print(f"Found alternative input path: {alt_path}") + args.input = alt_path + break + else: + print("No valid input directory found.") + exit(1) + + # 세그먼트 추출 실행 + segment_dataset( + base_dir=args.input, + output_base_dir=args.output, + temp_dir=args.temp, + num_workers=args.workers, + device=args.device + ) \ No newline at end of file diff --git a/quick-hangout-422604-e3-5f7544b49207.json b/quick-hangout-422604-e3-5f7544b49207.json new file mode 100644 index 0000000000000000000000000000000000000000..31423d1a48e2b7f55774d41f8ab71bd1cbd10c25 --- /dev/null +++ b/quick-hangout-422604-e3-5f7544b49207.json @@ -0,0 +1,13 @@ +{ + "type": "service_account", + "project_id": "quick-hangout-422604-e3", + "private_key_id": "5f7544b492071972ed28f236fc05f0db2b3f52e8", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDIOKhoGhyi0P4h\nrFTHrot+/mamOgcJklwRVGx/079krQWP6sdTv94D3sAG/xSx8hg10v1voxfAPGFM\nUENi/39lsutjcR4L2Ld0RaDb6jhq/Tuad0Cb49NIXESlGTPaTPuseXILzH7i+QYU\nmcD+l3lqXih2S6vGm1UvkvNVY1i2WAqCw0N81y7BbDjqWNodIIenvO+Ra9jNQXu9\nX9rEfqL0ppUlQSy+M3F9ky3UHFsU0C1tDTRmWzHR+wJp2hcXMOUTOkJ0XmfYeAPy\n3YmQG6BrUYQDqJimY1F+FySfytd5hwUBwgk0HVOuVGk04IZ9F6JvC5KzapXG4BON\nxxe0F3F9AgMBAAECggEAYVaSINIY+9qowcwbB1zG+n1JgCSTyUZ7Nf2aJebWlGY9\nXwMf1opfr/f52Sznbb8pn/ksNrrOUCnKj6Qxro5R7Co4n4adudqpDYDQPDm5JeJo\nuKajJWZ0ECizQqRm/gkRDX+ZPz0yrrusBhXdqqgPYfPWrfQJ1asslc1WOTvsI/D6\n7K1ZDUzVRJ6mFxmEjT40QrONIu616N1FWLjm2yqQmyGsflh4t9TADPYuLPZCZ7G3\n0DqT4tIdG1doiZ8FzzeDpw5eZMc9Hxt2mEnFe6vyIWJPULy9ZFIwqMHTkdDRu6rg\nw05rUrCVpvyfvMFrh1rTtYySAqbukTDUuYAb7m/2LQKBgQDvJGJ2ktjgXC/Cy03n\n5zCehL5+CEvWFoJ2BtlEH6CPcc9LpI24GPiXAm471ED2rImRmFakd7LaklG5T1Je\nx51rycM3/mBGv1tM2Zx2y3nKFOJ5aa4Z5w1n8G8hawfklVko1+XGeWA/THzRAJKt\nnVafSaJScxtBPr2m786RpOeCvwKBgQDWVeYsiD3fsen2HHmIhYI8/2q5H2tKdf3Q\nOYZXvwq0Mgw7atq4wT5sAOImawHHr9uWeNjSd1CwS7vMh8CKEWxqy6Nl33AKeBJs\nxx5VDW7XYSgk07eD34jo9Y0808byHbZQM7vb37Ec5xMei3k1kQRgWyNmKocZQ+5X\nvV9d78mmwwKBgEE167+ntYoguUlmBP161pAzZHqbqopbowGqRm7ELRVQlJVs7tRH\nwuny2Lpp27koPW89WksI9PWTNsPQdIax6iRtZVWMgRZpsezX3kmqnLBVV5iCD96y\nWb6BGtzDAej8LM7taJPhnzRDmDmp7VV4dRmEi8Xt6320LSUreWnPzO2jAoGBAJ+p\n1ZEO3KiqrBJ+G7qrWd1+l03YLeCGDND0STNMSPj630nTy6MdsRZbghwEUosiYX2y\nADKoVx89C2TNK4yudgkIMWxOCfSChZcqrVnGa+9dnL3ySR6fgimn4dKSH/10TL9q\nCmM8O3/AUunKFDznDk+JFNGilIkppX8OvuAHJDxpAoGBALIzCKyWOjcWsRz8vtmb\nH0Rayck3oWnA5aRxwJT2R4/x5t1339hzMKi/3ohauwih1LSIEGw11cbgE15m9GTx\nIr37S/fhaaqPOXHNyIKlCNqsUyS/dJKyOo/tEPj8fBsZw6pnTcGISXnWgCL1NN66\nL1twW2pErY8+iKdsHkn9nS32\n-----END PRIVATE KEY-----\n", + "client_email": 
"mippia-storage-service@quick-hangout-422604-e3.iam.gserviceaccount.com", + "client_id": "102015060996162241130", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/mippia-storage-service%40quick-hangout-422604-e3.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +} diff --git "a/v2_request/request_music/AI_Detection/9486_music_id_Lost Sky x Anna Yvette - Carry On \357\275\234 Trap \357\275\234 NCS - Copyright Free Music.wav" "b/v2_request/request_music/AI_Detection/9486_music_id_Lost Sky x Anna Yvette - Carry On \357\275\234 Trap \357\275\234 NCS - Copyright Free Music.wav" new file mode 100644 index 0000000000000000000000000000000000000000..7fe238af2bb5de4ac7c63d76fe133e3ac3a60fef --- /dev/null +++ "b/v2_request/request_music/AI_Detection/9486_music_id_Lost Sky x Anna Yvette - Carry On \357\275\234 Trap \357\275\234 NCS - Copyright Free Music.wav" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d89e86b7510d3fe1fd6662f2e8ec54eb0bbf52bc73c22df54e663242142e7e69 +size 37400658 diff --git a/web_inference.py b/web_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..39158dc408666d6afcbf0e85a614e0635598f292 --- /dev/null +++ b/web_inference.py @@ -0,0 +1,54 @@ +from database import SessionLocal +from db_models.music import Music, AI_Detection_Music +from db_models.analysis import Music_Analysis, AnalysisStatus +from db_models.user import User +from google.cloud import storage +import uuid +import os +import torch +from inference import run_inference, load_audio, download_from_gcs +from inference import get_model +device = 'cuda' +backbone_model, input_dim = get_model('MERT',device) + +def do_web_inference(model, music_id): + """ + 음악 파일에 대해 AI 생성 탐지 추론을 실행합니다. 
+
+    Args:
+        model: loaded MusicAudioClassifier used for the prediction
+        music_id: ID of the music to analyze
+
+    Returns:
+        Dictionary containing the analysis results
+    """
+    try:
+        # Fetch the music record from the database
+        db = SessionLocal()
+        music = db.query(Music).filter(Music.id == music_id).first()
+        if not music:
+            return {"status": "error", "message": f"Music ID {music_id} not found"}
+
+        ai_detection_music = music.ai_detection_musics[0]
+
+        print(music, music_id)
+        print(ai_detection_music.id)
+
+        # Get the file path and download the audio from Cloud Storage
+        wav_path = music.music_path
+        download_from_gcs('mippia-bucket', wav_path, wav_path)
+        segments, padding_mask = load_audio(wav_path, sr=24000)
+        segments = segments.to(device).to(torch.float32)
+        logits, embedding = backbone_model(segments.squeeze(1))
+        embedding = embedding.to(device)
+        # Run inference
+        results = run_inference(model, embedding, padding_mask, device=device)
+
+        # Remove the temporary local file
+        if os.path.exists(wav_path):
+            os.remove(wav_path)
+        print(results)
+    finally:
+        # Close the database session
+        if 'db' in locals():
+            db.close()
diff --git a/with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt b/with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..c49a6db17f4816b5bccecc1f9b6fe260feabb4dc
--- /dev/null
+++ b/with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:762e36f2a259c7e46c712aecafef534576b455ee6cecf6f602d57b8ce1e7eabb
+size 95150792
diff --git a/worker.py b/worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..977e1fe0d20d2e32a99982035b54f0eb0919a6b8
--- /dev/null
+++ b/worker.py
@@ -0,0 +1,24 @@
+from celery_app import AI_detection_celery_app
+from web_inference import do_web_inference
+from model_with_pure_bert import MusicAudioClassifier
+from ISMIR_2025.MERT.networks import MERTFeatureExtractor
+
+
+checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt"  # adjust to the actual path
+backbone_model = MERTFeatureExtractor()
+input_dim = 768
+# Load the classifier model
+model = MusicAudioClassifier.load_from_checkpoint(
+    checkpoint_path,
+    input_dim=input_dim,
+    emb_model=backbone_model,
+    strict=False
+)
+
+@AI_detection_celery_app.task(name='AI_detection_task')
+def AI_detection_task(data):
+    music_id = data.get('music_id')
+    do_web_inference(model, music_id)
+    return {"status":"done"}
+
+
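# ---------------------------------------------------------------------------
# Illustrative sketch: a producer enqueues the detection job by task name with
# a music_id payload; the worker above then runs do_web_inference(model,
# music_id). Broker/backend settings live in celery_app, which is not shown in
# this diff:
from celery_app import AI_detection_celery_app

async_result = AI_detection_celery_app.send_task(
    "AI_detection_task",
    args=[{"music_id": 9486}],     # 9486 matches the sample request file above
)
print(async_result.id)             # poll or .get() depending on the result backend
# ---------------------------------------------------------------------------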