# Upload-page scrape residue (preserved as comments so the file stays valid Python):
# LINC-BIT's picture
# Upload 1912 files
# b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import time
from abc import abstractmethod
import torch
from .base_trainer import BaseTrainer
_logger = logging.getLogger(__name__)
class TorchTensorEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``torch.Tensor`` objects as (nested) lists.

    Anything that is not a tensor falls back to the stock ``json.JSONEncoder``
    handling, which raises ``TypeError`` for unsupported types.
    """

    def default(self, o):  # pylint: disable=method-hidden
        if not isinstance(o, torch.Tensor):
            return super().default(o)
        as_list = o.tolist()
        # A non-bool tensor whose elements are all 0/1 is likely a mask; nudge
        # the user towards storing it as bool instead.
        looks_boolean = all(elem == 0 or elem == 1 for elem in as_list)
        if looks_boolean and "bool" not in o.type().lower():
            _logger.warning("Every element in %s is either 0 or 1. "
                            "You might consider convert it into bool.", as_list)
        return as_list
class Trainer(BaseTrainer):
    """
    A trainer with some helper functions implemented. To implement a new trainer,
    users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    mutator : BaseMutator
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
        See `PyTorch loss functions`_ for examples.
    metrics : callable
        Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,

        .. code-block:: python

            def metrics_fn(output, target):
                return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}

    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    dataset_train : torch.utils.data.Dataset
        Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
        PyTorch Dataset. See `torch.utils.data`_ for examples.
    dataset_valid : torch.utils.data.Dataset
        Dataset of validation/testing.
    batch_size : int
        Batch size.
    workers : int
        Number of workers used in data preprocessing.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
        automatic detects GPU and selects GPU first.
    log_frequency : int
        Number of mini-batches to log metrics.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.

    .. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
    .. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        # Prefer CUDA when available unless the caller pins a device explicitly.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = model
        self.mutator = mutator
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        # Move model, mutator and loss to the target device. NOTE(review): this
        # assumes ``loss`` is an ``nn.Module`` (has ``.to``), which is stricter
        # than the "callable" documented above — confirm against callers.
        self.model.to(self.device)
        self.mutator.to(self.device)
        self.loss.to(self.device)
        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.batch_size = batch_size
        self.workers = workers
        self.log_frequency = log_frequency
        # One log directory per trainer instance, keyed by creation timestamp.
        self.log_dir = os.path.join("logs", str(time.time()))
        os.makedirs(self.log_dir, exist_ok=True)
        # NOTE(review): this handle stays open for the trainer's lifetime and is
        # never explicitly closed; it is only released at interpreter shutdown.
        self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
        # Fix: initialize eagerly instead of creating the attribute lazily in
        # enable_visualization(), so readers do not need a hasattr() probe.
        self.visualization_enabled = False
        self.callbacks = callbacks if callbacks is not None else []
        for callback in self.callbacks:
            callback.build(self.model, self.mutator, self)

    @abstractmethod
    def train_one_epoch(self, epoch):
        """
        Train one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        pass

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """
        Validate one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        pass

    def train(self, validate=True):
        """
        Train ``num_epochs``.
        Trigger callbacks at the start and the end of each epoch.

        Parameters
        ----------
        validate : bool
            If ``True``, will do validation every epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # training
            _logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)

            if validate:
                # validation
                _logger.info("Epoch %d Validating", epoch + 1)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        """
        Do one validation. Epoch number is -1 to mark an out-of-training run.
        """
        self.validate_one_epoch(-1)

    def export(self, file):
        """
        Call ``mutator.export()`` and dump the architecture to ``file``.

        Parameters
        ----------
        file : str
            A file path. Expected to be a JSON.
        """
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        """
        Return trainer checkpoint.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override to support checkpointing.
        """
        raise NotImplementedError("Not implemented yet")

    def enable_visualization(self):
        """
        Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
        """
        # NOTE(review): relies on ``self.train_loader`` being set by a subclass —
        # it is never defined in this base class; confirm subclasses provide it.
        sample = None
        for x, _ in self.train_loader:
            sample = x.to(self.device)[:2]
            break
        if sample is None:
            # Empty loader: warn but still attempt to build the graph below,
            # matching the original best-effort behavior.
            _logger.warning("Sample is %s.", sample)
        _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
        with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
            json.dump(self.mutator.graph(sample), f)
        self.visualization_enabled = True

    def _write_graph_status(self):
        """Append the mutator's current status as one JSON line to the status log."""
        # getattr guards subclasses that never ran this class's __init__.
        if getattr(self, "visualization_enabled", False):
            print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)