# Before running, install required packages:
#   !pip install numpy torch torchvision pytorch-ignite
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss


# ----------------------------------- Setup -----------------------------------

# The MNIST dataset is loaded further down.

# Set up hyperparameters.
lr = 0.001
batch_size = 200
num_epochs = 1

# Set up logging.
print_every = 1  # batches

# Set up device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


# -------------------------- Dataset & Preprocessing --------------------------

def load_data(train):
    """Download MNIST and wrap it in a DataLoader.

    Images are resized to 256, center-cropped to 224x224 and replicated to
    three identical channels, so they can be fed to the AlexNet model below.

    Args:
        train: If True, load the training split, apply random vertical flips
            as augmentation and shuffle. If False, load the test split with a
            deterministic transform and no shuffling.

    Returns:
        A DataLoader over the chosen split with batch size ``batch_size``.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # Grayscale -> 3 identical channels (same result as repeating the
        # single channel). Unlike a lambda, this transform is picklable, which
        # DataLoader worker processes require under the "spawn" start method.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
    if train:
        # Random augmentation on training data only, so that evaluation on
        # the test split is deterministic.
        # NOTE(review): a vertically flipped digit is generally not a valid
        # image of the same digit; consider dropping this augmentation.
        steps.append(transforms.RandomVerticalFlip())
    transform = transforms.Compose(steps)
    dataset = datasets.MNIST("./data", train=train, download=True, transform=transform)

    # Pinned host memory and a background worker are only requested when
    # copying batches to a CUDA device.
    kwargs = {"pin_memory": True, "num_workers": 1} if use_cuda else {}
    return DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)


train_loader = load_data(train=True)
val_loader = None
test_loader = load_data(train=False)


# ----------------------------------- Model -----------------------------------

# Set up model, loss, optimizer. AlexNet is trained from scratch (no pretrained
# weights, the default) with a 10-way output layer for the ten MNIST digit
# classes instead of the default 1000, so predictions are always valid labels.
model = models.alexnet(num_classes=10)
model = model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


# --------------------------------- Training ----------------------------------

# Set up pytorch-ignite trainer and evaluator.
trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_func,
    device=device,
)

metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_func),
}

evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)


@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
def log_batch(engine):
    """Print the training loss reported for the most recent batch."""
    # state.iteration keeps counting across epochs; convert it to a 1-based
    # batch index within the current epoch.
    batch = (engine.state.iteration - 1) % engine.state.epoch_length + 1
    print(
        f"Epoch {engine.state.epoch} / {num_epochs}, "
        f"batch {batch} / {engine.state.epoch_length}: "
        f"loss: {engine.state.output:.3f}"
    )


def _log_results(name, results):
    """Print the loss and accuracy stored in an evaluator's metrics dict."""
    print(
        f"{name + ':':6} loss: {results['loss']:.3f}, accuracy: {results['accuracy']:.3f}"
    )


@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(engine):
    """Evaluate the model on each configured split and print the results."""
    print(f"Epoch {engine.state.epoch} / {num_epochs} average results: ")
    # NOTE(review): train metrics are measured on train_loader, i.e. with
    # whatever shuffling/augmentation that loader applies.
    splits = (("train", train_loader), ("val", val_loader), ("test", test_loader))
    for name, loader in splits:
        # Compare against None explicitly: truth-testing a DataLoader calls
        # len(), which raises TypeError for IterableDataset-backed loaders.
        if loader is None:
            continue
        evaluator.run(loader)
        _log_results(name, evaluator.state.metrics)
    print()
    print("-" * 80)
    print()


# Start training.
trainer.run(train_loader, max_epochs=num_epochs)