# Scraped from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import torch
import torch.nn.functional as F
from torch import nn

from model import ResNet18
from preprocessing import PreprocessedImageFolder, augmentations, make_dls
from trainer import (
    ActivationStatsCB,
    AugmentCB,
    DeviceCB,
    LRFinderCB,
    MultiClassAccuracyCB,
    ProgressCB,
    Trainer,
    WandBCB,
)
device = "cuda" if torch.cuda.is_available() else "cpu"


def main(
    epochs: int = 5,
    lr: float = 1e-4,
    batch_size: int = 32,
    num_workers: int = 2,
) -> None:
    """Train a single-input-channel ResNet18 on ./dataset and log the run to W&B.

    Builds train/valid loaders from ``./dataset/train`` and ``./dataset/test``,
    trains with SGD and cross-entropy loss, and saves the loss plot produced by
    ``ProgressCB``.

    Args:
        epochs: Number of epochs passed to ``Trainer.fit``.
        lr: Learning rate handed to ``Trainer`` for the SGD optimizer.
        batch_size: Batch size passed to ``make_dls``.
        num_workers: DataLoader worker count passed to ``make_dls``.
    """
    # NOTE(review): the second argument is presumably an optional transform;
    # augmentation is applied later on-device by AugmentCB instead — confirm.
    train_ds = PreprocessedImageFolder("./dataset/train", None)
    valid_ds = PreprocessedImageFolder("./dataset/test", None)
    dls = make_dls(train_ds, valid_ds, batch_size=batch_size, num_workers=num_workers)

    model = ResNet18(in_channels=1, num_classes=len(train_ds.classes))

    # Optional debugging callbacks; uncomment and add to `cbs` below to enable.
    # lr_find = LRFinderCB(min_lr=1e-4, max_lr=0.1, max_mult=3)
    # act_stats = ActivationStatsCB(
    #     mod_filter=lambda m: isinstance(m, (nn.Conv2d, nn.Linear)),
    #     with_wandb=True,
    # )
    progress = ProgressCB(in_notebook=False)
    wandb_cb = WandBCB(proj_name="test", model_path="./model.pth")
    augment = AugmentCB(device=device, transform=augmentations)
    acc_cb = MultiClassAccuracyCB(with_wandb=True)

    trainer = Trainer(
        model,
        dls,
        F.cross_entropy,
        torch.optim.SGD,
        lr=lr,
        cbs=[DeviceCB(device), augment, progress, wandb_cb, acc_cb],
    )
    # NOTE(review): the two positional flags presumably enable the train and
    # validation phases — confirm against Trainer.fit's signature.
    trainer.fit(epochs, True, True)

    # TODO: saving plots to wandb
    progress.plot_losses(save=True)
    # act_stats.plot_stats(save=True)
    # act_stats.color_dim(save=True)
    # act_stats.dead_chart(save=True)
    # No explicit torch.save here: per the original author's note, the model
    # state_dict is written to ./model.pth by WandBCB (see model_path above).


# Guard the entry point: with num_workers > 0, DataLoader workers started via
# the "spawn" method (default on Windows/macOS) re-import this module, and an
# unguarded script would re-run training in every worker.
if __name__ == "__main__":
    main()