"""Tweaked version of corresponding AllenNLP file"""
import datetime
import logging
import math
import os
import time
import traceback
from typing import Dict, Optional, List, Tuple, Union, Iterable, Any
import torch
import torch.optim.lr_scheduler
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError, parse_cuda_device
from allennlp.common.tqdm import Tqdm
from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.models.model import Model
from allennlp.nn import util as nn_util
from allennlp.training import util as training_util
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
from allennlp.training.momentum_schedulers import MomentumScheduler
from allennlp.training.moving_average import MovingAverage
from allennlp.training.optimizers import Optimizer
from allennlp.training.tensorboard_writer import TensorboardWriter
from allennlp.training.trainer_base import TrainerBase
logger = logging.getLogger(__name__)
class Trainer(TrainerBase):
def __init__(
self,
model: Model,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
iterator: DataIterator,
train_dataset: Iterable[Instance],
validation_dataset: Optional[Iterable[Instance]] = None,
patience: Optional[int] = None,
validation_metric: str = "-loss",
validation_iterator: DataIterator = None,
shuffle: bool = True,
num_epochs: int = 20,
accumulated_batch_count: int = 1,
serialization_dir: Optional[str] = None,
num_serialized_models_to_keep: int = 20,
keep_serialized_model_every_num_seconds: int = None,
checkpointer: Checkpointer = None,
model_save_interval: float = None,
        cuda_device: Union[int, List[int]] = -1,
grad_norm: Optional[float] = None,
grad_clipping: Optional[float] = None,
learning_rate_scheduler: Optional[LearningRateScheduler] = None,
momentum_scheduler: Optional[MomentumScheduler] = None,
summary_interval: int = 100,
histogram_interval: int = None,
should_log_parameter_statistics: bool = True,
should_log_learning_rate: bool = False,
log_batch_size_period: Optional[int] = None,
moving_average: Optional[MovingAverage] = None,
cold_step_count: int = 0,
cold_lr: float = 1e-3,
        cuda_verbose_step: Optional[int] = None,
) -> None:
"""
A trainer for doing supervised learning. It just takes a labeled dataset
and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights
for your model over some fixed number of epochs. You can also pass in a validation
dataset and enable early stopping. There are many other bells and whistles as well.
Parameters
----------
model : ``Model``, required.
An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
their ``forward`` method returns a dictionary with a "loss" key, containing a
scalar tensor representing the loss function to be optimized.
If you are training your model using GPUs, your model should already be
on the correct device. (If you use `Trainer.from_params` this will be
handled for you.)
        optimizer : ``torch.optim.Optimizer``, required.
An instance of a Pytorch Optimizer, instantiated with the parameters of the
model to be optimized.
iterator : ``DataIterator``, required.
A method for iterating over a ``Dataset``, yielding padded indexed batches.
train_dataset : ``Dataset``, required.
A ``Dataset`` to train on. The dataset should have already been indexed.
validation_dataset : ``Dataset``, optional, (default = None).
A ``Dataset`` to evaluate on. The dataset should have already been indexed.
patience : Optional[int] > 0, optional (default=None)
Number of epochs to be patient before early stopping: the training is stopped
after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
If None, early stopping is disabled.
        validation_metric : str, optional (default="-loss")
Validation metric to measure for whether to stop training using patience
and whether to serialize an ``is_best`` model each epoch. The metric name
must be prepended with either "+" or "-", which specifies whether the metric
is an increasing or decreasing function.
validation_iterator : ``DataIterator``, optional (default=None)
An iterator to use for the validation set. If ``None``, then
use the training `iterator`.
shuffle: ``bool``, optional (default=True)
Whether to shuffle the instances in the iterator or not.
num_epochs : int, optional (default = 20)
Number of training epochs.
serialization_dir : str, optional (default=None)
Path to directory for saving and loading model files. Models will not be saved if
this parameter is not passed.
num_serialized_models_to_keep : ``int``, optional (default=20)
Number of previous model checkpoints to retain. Default is to keep 20 checkpoints.
A value of None or -1 means all checkpoints will be kept.
keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
If num_serialized_models_to_keep is not None, then occasionally it's useful to
save models at a given interval in addition to the last num_serialized_models_to_keep.
To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
between permanently saved checkpoints. Note that this option is only used if
num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
checkpointer : ``Checkpointer``, optional (default=None)
An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
not be specified. The caller is responsible for initializing the checkpointer so that it is
consistent with serialization_dir.
model_save_interval : ``float``, optional (default=None)
If provided, then serialize models every ``model_save_interval``
seconds within single epochs. In all cases, models are also saved
at the end of every epoch if ``serialization_dir`` is provided.
cuda_device : ``Union[int, List[int]]``, optional (default = -1)
An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used.
grad_norm : ``float``, optional, (default = None).
If provided, gradient norms will be rescaled to have a maximum of this value.
grad_clipping : ``float``, optional (default = ``None``).
If provided, gradients will be clipped `during the backward pass` to have an (absolute)
maximum of this value. If you are getting ``NaNs`` in your gradients during training
that are not solved by using ``grad_norm``, you may need this.
learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None)
If specified, the learning rate will be decayed with respect to
this schedule at the end of each epoch (or batch, if the scheduler implements
the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`,
this will use the ``validation_metric`` provided to determine if learning has plateaued.
To support updating the learning rate on every batch, this can optionally implement
``step_batch(batch_num_total)`` which updates the learning rate given the batch number.
momentum_scheduler : ``MomentumScheduler``, optional (default = None)
If specified, the momentum will be updated at the end of each batch or epoch
according to the schedule.
summary_interval: ``int``, optional, (default = 100)
Number of batches between logging scalars to tensorboard
histogram_interval : ``int``, optional, (default = ``None``)
If not None, then log histograms to tensorboard every ``histogram_interval`` batches.
When this parameter is specified, the following additional logging is enabled:
* Histograms of model parameters
* The ratio of parameter update norm to parameter norm
* Histogram of layer activations
We log histograms of the parameters returned by
``model.get_parameters_for_histogram_tensorboard_logging``.
The layer activations are logged for any modules in the ``Model`` that have
the attribute ``should_log_activations`` set to ``True``. Logging
histograms requires a number of GPU-CPU copies during training and is typically
slow, so we recommend logging histograms relatively infrequently.
Note: only Modules that return tensors, tuples of tensors or dicts
with tensors as values currently support activation logging.
should_log_parameter_statistics : ``bool``, optional, (default = True)
Whether to send parameter statistics (mean and standard deviation
of parameters and gradients) to tensorboard.
should_log_learning_rate : ``bool``, optional, (default = False)
Whether to send parameter specific learning rate to tensorboard.
log_batch_size_period : ``int``, optional, (default = ``None``)
If defined, how often to log the average batch size.
moving_average: ``MovingAverage``, optional, (default = None)
            If provided, we will maintain moving averages for all parameters. During training, we
            maintain a shadow variable for each parameter that tracks its moving average. During
            evaluation, we back up the original parameters and assign the moving averages to the
            corresponding parameters. Note that when a checkpoint is saved, we save the moving
            averages of the parameters rather than the raw values; this keeps the saved model as
            good as the validated one, but it may cause problems if you restart training from
            that checkpoint.
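        scheduler : ``torch.optim.lr_scheduler``, required.
            A PyTorch learning rate scheduler that is stepped with the validation loss at the
            end of every epoch (typically ``ReduceLROnPlateau``). This is separate from the
            optional AllenNLP ``learning_rate_scheduler`` above.
        accumulated_batch_count : ``int``, optional (default = 1)
            Number of batches over which gradients are accumulated before each optimizer step.
        cold_step_count : ``int``, optional (default = 0)
            Number of initial "cold" epochs during which the BERT embedder is frozen and the
            learning rate is set to ``cold_lr``; afterwards the embedder is unfrozen and the
            base learning rate is restored.
        cold_lr : ``float``, optional (default = 1e-3)
            Learning rate used during the cold epochs.
        cuda_verbose_step : ``int``, optional (default = None)
            If set, print CUDA memory statistics every ``cuda_verbose_step`` batches.
        Example
        -------
        A minimal usage sketch; ``model``, ``iterator``, ``train_data`` and ``val_data`` are
        assumed to be an already built AllenNLP ``Model``, ``DataIterator`` and indexed
        datasets::

            optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
            trainer = Trainer(model, optimizer, scheduler, iterator,
                              train_dataset=train_data,
                              validation_dataset=val_data,
                              patience=3,
                              num_epochs=20,
                              serialization_dir="model_dir",
                              cuda_device=0)
            metrics = trainer.train()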
"""
super().__init__(serialization_dir, cuda_device)
# I am not calling move_to_gpu here, because if the model is
# not already on the GPU then the optimizer is going to be wrong.
self.model = model
self.iterator = iterator
self._validation_iterator = validation_iterator
self.shuffle = shuffle
self.optimizer = optimizer
self.scheduler = scheduler
self.train_data = train_dataset
self._validation_data = validation_dataset
self.accumulated_batch_count = accumulated_batch_count
self.cold_step_count = cold_step_count
self.cold_lr = cold_lr
self.cuda_verbose_step = cuda_verbose_step
if patience is None: # no early stopping
if validation_dataset:
logger.warning(
"You provided a validation dataset but patience was set to None, "
"meaning that early stopping is disabled"
)
elif (not isinstance(patience, int)) or patience <= 0:
raise ConfigurationError(
'{} is an invalid value for "patience": it must be a positive integer '
"or None (if you want to disable early stopping)".format(patience)
)
# For tracking is_best_so_far and should_stop_early
self._metric_tracker = MetricTracker(patience, validation_metric)
# Get rid of + or -
self._validation_metric = validation_metric[1:]
self._num_epochs = num_epochs
if checkpointer is not None:
# We can't easily check if these parameters were passed in, so check against their default values.
# We don't check against serialization_dir since it is also used by the parent class.
if num_serialized_models_to_keep != 20 \
or keep_serialized_model_every_num_seconds is not None:
raise ConfigurationError(
"When passing a custom Checkpointer, you may not also pass in separate checkpointer "
"args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
)
self._checkpointer = checkpointer
else:
self._checkpointer = Checkpointer(
serialization_dir,
keep_serialized_model_every_num_seconds,
num_serialized_models_to_keep,
)
self._model_save_interval = model_save_interval
self._grad_norm = grad_norm
self._grad_clipping = grad_clipping
self._learning_rate_scheduler = learning_rate_scheduler
self._momentum_scheduler = momentum_scheduler
self._moving_average = moving_average
# We keep the total batch number as an instance variable because it
# is used inside a closure for the hook which logs activations in
# ``_enable_activation_logging``.
self._batch_num_total = 0
self._tensorboard = TensorboardWriter(
get_batch_num_total=lambda: self._batch_num_total,
serialization_dir=serialization_dir,
summary_interval=summary_interval,
histogram_interval=histogram_interval,
should_log_parameter_statistics=should_log_parameter_statistics,
should_log_learning_rate=should_log_learning_rate,
)
self._log_batch_size_period = log_batch_size_period
self._last_log = 0.0 # time of last logging
# Enable activation logging.
if histogram_interval is not None:
self._tensorboard.enable_activation_logging(self.model)
def rescale_gradients(self) -> Optional[float]:
return training_util.rescale_gradients(self.model, self._grad_norm)
def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
"""
Does a forward pass on the given batches and returns the ``loss`` value in the result.
If ``for_training`` is `True` also applies regularization penalty.
"""
if self._multiple_gpu:
output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
else:
assert len(batch_group) == 1
batch = batch_group[0]
batch = nn_util.move_to_device(batch, self._cuda_devices[0])
output_dict = self.model(**batch)
try:
loss = output_dict["loss"]
if for_training:
loss += self.model.get_regularization_penalty()
except KeyError:
if for_training:
raise RuntimeError(
"The model you are trying to optimize does not contain a"
" 'loss' key in the output of model.forward(inputs)."
)
loss = None
return loss
def _train_epoch(self, epoch: int) -> Dict[str, float]:
"""
Trains one epoch and returns metrics.
"""
logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
peak_cpu_usage = peak_memory_mb()
logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
gpu_usage = []
for gpu, memory in gpu_memory_mb().items():
gpu_usage.append((gpu, memory))
logger.info(f"GPU {gpu} memory usage MB: {memory}")
train_loss = 0.0
# Set the model to "train" mode.
self.model.train()
num_gpus = len(self._cuda_devices)
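        # When training on more than one GPU, consecutive batches are grouped so that
        # each device receives one batch per step (via ``lazy_groups_of`` below).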
# Get tqdm for the training batches
raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus)
residue = num_training_batches % self.accumulated_batch_count
self._last_log = time.time()
last_save_time = time.time()
batches_this_epoch = 0
if self._batch_num_total is None:
self._batch_num_total = 0
histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())
logger.info("Training")
train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches)
cumulative_batch_size = 0
self.optimizer.zero_grad()
for batch_group in train_generator_tqdm:
batches_this_epoch += 1
self._batch_num_total += 1
batch_num_total = self._batch_num_total
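            # Gradient accumulation: each loss is divided by the size of the current
            # accumulation window so that the accumulated gradient equals the gradient of
            # the mean loss. The last window of the epoch may be shorter, containing only
            # ``residue`` batches.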
iter_len = self.accumulated_batch_count \
if batches_this_epoch <= (num_training_batches - residue) else residue
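            # Optional CUDA memory debugging: when ``cuda_verbose_step`` is set, memory
            # statistics are printed every that many batches (before/after the forward
            # pass, after backprop, and after emptying the cache).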
if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
print(f'Before forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
try:
loss = self.batch_loss(batch_group, for_training=True) / iter_len
except RuntimeError as e:
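                # On a runtime error (typically CUDA OOM), dump the sentence counts and
                # tensor shapes of the offending batch to aid debugging, then re-raise.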
print(e)
for x in batch_group:
all_words = [len(y['words']) for y in x['metadata']]
print(f"Total sents: {len(all_words)}. "
f"Min {min(all_words)}. Max {max(all_words)}")
for elem in ['labels', 'd_tags']:
tt = x[elem]
print(
f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
for elem in ["bert", "mask", "bert-offsets"]:
tt = x['tokens'][elem]
print(
f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
raise e
if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
print(f'After forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
print(f'After forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
if torch.isnan(loss):
raise ValueError("nan loss encountered")
loss.backward()
if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
print(f'After backprop - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
print(f'After backprop - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
train_loss += loss.item() * iter_len
            # ``batch_group`` is still needed for the batch-size logging further below,
            # so record its size before the tensors are freed to release GPU memory.
            if self._log_batch_size_period:
                cur_batch = sum(training_util.get_batch_size(batch) for batch in batch_group)
            del batch_group, loss
torch.cuda.empty_cache()
if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
print(f'After collecting garbage - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
print(f'After collecting garbage - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
batch_grad_norm = self.rescale_gradients()
# This does nothing if batch_num_total is None or you are using a
# scheduler which doesn't update per batch.
if self._learning_rate_scheduler:
self._learning_rate_scheduler.step_batch(batch_num_total)
if self._momentum_scheduler:
self._momentum_scheduler.step_batch(batch_num_total)
if self._tensorboard.should_log_histograms_this_batch():
# get the magnitude of parameter updates for logging
# We need a copy of current parameters to compute magnitude of updates,
# and copy them to CPU so large models won't go OOM on the GPU.
param_updates = {
name: param.detach().cpu().clone()
for name, param in self.model.named_parameters()
}
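                # Step the optimizer only at the end of an accumulation window
                # (or on the very last batch of the epoch).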
if batches_this_epoch % self.accumulated_batch_count == 0 or \
batches_this_epoch == num_training_batches:
self.optimizer.step()
self.optimizer.zero_grad()
for name, param in self.model.named_parameters():
param_updates[name].sub_(param.detach().cpu())
update_norm = torch.norm(param_updates[name].view(-1))
param_norm = torch.norm(param.view(-1)).cpu()
self._tensorboard.add_train_scalar(
"gradient_update/" + name, update_norm / (param_norm + 1e-7)
)
else:
if batches_this_epoch % self.accumulated_batch_count == 0 or \
batches_this_epoch == num_training_batches:
self.optimizer.step()
self.optimizer.zero_grad()
# Update moving averages
if self._moving_average is not None:
self._moving_average.apply(batch_num_total)
# Update the description with the latest metrics
metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch)
description = training_util.description_from_metrics(metrics)
train_generator_tqdm.set_description(description, refresh=False)
# Log parameter values to Tensorboard
if self._tensorboard.should_log_this_batch():
self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
self._tensorboard.log_learning_rates(self.model, self.optimizer)
self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})
if self._tensorboard.should_log_histograms_this_batch():
self._tensorboard.log_histograms(self.model, histogram_parameters)
if self._log_batch_size_period:
                cumulative_batch_size += cur_batch
if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
average = cumulative_batch_size / batches_this_epoch
logger.info(f"current batch size: {cur_batch} mean batch size: {average}")
self._tensorboard.add_train_scalar("current_batch_size", cur_batch)
self._tensorboard.add_train_scalar("mean_batch_size", average)
# Save model if needed.
if self._model_save_interval is not None and (
time.time() - last_save_time > self._model_save_interval
):
last_save_time = time.time()
self._save_checkpoint(
"{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
)
metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch, reset=True)
metrics["cpu_memory_MB"] = peak_cpu_usage
for (gpu_num, memory) in gpu_usage:
metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
return metrics
def _validation_loss(self) -> Tuple[float, int]:
"""
Computes the validation loss. Returns it and the number of batches.
"""
logger.info("Validating")
self.model.eval()
# Replace parameter values with the shadow values from the moving averages.
if self._moving_average is not None:
self._moving_average.assign_average_value()
if self._validation_iterator is not None:
val_iterator = self._validation_iterator
else:
val_iterator = self.iterator
num_gpus = len(self._cuda_devices)
raw_val_generator = val_iterator(self._validation_data, num_epochs=1, shuffle=False)
val_generator = lazy_groups_of(raw_val_generator, num_gpus)
num_validation_batches = math.ceil(
val_iterator.get_num_batches(self._validation_data) / num_gpus
)
val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches)
batches_this_epoch = 0
val_loss = 0
for batch_group in val_generator_tqdm:
loss = self.batch_loss(batch_group, for_training=False)
if loss is not None:
# You shouldn't necessarily have to compute a loss for validation, so we allow for
# `loss` to be None. We need to be careful, though - `batches_this_epoch` is
# currently only used as the divisor for the loss function, so we can safely only
# count those batches for which we actually have a loss. If this variable ever
# gets used for something else, we might need to change things around a bit.
batches_this_epoch += 1
val_loss += loss.detach().cpu().numpy()
# Update the description with the latest metrics
val_metrics = training_util.get_metrics(self.model, val_loss, batches_this_epoch)
description = training_util.description_from_metrics(val_metrics)
val_generator_tqdm.set_description(description, refresh=False)
# Now restore the original parameter values.
if self._moving_average is not None:
self._moving_average.restore()
return val_loss, batches_this_epoch
def train(self) -> Dict[str, Any]:
"""
Trains the supplied model with the supplied parameters.
"""
try:
epoch_counter = self._restore_checkpoint()
except RuntimeError:
traceback.print_exc()
raise ConfigurationError(
"Could not recover training from the checkpoint. Did you mean to output to "
"a different serialization directory or delete the existing serialization "
"directory?"
)
training_util.enable_gradient_clipping(self.model, self._grad_clipping)
logger.info("Beginning training.")
train_metrics: Dict[str, float] = {}
val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
metrics: Dict[str, Any] = {}
epochs_trained = 0
training_start_time = time.time()
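        # "Cold start": for the first ``cold_step_count`` epochs, freeze the BERT embedder
        # and train the rest of the model with the reduced ``cold_lr``; the base learning
        # rate is restored and the embedder unfrozen once ``cold_step_count`` epochs have
        # passed (see the check inside the epoch loop below).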
if self.cold_step_count > 0:
base_lr = self.optimizer.param_groups[0]['lr']
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.cold_lr
self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)
metrics["best_epoch"] = self._metric_tracker.best_epoch
for key, value in self._metric_tracker.best_epoch_metrics.items():
metrics["best_validation_" + key] = value
for epoch in range(epoch_counter, self._num_epochs):
if epoch == self.cold_step_count and epoch != 0:
for param_group in self.optimizer.param_groups:
param_group['lr'] = base_lr
self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)
epoch_start_time = time.time()
train_metrics = self._train_epoch(epoch)
# get peak of memory usage
if "cpu_memory_MB" in train_metrics:
metrics["peak_cpu_memory_MB"] = max(
metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
)
for key, value in train_metrics.items():
if key.startswith("gpu_"):
metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
# clear cache before validation
torch.cuda.empty_cache()
if self._validation_data is not None:
with torch.no_grad():
# We have a validation set, so compute all the metrics on it.
val_loss, num_batches = self._validation_loss()
val_metrics = training_util.get_metrics(
self.model, val_loss, num_batches, reset=True
)
# Check validation metric for early stopping
this_epoch_val_metric = val_metrics[self._validation_metric]
self._metric_tracker.add_metric(this_epoch_val_metric)
if self._metric_tracker.should_stop_early():
logger.info("Ran out of patience. Stopping training.")
break
self._tensorboard.log_metrics(
train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
) # +1 because tensorboard doesn't like 0
# Create overall metrics dict
training_elapsed_time = time.time() - training_start_time
metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
metrics["training_start_epoch"] = epoch_counter
metrics["training_epochs"] = epochs_trained
metrics["epoch"] = epoch
for key, value in train_metrics.items():
metrics["training_" + key] = value
for key, value in val_metrics.items():
metrics["validation_" + key] = value
# if self.cold_step_count <= epoch:
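            # The per-epoch ``scheduler`` is stepped with the validation loss, so this
            # tweaked trainer expects a validation dataset to be provided.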
self.scheduler.step(metrics['validation_loss'])
if self._metric_tracker.is_best_so_far():
# Update all the best_ metrics.
# (Otherwise they just stay the same as they were.)
metrics["best_epoch"] = epoch
for key, value in val_metrics.items():
metrics["best_validation_" + key] = value
self._metric_tracker.best_epoch_metrics = val_metrics
if self._serialization_dir:
dump_metrics(
os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
)
# The Scheduler API is agnostic to whether your schedule requires a validation metric -
# if it doesn't, the validation metric passed here is ignored.
if self._learning_rate_scheduler:
self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
if self._momentum_scheduler:
self._momentum_scheduler.step(this_epoch_val_metric, epoch)
self._save_checkpoint(epoch)
epoch_elapsed_time = time.time() - epoch_start_time
logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))
if epoch < self._num_epochs - 1:
training_elapsed_time = time.time() - training_start_time
estimated_time_remaining = training_elapsed_time * (
(self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
)
formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
logger.info("Estimated training time remaining: %s", formatted_time)
epochs_trained += 1
# make sure pending events are flushed to disk and files are closed properly
# self._tensorboard.close()
# Load the best model state before returning
best_model_state = self._checkpointer.best_model_state()
if best_model_state:
self.model.load_state_dict(best_model_state)
return metrics
def _save_checkpoint(self, epoch: Union[int, str]) -> None:
"""
Saves a checkpoint of the model to self._serialization_dir.
Is a no-op if self._serialization_dir is None.
Parameters
----------
epoch : Union[int, str], required.
The epoch of training. If the checkpoint is saved in the middle
of an epoch, the parameter is a string with the epoch and timestamp.
"""
# If moving averages are used for parameters, we save
# the moving average values into checkpoint, instead of the current values.
if self._moving_average is not None:
self._moving_average.assign_average_value()
# These are the training states we need to persist.
training_states = {
"metric_tracker": self._metric_tracker.state_dict(),
"optimizer": self.optimizer.state_dict(),
"batch_num_total": self._batch_num_total,
}
# If we have a learning rate or momentum scheduler, we should persist them too.
if self._learning_rate_scheduler is not None:
training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
if self._momentum_scheduler is not None:
training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict()
self._checkpointer.save_checkpoint(
model_state=self.model.state_dict(),
epoch=epoch,
training_states=training_states,
is_best_so_far=self._metric_tracker.is_best_so_far(),
)
# Restore the original values for parameters so that training will not be affected.
if self._moving_average is not None:
self._moving_average.restore()
def _restore_checkpoint(self) -> int:
"""
Restores the model and training state from the last saved checkpoint.
This includes an epoch count and optimizer state, which is serialized separately
from model parameters. This function should only be used to continue training -
if you wish to load a model for inference/load parts of a model into a new
computation graph, you should use the native Pytorch functions:
        ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``
If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
this function will do nothing and return 0.
Returns
-------
epoch: int
The epoch at which to resume training, which should be one after the epoch
in the saved training state.
"""
model_state, training_state = self._checkpointer.restore_checkpoint()
if not training_state:
# No checkpoint to restore, start at 0
return 0
self.model.load_state_dict(model_state)
self.optimizer.load_state_dict(training_state["optimizer"])
if self._learning_rate_scheduler is not None \
and "learning_rate_scheduler" in training_state:
self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"])
if self._momentum_scheduler is not None and "momentum_scheduler" in training_state:
self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"])
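        # Move the optimizer's state tensors onto the GPU so they match the device of the
        # (possibly CUDA-resident) model parameters they belong to.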
training_util.move_optimizer_to_cuda(self.optimizer)
# Currently the ``training_state`` contains a serialized ``MetricTracker``.
if "metric_tracker" in training_state:
self._metric_tracker.load_state_dict(training_state["metric_tracker"])
# It used to be the case that we tracked ``val_metric_per_epoch``.
elif "val_metric_per_epoch" in training_state:
self._metric_tracker.clear()
self._metric_tracker.add_metrics(training_state["val_metric_per_epoch"])
# And before that we didn't track anything.
else:
self._metric_tracker.clear()
if isinstance(training_state["epoch"], int):
epoch_to_return = training_state["epoch"] + 1
else:
epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1
# For older checkpoints with batch_num_total missing, default to old behavior where
# it is unchanged.
batch_num_total = training_state.get("batch_num_total")
if batch_num_total is not None:
self._batch_num_total = batch_num_total
return epoch_to_return
# Requires custom from_params.
@classmethod
def from_params( # type: ignore
cls,
model: Model,
serialization_dir: str,
iterator: DataIterator,
train_data: Iterable[Instance],
validation_data: Optional[Iterable[Instance]],
params: Params,
validation_iterator: DataIterator = None,
) -> "Trainer":
patience = params.pop_int("patience", None)
validation_metric = params.pop("validation_metric", "-loss")
shuffle = params.pop_bool("shuffle", True)
num_epochs = params.pop_int("num_epochs", 20)
cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
grad_norm = params.pop_float("grad_norm", None)
grad_clipping = params.pop_float("grad_clipping", None)
lr_scheduler_params = params.pop("learning_rate_scheduler", None)
momentum_scheduler_params = params.pop("momentum_scheduler", None)
if isinstance(cuda_device, list):
model_device = cuda_device[0]
else:
model_device = cuda_device
if model_device >= 0:
# Moving model to GPU here so that the optimizer state gets constructed on
# the right device.
model = model.cuda(model_device)
parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
if "moving_average" in params:
moving_average = MovingAverage.from_params(
params.pop("moving_average"), parameters=parameters
)
else:
moving_average = None
if lr_scheduler_params:
lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params)
else:
lr_scheduler = None
if momentum_scheduler_params:
momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params)
else:
momentum_scheduler = None
if "checkpointer" in params:
if "keep_serialized_model_every_num_seconds" in params \
or "num_serialized_models_to_keep" in params:
raise ConfigurationError(
"Checkpointer may be initialized either from the 'checkpointer' key or from the "
"keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
" but the passed config uses both methods."
)
checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
else:
num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
keep_serialized_model_every_num_seconds = params.pop_int(
"keep_serialized_model_every_num_seconds", None
)
checkpointer = Checkpointer(
serialization_dir=serialization_dir,
num_serialized_models_to_keep=num_serialized_models_to_keep,
keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
)
model_save_interval = params.pop_float("model_save_interval", None)
summary_interval = params.pop_int("summary_interval", 100)
histogram_interval = params.pop_int("histogram_interval", None)
should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True)
should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)
log_batch_size_period = params.pop_int("log_batch_size_period", None)
params.assert_empty(cls.__name__)
        # NOTE: this tweaked ``__init__`` takes an explicit ``scheduler`` as its third
        # positional argument. ``from_params`` does not build one, so ``None`` is passed
        # here; callers that need the scheduler should construct the Trainer directly.
        return cls(
            model,
            optimizer,
            None,
            iterator,
            train_data,
            validation_data,
patience=patience,
validation_metric=validation_metric,
validation_iterator=validation_iterator,
shuffle=shuffle,
num_epochs=num_epochs,
serialization_dir=serialization_dir,
cuda_device=cuda_device,
grad_norm=grad_norm,
grad_clipping=grad_clipping,
learning_rate_scheduler=lr_scheduler,
momentum_scheduler=momentum_scheduler,
checkpointer=checkpointer,
model_save_interval=model_save_interval,
summary_interval=summary_interval,
histogram_interval=histogram_interval,
should_log_parameter_statistics=should_log_parameter_statistics,
should_log_learning_rate=should_log_learning_rate,
log_batch_size_period=log_batch_size_period,
moving_average=moving_average,
)