Spaces:
Sleeping
Sleeping
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from NeuralNet import NeuralNet
# Device configuration: use the CUDA GPU when PyTorch reports one is
# available, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Hyper-parameters.
# Each 28x28 MNIST image is flattened into one feature vector of this length.
input_size = 28 * 28  # = 784
hidden_size = 100     # passed to NeuralNet as its hidden size
num_classes = 10      # one output score per digit class

# Optimisation settings.
num_epochs = 20
batch_size = 500
learning_rate = 0.001
# MNIST digits. ToTensor() converts each PIL image into a float tensor.
# The training split is downloaded into ./data if missing. The test split
# does not pass download=True, so it relies on the files already being in
# ./data (presumably fetched by the training-split call just above).
training_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

# Batch both splits with the same batch size; shuffle only the training
# data, since evaluation order does not affect the accuracy computed later.
train_loader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
# Sanity check: pull a single batch from the training loader and print the
# shapes of the image and label tensors before training starts.
# (The previous commented-out matplotlib preview referenced an undefined
# `plt` and was removed as dead code.)
samples, labels = next(iter(train_loader))
print(samples.shape, labels.shape)
# Build the classifier and move its parameters onto `device`. The training
# and evaluation loops send every image/label batch to `device`. Without
# this .to(device), a CUDA run fails with a device-mismatch error on the
# first forward pass.
model = NeuralNet(input_size, hidden_size, num_classes).to(device)

# Loss and optimizer.
# NOTE(review): CrossEntropyLoss applies log-softmax itself, so this assumes
# NeuralNet returns raw logits (no final softmax). Confirm in NeuralNet.
criterion = nn.CrossEntropyLoss()
# Created after .to(device) so the optimizer holds the device-resident params.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Training loop.
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    # Enable training-mode behaviour (e.g. dropout / batch-norm statistics
    # updates) in case NeuralNet contains such layers.
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        # Batches arrive as (n, c, h, w) = (batch_size, 1, 28, 28); flatten
        # each image into a vector of input_size features for the network.
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log every 100 steps. With a large batch_size an epoch has few steps
        # (n_total_steps), so this may fire only once per epoch.
        if (i + 1) % 100 == 0:
            print(f'epoch {epoch+1}/{num_epochs}, step {i+1}/{n_total_steps}, loss = {loss.item():.4f}')
# Evaluation on the test split.
# Switch dropout / batch-norm layers (if NeuralNet has any) to inference
# behaviour; the training loop leaves the model in training mode.
model.eval()
with torch.no_grad():  # no autograd graph is needed for evaluation
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        # Flatten exactly as in training so the input width matches.
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score along dim 1.
        predictions = outputs.argmax(dim=1)
        n_samples += labels.shape[0]
        n_correct += (predictions == labels).sum().item()
    acc = 100.0 * n_correct / n_samples
    print(f'accuracy = {acc}')
# Persist only the learned parameters (state_dict). Create the target
# directory first: torch.save does not create missing parent directories and
# would otherwise fail on a fresh checkout.
os.makedirs('model', exist_ok=True)
torch.save(model.state_dict(), 'model/model.pt')