# (Residue from the hosting page header, kept for reference: "Spaces: Sleeping".)
import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import classification_report, confusion_matrix

from models.resnet_model import MonkeyResNet  # using your ResNet model
from utils.data_loader import get_data_loaders
# Set data path, batch size, and model file | |
data_dir = "data" | |
batch_size = 32 | |
model_path = "models/monkey_resnet.pth" | |
# Load validation data (no training data needed here) | |
_, val_loader, class_names = get_data_loaders(data_dir, batch_size) | |
# Load the trained model | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model = MonkeyResNet(num_classes=len(class_names)).to(device) | |
model.load_state_dict(torch.load(model_path, map_location=device)) | |
model.eval() # disable training-specific features like dropout | |
# Store all predictions and actual labels | |
all_preds = [] | |
all_labels = [] | |
# Go through the validation data to collect predictions | |
with torch.no_grad(): | |
for images, labels in val_loader: | |
images = images.to(device) | |
outputs = model(images) | |
_, predicted = torch.max(outputs, 1) | |
all_preds.extend(predicted.cpu().numpy()) | |
all_labels.extend(labels.numpy()) | |
# Generate confusion matrix plot | |
cm = confusion_matrix(all_labels, all_preds) | |
plt.figure(figsize=(10, 8)) | |
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues") | |
plt.xlabel("Predicted") | |
plt.ylabel("Actual") | |
plt.title("Confusion Matrix") | |
plt.tight_layout() | |
plt.savefig("confusion_matrix.png") | |
plt.show() | |
# Print the classification report in text format | |
print("Classification Report:") | |
print(classification_report(all_labels, all_preds, target_names=class_names)) | |