import logging
import os

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)


class Callback:
    """
    Callback provides an easy way to react to events such as the beginning and end of epochs and batches.
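
    Examples
    --------
    A minimal sketch of a custom callback; the logging body is illustrative
    only and not part of the library::

        class EpochLogger(Callback):
            def on_epoch_begin(self, epoch):
                _logger.info("Epoch %d is starting.", epoch)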

    """

    def __init__(self):
        self.model = None
        self.mutator = None
        self.trainer = None

    def build(self, model, mutator, trainer):
        """
        A callback needs to be built with the model, mutator and trainer before it can receive updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model = model
        self.mutator = mutator
        self.trainer = trainer

    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_batch_begin(self, epoch):
        """
        Implement this to do something at the beginning of a batch.

        Parameters
        ----------
        epoch : int
            Current epoch number, starting from 0.
        """
        pass

    def on_batch_end(self, epoch):
        """
        Implement this to do something at the end of a batch.

        Parameters
        ----------
        epoch : int
            Current epoch number, starting from 0.
        """
        pass


class LRSchedulerCallback(Callback):
    """
    Calls the scheduler at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    mode : str
        Stepping mode. Only ``"epoch"`` (one step per epoch) is currently supported.
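
    Examples
    --------
    A usage sketch, assuming ``model`` is an existing ``nn.Module`` and the
    resulting list is handed to a trainer as its callbacks::

        optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
        callbacks = [LRSchedulerCallback(scheduler)]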

    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        assert mode == "epoch"
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.
        """
        self.scheduler.step()


class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch to checkpoint the architecture.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
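
    Examples
    --------
    A usage sketch; the directory name is arbitrary, and one JSON file per
    epoch (``epoch_0.json``, ``epoch_1.json``, ...) is written into it::

        callbacks = [ArchitectureCheckpoint("./checkpoints")]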

    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``<checkpoint_dir>/epoch_{number}.json`` on epoch end.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)


class ModelCheckpoint(Callback):
    """
    Saves the model's ``state_dict`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
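
    Examples
    --------
    A sketch of restoring a checkpoint afterwards; ``model`` is assumed to be
    an instance of the same architecture that was trained::

        state_dict = torch.load("./checkpoints/epoch_0.pth.tar")
        model.load_state_dict(state_dict)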

    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``<checkpoint_dir>/epoch_{number}.pth.tar`` on every epoch end.
        For a ``DataParallel``-wrapped model, the ``state_dict`` of the inner module is exported.
        """
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)