"""Module to define the train and test functions."""
# from functools import partial
import modules.config as config
import pytorch_lightning as pl
import torch
from modules.utils import create_folder_if_not_exists
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary
# Import tuner
from pytorch_lightning.tuner.tuning import Tuner
# What is the start LR and weight decay you'd prefer?
PREFERRED_START_LR = config.PREFERRED_START_LR
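# modules.config is also expected to provide CHECKPOINT_PATH, MODEL_PATH, and
# MISCLASSIFIED_PATH, which are used below when saving training artifacts.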


def train_and_test_model(
    batch_size,
    num_epochs,
    model,
    datamodule,
    logger,
    debug=False,
):
    """Train and test the model across epochs using the Lightning Trainer."""
print(f"\n\nBatch size: {batch_size}, Total epochs: {num_epochs}\n\n")
print("Defining Lightning Callbacks")
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
checkpoint = ModelCheckpoint(
dirpath=config.CHECKPOINT_PATH, monitor="val_acc", mode="max", filename="model_best_epoch", save_last=True
)
# # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#learningratemonitor
lr_rate_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelSummary.html#lightning.pytorch.callbacks.ModelSummary
model_summary = ModelSummary(max_depth=0)
print("Defining Lightning Trainer")
    # Change trainer settings for debugging. With fast_dev_run=True, Lightning
    # runs a single batch through train/val/test and disables loggers,
    # checkpoint callbacks, and the tuner.
    if debug:
        num_epochs = 1
        fast_dev_run = True
        overfit_batches = 0.1
        profiler = "advanced"
    else:
        fast_dev_run = False
        overfit_batches = 0.0
        profiler = None
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods
    trainer = pl.Trainer(
        precision=16,  # 16-bit mixed precision (spelled "16-mixed" in Lightning >= 2.0)
        fast_dev_run=fast_dev_run,
        # deterministic=True,
        # devices="auto",
        # accelerator="auto",
        max_epochs=num_epochs,
        logger=logger,
        # enable_model_summary=False,
        overfit_batches=overfit_batches,
        log_every_n_steps=10,
        # num_sanity_val_steps=5,
        profiler=profiler,
        # check_val_every_n_epoch=1,
        callbacks=[checkpoint, lr_rate_monitor, model_summary],
    )

    # Find a starting learning rate with Lightning's Tuner (this replaces the
    # older custom helper: model.find_optimal_lr(train_loader=...)).
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.tuner.tuning.Tuner.html#lightning.pytorch.tuner.tuning.Tuner
    # https://www.youtube.com/watch?v=cLZv0eZQSIE
    print("Finding the optimal learning rate using Lightning Tuner.")
    tuner = Tuner(trainer)
    tuner.lr_find(
        model=model,
        datamodule=datamodule,
        min_lr=PREFERRED_START_LR,
        max_lr=5,
        num_training=200,
        mode="linear",
        early_stop_threshold=10,
        attr_name="learning_rate",  # write the suggested LR to model.learning_rate
    )
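    # The call above stores the tuner's suggestion on model.learning_rate (via
    # attr_name), so surface the chosen value in the run logs.
    print(f"Tuned learning rate: {model.learning_rate}")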

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, dataloaders=datamodule.test_dataloader())

    # Collect the epoch-level results dictionary tracked by the model.
    print("Collecting epoch level model results.")
    results = model.results

    # Collect the misclassified images recorded during evaluation.
    print("Collecting misclassified images.")
    misclassified_image_data = model.misclassified_image_data

    # Save the raw state_dict with torch.save as a backup of the checkpoint.
    print("Saving the model.")
    create_folder_if_not_exists(config.MODEL_PATH)
    torch.save(model.state_dict(), config.MODEL_PATH)
    print(f"Model saved to {config.MODEL_PATH}")

    # Persist only the first few misclassified images to keep the file small.
    num_elements = 20
    print(f"Saving first {num_elements} misclassified images.")
    subset_misclassified_image_data = {
        key: misclassified_image_data[key][:num_elements]
        for key in ("images", "ground_truths", "predicted_vals")
    }
    create_folder_if_not_exists(config.MISCLASSIFIED_PATH)
    torch.save(subset_misclassified_image_data, config.MISCLASSIFIED_PATH)

    return trainer, results, misclassified_image_data
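

if __name__ == "__main__":
    # Minimal usage sketch, assuming hypothetical MyLitModel / MyDataModule
    # classes defined elsewhere in this project; only TensorBoardLogger below
    # is a real Lightning import.
    from pytorch_lightning.loggers import TensorBoardLogger

    tb_logger = TensorBoardLogger(save_dir="logs", name="train_and_test")
    # model = MyLitModel(learning_rate=PREFERRED_START_LR)  # hypothetical class
    # datamodule = MyDataModule(batch_size=64)              # hypothetical class
    # trainer, results, misclassified = train_and_test_model(
    #     batch_size=64,
    #     num_epochs=20,
    #     model=model,
    #     datamodule=datamodule,
    #     logger=tb_logger,
    #     debug=False,
    # )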