"""Training on a single process."""
import torch

from onmt.inputters.inputter import IterOnDevice
from onmt.model_builder import build_model
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser

from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter


def configure_process(opt, device_id):
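    """Pin this process to `device_id`'s GPU (if any) and seed its RNGs."""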
    if device_id >= 0:
        torch.cuda.set_device(device_id)
    set_random_seed(opt.seed, device_id >= 0)


def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build the model; may load from `checkpoint` if any."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        if (opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and
                hasattr(model_opt, 'tensorboard_log_dir_dated')):
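            # Keep writing TensorBoard logs into the dated log directory
            # that the resumed checkpoint was already using.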
            opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated
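        # Take update_vocab from the current run's options, not the checkpoint.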
        model_opt.update_vocab = opt.update_vocab
    else:
        model_opt = opt
    return model_opt


def _build_valid_iter(opt, fields, transforms_cls):
    """Build iterator used for validation."""
    valid_iter = build_dynamic_dataset_iter(
        fields, transforms_cls, opt, is_train=False)
    return valid_iter


def _build_train_iter(opt, fields, transforms_cls, stride=1, offset=0):
    """Build training iterator."""
    train_iter = build_dynamic_dataset_iter(
        fields, transforms_cls, opt, is_train=True,
        stride=stride, offset=offset)
    return train_iter


def main(opt, fields, transforms_cls, checkpoint, device_id,
         batch_queue=None, semaphore=None):
    """Start training on `device_id`."""
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    model_opt = _get_model_opts(opt, checkpoint=checkpoint)
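
    # Build the model, initializing its weights from the checkpoint if given.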
    model = build_model(model_opt, opt, fields, checkpoint)
    model.count_parameters(log=logger.info)
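
    # Build the optimizer; its state is restored from the checkpoint on resume.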
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
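
    # Build the model saver used for periodic checkpointing.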
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)
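
    # Single-process runs build their own data iterator; multi-GPU runs are
    # fed pre-built batches through `batch_queue` by a producer process, with
    # `semaphore` throttling how far ahead the producer may run.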
    if batch_queue is None:
        _train_iter = _build_train_iter(opt, fields, transforms_cls)
        train_iter = IterOnDevice(_train_iter, device_id)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
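                # Move the queued batch onto this process's device before
                # yielding it to the trainer.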
                IterOnDevice.batch_to_device(batch, device_id)
                yield batch

        train_iter = _train_iter()

    valid_iter = _build_valid_iter(opt, fields, transforms_cls)
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
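    # single_pass requests exactly one pass over the training data, so any
    # configured step limit is dropped.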
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)
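
    # Close the TensorBoard writer, if reporting was configured to use one.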
    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()