# Hugging Face Spaces status at capture time: Runtime error
# Script: load the trained SimpleCNN checkpoint and print its architecture,
# a parameter-count summary, and the shape of every learnable parameter.
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from device import get_device
from simple_nn import SimpleNN
from simple_cnn import SimpleCNN
from view_image import view_image, view_tensor_image

# NOTE(review): ".pht" looks like a typo for the conventional ".pth" extension;
# kept unchanged because the checkpoint on disk may really be named this way — confirm.
model_path = "mnist_model.pht"

# Load model
model = SimpleCNN()
with open(model_path, 'rb') as f:
    # map_location="cpu" remaps any CUDA-saved tensors onto the CPU, so a
    # checkpoint trained on a GPU can still be restored on a CPU-only host
    # (without it torch.load raises RuntimeError when CUDA is unavailable).
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# View model information: nn.Module's repr lists the layer-by-layer architecture.
print(model)

# Aggregate parameter counts (the repr above does not include them).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model summary: {total_params:,} parameters ({trainable_params:,} trainable)")

# Per-tensor detail: name and shape of every parameter registered on the model.
for name, param in model.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}")