import os
import torch
import torchvision
from torchvision import datasets, transforms, utils
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# from IPython.core.display import display

import misclas_helper
import gradcam_helper
import lightningmodel
from misclas_helper import display_cifar_misclassified_data
from gradcam_helper import display_gradcam_output
from misclas_helper import get_misclassified_data2
from misclas_helper import classify_images
from lightningmodel import LitResnet

# ref : https://pytorch-lightning.readthedocs.io/en/1.2.10/common/weights_loading.html
from pytorch_lightning.callbacks import ModelCheckpoint

# Class names in label-index order: the 10 CIFAR-10 classes plus an extra
# 'NotApplicable' entry (how that extra label is used is not visible here).
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck', 'NotApplicable')

# Inverse of a per-channel Normalize(mean=0.50, std=0.23), used to map
# normalized tensors back towards displayable pixel values.
# NOTE(review): these constants are hard-coded; confirm they match the
# statistics of cifar10_normalization(), which normalizes the data in ts_lt,
# otherwise the round trip is only approximate.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)


def _build_trainer(epochs):
    """Create the Trainer: CSV logging, step-level LR monitoring, a progress
    bar, and top-3 ``val_acc`` checkpointing into ``/content/``."""
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',  # assumes LitResnet logs a metric named 'val_acc'
        dirpath='/content/',
        filename='weights_{epoch:02d}_{val_acc:.2f}',
        save_top_k=3,
        mode='max',
    )
    return Trainer(
        max_epochs=epochs,  # 26
        accelerator="auto",
        # Use a single GPU when one is available (limiting for IPython runs).
        devices=1 if torch.cuda.is_available() else None,
        logger=CSVLogger(save_dir="logs/"),
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            TQDMProgressBar(refresh_rate=10),
            checkpoint_callback,
        ],
    )


def _build_cifar10_datamodule():
    """Create the CIFAR-10 datamodule: random crop + flip augmentation for
    training, plain ToTensor + normalization for validation and testing."""
    data_dir = os.environ.get("PATH_DATASETS", ".")
    batch_size = 256 if torch.cuda.is_available() else 64
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 0 workers (load in the main process) instead of crashing.
    num_workers = (os.cpu_count() or 0) // 2

    train_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )
    test_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )
    return CIFAR10DataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transforms=train_transforms,
        test_transforms=test_transforms,
        val_transforms=test_transforms,
    )


def ts_lt(  # Train and Save Vs Load and Test
    save1_or_load0,  # decision maker for training Vs testing
    Epochs=1,  # argument for training
    wt_fname="/content/weights.ckpt",  # argument for testing
):
    """Train a fresh LitResnet, or load one from a checkpoint and test it.

    Args:
        save1_or_load0: Truthy (e.g. ``1`` / ``True``) to train a new
            ``LitResnet(lr=0.05)``. Falsy (e.g. ``0`` / ``False``) to load
            ``wt_fname`` and run ``trainer.test`` with the CIFAR-10 datamodule.
        Epochs: Maximum number of training epochs (used when training).
        wt_fname: Checkpoint path to load (used when testing).

    Returns:
        A ``(model, trainer)`` tuple holding the trained or loaded LitResnet
        and the Trainer that ran it. While training, the Trainer's checkpoint
        callback keeps the top-3 ``val_acc`` checkpoints in ``/content/``.
    """
    trainer = _build_trainer(Epochs)
    cifar10_dm = _build_cifar10_datamodule()

    if save1_or_load0:
        model = LitResnet(lr=0.05)
        trainer.fit(model, cifar10_dm)
    else:
        # load_from_checkpoint is a classmethod. Calling it on the class avoids
        # building a throwaway network, and recent Lightning releases reject
        # calling it on an instance.
        model = LitResnet.load_from_checkpoint(wt_fname)
        trainer.test(model, datamodule=cifar10_dm)

    return model, trainer