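# Evaluation script for an AI-music detector built on MERT features.
# Loads a fine-tuned checkpoint, runs one of three FakeMusicCaps test splits,
# and reports loss, accuracy, balanced accuracy, precision, recall, F1,
# plus a saved confusion-matrix plot.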
import os
import argparse

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    balanced_accuracy_score, confusion_matrix,
)

from datalib import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels,
)
from networks import MERTFeatureExtractor
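# Command-line options: GPU id, batch size, checkpoint path, which test
# split to run, and where to write results.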
parser = argparse.ArgumentParser(description="AI Music Detection Testing with MERT")
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default="/data/kym/AI_Music_Detection/Code/model/MERT/ckpt/1e-3/mert_finetune_10.pth", help='Path to the pretrained checkpoint')
parser.add_argument('--model_name', type=str, default="mert", help="Model name")
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='.', help='Directory to save test results')
args = parser.parse_args()
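# Pin the visible GPU before any CUDA context is created.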
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    """Render a labeled confusion matrix and save it as an image."""
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]
    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')
    # Annotate each cell with its count, switching text color for readability.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)
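# Build the classifier and restore the fine-tuned weights.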
model = MERTFeatureExtractor().to(device)
ckpt_file = args.ckpt_path
if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
print(f"\nLoading MERT model from {ckpt_file}")
model.load_state_dict(torch.load(ckpt_file, map_location=device))
model.eval()
torch.cuda.empty_cache()
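# Select the evaluation split: closed test (FakeMusicCaps only),
# open-set test (adds SunoCaps), or the held-out 20% validation split.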
if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
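# Run inference over the loader and aggregate classification metrics.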
def test_mert(model, test_loader, device):
    """Evaluate the model and report loss, accuracy, and binary metrics."""
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            test_loss += loss.item() * data.size(0)  # sum now, average below
            preds = output.argmax(dim=1)
            test_correct += (preds == target).sum().item()
            test_total += target.size(0)
            all_labels.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")
    print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
          f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
          f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")
    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path)
print("\nEvaluating MERT Model on Test Set...")
test_mert(model, test_loader, device)
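# Example invocation (the script name and paths are illustrative):
#   python test_mert.py --open_test --gpu 0 \
#       --ckpt_path ckpt/1e-3/mert_finetune_10.pth --output_path results/mert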