handwritten-digit-recognition / view_model_information.py
quanglnt's picture
Add application files
8c36119
raw
history blame contribute delete
784 Bytes
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from device import get_device
from simple_nn import SimpleNN
from simple_cnn import SimpleCNN
from view_image import view_image, view_tensor_image
# Path to the saved state_dict checkpoint.
# NOTE(review): ".pht" differs from the conventional ".pth" PyTorch extension —
# confirm it matches the file actually present on disk before renaming.
model_path = "mnist_model.pht"

# Load model: build the architecture, then restore the trained weights into it.
model = SimpleCNN()
with open(model_path, 'rb') as f:
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    # map_location="cpu" remaps tensors that were saved on a GPU, so this
    # inspection script also works on CPU-only machines.
    state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# View model information
print(model)  # Display the model architecture

# Summary: printing `model` again would only repeat the output above, so report
# overall parameter counts instead (total, and those with requires_grad=True).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model summary: {total_params:,} parameters ({trainable_params:,} trainable)")

# Per-parameter details: name, shape and element count of every parameter tensor.
for name, param in model.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}, Elements: {param.numel()}")