# MNISTClassification / classification.py
# Author: csisc
# Commit d139de5: "Create classification.py"
# Before running, install required packages:
#   pip install numpy torch torchvision pytorch-ignite
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
# ----------------------------------- Setup -----------------------------------
# The MNIST dataset itself is loaded further down.
# Hyperparameters.
lr = 0.001       # Adam learning rate
batch_size = 200
num_epochs = 1
# Logging cadence.
print_every = 1  # in batches
# Device selection: prefer the GPU when one is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# -------------------------- Dataset & Preprocessing --------------------------
def load_data(train):
    """Build a DataLoader over MNIST with AlexNet-compatible preprocessing.

    Args:
        train: if True, load the training split (shuffled, with augmentation);
            otherwise load the test split with deterministic preprocessing.

    Returns:
        A torch.utils.data.DataLoader yielding (image, label) batches where
        each image is a 3x224x224 float tensor.
    """
    # Deterministic preprocessing: MNIST is 1x28x28 grayscale; AlexNet expects
    # 3x224x224 3-channel input, so upscale and replicate the channel.
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if train:
        # Fix: random augmentation belongs on the training split only — the
        # original pipeline also flipped eval images, making test metrics
        # nondeterministic.
        # NOTE(review): vertical flips can change a digit's identity (e.g. 6
        # vs 9) — consider dropping this augmentation entirely.
        steps.append(transforms.RandomVerticalFlip())
    # Grayscale -> RGB by repeating the single channel (operates on the tensor).
    steps.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))
    transform = transforms.Compose(steps)
    dataset = datasets.MNIST("./data", train=train, download=True, transform=transform)
    # pin_memory + a worker process speed up host->GPU transfers under CUDA.
    kwargs = {"pin_memory": True, "num_workers": 1} if use_cuda else {}
    # Shuffle only the training split.
    return DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)
# Build the data loaders (downloads MNIST into ./data on first run).
train_loader = load_data(train=True)
# No separate validation split is used in this script; see log_epoch below.
val_loader = None
test_loader = load_data(train=False)
# ----------------------------------- Model -----------------------------------
# Model, loss, optimizer.
# Fix: AlexNet's default head has 1000 outputs (ImageNet); MNIST has only 10
# classes, so the original model wasted parameters and could predict labels
# that can never be correct. `num_classes=10` sizes the classifier head.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favor of
# `weights=`; kept here to match the torchvision API this file already uses.
model = models.alexnet(pretrained=False, num_classes=10)
model = model.to(device)
# CrossEntropyLoss expects raw logits and integer class labels.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# --------------------------------- Training ----------------------------------
# Wire up pytorch-ignite: a trainer that runs the optimization step and an
# evaluator that computes metrics over a data loader.
trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)
metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_func),
}
evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
def log_batch(trainer):
    # Per-batch progress line. `iteration` counts globally across epochs, so
    # derive the 1-based batch index within the current epoch.
    batch_in_epoch = (trainer.state.iteration - 1) % trainer.state.epoch_length + 1
    epoch = trainer.state.epoch
    batches_total = trainer.state.epoch_length
    print(
        f"Epoch {epoch} / {num_epochs}, "
        f"batch {batch_in_epoch} / {batches_total}: "
        f"loss: {trainer.state.output:.3f}"
    )
@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(trainer):
    """Print average loss/accuracy over the train split (and val/test, when
    their loaders exist) at the end of each epoch."""
    print(f"Epoch {trainer.state.epoch} / {num_epochs} average results: ")

    def log_results(name, metrics):
        # Fix: the original helper declared an unused `epoch` parameter.
        print(
            f"{name + ':':6} loss: {metrics['loss']:.3f}, accuracy: {metrics['accuracy']:.3f}"
        )

    # Metrics over the full training set.
    evaluator.run(train_loader)
    log_results("train", evaluator.state.metrics)
    # Optional validation set (None in this script).
    if val_loader:
        evaluator.run(val_loader)
        log_results("val", evaluator.state.metrics)
    # NOTE(review): evaluating on the test set every epoch is fine for a demo,
    # but in real experiments it leaks test feedback into model selection.
    if test_loader:
        evaluator.run(test_loader)
        log_results("test", evaluator.state.metrics)
    print()
    print("-" * 80)
    print()
# Start training: runs the optimization loop for num_epochs epochs, firing the
# per-batch and per-epoch logging handlers registered above.
trainer.run(train_loader, max_epochs=num_epochs)