# Page header captured with this file: quanglnt · "Add application files" · commit 8c36119
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from device import get_device
from simple_nn import SimpleNN
from simple_cnn import SimpleCNN
from view_image import view_image, view_tensor_image
# Shared preprocessing pipeline: convert the input to a tensor, then shift and
# scale it with mean 0.5 and standard deviation 0.5 (single channel).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def predict_image(image):
    """Classify one digit image and return the per-class probabilities.

    The image is flattened onto a white background using its alpha channel,
    converted to grayscale, resized to 28x28, colour-inverted and scaled to
    [0, 1] before being passed through the module-level ``transform`` and a
    ``SimpleCNN`` whose weights are read from ``mnist_simple_cnn.pht`` in the
    current working directory.

    Args:
        image: A PIL image; it is converted to RGBA internally.

    Returns:
        A dict mapping each class index (as a string) to its softmax
        probability (a Python float).

    Side effects:
        Writes the preprocessed 28x28 image to ``processed_image.png`` and
        prints the resulting probability dict to stdout.
    """
    model = SimpleCNN()
    with open('mnist_simple_cnn.pht', 'rb') as f:
        # The model is created on the CPU and never moved, so remap the
        # checkpoint to the CPU: without this, weights saved from a GPU fail
        # to load on a machine that lacks that device.
        state_dict = torch.load(f, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Composite onto white so transparent regions become background.
    rgba_image = image.convert('RGBA')
    grayscale_image = Image.new("L", rgba_image.size, 255)
    grayscale_image.paste(rgba_image.convert("L"), mask=rgba_image.split()[3])
    grayscale_image = grayscale_image.resize((28, 28))
    grayscale_image.save("processed_image.png")

    # Invert colours (MNIST has white digits on black), then scale to [0, 1].
    image_np = 255 - np.array(grayscale_image)
    image_np = image_np / 255.0

    image_tensor = transform(image_np)
    image_tensor = image_tensor.unsqueeze(0)  # prepend the batch dimension
    image_tensor = image_tensor.to(torch.float32)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)

    class_probabilities = {
        str(class_index): prob.item()
        for class_index, prob in enumerate(probabilities[0])
    }
    print(class_probabilities)
    return class_probabilities
def predict(model_path, image_path):
    """Classify the digit image stored at ``image_path``.

    The image is converted to grayscale, resized to 28x28, colour-inverted
    and scaled to [0, 1] before being passed through the module-level
    ``transform`` and a ``SimpleCNN`` loaded from ``model_path``.

    Args:
        model_path: Path to the saved ``SimpleCNN`` state dict.
        image_path: Path to the image file to classify.

    Returns:
        A dict mapping each class index (as a string) to its softmax
        probability (a Python float).
    """
    model = SimpleCNN()
    with open(model_path, 'rb') as f:
        # The model is created on the CPU and never moved, so remap the
        # checkpoint to the CPU: without this, weights saved from a GPU fail
        # to load on a machine that lacks that device.
        state_dict = torch.load(f, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Open via a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert("L")
    image = image.resize((28, 28))

    # Invert colours (MNIST has white digits on black), then scale to [0, 1].
    image_np = 255 - np.array(image)
    image_np = image_np / 255.0

    image_tensor = transform(image_np)
    image_tensor = image_tensor.unsqueeze(0)  # prepend the batch dimension
    image_tensor = image_tensor.to(torch.float32)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)

    return {
        str(class_index): prob.item()
        for class_index, prob in enumerate(probabilities[0])
    }
if __name__ == "__main__":
    # Script entry point: run predict() on every .png file in the test folder.
    device = get_device()
    # NOTE(review): this path is assigned but never used; the predict() call
    # below passes the literal "mnist_model.pht" instead. Confirm which
    # checkpoint is intended.
    model_path = "trained_model/mnist_simple_cnn.pht"

    test_folder = "test/"
    for filename in os.listdir(test_folder):
        # Skip anything that is not a .png file.
        if not filename.endswith(".png"):
            continue
        image_path = os.path.join(test_folder, filename)
        predicted = predict(model_path="mnist_model.pht", image_path=image_path)
        print(F"[INFO] The predicted results of the image {image_path} are: {predicted}")
        print()