Spaces:
Sleeping
Sleeping
File size: 1,709 Bytes
9e6a96a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from utils.data_loader import get_data_loaders
from models.resnet_model import MonkeyResNet # using your ResNet model
# Evaluate the trained MonkeyResNet on the validation split: save/show a
# confusion-matrix heatmap and print a per-class classification report.

# Set data path, batch size, and model file
data_dir = "data"
batch_size = 32
model_path = "models/monkey_resnet.pth"

# Load validation data (no training data needed here)
_, val_loader, class_names = get_data_loaders(data_dir, batch_size)

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MonkeyResNet(num_classes=len(class_names)).to(device)
# NOTE(review): torch.load unpickles the checkpoint file, so only load files you
# trust. On PyTorch >= 1.13, passing weights_only=True restricts loading to
# tensors/primitive containers, which is sufficient for a state_dict.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # inference mode: affects layers such as dropout and batch norm

# Store all predictions and actual labels
all_preds = []
all_labels = []

# Go through the validation data to collect predictions
with torch.no_grad():  # no autograd graph needed for evaluation
    for images, labels in val_loader:
        outputs = model(images.to(device))
        # Index of the highest score along dim 1; the same dim the original
        # torch.max(outputs, 1) reduced over (presumably the class dimension).
        predicted = outputs.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Pin the label set to every class index so the confusion matrix and report
# always have exactly one row/column per entry in class_names. Otherwise
# sklearn infers labels from the data. If a class never appears in either the
# labels or the predictions, the heatmap axes are then mislabeled and
# classification_report raises a ValueError on the target_names mismatch.
# Assumes labels are integer class indices 0..len(class_names)-1, which the
# original target_names=class_names usage already relies on.
class_indices = list(range(len(class_names)))

# Generate confusion matrix plot (rows = actual class, columns = predicted)
cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
plt.show()

# Print the classification report in text format
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=class_indices, target_names=class_names))
|