"""Train the project's ``NeuralNet`` model on MNIST, report test accuracy, and save the weights."""

import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from NeuralNet import NeuralNet

# Device config: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
input_size = 784  # 28*28: the length each image is flattened to before the forward pass
hidden_size = 100
num_classes = 10
num_epochs = 20
batch_size = 500
learning_rate = 0.001

# Where the trained weights are written at the end of the script.
model_path = 'model/model.pt'

# MNIST (only the training split is constructed with download=True)
training_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

train_loader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Sanity check: pull one batch and print the tensor shapes.
example = iter(train_loader)
samples, labels = next(example)
print(samples.shape, labels.shape)

# Optional visual check of the first six samples.
# NOTE(review): `plt` is not imported in this file; presumably matplotlib.pyplot.
# for i in range(6):
#     plt.subplot(2, 3, i+1)
#     plt.imshow(samples[i][0], cmap='gray')
# plt.show()

# The model must live on the same device as the batches fed to it below;
# without .to(device) the forward pass fails whenever CUDA is selected.
model = NeuralNet(input_size, hidden_size, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Batches arrive as (n, c, h, w); flatten each image to a vector of
        # length input_size before passing it to the model.
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)

        # Forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward: clear stale gradients, backpropagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log every 100th step.
        if (i+1) % 100 == 0:
            print(
                f'epoch {epoch+1}/{num_epochs}, step {i+1}/{n_total_steps}, '
                f'loss = {loss.item():.4f}'
            )

# Test
# Switch to inference mode so that train-only layers (dropout, batch-norm),
# if NeuralNet has any, do not perturb the evaluation.
model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)

        # torch.max over dim 1 returns (value, index); the index is the predicted class.
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        n_correct += (predictions == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'accuracy = {acc}')

# torch.save does not create missing parent directories, so make sure the
# target directory exists instead of failing after the whole training run.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)