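"""Train a CNN to distinguish real from AI-generated music clips (FakeMusicCaps).

Each clip is forced to 10 seconds, converted to MFCC features with librosa, and
classified by a small 2-class CNN; metrics are logged to Weights & Biases and the
best checkpoint is kept with a simple patience-based early-stopping rule.

Example invocation (the script filename is an assumption):
    python train_mfcc_cnn.py --batch_size 32 --epochs 20 --learning_rate 1e-4 \
        --root_dir audio/FakeMusicCaps
"""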
import os
import glob
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm
import argparse
import wandb
class RealFakeDataset(Dataset):
"""
audio/FakeMusicCaps/
β”œβ”€ real/
β”‚ └─ MusicCaps/*.wav (label=0)
└─ generative/
└─ .../*.wav (label=1)
"""
def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0):
self.sr = sr
self.n_mels = n_mels
self.target_duration = target_duration
        self.target_samples = int(target_duration * sr)  # 10 s * 16 kHz = 160,000 samples
self.file_paths = []
self.labels = []
        # Real data (label=0)
real_dir = os.path.join(root_dir, "real")
real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True)
for f in real_wav_files:
self.file_paths.append(f)
self.labels.append(0)
        # Generative data (label=1)
gen_dir = os.path.join(root_dir, "generative")
gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True)
for f in gen_wav_files:
self.file_paths.append(f)
self.labels.append(1)
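        # Optional sanity check (sketch, not part of the original flow): fail fast
        # when the expected directory layout is missing.
        # if not self.file_paths:
        #     raise RuntimeError(f"No .wav files found under {root_dir}")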
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
audio_path = self.file_paths[idx]
label = self.labels[idx]
        # print(f"[DEBUG] Path: {audio_path}, Label: {label}")
waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True)
current_samples = waveform.shape[0]
if current_samples > self.target_samples:
waveform = waveform[:self.target_samples]
        elif current_samples < self.target_samples:
            # time_stretch lengthens audio when rate < 1, so use current/target (not target/current).
            stretch_factor = current_samples / self.target_samples
            waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor)
            # Trim or zero-pad to exactly target_samples to absorb rounding in time_stretch.
            waveform = waveform[:self.target_samples]
            if waveform.shape[0] < self.target_samples:
                waveform = np.pad(waveform, (0, self.target_samples - waveform.shape[0]))
mfcc = librosa.feature.mfcc(
y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256
)
mfcc = librosa.util.normalize(mfcc)
mfcc = np.expand_dims(mfcc, axis=0)
mfcc_tensor = torch.tensor(mfcc, dtype=torch.float)
label_tensor = torch.tensor(label, dtype=torch.long)
return mfcc_tensor, label_tensor
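
# With the defaults (sr=16000, target_duration=10.0, n_mels=64 used as n_mfcc,
# hop_length=256), each item is an MFCC tensor of shape (1, 64, 626),
# since 1 + 160000 // 256 = 626 frames.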
class AudioCNN(nn.Module):
def __init__(self, num_classes=2):
super(AudioCNN, self).__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final feature map -> (B, 32, 4, 4)
)
self.fc_block = nn.Sequential(
nn.Linear(32*4*4, 128),
nn.ReLU(),
nn.Linear(128, num_classes)
)
def forward(self, x):
x = self.conv_block(x)
# x.shape: (B,32,new_freq,new_time)
# 1) Flatten
        B, C, H, W = x.shape  # dynamic shape; H = W = 4 after the adaptive pool
x = x.view(B, -1) # (B, 32*H*W)
# 2) FC
x = self.fc_block(x)
return x
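
# Because AdaptiveAvgPool2d((4, 4)) fixes the spatial size, the flattened feature is
# always 32 * 4 * 4 = 512, which matches the first Linear layer regardless of how many
# time frames the input has.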
def my_collate_fn(batch):
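    """Stack a batch of MFCC tensors, right-padding the time axis to the longest item.

    With a fixed target_duration every clip yields the same number of frames, so the
    padding branch is mainly a safety net for variable-length inputs.
    """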
mel_list, label_list = zip(*batch)
max_frames = max(m.shape[2] for m in mel_list)
padded = []
for m in mel_list:
diff = max_frames - m.shape[2]
if diff > 0:
print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}")
m = F.pad(m, (0, diff), mode='constant', value=0)
padded.append(m)
mel_batch = torch.stack(padded, dim=0)
label_batch = torch.tensor(label_list, dtype=torch.long)
return mel_batch, label_batch
class EarlyStopping:
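    """Validation-loss early stopping with checkpointing.

    Note: train() below implements its own patience counter, so this class is currently
    unused; the {batch_size}-style fields in the default path are literal text unless
    the caller formats them (e.g. with str.format).
    """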
def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False):
self.patience = patience
self.delta = delta
self.path = path
self.verbose = verbose
self.counter = 0
self.best_loss = None
self.early_stop = False
def __call__(self, val_loss, model):
if self.best_loss is None:
self.best_loss = val_loss
self._save_checkpoint(val_loss, model)
elif val_loss > self.best_loss - self.delta:
self.counter += 1
if self.verbose:
print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
        else:
            # Save before updating best_loss so the checkpoint log reports the previous best.
            self._save_checkpoint(val_loss, model)
            self.best_loss = val_loss
            self.counter = 0
def _save_checkpoint(self, val_loss, model):
if self.verbose:
print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...")
torch.save(model.state_dict(), self.path)
def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"):
    os.makedirs("./ckpt/mfcc/", exist_ok=True)
wandb.init(
project="AI Music Detection",
name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}",
config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate},
)
dataset = RealFakeDataset(root_dir=root_dir)
n_total = len(dataset)
n_train = int(n_total * 0.8)
n_val = n_total - n_train
train_ds, val_ds = random_split(dataset, [n_train, n_val])
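    # Note: random_split without an explicit generator gives a different train/val split
    # on every run; pass generator=torch.Generator().manual_seed(0) for a reproducible split.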
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AudioCNN(num_classes=2).to(device)
criterion = nn.CrossEntropyLoss()
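    # CrossEntropyLoss expects raw (unnormalized) logits, which AudioCNN.forward returns.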
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_val_loss = float('inf')
patience = 3
patience_counter = 0
for epoch in range(1, epochs + 1):
print(f"\n[Epoch {epoch}/{epochs}]")
# Training
model.train()
train_loss, train_correct, train_total = 0, 0, 0
train_pbar = tqdm(train_loader, desc="Train", leave=False)
for mel_batch, labels in train_pbar:
mel_batch, labels = mel_batch.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(mel_batch)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * mel_batch.size(0)
preds = outputs.argmax(dim=1)
train_correct += (preds == labels).sum().item()
train_total += labels.size(0)
train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})
train_loss /= train_total
train_acc = train_correct / train_total
# Validation
model.eval()
val_loss, val_correct, val_total = 0, 0, 0
all_preds, all_labels = [], []
val_pbar = tqdm(val_loader, desc=" Val ", leave=False)
with torch.no_grad():
for mel_batch, labels in val_pbar:
mel_batch, labels = mel_batch.to(device), labels.to(device)
outputs = model(mel_batch)
loss = criterion(outputs, labels)
val_loss += loss.item() * mel_batch.size(0)
preds = outputs.argmax(dim=1)
val_correct += (preds == labels).sum().item()
val_total += labels.size(0)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
val_loss /= val_total
val_acc = val_correct / val_total
val_precision = precision_score(all_labels, all_preds, average="macro")
val_recall = recall_score(all_labels, all_preds, average="macro")
val_f1 = f1_score(all_labels, all_preds, average="macro")
print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | "
f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} "
f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}")
wandb.log({"train_loss": train_loss, "train_acc": train_acc,
"val_loss": val_loss, "val_acc": val_acc,
"val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1})
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth"
torch.save(model.state_dict(), best_model_path)
print(f"[INFO] New best model saved: {best_model_path}")
else:
patience_counter += 1
if patience_counter >= patience:
print("Early stopping triggered!")
break
wandb.finish()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train AI Music Detection model.")
parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training")
parser.add_argument('--epochs', type=int, required=True, help="Number of epochs")
parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate")
parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset")
args = parser.parse_args()
train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir)