SpiralSense / evaluate.py
cycool29's picture
Update
73666ad
raw
history blame
6.36 kB
import torch
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import rcParams
from sklearn.metrics import (
classification_report,
precision_recall_curve,
accuracy_score,
f1_score,
confusion_matrix,
matthews_corrcoef,
ConfusionMatrixDisplay,
roc_curve,
auc,
average_precision_score,
cohen_kappa_score,
)
from sklearn.preprocessing import label_binarize
from configs import *
# Use Times New Roman as the default font family for every matplotlib figure
# produced by this script.
rcParams["font.family"] = "Times New Roman"
# Load the model: move the configured network to the target device, restore
# its weights from the checkpoint file (tensors are remapped onto DEVICE while
# loading) and switch it to evaluation mode.
# NOTE(review): MODEL, DEVICE and MODEL_SAVE_PATH are not defined in this file;
# presumably they come from `from configs import *` -- confirm.
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()
# Disabled alternative: evaluate a weighted-vote ensemble of three models
# instead of the single model loaded above. Uncomment the lines below to use it.
# model2 = EfficientNetB3WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB3WithDropout.pth"))
# model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
# model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth"))
# model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model3.load_state_dict(torch.load("output/checkpoints/MobileNetV2WithDropout.pth"))
# model1.eval()
# model2.eval()
# model3.eval()
# # Load the ensemble model
# model = WeightedVoteEnsemble([model1, model2, model3], [0.38, 0.34, 0.28])
# # model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
# model.load_state_dict(
#     torch.load("output/checkpoints/WeightedVoteEnsemble.pth", map_location=DEVICE)
# )
# model.eval()
def _save_curve_plot(x_values, y_values, title, x_label, y_label, auc_text, out_path):
    """Plot a single curve with an AUC annotation, save it to ``out_path`` and show it."""
    plt.figure(figsize=(10, 6))
    plt.plot(x_values, y_values)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    # Show the AUC value on the plot
    plt.text(
        0.6,
        0.2,
        auc_text,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig(out_path)
    plt.show()


def predict_image(image_path, model, transform):
    """Evaluate ``model`` on every ``*.png`` found recursively under ``image_path``.

    The true class of an image is the index, within ``CLASSES``, of the name
    of the folder that directly contains it. For each image the function
    prints the path, the true class and the predicted class; afterwards it
    prints accuracy, weighted F1, the classification report, AUC-ROC, AUC-PRC,
    Matthews correlation coefficient and Cohen's kappa, and saves (and shows)
    the confusion matrix, precision-recall curve and ROC curve under
    ``docs/evaluation/``.

    Args:
        image_path: Root directory that is searched recursively for PNG files.
        model: Network to evaluate; it is put into evaluation mode and called
            with one image at a time (batch dimension of size 1).
        transform: Callable applied to each RGB ``PIL.Image``; its result must
            support ``.unsqueeze(0)`` and ``.to(DEVICE)``.

    Raises:
        FileNotFoundError: If no ``*.png`` file exists under ``image_path``.
        ValueError: If an image's parent folder name is not in ``CLASSES``
            (raised by ``CLASSES.index``).
    """
    model.eval()
    # Get a list of image files
    images = list(pathlib.Path(image_path).rglob("*.png"))
    if not images:
        # Fail early with a clear message instead of an obscure error from
        # the metric functions further down.
        raise FileNotFoundError(f"No .png images found under {image_path!r}")

    true_classes = []
    predicted_labels = []
    predicted_scores = []  # One array of class probabilities per image
    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            # The true label is the position of the parent folder name in CLASSES
            true_class = CLASSES.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)
            with Image.open(image_file) as raw_image:
                image = raw_image.convert("RGB")
            image = transform(image).unsqueeze(0).to(DEVICE)
            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            predicted_scores.append(output.softmax(dim=1).cpu().numpy())

    # Calculate accuracy and f1 score
    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    # The figures below are written into this directory; create it on demand
    # so that savefig does not fail on a fresh checkout.
    pathlib.Path("docs/evaluation").mkdir(parents=True, exist_ok=True)

    # Pin the label set so the matrix and the report always have one entry per
    # name in CLASSES, even when a class never occurs in the evaluated images.
    class_indices = list(range(len(CLASSES)))

    # Calculate and plot the confusion matrix
    conf_matrix = confusion_matrix(
        true_classes,
        predicted_labels,
        labels=class_indices,
    )
    ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=CLASSES).plot(
        cmap=plt.cm.Blues, xticks_rotation=25
    )
    # Fixed margins so the rotated tick labels fit inside the saved figure
    plt.subplots_adjust(
        top=0.935,
        bottom=0.155,
        left=0.125,
        right=0.905,
        hspace=0.2,
        wspace=0.2,
    )
    plt.title("Confusion Matrix")
    manager = plt.get_current_fig_manager()
    manager.full_screen_toggle()
    plt.savefig("docs/evaluation/confusion_matrix.png")
    plt.show()

    # Classification report
    report = classification_report(
        true_classes,
        predicted_labels,
        labels=class_indices,
        target_names=CLASSES,
    )
    print("Classification Report:\n", report)

    # One-hot encode the true labels and stack the per-image probabilities so
    # both can be flattened into aligned (sample, class) pairs.
    true_classes_binary = label_binarize(true_classes, classes=range(NUM_CLASSES))
    scores = np.concatenate(predicted_scores, axis=0)
    if true_classes_binary.shape[1] == 1 and scores.shape[1] == 2:
        # label_binarize returns a single column for two classes; expand it to
        # two columns so it lines up with the two-column score matrix.
        true_classes_binary = np.hstack(
            [1 - true_classes_binary, true_classes_binary]
        )
    flat_truth = true_classes_binary.ravel()
    flat_scores = scores.ravel()

    # ROC and precision-recall statistics pooled over all (sample, class) pairs
    fpr, tpr, _ = roc_curve(flat_truth, flat_scores)
    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)
    precision, recall, _ = precision_recall_curve(flat_truth, flat_scores)
    auc_prc = average_precision_score(flat_truth, flat_scores)
    print("AUC PRC:", auc_prc)

    # Plot precision-recall curve
    _save_curve_plot(
        recall,
        precision,
        "Precision-Recall Curve",
        "Recall",
        "Precision",
        "AUC-PRC = {:.3f}".format(auc_prc),
        "docs/evaluation/prc.png",
    )
    # Plot ROC curve
    _save_curve_plot(
        fpr,
        tpr,
        "ROC Curve",
        "False Positive Rate",
        "True Positive Rate",
        "AUC-ROC = {:.3f}".format(auc_roc),
        "docs/evaluation/roc.png",
    )

    # Matthew's correlation coefficient
    print(
        "Matthew's correlation coefficient:",
        matthews_corrcoef(true_classes, predicted_labels),
    )
    # Cohen's kappa
    print("Cohen's kappa:", cohen_kappa_score(true_classes, predicted_labels))
# Run the evaluation on the images under data/test/Task 1/ with the model
# loaded above.
# NOTE(review): `preprocess` is not defined in this file; presumably it is
# provided by `from configs import *` -- confirm.
predict_image("data/test/Task 1/", model, preprocess)
# 89 EfficientNetB2WithDropout / 0.873118944547516
# 89 MobileNetV2WithDropout / 0.8731189445475158
# 89 SqueezeNet1_0WithSE / .8865856365856365