raja5259 commited on
Commit
c0502a3
·
verified ·
1 Parent(s): 926a474

Create trainsave_loadtest.py

Browse files
Files changed (1) hide show
  1. trainsave_loadtest.py +101 -0
trainsave_loadtest.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torchvision import datasets, transforms, utils
5
+ from pl_bolts.datamodules import CIFAR10DataModule
6
+ from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
7
+ from pytorch_lightning import LightningModule, Trainer, seed_everything
8
+ from pytorch_lightning.callbacks import LearningRateMonitor
9
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
10
+ from pytorch_lightning.loggers import CSVLogger
11
+ from torch.optim.lr_scheduler import OneCycleLR
12
+ from torch.optim.swa_utils import AveragedModel, update_bn
13
+ from torchmetrics.functional import accuracy
14
+ import pandas as pd
15
+ import seaborn as sn
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ # from IPython.core.display import display
19
+ import misclas_helper
20
+ import gradcam_helper
21
+ import lightningmodel
22
+ from misclas_helper import display_cifar_misclassified_data
23
+ from gradcam_helper import display_gradcam_output
24
+ from misclas_helper import get_misclassified_data2
25
+ from misclas_helper import classify_images
26
+ from lightningmodel import LitResnet
27
+ #ref : https://pytorch-lightning.readthedocs.io/en/1.2.10/common/weights_loading.html
28
+ from pytorch_lightning.callbacks import ModelCheckpoint
29
+
30
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
31
+ 'dog', 'frog', 'horse', 'ship', 'truck', 'NotApplicable')
32
+
33
+ inv_normalize = transforms.Normalize(
34
+ mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
35
+ std=[1/0.23, 1/0.23, 1/0.23]
36
+ )
37
+
38
+ def ts_lt( # Train and Save Vs Load and Test
39
+ save1_or_load0, # decision maker for training Vs testing
40
+ Epochs = 1, # argument for training
41
+ wt_fname = "/content/weights.ckpt" # argument for testing
42
+ ):
43
+ checkpoint_callback = ModelCheckpoint(
44
+ monitor='val_acc',
45
+ dirpath='/content/',
46
+ filename='weights_{epoch:02d}_{val_acc:.2f}',
47
+ save_top_k=3,
48
+ mode='max',
49
+ )
50
+
51
+ trainer = Trainer(
52
+ max_epochs=Epochs, #26
53
+ accelerator="auto",
54
+ devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
55
+ logger=CSVLogger(save_dir="logs/"),
56
+ callbacks=[LearningRateMonitor(logging_interval="step"), TQDMProgressBar(refresh_rate=10), checkpoint_callback],
57
+ )
58
+ PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
59
+ BATCH_SIZE = 256 if torch.cuda.is_available() else 64
60
+ NUM_WORKERS = int(os.cpu_count() / 2)
61
+
62
+ train_transforms = torchvision.transforms.Compose(
63
+ [
64
+ torchvision.transforms.RandomCrop(32, padding=4),
65
+ torchvision.transforms.RandomHorizontalFlip(),
66
+ torchvision.transforms.ToTensor(),
67
+ cifar10_normalization(),
68
+ ]
69
+ )
70
+
71
+ test_transforms = torchvision.transforms.Compose(
72
+ [
73
+ torchvision.transforms.ToTensor(),
74
+ cifar10_normalization(),
75
+ ]
76
+ )
77
+ cifar10_dm = CIFAR10DataModule(
78
+ data_dir=PATH_DATASETS,
79
+ batch_size=BATCH_SIZE,
80
+ num_workers=NUM_WORKERS,
81
+ train_transforms=train_transforms,
82
+ test_transforms=test_transforms,
83
+ val_transforms=test_transforms,
84
+ )
85
+
86
+ if save1_or_load0 == True:
87
+ model = LitResnet(lr=0.05)
88
+ checkpoint_callback = ModelCheckpoint(
89
+ monitor='val_acc',
90
+ dirpath='/content/',
91
+ filename='weights_{epoch:02d}_{val_acc:.2f}',
92
+ save_top_k=3,
93
+ mode='max',
94
+ )
95
+ trainer.fit(model, cifar10_dm)
96
+ else:
97
+ model = LitResnet(lr=0.05).load_from_checkpoint(wt_fname)
98
+
99
+ trainer.test(model, datamodule=cifar10_dm)
100
+
101
+ return model, trainer