# Spaces status: Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
import torchvision.transforms as transforms | |
import matplotlib.pyplot as plt | |
# Run on the first CUDA device if one is visible, otherwise on the CPU, and
# announce which backend was picked.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("GPU" if device.type == "cuda" else "CPU")
# MNIST dataset. The training split is downloaded into ./data when missing;
# the test split is opened from the same root without download=True.
# NOTE(review): this relies on the training-split download having fetched the
# test files too (torchvision's MNIST downloader grabs both splits) — confirm.
batch_size = 64
_MNIST_ROOT = './data'

train_dataset = torchvision.datasets.MNIST(
    root=_MNIST_ROOT,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root=_MNIST_ROOT,
    train=False,
    transform=transforms.ToTensor(),
)


def _mnist_loader(dataset, shuffle):
    """Batch *dataset* with the shared ``batch_size``, shuffling only if asked."""
    return torch.utils.data.DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=shuffle
    )


# Data loaders: training batches are shuffled, test order stays fixed.
train_loader = _mnist_loader(train_dataset, shuffle=True)
test_loader = _mnist_loader(test_dataset, shuffle=False)
# NEURAL NETWORK
class LeNet(nn.Module):
    """Small LeNet-style CNN producing 10 output logits per image.

    Two (5x5 conv -> tanh -> 2x2 average-pool) stages feed one linear layer.
    The linear input size 12 * 4 * 4 matches a 1x28x28 input
    (28 -> 24 -> 12 -> 8 -> 4 spatially); other sizes will not fit.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channels=4, out_channels=12, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        ]
        # Kept as Sequential containers so state_dict keys stay
        # "convs.N.*" / "linear.0.*" and saved checkpoints still load.
        self.convs = nn.Sequential(*feature_layers)
        self.linear = nn.Sequential(nn.Linear(12 * 4 * 4, 10))

    def forward(self, x):
        """Return raw class logits of shape (batch, 10) for image batch *x*."""
        features = self.convs(x).flatten(start_dim=1)
        return self.linear(features)
# TRAIN PARAMETERS
num_epochs = 10
n_steps = len(train_loader)  # number of batches in one pass over train_loader

criterion = nn.CrossEntropyLoss()
model_adam = LeNet().to(device)
# NOTE(review): lr=0.05 is 50x Adam's documented default (1e-3); training may
# be unstable at this rate — confirm it is intentional.
optimizer = torch.optim.Adam(params=model_adam.parameters(), lr=5e-2)
# TRAIN
def train(model, opt=None, *, epochs=None, loader=None, loss_fn=None,
          save_path="model_mnist.pth"):
    """Train *model* in place, then save its weights.

    Every optional argument falls back to the module-level setting, so
    ``train(model_adam)`` behaves as before. Passing a different model now
    requires an optimizer over that model's parameters. Previously the global
    ``optimizer`` (built for ``model_adam``) was always stepped, which left
    any other model untouched.

    Args:
        model: module to optimise. Batches are moved to the device of its
            first parameter.
        opt: optimizer over ``model``'s parameters; defaults to the
            module-level ``optimizer``.
        epochs: number of passes over ``loader``; defaults to ``num_epochs``.
        loader: iterable of ``(inputs, targets)`` batches; defaults to
            ``train_loader``.
        loss_fn: callable ``(outputs, targets) -> scalar loss``; defaults to
            ``criterion``.
        save_path: file that receives ``model.state_dict()`` after training;
            ``None`` skips saving.

    Returns:
        List with the mean training loss of each epoch (NaN for an epoch
        with no batches).

    Raises:
        ValueError: if ``model`` has no parameters, or ``opt`` does not
            manage all of them.
    """
    # Resolve defaults lazily so the globals are only needed when used.
    if opt is None:
        opt = optimizer
    if epochs is None:
        epochs = num_epochs
    if loader is None:
        loader = train_loader
    if loss_fn is None:
        loss_fn = criterion

    params = list(model.parameters())
    if not params:
        raise ValueError("model has no parameters to train")
    managed = {id(p) for group in opt.param_groups for p in group["params"]}
    if any(id(p) not in managed for p in params):
        raise ValueError(
            "optimizer does not manage all parameters of `model`; "
            "pass an optimizer built from model.parameters()"
        )

    dev = params[0].device  # send data wherever the model actually lives
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        # Accumulate on-device; a single .item() per epoch avoids a host sync
        # on every batch.
        total = torch.zeros((), device=dev)
        n_batches = 0
        for images, labels in loader:
            images = images.to(dev)
            labels = labels.to(dev)
            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            # Backward and optimize
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.detach()  # detached: don't keep autograd graphs alive
            n_batches += 1
        epoch_losses.append(total.item() / n_batches if n_batches else float("nan"))

    # Save the model that was just trained (not a fixed global), after training.
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return epoch_losses