Spaces:
Runtime error
Runtime error
from sklearn.metrics import confusion_matrix | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
    """Plot a confusion matrix and log it to ``writer`` as a figure.

    Args:
        y_true: Ground-truth labels, passed to
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, same form as ``y_true``.
        classes: Sequence of display names used as tick labels. When every
            observed label is an integer in ``[0, len(classes))``, labels are
            treated as indices into ``classes``. The matrix then always has
            ``len(classes)`` rows/columns, so a class absent from this call
            keeps its own all-zero row/column instead of shifting the names
            of every later class.
        writer: Object exposing ``add_figure(tag, figure, step)``; presumably
            a TensorBoard ``SummaryWriter`` (confirm at the call site).
        epoch: Step value forwarded to ``writer.add_figure``.
    """
    class_names = list(classes)
    # Sorted union of every label value seen; mirrors the label set sklearn
    # infers when ``labels`` is not given.
    observed = np.union1d(np.asarray(y_true), np.asarray(y_pred))

    # Integer labels that all index into ``class_names``: pin the label set so
    # that row/column k always corresponds to class_names[k].
    index_labels = np.issubdtype(observed.dtype, np.integer) and (
        observed.size == 0
        or (observed.min() >= 0 and observed.max() < len(class_names))
    )
    if index_labels and class_names:
        cm = confusion_matrix(y_true, y_pred,
                              labels=np.arange(len(class_names)))
    else:
        # Non-index labels (e.g. strings): keep sklearn's inferred label set.
        cm = confusion_matrix(y_true, y_pred)
    num_classes = cm.shape[0]

    # Matplotlib rejects tick-label lists whose length differs from the tick
    # count, so pad with the raw label values if ``classes`` is too short.
    tick_labels = class_names[:num_classes]
    tick_labels += [str(v) for v in observed[len(tick_labels):num_classes]]

    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    # Cells above half the maximum count get white text for contrast
    # against the darker colormap shade.
    thresh = cm.max() / 2.
    for i, j in np.ndindex(*cm.shape):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    # NOTE(review): the figure is not closed here. torch's
    # SummaryWriter.add_figure closes it by default (close=True); confirm
    # that holds for the writer actually passed in.
    writer.add_figure("Confusion Matrix", fig, epoch)