'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os

from Resnet101 import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
end_epoch = 300
resume = False
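# Run configuration: the model trains from scratch for end_epoch (300) epochs;
# set resume = True to reload the best checkpoint from ./checkpoint/ and continue
# from the epoch recorded there (see the resume block below).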

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
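# The Normalize means/stds above are the commonly used per-channel statistics of the
# CIFAR-10 training set; RandomCrop(32, padding=4) plus a horizontal flip is the usual
# light augmentation for 32x32 CIFAR images and is applied to the training split only.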

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = ResNet101()
net_name = net.name
save_path = './checkpoint/{0}_ckpt.pth'.format(net.name)
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
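# Note: DataParallel stores the model's parameters under a 'module.' prefix, so the
# checkpoints written by test() below load back cleanly when the network is wrapped
# the same way on resume (as it is here whenever CUDA is available).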

if resume:
    # Load the best checkpoint saved by a previous run.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(save_path)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=70, gamma=0.1)
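# With scheduler.step() called once per epoch in the loop at the bottom, StepLR decays
# the learning rate by 10x every 70 epochs: 0.1 for epochs 0-69, 0.01 for 70-139,
# 0.001 for 140-209, and so on over the 300-epoch run.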


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Per-epoch summary: average loss per batch and accuracy over the whole epoch.
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Per-epoch summary on the test set.
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save a checkpoint whenever the test accuracy improves on the best seen so far.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving ' + net_name + ' ..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, save_path)
        best_acc = acc


for epoch in range(start_epoch, end_epoch):
    train(epoch)
    test(epoch)
    scheduler.step()

print("\nBest test accuracy:", best_acc)