Spaces:
Runtime error
Runtime error
"""Module to define the train and test functions.""" | |
# from functools import partial | |
import modules.config as config | |
import pytorch_lightning as pl | |
import torch | |
from modules.utils import create_folder_if_not_exists | |
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary | |
# Import tuner | |
from pytorch_lightning.tuner.tuning import Tuner | |
# Preferred starting learning rate from the project config; used as the lower bound (min_lr) of the LR finder sweep.
PREFERRED_START_LR = config.PREFERRED_START_LR | |
def train_and_test_model(
    batch_size,
    num_epochs,
    model,
    datamodule,
    logger,
    debug=False,
):
    """Train and test the model with a Lightning Trainer, then save the artifacts.

    Steps, in order: build the callbacks and the Trainer, run the Lightning
    learning-rate finder, fit, test, save the model ``state_dict`` to
    ``config.MODEL_PATH`` and save the first few misclassified samples to
    ``config.MISCLASSIFIED_PATH``.

    Args:
        batch_size: Batch size in use. Only reported in the start-up message;
            it is not passed to the Trainer or the datamodule here.
        num_epochs: Maximum number of epochs to train for. Overridden to 1
            when ``debug`` is true.
        model: The Lightning module to train. It is read for ``results`` and
            ``misclassified_image_data`` after testing; the latter is indexed
            with the keys ``"images"``, ``"ground_truths"`` and
            ``"predicted_vals"``, each of which must support slicing. The LR
            finder targets its ``learning_rate`` attribute.
        datamodule: The Lightning datamodule supplying the dataloaders for the
            LR finder, fitting and testing.
        logger: Logger object handed to the Trainer.
        debug: When true, run a quick smoke test: one epoch, ``fast_dev_run``
            enabled, ``overfit_batches=0.1`` and the "advanced" profiler.

    Returns:
        A tuple ``(trainer, results, misclassified_image_data)`` where the
        last two are taken from the model after testing. The full
        misclassified data is returned; only the saved file is truncated.
    """
    print(f"\n\nBatch size: {batch_size}, Total epochs: {num_epochs}\n\n")

    print("Defining Lightning Callbacks")
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    checkpoint = ModelCheckpoint(
        dirpath=config.CHECKPOINT_PATH,
        monitor="val_acc",
        mode="max",
        filename="model_best_epoch",
        save_last=True,
    )
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#learningratemonitor
    lr_rate_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelSummary.html#lightning.pytorch.callbacks.ModelSummary
    model_summary = ModelSummary(max_depth=0)

    print("Defining Lightning Trainer")
    # Debug mode trades a real training run for a fast end-to-end smoke test.
    if debug:
        num_epochs = 1
        fast_dev_run = True
        overfit_batches = 0.1
        profiler = "advanced"
    else:
        fast_dev_run = False
        overfit_batches = 0.0
        profiler = None

    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods
    trainer = pl.Trainer(
        precision=16,
        fast_dev_run=fast_dev_run,
        max_epochs=num_epochs,
        logger=logger,
        overfit_batches=overfit_batches,
        log_every_n_steps=10,
        profiler=profiler,
        callbacks=[checkpoint, lr_rate_monitor, model_summary],
    )

    # Use the Tuner's LR finder to pick the learning rate before fitting.
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.tuner.tuning.Tuner.html#lightning.pytorch.tuner.tuning.Tuner
    # https://www.youtube.com/watch?v=cLZv0eZQSIE
    print("Finding the optimal learning rate using Lightning Tuner.")
    tuner = Tuner(trainer)
    tuner.lr_find(
        model=model,
        datamodule=datamodule,
        min_lr=PREFERRED_START_LR,
        max_lr=5,
        num_training=200,
        mode="linear",
        early_stop_threshold=10,
        attr_name="learning_rate",
    )

    trainer.fit(model, datamodule=datamodule)
    # Hand over the datamodule itself rather than an eagerly built loader so
    # the Trainer requests the test dataloader after its own setup hooks.
    trainer.test(model, datamodule=datamodule)

    print("Collecting epoch level model results.")
    results = model.results

    print("Collecting misclassified images.")
    misclassified_image_data = model.misclassified_image_data

    # Save the model using torch save as backup
    print("Saving the model.")
    create_folder_if_not_exists(config.MODEL_PATH)
    torch.save(model.state_dict(), config.MODEL_PATH)
    # Report success only once the file has actually been written.
    print(f"Model saved to {config.MODEL_PATH}")

    # Save first few misclassified images data to a file
    num_elements = 20
    print(f"Saving first {num_elements} misclassified images.")
    subset_misclassified_image_data = {
        key: misclassified_image_data[key][:num_elements]
        for key in ("images", "ground_truths", "predicted_vals")
    }
    create_folder_if_not_exists(config.MISCLASSIFIED_PATH)
    torch.save(subset_misclassified_image_data, config.MISCLASSIFIED_PATH)

    return trainer, results, misclassified_image_data