Spaces:
Runtime error
Runtime error
import os | |
from PIL import Image | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from device import get_device | |
from simple_nn import SimpleNN | |
from simple_cnn import SimpleCNN | |
from view_image import view_image, view_tensor_image | |
# Shared input pipeline for the classifier:
#   1. ToTensor turns an H x W (or H x W x C) array / PIL image into a
#      C x H x W tensor. It only rescales uint8 arrays and PIL images to
#      [0, 1]; float arrays pass through unscaled, so callers that hand it
#      floats must divide by 255 themselves.
#   2. Normalize maps every value v to (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Models already loaded by predict_image, keyed by checkpoint path. This avoids
# re-reading and re-deserializing the checkpoint on every request.
# NOTE: a checkpoint replaced on disk is not picked up until the process restarts.
_PREDICT_IMAGE_MODEL_CACHE = {}


def _load_predict_image_model(model_path):
    """Return an eval-mode SimpleCNN with weights from *model_path*, cached per path."""
    model = _PREDICT_IMAGE_MODEL_CACHE.get(model_path)
    if model is None:
        model = SimpleCNN()
        with open(model_path, "rb") as f:
            # weights_only=True restricts unpickling to tensors / plain containers.
            state_dict = torch.load(f, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        # Only cache after a successful load so a failed load is retried next call.
        _PREDICT_IMAGE_MODEL_CACHE[model_path] = model
    return model


def predict_image(image, model_path="mnist_simple_cnn.pht"):
    """Classify a single in-memory image with the SimpleCNN checkpoint.

    Args:
        image: A PIL image in any mode convertible to RGBA (presumably supplied
            by the UI). ``None`` (e.g. nothing drawn/uploaded yet) yields ``{}``.
        model_path: Path of the SimpleCNN ``state_dict`` checkpoint. It defaults
            to the file name this function has always used.

    Returns:
        A dict mapping each class index, as a string ("0", "1", ...), to its
        softmax probability. The number of classes is whatever SimpleCNN outputs
        (presumably 10 for MNIST).

    Side effects:
        Writes the 28x28 model input to "processed_image.png" in the current
        working directory, and prints the probability dict.
    """
    if image is None:
        return {}

    model = _load_predict_image_model(model_path)

    # Flatten any transparency onto a white canvas: transparent pixels become
    # background instead of whatever RGB value happens to sit under them.
    rgba = image.convert("RGBA")
    grayscale_image = Image.new("L", rgba.size, 255)
    grayscale_image.paste(rgba.convert("L"), mask=rgba.split()[3])
    grayscale_image = grayscale_image.resize((28, 28))
    grayscale_image.save("processed_image.png")  # debug snapshot of the model input

    # Invert to white-on-black (as the original comment notes, MNIST has white
    # digits on black), then scale to [0, 1] by hand: ToTensor does not rescale
    # float arrays.
    image_np = 255 - np.array(grayscale_image)
    image_np = image_np / 255.0

    # transform adds the channel dim (H x W -> 1 x H x W) and normalizes;
    # unsqueeze adds the batch dim (-> 1 x 1 x H x W).
    image_tensor = transform(image_np).unsqueeze(0).to(torch.float32)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)

    class_probabilities = {
        str(class_index): prob.item()
        for class_index, prob in enumerate(probabilities[0])
    }
    print(class_probabilities)
    return class_probabilities
def predict(model_path, image_path):
    """Classify an image file with a SimpleCNN checkpoint.

    Args:
        model_path: Path of the SimpleCNN ``state_dict`` checkpoint.
        image_path: Path of the image file to classify.

    Returns:
        A dict mapping each class index, as a string ("0", "1", ...), to its
        softmax probability.
    """
    model = SimpleCNN()
    with open(model_path, "rb") as f:
        state_dict = torch.load(f, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # The context manager closes the file handle deterministically.
    with Image.open(image_path) as img:
        has_alpha = img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
        if has_alpha:
            # Composite onto white (as predict_image does). A plain convert("L")
            # drops alpha, so transparent areas would take their hidden RGB
            # value (often black) and turn white after inversion.
            rgba = img.convert("RGBA")
            grayscale = Image.new("L", rgba.size, 255)
            grayscale.paste(rgba.convert("L"), mask=rgba.split()[3])
        else:
            # Opaque images: same conversion as before.
            grayscale = img.convert("L")

    grayscale = grayscale.resize((28, 28))

    # Invert to white-on-black (MNIST has white digits on black), then scale to
    # [0, 1] by hand: ToTensor does not rescale float arrays.
    image_np = 255 - np.array(grayscale)
    image_np = image_np / 255.0

    # transform adds the channel dim and normalizes; unsqueeze adds the batch dim.
    image_tensor = transform(image_np).unsqueeze(0).to(torch.float32)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)

    return {
        str(class_index): prob.item()
        for class_index, prob in enumerate(probabilities[0])
    }
if __name__ == "__main__":
    # NOTE(review): device is not used yet. predict() never moves the model or
    # its tensors to it.
    device = get_device()
    model_path = "trained_model/mnist_simple_cnn.pht"

    # Classify every .png file in the test folder. Files are sorted because
    # os.listdir order is arbitrary.
    test_folder = "test/"
    for filename in sorted(os.listdir(test_folder)):
        # Extend this check to accept more image extensions if needed.
        if not filename.lower().endswith(".png"):
            continue
        image_path = os.path.join(test_folder, filename)
        # Use the configured checkpoint, not a stale hard-coded file name.
        predicted = predict(model_path=model_path, image_path=image_path)
        print(f"[INFO] The predicted results of the image {image_path} are: {predicted}")
        print()