handwritten-digit-recognition / view_model_information.py
quanglnt's picture
Add application files
8c36119
raw
history blame
784 Bytes
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from device import get_device
from simple_nn import SimpleNN
from simple_cnn import SimpleCNN
from view_image import view_image, view_tensor_image
# Path to the saved state_dict produced by training (loaded relative to the CWD).
model_path = "mnist_model.pht"

# Load model: build the architecture first, then restore the saved weights into it.
model = SimpleCNN()
with open(model_path, 'rb') as f:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is the safe way to read a bare state_dict.
    # map_location="cpu" lets a checkpoint saved during a GPU run be
    # deserialized on a CPU-only machine; the freshly built model above is on
    # the CPU, so that is where the tensors need to end up anyway.
    state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# View model information: the module repr lists every layer and its configuration.
print(model)

# Summary: parameter counts, summed over every parameter tensor of the model.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model summary: {total_params:,} parameters ({trainable_params:,} trainable)")

# Per-parameter details: qualified parameter name and tensor shape.
for name, param in model.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}")