# Hugging Face Spaces status when this file was captured: Runtime error
import torch | |
import numpy as np | |
import pathlib | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from matplotlib import rcParams | |
from sklearn.metrics import ( | |
classification_report, | |
precision_recall_curve, | |
accuracy_score, | |
f1_score, | |
confusion_matrix, | |
matthews_corrcoef, | |
ConfusionMatrixDisplay, | |
roc_curve, | |
auc, | |
average_precision_score, | |
cohen_kappa_score, | |
) | |
from sklearn.preprocessing import label_binarize | |
from configs import * | |
# Every figure produced below is rendered with a serif typeface.
rcParams["font.family"] = "Times New Roman"

# Single-model evaluation: take the architecture selected in `configs`,
# place it on the configured device, restore the trained weights from
# MODEL_SAVE_PATH and switch to inference mode.
model = MODEL.to(DEVICE)
model.load_state_dict(
    state_dict=torch.load(MODEL_SAVE_PATH, map_location=DEVICE),
)
model.eval()

# Disabled alternative: evaluate a weighted-vote ensemble of three backbones
# instead of the single configured model. Uncomment (and comment out the
# block above) to use it.
# model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
# model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth"))
# model2 = EfficientNetB3WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB3WithDropout.pth"))
# model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model3.load_state_dict(torch.load("output/checkpoints/MobileNetV2WithDropout.pth"))
# model1.eval()
# model2.eval()
# model3.eval()
# model = WeightedVoteEnsemble([model1, model2, model3], [0.38, 0.34, 0.28])
# # model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
# model.load_state_dict(
#     torch.load("output/checkpoints/WeightedVoteEnsemble.pth", map_location=DEVICE)
# )
# model.eval()
def _collect_predictions(image_path, model, transform):
    """Classify every ``*.png`` below ``image_path`` one image at a time.

    The ground-truth class of an image is the position in ``CLASSES`` of the
    name of the folder that directly contains it (``CLASSES.index`` raises
    ``ValueError`` for an unknown folder name).

    Returns:
        ``(true_classes, predicted_labels, predicted_scores)``: two lists of
        ints and a list of per-image softmax arrays, each holding the single
        row produced for a batch of size 1.
    """
    true_classes = []
    predicted_labels = []
    predicted_scores = []
    images = list(pathlib.Path(image_path).rglob("*.png"))
    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            true_class = CLASSES.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)
            # The context manager releases the file handle deterministically.
            with Image.open(image_file) as img:
                image = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)
            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            # Softmax here assumes the model returns raw logits (TODO confirm).
            predicted_scores.append(output.softmax(dim=1).cpu().numpy())
    return true_classes, predicted_labels, predicted_scores


def _plot_confusion_matrix(true_classes, predicted_labels, save_path):
    """Draw, save and show the confusion matrix over every entry of ``CLASSES``."""
    # Pinning ``labels`` keeps the matrix len(CLASSES) x len(CLASSES) even when
    # a class never occurs in either list; otherwise it would shrink and stop
    # matching ``display_labels``.
    conf_matrix = confusion_matrix(
        true_classes, predicted_labels, labels=list(range(len(CLASSES)))
    )
    ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=CLASSES).plot(
        cmap=plt.cm.Blues, xticks_rotation=25
    )
    # Fixed margins leaving room for the rotated x tick labels and the y axis
    # class names (values presumably taken from matplotlib's subplot tool).
    plt.subplots_adjust(
        top=0.935,
        bottom=0.155,
        left=0.125,
        right=0.905,
        hspace=0.2,
        wspace=0.2,
    )
    plt.title("Confusion Matrix")
    # Toggle full-screen mode on the current figure's window before saving.
    plt.get_current_fig_manager().full_screen_toggle()
    plt.savefig(save_path)
    plt.show()
    plt.close()


def _plot_curve(x, y, title, xlabel, ylabel, annotation, save_path):
    """Plot ``y`` against ``x`` with a boxed ``annotation``, then save and show it."""
    plt.figure(figsize=(10, 6))
    plt.plot(x, y)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # The annotation is placed at data coordinates (0.6, 0.2).
    plt.text(
        0.6,
        0.2,
        annotation,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig(save_path)
    plt.show()
    plt.close()


def predict_image(image_path, model, transform, output_dir="docs/evaluation"):
    """Evaluate ``model`` on a folder of labelled PNG images and report metrics.

    Every ``*.png`` found recursively under ``image_path`` is classified; its
    true label is the index in ``CLASSES`` of its parent folder's name. Prints
    accuracy, weighted F1, a classification report, AUC-ROC and average
    precision (both computed on the flattened one-vs-rest label/score
    matrices), the Matthews correlation coefficient and Cohen's kappa.
    Saves the confusion-matrix, precision-recall and ROC figures into
    ``output_dir``.

    Args:
        image_path: Root folder holding one sub-folder per class name.
        model: Classifier called as ``model(batch)``; switched to eval mode.
        transform: Callable turning a PIL RGB image into a tensor (a batch
            dimension is added here with ``unsqueeze(0)``).
        output_dir: Folder for the saved figures; created if it is missing.

    Returns:
        dict with the scalar metrics that were printed.

    Raises:
        ValueError: if no PNG image is found under ``image_path``, or an image
            lives in a folder whose name is not in ``CLASSES``.
    """
    model.eval()
    true_classes, predicted_labels, predicted_scores = _collect_predictions(
        image_path, model, transform
    )
    if not true_classes:
        raise ValueError(f"No .png images found under {str(image_path)!r}")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    _plot_confusion_matrix(
        true_classes, predicted_labels, output_dir / "confusion_matrix.png"
    )

    # Explicit labels keep the report rows aligned with ``target_names`` even
    # when some class is missing from both the truth and the predictions.
    report = classification_report(
        true_classes,
        predicted_labels,
        labels=list(range(len(CLASSES))),
        target_names=CLASSES,
    )
    print("Classification Report:\n", report)

    # Flatten the one-hot truth and the (n_images, n_classes) score matrix so
    # each (image, class) pair is one binary decision.
    scores = np.concatenate(predicted_scores, axis=0)
    true_one_hot = label_binarize(true_classes, classes=range(NUM_CLASSES))
    if true_one_hot.shape[1] == 1:
        # With two classes label_binarize returns a single column; expand it
        # so it lines up with the two softmax columns.
        true_one_hot = np.hstack([1 - true_one_hot, true_one_hot])
    flat_truth = true_one_hot.ravel()
    flat_scores = scores.ravel()

    fpr, tpr, _ = roc_curve(flat_truth, flat_scores)
    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)
    precision, recall, _ = precision_recall_curve(flat_truth, flat_scores)
    auc_prc = average_precision_score(flat_truth, flat_scores)
    print("AUC PRC:", auc_prc)

    _plot_curve(
        recall,
        precision,
        "Precision-Recall Curve",
        "Recall",
        "Precision",
        "AUC-PRC = {:.3f}".format(auc_prc),
        output_dir / "prc.png",
    )
    _plot_curve(
        fpr,
        tpr,
        "ROC Curve",
        "False Positive Rate",
        "True Positive Rate",
        "AUC-ROC = {:.3f}".format(auc_roc),
        output_dir / "roc.png",
    )

    mcc = matthews_corrcoef(true_classes, predicted_labels)
    print("Matthew's correlation coefficient:", mcc)
    kappa = cohen_kappa_score(true_classes, predicted_labels)
    print("Cohen's kappa:", kappa)

    return {
        "accuracy": accuracy,
        "weighted_f1": f1,
        "auc_roc": auc_roc,
        "auc_prc": auc_prc,
        "mcc": mcc,
        "cohen_kappa": kappa,
    }
# Run the evaluation on the Task 1 test split, which holds one sub-folder per
# class name; `preprocess` is the image transform taken from `configs`.
predict_image(
    image_path="data/test/Task 1/",
    model=model,
    transform=preprocess,
)
# Previously recorded results, kept for reference. Each line reads
# "<number> <model name> / <score>"; the leading number is presumably an
# accuracy percentage and the trailing one the weighted F1 printed above
# (NOTE(review): confirm before relying on these).
# 89 EfficientNetB2WithDropout / 0.873118944547516
# 89 MobileNetV2WithDropout / 0.8731189445475158
# 89 SqueezeNet1_0WithSE / .8865856365856365