File size: 6,362 Bytes
97dcf92 c45c4e1 97dcf92 c45c4e1 73666ad 97dcf92 c45c4e1 97dcf92 c45c4e1 97dcf92 c45c4e1 97dcf92 c45c4e1 73666ad 97dcf92 c45c4e1 97dcf92 73666ad 97dcf92 c45c4e1 97dcf92 73666ad 97dcf92 c45c4e1 97dcf92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import torch
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import rcParams
from sklearn.metrics import (
classification_report,
precision_recall_curve,
accuracy_score,
f1_score,
confusion_matrix,
matthews_corrcoef,
ConfusionMatrixDisplay,
roc_curve,
auc,
average_precision_score,
cohen_kappa_score,
)
from sklearn.preprocessing import label_binarize
from configs import *
# Use Times New Roman for all matplotlib text drawn by this script.
rcParams["font.family"] = "Times New Roman"

# Instantiate the configured network on DEVICE, restore the trained weights
# stored at MODEL_SAVE_PATH (mapped onto DEVICE), and switch to evaluation mode.
model = MODEL.to(DEVICE)
model.load_state_dict(
    torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
)
model.eval()

# Alternative previously evaluated here (kept for reference, not active):
# a WeightedVoteEnsemble over [SqueezeNet1_0WithSE, EfficientNetB3WithDropout,
# MobileNetV2WithDropout] with vote weights [0.38, 0.34, 0.28]; each member was
# loaded from output/checkpoints/<ClassName>.pth and the ensemble itself from
# output/checkpoints/WeightedVoteEnsemble.pth.
def predict_image(image_path, model, transform):
    """Evaluate *model* on a labelled image folder and report classification metrics.

    Every ``*.png`` found recursively under *image_path* is classified. Its
    true label is the index in ``CLASSES`` of the name of the folder that
    directly contains it. The function prints per-image predictions,
    accuracy, weighted F1, a classification report, micro-averaged AUC-ROC
    and average precision, Matthews correlation coefficient and Cohen's
    kappa. It also saves (under ``docs/evaluation/``) and shows the
    confusion matrix, precision-recall curve and ROC curve.

    Args:
        image_path: Root directory of the test set, one sub-folder per class.
        model: Classifier whose output for a batch of one image is a single
            row of per-class scores (presumably logits -- softmax is applied here).
        transform: Callable mapping a PIL RGB image to a tensor; a batch
            dimension is added with ``unsqueeze(0)``.

    Returns:
        None. Results are printed, plotted and saved to disk.

    Raises:
        ValueError: If an image's parent folder name is not in ``CLASSES``.
    """
    model.eval()
    true_classes, predicted_labels, predicted_scores = _collect_predictions(
        image_path, model, transform
    )
    if not true_classes:
        # sklearn metrics are undefined on an empty set; report and stop.
        print(f"No .png images found under {str(image_path)!r}; nothing to evaluate.")
        return

    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    _plot_confusion_matrix(
        true_classes, predicted_labels, "docs/evaluation/confusion_matrix.png"
    )

    # labels= keeps the report aligned with target_names even when some class
    # never occurs in either the ground truth or the predictions.
    report = classification_report(
        true_classes,
        predicted_labels,
        labels=list(range(len(CLASSES))),
        target_names=CLASSES,
    )
    print("Classification Report:\n", report)

    # Micro-average over all one-vs-rest problems: flatten one-hot targets and
    # per-class scores in the same (image, class) order. np.eye is used instead
    # of label_binarize because the latter returns a single column when there
    # are only two classes, which would make the flattened arrays differ in length.
    y_true = np.eye(NUM_CLASSES, dtype=int)[true_classes].ravel()
    y_score = predicted_scores.ravel()

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)

    precision, recall, _ = precision_recall_curve(y_true, y_score)
    # Average precision: sklearn's step-wise summary of the PR curve.
    auc_prc = average_precision_score(y_true, y_score)
    print("AUC PRC:", auc_prc)

    _plot_curve(
        recall,
        precision,
        title="Precision-Recall Curve",
        xlabel="Recall",
        ylabel="Precision",
        annotation="AUC-PRC = {:.3f}".format(auc_prc),
        save_path="docs/evaluation/prc.png",
    )
    _plot_curve(
        fpr,
        tpr,
        title="ROC Curve",
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        annotation="AUC-ROC = {:.3f}".format(auc_roc),
        save_path="docs/evaluation/roc.png",
    )

    print("Matthew's correlation coefficient:", matthews_corrcoef(true_classes, predicted_labels))
    print("Cohen's kappa:", cohen_kappa_score(true_classes, predicted_labels))


def _collect_predictions(image_path, model, transform):
    """Classify every ``*.png`` under *image_path*; return labels, predictions and scores.

    Returns a tuple ``(true_classes, predicted_labels, predicted_scores)``:
    two lists of int class indices, and a 2-D array holding one row of
    softmax probabilities per image (empty lists / empty array if no images).
    """
    # Sorted so the evaluation order (and the printed log) is reproducible;
    # rglob yields files in file-system order.
    images = sorted(pathlib.Path(image_path).rglob("*.png"))
    true_classes = []
    predicted_labels = []
    score_rows = []
    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            print("Image path:", image_file)
            # Ground truth is the name of the folder directly containing the image.
            folder = image_file.parent.name
            if folder not in CLASSES:
                raise ValueError(
                    f"Folder {folder!r} of {image_file} is not one of CLASSES"
                )
            true_class = CLASSES.index(folder)
            print("True class:", true_class)
            # Context manager releases the file handle as soon as the pixels are read.
            with Image.open(image_file) as img:
                rgb = img.convert("RGB")
            batch = transform(rgb).unsqueeze(0).to(DEVICE)
            output = model(batch)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            # Batch size is 1 (unsqueeze(0) above), so keep only row 0.
            score_rows.append(output.softmax(dim=1).cpu().numpy()[0])
    return true_classes, predicted_labels, np.array(score_rows)


def _plot_confusion_matrix(true_classes, predicted_labels, save_path):
    """Draw the confusion matrix over all CLASSES, then save, show and close it."""
    # labels= pins the matrix to every class so its shape always matches
    # display_labels, even if a class is missing from the evaluated data.
    conf_matrix = confusion_matrix(
        true_classes, predicted_labels, labels=list(range(len(CLASSES)))
    )
    ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=CLASSES).plot(
        cmap=plt.cm.Blues, xticks_rotation=25
    )
    # Subplot margins (values exported from matplotlib's subplot configuration tool).
    plt.subplots_adjust(
        top=0.935,
        bottom=0.155,
        left=0.125,
        right=0.905,
        hspace=0.2,
        wspace=0.2,
    )
    plt.title("Confusion Matrix")
    plt.get_current_fig_manager().full_screen_toggle()
    _save_and_show(save_path)


def _plot_curve(x, y, title, xlabel, ylabel, annotation, save_path):
    """Plot one curve with a boxed annotation, then save, show and close it."""
    plt.figure(figsize=(10, 6))
    plt.plot(x, y)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # Placed in data coordinates; both PR and ROC axes span [0, 1].
    plt.text(
        0.6,
        0.2,
        annotation,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    _save_and_show(save_path)


def _save_and_show(save_path):
    """Save the current figure to *save_path* (creating parent dirs), show it, then close it."""
    save_path = pathlib.Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path)
    plt.show()
    # Release the figure so non-interactive backends don't accumulate open figures.
    plt.close()
# Evaluate the loaded model on the Task 1 test split (one sub-folder per class).
predict_image(image_path="data/test/Task 1/", model=model, transform=preprocess)

# Previously recorded results per backbone (first number presumably accuracy in %,
# second presumably weighted F1 -- TODO confirm which metric was logged):
# 89 EfficientNetB2WithDropout / 0.873118944547516
# 89 MobileNetV2WithDropout / 0.8731189445475158
# 89 SqueezeNet1_0WithSE / .8865856365856365
|