# NOTE: the lines below are residue from the web file viewer, not script code.
# nullHawk's picture
# deploy to huggingface
# 78f38b4 verified
# raw
# history blame
# 2.41 kB
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from NeuralNet import NeuralNet
# Device selection: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters.
input_size = 28 * 28  # 784 values per image once it is flattened
hidden_size = 100
num_classes = 10
num_epochs, batch_size = 20, 500
learning_rate = 1e-3
# MNIST datasets, both rooted at ./data and converted to tensors on access.
# Only the training split passes download=True.
training_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

# Batched loaders; only the training loader shuffles its samples.
train_loader = torch.utils.data.DataLoader(
    training_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

# Sanity check: pull a single training batch and print the tensor shapes.
example = iter(train_loader)
samples, labels = next(example)
print(samples.shape, labels.shape)

# Optional visual check of the first six samples (left disabled):
# for i in range(6):
#     plt.subplot(2, 3, i + 1)
#     plt.imshow(samples[i][0], cmap='gray')
# plt.show()
# Build the network and move its parameters to the selected device. Every
# batch is sent to `device` before the forward pass, so the model has to live
# there as well; otherwise the first forward pass fails with a device
# mismatch whenever CUDA is available.
model = NeuralNet(input_size, hidden_size, num_classes).to(device)

# Loss and optimizer. The optimizer is created after the move so that it is
# handed the parameters that live on `device`.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Training loop: num_epochs full passes over the training loader.
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, batch in enumerate(train_loader):
        # A batch is an (images, labels) pair. Images arrive as
        # (n, c, h, w) = (n, 1, 28, 28) and are flattened to (n, 784).
        images, labels = batch
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: clear old gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only every 100th step is reported.
        if (i+1) % 100 != 0:
            continue
        print(f'epoch {epoch+1}/{num_epochs}, step {i+1}/{n_total_steps}, loss = {loss.item():.4f}')
# Evaluate classification accuracy on the test loader.
# Switch to evaluation mode first so that layers which behave differently
# during training (e.g. dropout or batch norm, if NeuralNet contains any)
# do not distort the measurement.
model.eval()
with torch.no_grad():  # gradients are not needed for inference
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        # Flatten each image to a 784-vector, matching the training input.
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # torch.max along dim 1 returns (values, indices); the index of the
        # largest score is taken as the predicted class.
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        n_correct += (predictions == labels).sum().item()
    # Accuracy as a percentage of correctly classified test samples.
    acc = 100.0 * n_correct / n_samples
    print(f'accuracy = {acc}')
# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure the target folder exists before writing;
# otherwise the whole training run is lost to an error on the last line.
checkpoint_path = 'model/model.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)