# Spectrum / validate.py
# Uploaded by nilekhet ("Upload 6 files", revision b743670)
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import load_model
from sklearn.metrics import classification_report, confusion_matrix
import pickle
# Paths and parameters for the validation run.
MODEL_PATH = "malware_classifier_lime.h5"
DATA_DIR = 'Malign/extract'
CACHE_PATH = "cache.pkl"
BATCH_SIZE = 32
# Presumably must match the input size used at training time — confirm
# against the training script.
IMAGE_SIZE = (200, 200)


def ordered_class_names(class_indices):
    """Return the class names sorted by their integer label.

    Args:
        class_indices: Mapping of class name to label index, e.g. the
            ``class_indices`` attribute of a Keras directory iterator.

    Returns:
        A list whose i-th entry is the name of label ``i``, so it lines up
        with the label order passed to ``classification_report``.
    """
    return [name for name, _ in sorted(class_indices.items(), key=lambda item: item[1])]


def check_class_count(cached_count, found_count):
    """Raise if the cached class count disagrees with the dataset.

    Args:
        cached_count: Value loaded from the cache file. If it cannot be
            converted with ``int()`` the check is skipped.
        found_count: Number of classes discovered in the data directory.

    Raises:
        ValueError: If both counts are known and differ.
    """
    try:
        expected = int(cached_count)
    except (TypeError, ValueError):
        # The cache does not hold a plain count; nothing to compare against.
        return
    if expected != found_count:
        raise ValueError(
            f"Cache reports {expected} classes but {found_count} class "
            f"directories were found; the model and data do not match."
        )


def main(model_path=MODEL_PATH, data_dir=DATA_DIR, cache_path=CACHE_PATH,
         batch_size=BATCH_SIZE, image_size=IMAGE_SIZE):
    """Evaluate the saved model on ``data_dir`` and print its metrics.

    Prints loss and accuracy from ``model.evaluate``, then a per-class
    classification report and a confusion matrix built from
    ``model.predict``.

    Args:
        model_path: Path of the saved Keras model.
        data_dir: Directory with one sub-directory per class.
        cache_path: Pickle file holding the class count (trusted input only).
        batch_size: Batch size used for evaluation and prediction.
        image_size: ``(height, width)`` the images are resized to.

    Raises:
        ValueError: If the cached class count differs from the number of
            class directories found under ``data_dir``.
    """
    # Load the saved model
    model = load_model(model_path)

    # Load the number of classes from the cache file. pickle can execute
    # arbitrary code while loading, so this must only read a trusted,
    # locally produced file.
    with open(cache_path, "rb") as f:
        num_classes = pickle.load(f)

    # Data preprocessing; shuffle=False keeps prediction order aligned with
    # test_generator.classes.
    test_datagen = ImageDataGenerator(rescale=1./255)
    test_generator = test_datagen.flow_from_directory(
        data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False,
    )

    class_names = ordered_class_names(test_generator.class_indices)
    check_class_count(num_classes, len(class_names))
    # An explicit label list gives every class its own report row and matrix
    # row/column, even if it never occurs in the true or predicted labels.
    labels = list(range(len(class_names)))

    # Evaluate the model
    print("Evaluating the model...")
    score = model.evaluate(test_generator)
    # NOTE(review): assumes the model was compiled with accuracy as its
    # first (and only) metric — confirm against the training script.
    print("Loss: ", score[0])
    print("Accuracy: ", score[1])

    # Predict the class labels. Rewind the iterator first so predictions
    # start at the first file, in the same order as test_generator.classes.
    print("Predicting the class labels...")
    test_generator.reset()
    y_pred = model.predict(test_generator)
    y_pred_classes = np.argmax(y_pred, axis=1)

    # Classification report
    print("Classification report:")
    print(classification_report(
        test_generator.classes,
        y_pred_classes,
        labels=labels,
        target_names=class_names,
    ))

    # Confusion matrix
    print("Confusion matrix:")
    print(confusion_matrix(test_generator.classes, y_pred_classes, labels=labels))


if __name__ == "__main__":
    main()