# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class Callback:
    """
    Callback provides an easy way to react to events like begin/end of epochs.
    """

    def __init__(self):
        self.model = None
        self.mutator = None
        self.trainer = None

    def build(self, model, mutator, trainer):
        """
        Callback needs to be built with model, mutator and trainer to receive updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model = model
        self.mutator = mutator
        self.trainer = trainer

    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_batch_begin(self, epoch):
        """
        Implement this to do something at the beginning of a batch. ``epoch`` is the current epoch number.
        """
        pass

    def on_batch_end(self, epoch):
        """
        Implement this to do something at the end of a batch. ``epoch`` is the current epoch number.
        """
        pass
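
# A minimal sketch of how a trainer is expected to drive these hooks. The
# loop below is illustrative only; ``trainer``, ``callbacks`` and
# ``num_epochs`` are assumptions, not part of this module:
#
#     for callback in callbacks:
#         callback.build(trainer.model, trainer.mutator, trainer)
#     for epoch in range(num_epochs):
#         for callback in callbacks:
#             callback.on_epoch_begin(epoch)
#         ...  # train one epoch
#         for callback in callbacks:
#             callback.on_epoch_end(epoch)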


class LRSchedulerCallback(Callback):
    """
    Calls the scheduler at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    mode : str
        Must be "epoch"; stepping per batch is not supported.
    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        assert mode == "epoch"
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.
        """
        self.scheduler.step()
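
# Usage sketch: wiring a standard PyTorch scheduler into the callback. The
# model and trainer setup around it is assumed, only the callback
# construction comes from this module:
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
#     callbacks = [LRSchedulerCallback(scheduler)]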


class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``checkpoint_dir/epoch_{number}.json`` on epoch end.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)
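
# Usage sketch: export the searched architecture once per epoch. This
# assumes the trainer's ``export(file)`` writes the current architecture
# to the given path, as the callback above expects:
#
#     callbacks = [ArchitectureCheckpoint("./checkpoints")]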


class ModelCheckpoint(Callback):
    """
    Saves the model state dict at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``checkpoint_dir/epoch_{number}.pth.tar`` on every epoch end.
        For a ``DataParallel`` model, the state dict of the wrapped module is saved.
        """
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)
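
# Usage sketch: weights written by ModelCheckpoint can be restored with the
# standard PyTorch pattern; the path below is illustrative and matches the
# file name that on_epoch_end writes for epoch 0:
#
#     callbacks = [ModelCheckpoint("./checkpoints")]
#     ...
#     state_dict = torch.load("./checkpoints/epoch_0.pth.tar")
#     model.load_state_dict(state_dict)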