"""Module to define the train and test functions."""

import modules.config as config
import pytorch_lightning as pl
import torch
from modules.utils import create_folder_if_not_exists
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary

# Import tuner
from pytorch_lightning.tuner.tuning import Tuner

# Starting learning rate for the LR range test, taken from the config module
PREFERRED_START_LR = config.PREFERRED_START_LR


def train_and_test_model(
    batch_size,
    num_epochs,
    model,
    datamodule,
    logger,
    debug=False,
):
    """Trains and tests the model by iterating through epochs using Lightning Trainer."""

    print(f"\n\nBatch size: {batch_size}, Total epochs: {num_epochs}\n\n")

    print("Defining Lightning Callbacks")

    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    checkpoint = ModelCheckpoint(
        dirpath=config.CHECKPOINT_PATH, monitor="val_acc", mode="max", filename="model_best_epoch", save_last=True
    )
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#learningratemonitor
    lr_rate_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelSummary.html#lightning.pytorch.callbacks.ModelSummary
    model_summary = ModelSummary(max_depth=0)
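    # Note: max_depth=0 turns the per-layer summary off entirely; use
    # max_depth=1 (the callback's default) or higher to print a summary table.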

    print("Defining Lightning Trainer")
    # Change trainer settings for debugging
    if debug:
        num_epochs = 1
        fast_dev_run = True
        overfit_batches = 0.1
        profiler = "advanced"
    else:
        fast_dev_run = False
        overfit_batches = 0.0
        profiler = None
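    # Caveat: fast_dev_run=True runs only a single batch of train/val/test as a
    # smoke test, and Lightning skips the LR finder below with a warning while
    # it is enabled, so overfit_batches and the profiler matter mainly for
    # full (non-fast_dev_run) debug runs.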

    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods
    trainer = pl.Trainer(
        precision="16-mixed",  # mixed-precision training; Lightning 2.x discourages the bare 16 spelling
        fast_dev_run=fast_dev_run,
        # deterministic=True,
        # devices="auto",
        # accelerator="auto",
        max_epochs=num_epochs,
        logger=logger,
        # enable_model_summary=False,
        overfit_batches=overfit_batches,
        log_every_n_steps=10,
        # num_sanity_val_steps=5,
        profiler=profiler,
        # check_val_every_n_epoch=1,
        callbacks=[checkpoint, lr_rate_monitor, model_summary],
        # callbacks=[checkpoint],
    )

    # Alternative: a custom LR finder implemented on the model itself
    # model.learning_rate = model.find_optimal_lr(train_loader=datamodule.train_dataloader())

    # Use Lightning's Tuner.lr_find instead (Trainer.tune was removed in Lightning 2.x)
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.tuner.tuning.Tuner.html#lightning.pytorch.tuner.tuning.Tuner
    # https://www.youtube.com/watch?v=cLZv0eZQSIE
    print("Finding the optimal learning rate using Lightning Tuner.")
    tuner = Tuner(trainer)
    tuner.lr_find(
        model=model,
        datamodule=datamodule,
        min_lr=PREFERRED_START_LR,
        max_lr=5,
        num_training=200,
        mode="linear",
        early_stop_threshold=10,
        attr_name="learning_rate",
    )
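    # update_attr defaults to True, so lr_find writes the suggested LR straight
    # into model.learning_rate (the attr_name above); the returned LRFinder
    # object also exposes .suggestion() and .plot() for inspecting the sweep.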

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, dataloaders=datamodule.test_dataloader())
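    # Note: trainer.test evaluates the weights currently in memory; pass
    # ckpt_path="best" to score the best checkpoint tracked by ModelCheckpoint
    # instead.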

    # Obtain the per-epoch results dictionary collected by the model
    print("Collecting epoch level model results.")
    results = model.results

    # Get the misclassified-image data collected by the model
    print("Collecting misclassified images.")
    misclassified_image_data = model.misclassified_image_data

    # Save the raw state_dict with torch.save as a backup to the Lightning checkpoints
    print("Saving the model.")
    create_folder_if_not_exists(config.MODEL_PATH)
    torch.save(model.state_dict(), config.MODEL_PATH)
    print(f"Model saved to {config.MODEL_PATH}")

    # Save the first few misclassified images to a file
    num_elements = 20
    print(f"Saving first {num_elements} misclassified images.")
    subset_misclassified_image_data = {
        key: misclassified_image_data[key][:num_elements]
        for key in ("images", "ground_truths", "predicted_vals")
    }
    create_folder_if_not_exists(config.MISCLASSIFIED_PATH)
    torch.save(subset_misclassified_image_data, config.MISCLASSIFIED_PATH)
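    # The subset can be reloaded later (e.g. for plotting) with
    # torch.load(config.MISCLASSIFIED_PATH).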

    return trainer, results, misclassified_image_data