aimusicdetection / ISMIR_2025 /wav2vec /utils /confusion_matrix_plot.py
nininigold's picture
Upload folder using huggingface_hub
3cecacc verified
raw
history blame
988 Bytes
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
    """Draw a confusion-matrix heat map and log it through ``writer``.

    Args:
        y_true: Ground-truth labels, passed straight to
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, same length as ``y_true``.
        classes: Sequence of display names used for the axis ticks.
            Assumed to be indexed by integer class id -- TODO confirm
            against callers.
        writer: Object exposing ``add_figure(tag, figure, global_step)``
            (presumably a TensorBoard ``SummaryWriter`` -- verify).
        epoch: Step value forwarded to ``writer.add_figure``.

    Returns:
        None. The figure is handed to ``writer`` under the tag
        ``"Confusion Matrix"``.
    """
    cm = confusion_matrix(y_true, y_pred)
    num_classes = cm.shape[0]

    # Without an explicit ``labels`` argument the matrix only contains the
    # labels that occur in y_true/y_pred, in sorted order. Naming the ticks
    # with ``classes[:num_classes]`` therefore mislabels the axes whenever a
    # class is missing (e.g. only ids 0 and 2 occur). Look the names up by
    # the observed ids instead, and keep the positional slice as a fallback
    # for label types that cannot index ``classes``.
    observed = np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)]))
    labels_index_classes = (
        observed.size == num_classes
        and observed.size > 0
        and np.issubdtype(observed.dtype, np.integer)
        and observed.min() >= 0
        and observed.max() < len(classes)
    )
    if labels_index_classes:
        tick_labels = [classes[int(label)] for label in observed]
    else:
        tick_labels = classes[:num_classes]

    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(num_classes),
        yticks=np.arange(num_classes),
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ylabel='True label',
        xlabel='Predicted label',
    )

    # Cells darker than half of the maximum count get white text so the
    # number stays readable; an empty matrix has no maximum to take.
    thresh = cm.max() / 2. if cm.size else 0.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(
                j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black",
            )

    fig.tight_layout()
    writer.add_figure("Confusion Matrix", fig, epoch)