# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os
import time
from abc import abstractmethod

import torch

from .base_trainer import BaseTrainer

_logger = logging.getLogger(__name__)


class TorchTensorEncoder(json.JSONEncoder):
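    """Custom JSON encoder that serializes ``torch.Tensor`` objects as plain lists, so that
    exported architectures containing tensors can be dumped with ``json.dump``."""
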
    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            olist = o.tolist()
            if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
                _logger.warning("Every element in %s is either 0 or 1. "
                                "You might consider converting it to bool.", olist)
            return olist
        return super().default(o)


class Trainer(BaseTrainer):
    """
    A trainer with some helper functions implemented. To implement a new trainer,
    users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    mutator : BaseMutator
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
        See `PyTorch loss functions`_ for examples.
    metrics : callable
        Called with logits and targets. Returns a dict that maps metric names to metric values. For example,

        .. code-block:: python

            def metrics_fn(output, target):
                return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}

    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of training epochs.
    dataset_train : torch.utils.data.Dataset
        Dataset for training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
        PyTorch Dataset objects. See `torch.utils.data`_ for examples.
    dataset_valid : torch.utils.data.Dataset
        Dataset for validation/testing.
    batch_size : int
        Batch size.
    workers : int
        Number of workers used in data preprocessing.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
        automatically detects a GPU and prefers it over CPU.
    log_frequency : int
        Log metrics every ``log_frequency`` mini-batches.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
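
    Examples
    --------
    A minimal sketch of how a concrete trainer could look and be used. The subclass name, the
    ``loss_fn``/``metrics_fn`` helpers and the dataset objects below are illustrative placeholders,
    and the loop body is only one possible shape (the ``mutator.reset()`` call assumes a mutator
    that re-samples a candidate architecture this way):

    .. code-block:: python

        from torch.utils.data import DataLoader

        class MyTrainer(Trainer):
            def train_one_epoch(self, epoch):
                loader = DataLoader(self.dataset_train, batch_size=self.batch_size,
                                    num_workers=self.workers, shuffle=True)
                self.model.train()
                for x, y in loader:
                    x, y = x.to(self.device), y.to(self.device)
                    self.optimizer.zero_grad()
                    self.mutator.reset()  # assumption: re-sample a candidate architecture
                    loss = self.loss(self.model(x), y)
                    loss.backward()
                    self.optimizer.step()

            def validate_one_epoch(self, epoch):
                self.model.eval()  # evaluate self.model on self.dataset_valid here

        trainer = MyTrainer(model, mutator, loss_fn, metrics_fn, optimizer, num_epochs=10,
                            dataset_train=train_set, dataset_valid=valid_set, batch_size=64,
                            workers=4, device=None, log_frequency=10, callbacks=None)
        trainer.train()

    During :meth:`train`, each callback receives ``on_epoch_begin(epoch)`` before and
    ``on_epoch_end(epoch)`` after every epoch.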

    .. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
    .. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = model
        self.mutator = mutator
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.model.to(self.device)
        self.mutator.to(self.device)
        self.loss.to(self.device)
        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.batch_size = batch_size
        self.workers = workers
        self.log_frequency = log_frequency
        self.log_dir = os.path.join("logs", str(time.time()))
        os.makedirs(self.log_dir, exist_ok=True)
        self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
        self.callbacks = callbacks if callbacks is not None else []
        for callback in self.callbacks:
            callback.build(self.model, self.mutator, self)
    @abstractmethod
    def train_one_epoch(self, epoch):
        """
        Train one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        pass

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """
        Validate one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        pass
    def train(self, validate=True):
        """
        Train for ``num_epochs`` epochs.
        Trigger callbacks at the start and the end of each epoch.

        Parameters
        ----------
        validate : bool
            If ``True``, do one round of validation after every training epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # training
            _logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)

            if validate:
                # validation
                _logger.info("Epoch %d Validating", epoch + 1)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)
    def validate(self):
        """
        Do one round of validation, by calling :meth:`validate_one_epoch` with ``epoch`` set to ``-1``.
        """
        self.validate_one_epoch(-1)
    def export(self, file):
        """
        Call ``mutator.export()`` and dump the architecture to ``file``.

        Parameters
        ----------
        file : str
            Path of the output file. Expected to be a JSON file.
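
        Examples
        --------
        A hedged usage sketch; the file name here is only an illustrative placeholder:

        .. code-block:: python

            trainer.export("final_arch.json")  # tensors in the export are serialized via TorchTensorEncoder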
"""
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
    def checkpoint(self):
        """
        Return trainer checkpoint. Not implemented in this base class; subclasses should override it if needed.
        """
        raise NotImplementedError("Not implemented yet")
    def enable_visualization(self):
        """
        Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
        """
        sample = None
        # ``train_loader`` is expected to be created by the concrete trainer subclass.
        for x, _ in self.train_loader:
            sample = x.to(self.device)[:2]
            break
        if sample is None:
            _logger.warning("Sample is %s.", sample)
        _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
        with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
            json.dump(self.mutator.graph(sample), f)
        self.visualization_enabled = True
    def _write_graph_status(self):
        if hasattr(self, "visualization_enabled") and self.visualization_enabled:
            print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)