File size: 1,709 Bytes
9e6a96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from utils.data_loader import get_data_loaders
from models.resnet_model import MonkeyResNet  # using your ResNet model

# Evaluation settings: dataset root directory, loader batch size, and the
# checkpoint file containing the trained weights.
data_dir, batch_size, model_path = "data", 32, "models/monkey_resnet.pth"

# Only the validation split and the class names are used below; the first
# value returned by get_data_loaders (presumably the training loader) is
# discarded.
_, val_loader, class_names = get_data_loaders(
    data_dir,
    batch_size,
)

# Rebuild the network with one output per class, restore the saved weights,
# and switch it to inference mode. The GPU is preferred when one is present.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MonkeyResNet(num_classes=len(class_names))
model = model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only point model_path at a
# trusted checkpoint (torch >= 1.13 also accepts weights_only=True).
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
# eval() switches layers such as dropout and batch norm to inference behavior.
model.eval()

# Run the model once over the validation set, recording the predicted class
# index and the ground-truth label of every sample (as flat Python lists).
all_preds, all_labels = [], []

with torch.no_grad():  # inference only: skip building the autograd graph
    for batch_images, batch_labels in val_loader:
        logits = model(batch_images.to(device))
        # Position of the highest score in each row = predicted class index.
        top_class = logits.max(dim=1).indices
        all_preds.extend(top_class.cpu().numpy())
        all_labels.extend(batch_labels.numpy())

# Confusion-matrix heatmap over every known class.
# `labels` is passed explicitly so that row/column i always corresponds to
# class_names[i]. Without it, sklearn sizes the matrix from only the labels
# that actually occur in the truth or the predictions. If any class is
# missing, the tick labels below would then be misaligned or mismatched in
# count. This assumes integer label i maps to class_names[i], as the tick
# labels already did.
cm_labels = list(range(len(class_names)))
cm = confusion_matrix(all_labels, all_preds, labels=cm_labels)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.tight_layout()
# Save before show(): once the interactive window is closed the current
# figure may be released, and a later savefig would write a blank image.
plt.savefig("confusion_matrix.png")
plt.show()

# Per-class precision / recall / F1 as a text table.
# `labels` is given explicitly so that the report still lines up with
# target_names when a class is absent from both the truth and the
# predictions. Without it, sklearn raises ValueError because the number of
# observed classes differs from len(target_names). A class with no true
# samples is reported with recall 0.0, and sklearn warns about it.
report_labels = list(range(len(class_names)))
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=report_labels, target_names=class_names))