"""Evaluate the trained MonkeyResNet on the validation split.

Loads saved weights from ``model_path``, runs inference over the validation
loader, saves and shows a confusion-matrix heatmap, and prints a per-class
classification report.
"""

import torch
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from utils.data_loader import get_data_loaders
from models.resnet_model import MonkeyResNet  # using your ResNet model

# Set data path, batch size, and model file
data_dir = "data"
batch_size = 32
model_path = "models/monkey_resnet.pth"

# Load validation data (no training data needed here)
_, val_loader, class_names = get_data_loaders(data_dir, batch_size)

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MonkeyResNet(num_classes=len(class_names)).to(device)
# A state dict only contains tensors/primitive containers, so restrict
# unpickling to those instead of allowing arbitrary pickled objects.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.eval()  # disable training-specific features like dropout

# Class indices are assumed to be 0..len(class_names)-1, matching the
# num_classes used to build the model. Passing them explicitly to sklearn
# keeps the confusion matrix and report aligned with class_names even when
# some class never appears in the validation labels or predictions
# (otherwise the matrix shrinks and the report raises on target_names).
label_ids = list(range(len(class_names)))

# Store all predictions and actual labels
all_preds = []
all_labels = []

# Go through the validation data to collect predictions
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # index of the highest score per sample
        all_preds.extend(predicted.cpu().numpy())
        # .cpu() is a no-op for CPU tensors and guards against GPU labels.
        all_labels.extend(labels.cpu().numpy())

# Generate confusion matrix plot (rows = actual, columns = predicted)
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    xticklabels=class_names,
    yticklabels=class_names,
    cmap="Blues",
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
plt.show()

# Print the classification report in text format.
# zero_division=0 reports 0 for classes with no predictions (the same value
# sklearn's default uses) without emitting UndefinedMetricWarning.
print("Classification Report:")
print(
    classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
    )
)