#!/usr/bin/env python
"""Train models with dynamic data."""
import sys
import torch
from functools import partial
# import onmt.opts as opts
from onmt.utils.distributed import ErrorHandler, consumer, batch_producer
from onmt.utils.misc import set_random_seed
from onmt.modules.embeddings import prepare_pretrained_embeddings
from onmt.utils.logging import init_logger, logger
from onmt.models.model_saver import load_checkpoint
from onmt.train_single import main as single_main, _build_train_iter
from onmt.utils.parse import ArgumentParser
from onmt.opts import train_opts
from onmt.inputters.corpus import save_transformed_sample
from onmt.inputters.fields import build_dynamic_fields, save_fields, \
    load_fields
from onmt.transforms import make_transforms, save_transforms, \
    get_specials, get_transforms_cls

# Set the multiprocessing sharing strategy manually instead of the
# OS-dependent default.
torch.multiprocessing.set_sharing_strategy('file_system')
def prepare_fields_transforms(opt):
"""Prepare or dump fields & transforms before training."""
transforms_cls = get_transforms_cls(opt._all_transform)
specials = get_specials(opt, transforms_cls)
fields = build_dynamic_fields(
opt, src_specials=specials['src'], tgt_specials=specials['tgt'])
# maybe prepare pretrained embeddings, if any
prepare_pretrained_embeddings(opt, fields)
if opt.dump_fields:
save_fields(fields, opt.save_data, overwrite=opt.overwrite)
if opt.dump_transforms or opt.n_sample != 0:
transforms = make_transforms(opt, transforms_cls, fields)
if opt.dump_transforms:
save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
if opt.n_sample != 0:
logger.warning(
"`-n_sample` != 0: Training will not be started. "
f"Stop after saving {opt.n_sample} samples/corpus.")
save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
logger.info(
"Sample saved, please check it before restart training.")
sys.exit()
return fields, transforms_cls
def _init_train(opt):
"""Common initilization stuff for all training process."""
ArgumentParser.validate_prepare_opts(opt)
if opt.train_from:
# Load checkpoint if we resume from a previous training.
checkpoint = load_checkpoint(ckpt_path=opt.train_from)
fields = load_fields(opt.save_data, checkpoint)
transforms_cls = get_transforms_cls(opt._all_transform)
if (hasattr(checkpoint["opt"], '_all_transform') and
len(opt._all_transform.symmetric_difference(
checkpoint["opt"]._all_transform)) != 0):
_msg = "configured transforms is different from checkpoint:"
new_transf = opt._all_transform.difference(
checkpoint["opt"]._all_transform)
old_transf = checkpoint["opt"]._all_transform.difference(
opt._all_transform)
if len(new_transf) != 0:
_msg += f" +{new_transf}"
if len(old_transf) != 0:
_msg += f" -{old_transf}."
logger.warning(_msg)
if opt.update_vocab:
logger.info("Updating checkpoint vocabulary with new vocabulary")
fields, transforms_cls = prepare_fields_transforms(opt)
else:
checkpoint = None
fields, transforms_cls = prepare_fields_transforms(opt)
# Report src and tgt vocab sizes
for side in ['src', 'tgt']:
f = fields[side]
try:
f_iter = iter(f)
except TypeError:
f_iter = [(side, f)]
for sn, sf in f_iter:
if sf.use_vocab:
logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
return checkpoint, fields, transforms_cls
def train(opt):
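    """Validate options, build fields/transforms, and launch training
    on CPU, a single GPU, or multiple GPUs via spawned worker processes.
    """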
    init_logger(opt.log_file)

    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    checkpoint, fields, transforms_cls = _init_train(opt)
    train_process = partial(
        single_main,
        fields=fields,
        transforms_cls=transforms_cls,
        checkpoint=checkpoint)

    nb_gpu = len(opt.gpu_ranks)
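    # Dispatch: spawn one trainer process per GPU when world_size > 1,
    # otherwise train in-process on a single GPU (device_id=0) or CPU (-1).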
    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
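        # The semaphore throttles the batch producers so they do not run
        # arbitrarily far ahead of the trainer processes.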
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]

            procs.append(mp.Process(target=consumer, args=(
                train_process, opt, device_id, error_queue, q, semaphore),
                daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)

        producers = []
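        # One producer per GPU feeds batches into that device's queue;
        # stride/offset shard the corpus so each device sees distinct batches.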
        # This does not work if we merge with the first loop, not sure why
        for device_id in range(nb_gpu):
            # Get the iterator to generate from
            train_iter = _build_train_iter(
                opt, fields, transforms_cls, stride=nb_gpu, offset=device_id)
            producer = mp.Process(target=batch_producer,
                                  args=(train_iter, queues[device_id],
                                        semaphore, opt, device_id),
                                  daemon=True)
            producers.append(producer)
            producers[device_id].start()
            logger.info(" Starting producer process pid: {} ".format(
                producers[device_id].pid))
            error_handler.add_child(producers[device_id].pid)

        for p in procs:
            p.join()
        # Once training is done, we can terminate the producers
        for p in producers:
            p.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        train_process(opt, device_id=0)
    else:  # case only CPU
        train_process(opt, device_id=-1)

def _get_parser():
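    """Build an ArgumentParser populated with the training options."""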
    parser = ArgumentParser(description='train.py')
    train_opts(parser)
    return parser

def main():
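    """Entry point: parse command-line options and launch training."""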
    parser = _get_parser()

    # parse_known_args tolerates (and ignores) unrecognized options.
    opt, unknown = parser.parse_known_args()
    train(opt)

if __name__ == "__main__":
    main()