|
|
|
!pip install numpy torch torchvision pytorch-ignite |
|
|
|
import numpy as np |
|
import torch |
|
from torch import optim, nn |
|
from torch.utils.data import DataLoader, TensorDataset |
|
from torchvision import models, datasets, transforms |
|
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
from ignite.metrics import Accuracy, Loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Hyperparameters ---

lr = 0.001  # Adam learning rate

batch_size = 200  # samples per mini-batch

num_epochs = 1  # number of passes over the training set




# Print the running loss every `print_every` iterations (1 = every batch).
print_every = 1




# Run on GPU when one is available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
|
|
|
|
|
|
|
def load_data(train):
    """Build a DataLoader over an MNIST split, preprocessed for AlexNet.

    Parameters
    ----------
    train : bool
        True for the training split (shuffled, augmented);
        False for the test split (deterministic, no augmentation).

    Returns
    -------
    DataLoader yielding 3-channel 224x224 image tensors and integer labels.
    """
    # Deterministic preprocessing shared by both splits: scale/crop to the
    # 224x224 input AlexNet expects.
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
    if train:
        # BUG FIX: the random flip was previously applied to the test split
        # as well, making evaluation non-deterministic; random augmentation
        # belongs on the training split only. It is also moved before
        # ToTensor(), the conventional PIL-stage position (tensor support
        # for flips exists only in newer torchvision releases).
        # NOTE(review): vertically flipped digits are generally not valid
        # digits — confirm this augmentation is actually wanted for MNIST.
        steps.append(transforms.RandomVerticalFlip())
    steps += [
        transforms.ToTensor(),
        # MNIST is single-channel; replicate to 3 channels for AlexNet.
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ]
    transform = transforms.Compose(steps)

    dataset = datasets.MNIST("./data", train=train, download=True, transform=transform)

    # Pinned memory plus a worker process speeds up host->GPU transfers.
    kwargs = {"pin_memory": True, "num_workers": 1} if use_cuda else {}
    # Shuffle only the training split.
    return DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)
|
|
|
# Build the data loaders (downloads MNIST on first run).
train_loader = load_data(train=True)

# No validation split is carved out; the epoch logger skips it when None.
val_loader = None

test_loader = load_data(train=False)
|
|
|
|
|
|
|
|
|
# BUG FIX: AlexNet's default classifier head has 1000 outputs (ImageNet),
# but MNIST has only 10 classes — build the model with a 10-way head.
# (Dropping the deprecated `pretrained=False` also keeps this compatible
# with newer torchvision, where `weights=None` is the default.)
model = models.alexnet(num_classes=10)

model = model.to(device)

# Cross-entropy over class logits; reused below as the evaluation metric.
loss_func = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=lr)
|
|
|
|
|
|
|
|
|
# Ignite training engine: runs forward pass, loss, backward pass, and
# optimizer step for each batch.
trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_func,
    device=device,
)

# Metrics accumulated by the evaluator over one full pass of a loader.
metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_func),
}

# Inference-only engine that populates `evaluator.state.metrics`.
evaluator = create_supervised_evaluator(
    model, metrics=metrics, device=device
)
|
|
|
@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
def log_batch(trainer):
    """Print the current batch's training loss every `print_every` iterations."""
    state = trainer.state
    # Map the monotonically increasing global iteration counter to a
    # 1-based batch index within the current epoch.
    batch_in_epoch = (state.iteration - 1) % state.epoch_length + 1
    message = (
        f"Epoch {state.epoch} / {num_epochs}, "
        f"batch {batch_in_epoch} / {state.epoch_length}: "
        f"loss: {state.output:.3f}"
    )
    print(message)
|
|
|
@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(trainer):
    """At the end of each epoch, evaluate on every available split and print averages."""
    print(f"Epoch {trainer.state.epoch} / {num_epochs} average results: ")

    def log_results(name, metrics, epoch):
        # `epoch` is accepted for call-site symmetry but not printed here.
        line = f"{name + ':':6} loss: {metrics['loss']:.3f}, accuracy: {metrics['accuracy']:.3f}"
        print(line)

    # The training split is always evaluated; each `evaluator.run` call
    # repopulates `evaluator.state.metrics` for that loader.
    evaluator.run(train_loader)
    log_results("train", evaluator.state.metrics, trainer.state.epoch)

    # Optional splits are skipped when their loader is absent/empty.
    for split_name, loader in (("val", val_loader), ("test", test_loader)):
        if loader:
            evaluator.run(loader)
            log_results(split_name, evaluator.state.metrics, trainer.state.epoch)

    print()
    print("-" * 80)
    print()
|
|
|
|
|
trainer.run(train_loader, max_epochs=num_epochs) |