Spaces:
Runtime error
Runtime error
Add application files
Browse files- app.py +25 -0
- device.py +15 -0
- predict.py +105 -0
- simple_cnn.py +28 -0
- simple_nn.py +17 -0
- test.py +130 -0
- train.py +109 -0
- view_image.py +63 -0
- view_model_information.py +26 -0
app.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from predict import predict_image
|
3 |
+
# Custom CSS to make the label text bigger
|
4 |
+
def predict(image_dict):
|
5 |
+
# Extract the "composite" key from the dictionary
|
6 |
+
composite_image = image_dict["composite"]
|
7 |
+
# composite_image.save("sketchpad_output.png") # Save as PNG
|
8 |
+
predicted = predict_image(composite_image)
|
9 |
+
# print(predicted)
|
10 |
+
return predicted #, composite_image # Directly return the PIL image
|
11 |
+
css = """
|
12 |
+
.big-label {
|
13 |
+
font-size: 24px; /* Adjust this value to make the label bigger */
|
14 |
+
font-weight: bold; /* Optional: to make it bold */
|
15 |
+
}
|
16 |
+
"""
|
17 |
+
demo = gr.Interface(
|
18 |
+
fn=predict,
|
19 |
+
inputs=gr.Sketchpad(type="pil", brush=gr.Brush(default_size=20)), # Ensure it returns a PIL image
|
20 |
+
outputs=[gr.Label(num_top_classes=3, label="Predicted number is:", elem_classes=["big-label"])],
|
21 |
+
css=css
|
22 |
+
)
|
23 |
+
|
24 |
+
if __name__ == "__main__":
|
25 |
+
demo.launch()
|
device.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def get_device():
|
4 |
+
if torch.cuda.is_available:
|
5 |
+
# print('cuda is available')
|
6 |
+
return 'cuda'
|
7 |
+
elif torch.backends.mps.is_available:
|
8 |
+
# print('mps is available')
|
9 |
+
return 'mps'
|
10 |
+
else:
|
11 |
+
# print('using cpu')
|
12 |
+
return 'cpu'
|
13 |
+
|
14 |
+
device = get_device()
|
15 |
+
# print(device)
|
predict.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torchvision import transforms
|
6 |
+
from device import get_device
|
7 |
+
from simple_nn import SimpleNN
|
8 |
+
from simple_cnn import SimpleCNN
|
9 |
+
from view_image import view_image, view_tensor_image
|
10 |
+
|
11 |
+
transform = transforms.Compose([
|
12 |
+
transforms.ToTensor(),
|
13 |
+
transforms.Normalize((0.5,), (0.5,))
|
14 |
+
])
|
15 |
+
def predict_image(image):
|
16 |
+
|
17 |
+
model = SimpleCNN()
|
18 |
+
with open('mnist_simple_cnn.pht', 'rb') as f:
|
19 |
+
state_dict = torch.load(f, weights_only=True)
|
20 |
+
model.load_state_dict(state_dict)
|
21 |
+
model.eval()
|
22 |
+
|
23 |
+
image = image.convert('RGBA')
|
24 |
+
grayscale_image = Image.new("L", image.size, 255) # Create a white background
|
25 |
+
grayscale_image.paste(image.convert("L"), mask=image.split()[3]) # Use alpha channel as mask
|
26 |
+
grayscale_image = grayscale_image.resize((28, 28)) # Resize to 28x28 pixels
|
27 |
+
|
28 |
+
grayscale_image.save("processed_image.png")
|
29 |
+
|
30 |
+
image_np = np.array(grayscale_image)
|
31 |
+
image_np = 255 - image_np # Invert colors (MNIST has white digits on black)
|
32 |
+
|
33 |
+
# Normalize to range [0, 1]
|
34 |
+
image_np = image_np / 255.0
|
35 |
+
|
36 |
+
image_tensor = transform(image_np) # Add batch and channel dimensions
|
37 |
+
image_tensor = image_tensor.unsqueeze(0)
|
38 |
+
image_tensor = image_tensor.to(torch.float32)
|
39 |
+
|
40 |
+
# image_tensor = transform(grayscale_image).unsqueeze(0) # Add batch and channel dimensions
|
41 |
+
|
42 |
+
with torch.no_grad():
|
43 |
+
output = model(image_tensor)
|
44 |
+
#_, predicted = torch.max(output.data, 1)
|
45 |
+
probabilities = torch.softmax(output, dim=1)
|
46 |
+
# Convert probabilities to a list of (class, probability)
|
47 |
+
class_probabilities = {
|
48 |
+
str(class_index): prob.item() for class_index, prob in enumerate(probabilities[0])
|
49 |
+
}
|
50 |
+
print(class_probabilities)
|
51 |
+
# class_probabilities = {}
|
52 |
+
return class_probabilities
|
53 |
+
|
54 |
+
def predict(model_path, image_path):
|
55 |
+
model = SimpleCNN()
|
56 |
+
with open(model_path, 'rb') as f:
|
57 |
+
state_dict = torch.load(f, weights_only=True)
|
58 |
+
model.load_state_dict(state_dict)
|
59 |
+
model.eval()
|
60 |
+
|
61 |
+
# Load and preprocess the image
|
62 |
+
image = Image.open(image_path).convert("L") # Convert to grayscale
|
63 |
+
# view_image(image=image)
|
64 |
+
# Resize to 28x28
|
65 |
+
image = image.resize((28, 28))
|
66 |
+
|
67 |
+
# Convert to NumPy array and invert colors if needed
|
68 |
+
image_np = np.array(image)
|
69 |
+
image_np = 255 - image_np # Invert colors (MNIST has white digits on black)
|
70 |
+
|
71 |
+
# Normalize to range [0, 1]
|
72 |
+
image_np = image_np / 255.0
|
73 |
+
|
74 |
+
# Convert to tensor
|
75 |
+
|
76 |
+
image_tensor = transform(image_np) # Add batch and channel dimensions
|
77 |
+
image_tensor = image_tensor.unsqueeze(0)
|
78 |
+
image_tensor = image_tensor.to(torch.float32)
|
79 |
+
# Ensure the tensor is in the correct dtype
|
80 |
+
|
81 |
+
# view_tensor_image(image_tensor=image_tensor)
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
output = model(image_tensor)
|
85 |
+
#_, predicted = torch.max(output.data, 1)
|
86 |
+
probabilities = torch.softmax(output, dim=1)
|
87 |
+
# Convert probabilities to a list of (class, probability)
|
88 |
+
class_probabilities = {
|
89 |
+
str(class_index): prob.item() for class_index, prob in enumerate(probabilities[0])
|
90 |
+
}
|
91 |
+
# return predicted.item()
|
92 |
+
return class_probabilities
|
93 |
+
if __name__ == "__main__":
|
94 |
+
device = get_device()
|
95 |
+
model_path = "trained_model/mnist_simple_cnn.pht"
|
96 |
+
|
97 |
+
# Loop through all files in the test folder
|
98 |
+
test_folder = "test/"
|
99 |
+
for filename in os.listdir(test_folder):
|
100 |
+
if filename.endswith(".png"): # Only process .png files (you can add more extensions if needed)
|
101 |
+
image_path = os.path.join(test_folder, filename)
|
102 |
+
predicted = predict(model_path = "mnist_model.pht",image_path=image_path)
|
103 |
+
print(F"[INFO] The predicted results of the image {image_path} are: {predicted}")
|
104 |
+
print()
|
105 |
+
|
simple_cnn.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class SimpleCNN(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super(SimpleCNN, self).__init__()
|
8 |
+
# Convolutional layers
|
9 |
+
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) # 28x28 -> 28x28
|
10 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # 28x28 -> 28x28
|
11 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 28x28 -> 14x14
|
12 |
+
|
13 |
+
self.bn1 = nn.BatchNorm2d(32)
|
14 |
+
self.bn2 = nn.BatchNorm2d(64)
|
15 |
+
|
16 |
+
# Fully connected layers
|
17 |
+
self.fc1 = nn.Linear(64 * 14 * 14, 128)
|
18 |
+
self.dropout = nn.Dropout(0.5)
|
19 |
+
self.fc2 = nn.Linear(128, 10)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = F.relu(self.bn1(self.conv1(x))) # Apply first convolution and ReLU
|
23 |
+
x = self.pool(F.relu(self.bn2(self.conv2(x)))) # Apply second convolution, ReLU, and pooling
|
24 |
+
x = torch.flatten(x, 1) # Flatten the feature maps
|
25 |
+
x = F.relu(self.fc1(x)) # Fully connected layer with ReLU
|
26 |
+
x = self.dropout(x)
|
27 |
+
x = self.fc2(x) # Output layer
|
28 |
+
return x
|
simple_nn.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class SimpleNN(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super(SimpleNN, self).__init__()
|
7 |
+
self.fc1 = nn.Linear(28*28, 128)
|
8 |
+
self.fc2 = nn.Linear(128, 64)
|
9 |
+
self.fc3 = nn.Linear(64, 10)
|
10 |
+
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
x = x.view(-1, 28*28)
|
14 |
+
x = torch.relu(self.fc1(x))
|
15 |
+
x = torch.relu(self.fc2(x))
|
16 |
+
x = self.fc3(x)
|
17 |
+
return x
|
test.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
from torchvision import datasets, transforms
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch.utils.tensorboard import SummaryWriter
|
9 |
+
from simple_nn import SimpleNN
|
10 |
+
from simple_cnn import SimpleCNN
|
11 |
+
from device import get_device
|
12 |
+
from view_image import view_image,view_batch_images, save_batch_images
|
13 |
+
from tqdm import tqdm # Import tqdm for the progress bar
|
14 |
+
|
15 |
+
root_path = os.path.expanduser('data')
|
16 |
+
|
17 |
+
# Define transforms for training and testing
|
18 |
+
transforms = {
|
19 |
+
'train': transforms.Compose([
|
20 |
+
transforms.RandomRotation(10),
|
21 |
+
transforms.RandomHorizontalFlip(),
|
22 |
+
transforms.ToTensor(),
|
23 |
+
transforms.Normalize((0.5,), (0.5,)) # Normalize for MNIST
|
24 |
+
]),
|
25 |
+
'valid_test': transforms.Compose([
|
26 |
+
transforms.ToTensor(),
|
27 |
+
transforms.Normalize((0.5,), (0.5,))
|
28 |
+
])
|
29 |
+
}
|
30 |
+
|
31 |
+
# Define dataset and dataloader
|
32 |
+
train_dataset = datasets.MNIST(root=root_path, download=True, train=True, transform=transforms['train'])
|
33 |
+
test_dataset = datasets.MNIST(root=root_path, download=True, train=False, transform=transforms['valid_test'])
|
34 |
+
|
35 |
+
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
|
36 |
+
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
|
37 |
+
|
38 |
+
model = SimpleCNN()
|
39 |
+
device = get_device()
|
40 |
+
|
41 |
+
model.to(device=device)
|
42 |
+
criterion = nn.CrossEntropyLoss()
|
43 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
44 |
+
epochs = 50
|
45 |
+
for epoch in range(epochs):
|
46 |
+
model.train() # Set the model to training mode
|
47 |
+
epoch_loss = 0 # Initialize epoch loss
|
48 |
+
correct = 0 # Track number of correct predictions
|
49 |
+
total = 0 # Track total predictions
|
50 |
+
|
51 |
+
with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
|
52 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
53 |
+
data, target = data.to(device), target.to(device) # Move data to device
|
54 |
+
|
55 |
+
# Forward pass
|
56 |
+
output = model(data)
|
57 |
+
loss = criterion(output, target)
|
58 |
+
|
59 |
+
# Backward pass and optimization
|
60 |
+
optimizer.zero_grad()
|
61 |
+
loss.backward()
|
62 |
+
optimizer.step()
|
63 |
+
|
64 |
+
epoch_loss += loss.item() # Add current batch loss to epoch loss
|
65 |
+
|
66 |
+
# Calculate accuracy for this batch
|
67 |
+
_, predicted = torch.max(output, 1)
|
68 |
+
total += target.size(0)
|
69 |
+
correct += (predicted == target).sum().item()
|
70 |
+
|
71 |
+
# Update the progress bar
|
72 |
+
pbar.set_postfix(loss=loss.item(), accuracy=100. * correct / total)
|
73 |
+
pbar.update(1)
|
74 |
+
|
75 |
+
# Calculate the average loss and accuracy for this epoch
|
76 |
+
avg_loss = epoch_loss / len(train_loader)
|
77 |
+
accuracy = 100. * correct / total
|
78 |
+
|
79 |
+
|
80 |
+
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
|
81 |
+
|
82 |
+
# Now validation (on validation set)
|
83 |
+
correct_val = 0
|
84 |
+
total_val = 0
|
85 |
+
model.eval() # Set the model to evaluation mode
|
86 |
+
with torch.no_grad():
|
87 |
+
for data, target in test_loader: # Assuming `val_loader` is your validation data loader
|
88 |
+
data, target = data.to(device), target.to(device)
|
89 |
+
output = model(data)
|
90 |
+
_, predicted = torch.max(output.data, 1)
|
91 |
+
total_val += target.size(0)
|
92 |
+
correct_val += (predicted == target).sum().item()
|
93 |
+
|
94 |
+
accuracy_val = 100 * correct_val / total_val
|
95 |
+
print(f'Validation Accuracy: {accuracy_val:.2f}%')
|
96 |
+
|
97 |
+
|
98 |
+
# Save model after training
|
99 |
+
torch.save(model.state_dict(), "mnist_simple_cnn.pht")
|
100 |
+
print(f'Model saved')
|
101 |
+
|
102 |
+
|
103 |
+
model.eval()
|
104 |
+
|
105 |
+
test_folder = "test/"
|
106 |
+
|
107 |
+
# Loop through all files in the test folder
|
108 |
+
for filename in os.listdir(test_folder):
|
109 |
+
if filename.endswith(".png"): # Only process .png files (you can add more extensions if needed)
|
110 |
+
image_path = os.path.join(test_folder, filename)
|
111 |
+
|
112 |
+
# Load and preprocess the image
|
113 |
+
image = Image.open(image_path).convert("L") # Convert to grayscale
|
114 |
+
image = image.resize((28, 28)) # Resize to 28x28 pixels
|
115 |
+
image_tensor = transforms['valid_test'](image).unsqueeze(0) # Add batch and channel dimensions
|
116 |
+
image_tensor = image_tensor.to(device)
|
117 |
+
|
118 |
+
# Make prediction
|
119 |
+
with torch.no_grad():
|
120 |
+
output = model(image_tensor)
|
121 |
+
probabilities = torch.softmax(output, dim=1)
|
122 |
+
|
123 |
+
# Convert probabilities to a list of (class, probability)
|
124 |
+
class_probabilities = {
|
125 |
+
str(class_index): prob.item() for class_index, prob in enumerate(probabilities[0])
|
126 |
+
}
|
127 |
+
|
128 |
+
# Print or store the predictions for the current image
|
129 |
+
print(f"Predictions for {filename}: {class_probabilities}")
|
130 |
+
print() # Line break for separation
|
train.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
from torchvision import datasets, transforms
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch.utils.tensorboard import SummaryWriter
|
8 |
+
from simple_nn import SimpleNN
|
9 |
+
from simple_cnn import SimpleCNN
|
10 |
+
from device import get_device
|
11 |
+
from view_image import view_batch_images, save_batch_images
|
12 |
+
from tqdm import tqdm # Import tqdm for the progress bar
|
13 |
+
|
14 |
+
|
15 |
+
def train(model, device, train_loader, test_loader, criterion, optimizer, epochs = 5):
|
16 |
+
|
17 |
+
# Initialize TensorBoard writer
|
18 |
+
# writer = SummaryWriter('runs/mnist_experiment')
|
19 |
+
|
20 |
+
# Train the model
|
21 |
+
for epoch in range(epochs):
|
22 |
+
model.train()
|
23 |
+
epoch_loss = 0
|
24 |
+
with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
|
25 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
26 |
+
# Forward pass
|
27 |
+
data, target = data.to(device), target.to(device)
|
28 |
+
output = model(data)
|
29 |
+
loss = criterion(output, target)
|
30 |
+
|
31 |
+
# Backward pass and optimization
|
32 |
+
optimizer.zero_grad()
|
33 |
+
loss.backward()
|
34 |
+
optimizer.step()
|
35 |
+
|
36 |
+
epoch_loss += loss.item()
|
37 |
+
|
38 |
+
pbar.set_postfix(loss=loss.item())
|
39 |
+
pbar.update(1)
|
40 |
+
|
41 |
+
# Log the average loss for this epoch to TensorBoard
|
42 |
+
# avg_loss = epoch_loss / len(train_loader)
|
43 |
+
# writer.add_scalar('Loss/train', avg_loss, epoch)
|
44 |
+
|
45 |
+
print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')
|
46 |
+
validation(model,device,test_loader)
|
47 |
+
# After training, visualize with TensorBoard
|
48 |
+
# writer.close()
|
49 |
+
|
50 |
+
# Test the model
|
51 |
+
def validation(model, device, data_loader):
|
52 |
+
correct = 0
|
53 |
+
total = 0
|
54 |
+
model.eval() # Set the model to evaluation mode
|
55 |
+
with torch.no_grad():
|
56 |
+
for data, target in data_loader:
|
57 |
+
data, target = data.to(device), target.to(device)
|
58 |
+
output = model(data)
|
59 |
+
_, predicted = torch.max(output.data, 1)
|
60 |
+
total += target.size(0)
|
61 |
+
correct += (predicted == target).sum().item()
|
62 |
+
|
63 |
+
accuracy = 100 * correct / total
|
64 |
+
print(f'Accuracy on test set: {accuracy:.2f}%')
|
65 |
+
|
66 |
+
# Save the trained model
|
67 |
+
def save_model(model, model_save_path):
|
68 |
+
torch.save(model.state_dict(), model_save_path)
|
69 |
+
print(f'[INFO] Model saved to {model_save_path}')
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
root_path = os.path.expanduser('data')
|
74 |
+
# Load the dataset
|
75 |
+
transforms = {
|
76 |
+
'train': transforms.Compose([
|
77 |
+
transforms.RandomRotation(10),
|
78 |
+
transforms.RandomHorizontalFlip(),
|
79 |
+
transforms.ToTensor(),
|
80 |
+
transforms.Normalize((0.5,), (0.5,)) # Normalize for MNIST
|
81 |
+
]),
|
82 |
+
'valid_test' : transforms.Compose([
|
83 |
+
transforms.ToTensor(),
|
84 |
+
transforms.Normalize((0.5,), (0.5,))
|
85 |
+
])
|
86 |
+
}
|
87 |
+
train_dataset = datasets.MNIST(root=root_path, download=False, train=True, transform=transforms['train'])
|
88 |
+
test_dataset = datasets.MNIST(root=root_path, download=True, train=False, transform=transforms['valid_test'])
|
89 |
+
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
|
90 |
+
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
|
91 |
+
# view_batch_images(train_loader=train_loader)
|
92 |
+
|
93 |
+
# Get a batch of images
|
94 |
+
# data_iter = iter(train_loader)
|
95 |
+
# images, labels = next(data_iter)
|
96 |
+
# Save the images
|
97 |
+
# save_batch_images(images, save_dir="output_images", prefix="mnist_image", file_format="png")
|
98 |
+
|
99 |
+
# Initialize the model, loss function, and optimizer
|
100 |
+
device = get_device()
|
101 |
+
# model = SimpleNN()
|
102 |
+
model = SimpleCNN()
|
103 |
+
model.to(device=device)
|
104 |
+
criterion = nn.CrossEntropyLoss()
|
105 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
106 |
+
|
107 |
+
epochs = 20
|
108 |
+
train(model=model,device=device,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,epochs = epochs)
|
109 |
+
save_model(model=model,model_save_path="mnist_model.pht")
|
view_image.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import torchvision
|
5 |
+
from torchvision.transforms import ToPILImage
|
6 |
+
|
7 |
+
def view_image(image):
|
8 |
+
plt.imshow(image, cmap="gray")
|
9 |
+
plt.title("Grayscale Image")
|
10 |
+
plt.axis("off") # Hide axes for better visualization
|
11 |
+
plt.show()
|
12 |
+
|
13 |
+
def view_tensor_image(image_tensor, title="Image"):
|
14 |
+
image_np = image_tensor.squeeze().numpy()
|
15 |
+
plt.imshow(image_np)
|
16 |
+
plt.title(title)
|
17 |
+
plt.axis('off')
|
18 |
+
plt.show()
|
19 |
+
|
20 |
+
def view_batch_images(train_loader, num_images=8):
|
21 |
+
"""
|
22 |
+
Display a batch of images from the train_loader.
|
23 |
+
|
24 |
+
Parameters:
|
25 |
+
train_loader (DataLoader): The DataLoader containing the images.
|
26 |
+
num_images (int): Number of images to display from the batch.
|
27 |
+
"""
|
28 |
+
data_iter = iter(train_loader)
|
29 |
+
images, labels = next(data_iter) # Get a batch of images and labels
|
30 |
+
|
31 |
+
# Make a grid of images
|
32 |
+
img_grid = torchvision.utils.make_grid(images[:num_images], nrow=num_images, normalize=True)
|
33 |
+
img_np = img_grid.numpy().transpose((1, 2, 0)) # Rearrange dimensions for plotting
|
34 |
+
|
35 |
+
plt.figure(figsize=(12, 6))
|
36 |
+
plt.imshow(img_np, cmap="gray")
|
37 |
+
plt.title("Batch of Images")
|
38 |
+
plt.axis("off")
|
39 |
+
plt.show()
|
40 |
+
|
41 |
+
|
42 |
+
def save_batch_images(images, save_dir, prefix="image", file_format="png", unnormalize=None):
|
43 |
+
"""
|
44 |
+
Save each image in a batch to a specified directory.
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
images (torch.Tensor): Batch of images with shape (B, C, H, W).
|
48 |
+
save_dir (str): Directory to save the images.
|
49 |
+
prefix (str): Prefix for the saved image filenames.
|
50 |
+
file_format (str): File format for the saved images (e.g., "png", "jpg").
|
51 |
+
unnormalize (callable, optional): Function to unnormalize the images before saving.
|
52 |
+
"""
|
53 |
+
os.makedirs(save_dir, exist_ok=True) # Create the directory if it doesn't exist
|
54 |
+
to_pil = ToPILImage() # Converts tensors to PIL images
|
55 |
+
|
56 |
+
for idx, image in enumerate(images):
|
57 |
+
if unnormalize:
|
58 |
+
image = unnormalize(image) # Apply unnormalization if provided
|
59 |
+
|
60 |
+
pil_image = to_pil(image) # Convert to PIL Image
|
61 |
+
filename = os.path.join(save_dir, f"{prefix}_{idx}.{file_format}")
|
62 |
+
pil_image.save(filename)
|
63 |
+
print(f"Saved: {filename}")
|
view_model_information.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torchvision import transforms
|
6 |
+
from device import get_device
|
7 |
+
from simple_nn import SimpleNN
|
8 |
+
from simple_cnn import SimpleCNN
|
9 |
+
from view_image import view_image, view_tensor_image
|
10 |
+
|
11 |
+
model_path = "mnist_model.pht"
|
12 |
+
# Load model
|
13 |
+
model = SimpleCNN()
|
14 |
+
with open(model_path, 'rb') as f:
|
15 |
+
state_dict = torch.load(f, weights_only=True)
|
16 |
+
model.load_state_dict(state_dict)
|
17 |
+
|
18 |
+
# View model information
|
19 |
+
print(model) # Display the model architecture
|
20 |
+
|
21 |
+
# For more detailed information about the model's parameters:
|
22 |
+
print(f"Model summary: {model}")
|
23 |
+
|
24 |
+
# You can also view the parameters' details (e.g., number of parameters, layers, etc.)
|
25 |
+
for name, param in model.named_parameters():
|
26 |
+
print(f"Parameter: {name}, Shape: {param.shape}")
|