"""Train models with dynamic data."""
import sys
from functools import partial

import torch

from onmt.utils.distributed import ErrorHandler, consumer, batch_producer
from onmt.utils.misc import set_random_seed
from onmt.modules.embeddings import prepare_pretrained_embeddings
from onmt.utils.logging import init_logger, logger
from onmt.models.model_saver import load_checkpoint
from onmt.train_single import main as single_main, _build_train_iter
from onmt.utils.parse import ArgumentParser
from onmt.opts import train_opts
from onmt.inputters.corpus import save_transformed_sample
from onmt.inputters.fields import build_dynamic_fields, save_fields, \
    load_fields
from onmt.transforms import make_transforms, save_transforms, \
    get_specials, get_transforms_cls
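
# The producer/consumer queues below pass batch tensors between processes;
# the 'file_system' sharing strategy avoids exhausting per-process file
# descriptors when many tensors are queued at once.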

torch.multiprocessing.set_sharing_strategy('file_system')


def prepare_fields_transforms(opt):
    """Prepare or dump fields & transforms before training."""
    transforms_cls = get_transforms_cls(opt._all_transform)
    specials = get_specials(opt, transforms_cls)

    fields = build_dynamic_fields(
        opt, src_specials=specials['src'], tgt_specials=specials['tgt'])
    prepare_pretrained_embeddings(opt, fields)

    if opt.dump_fields:
        save_fields(fields, opt.save_data, overwrite=opt.overwrite)
    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, fields)
        if opt.dump_transforms:
            save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
        if opt.n_sample != 0:
            logger.warning(
                "`-n_sample` != 0: training will not be started. "
                f"Stopping after saving {opt.n_sample} samples per corpus.")
            save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
            logger.info(
                "Samples saved; please check them before restarting training.")
            sys.exit()
    return fields, transforms_cls
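

# Hedged dry-run sketch (the config path is hypothetical; `-config` and
# `-n_sample` are flags defined via onmt.opts.train_opts):
#     python train.py -config data.yaml -n_sample 100
# saves 100 transformed samples per corpus and exits without training, so
# the transform pipeline can be inspected first.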


def _init_train(opt):
    """Common initialization for all training processes."""
    ArgumentParser.validate_prepare_opts(opt)

    if opt.train_from:
        # Resuming: restore fields from the checkpoint and warn if the
        # configured transforms no longer match the ones it was saved with.
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        fields = load_fields(opt.save_data, checkpoint)
        transforms_cls = get_transforms_cls(opt._all_transform)
        if (hasattr(checkpoint["opt"], '_all_transform') and
                len(opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform)) != 0):
            _msg = "configured transforms differ from checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform)
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform)
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
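            # Illustrative: enabling `filtertoolong` and dropping
            # `sentencepiece` relative to the checkpoint would log
            #   configured transforms differ from checkpoint:
            #   +{'filtertoolong'} -{'sentencepiece'}.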
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            fields, transforms_cls = prepare_fields_transforms(opt)
    else:
        checkpoint = None
        fields, transforms_cls = prepare_fields_transforms(opt)

    # Log vocab sizes; a multi-field yields (name, field) pairs when
    # iterated, while a plain field is wrapped so both cases log uniformly.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
    return checkpoint, fields, transforms_cls


def train(opt):
    init_logger(opt.log_file)
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Second argument (is_cuda) is False: no device is set yet, so CUDA
    # RNGs are seeded later in the per-device processes.
    set_random_seed(opt.seed, False)

    checkpoint, fields, transforms_cls = _init_train(opt)
    # Bind the shared state so each worker only needs (opt, device_id).
    train_process = partial(
        single_main,
        fields=fields,
        transforms_cls=transforms_cls,
        checkpoint=checkpoint)

    nb_gpu = len(opt.gpu_ranks)
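
    # Multi-GPU path: one consumer process per GPU runs the training loop
    # while a paired producer streams its batches; the shared semaphore caps
    # in-flight batches at world_size * queue_size.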

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)

        # Children report exceptions back through this queue.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=consumer, args=(
                train_process, opt, device_id, error_queue, q, semaphore),
                daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)

        producers = []
        for device_id in range(nb_gpu):
            # Each producer reads a strided shard of the corpus so the
            # GPUs receive disjoint batches.
            train_iter = _build_train_iter(
                opt, fields, transforms_cls, stride=nb_gpu, offset=device_id)
            producer = mp.Process(target=batch_producer,
                                  args=(train_iter, queues[device_id],
                                        semaphore, opt, device_id),
                                  daemon=True)
            producers.append(producer)
            producers[device_id].start()
            logger.info(" Starting producer process pid: {} ".format(
                producers[device_id].pid))
            error_handler.add_child(producers[device_id].pid)

        for p in procs:
            p.join()
        # Producers may be blocked on a full queue; stop them explicitly.
        for p in producers:
            p.terminate()

    elif nb_gpu == 1:  # single GPU
        train_process(opt, device_id=0)
    else:  # CPU only
        train_process(opt, device_id=-1)


def _get_parser():
    parser = ArgumentParser(description='train.py')
    train_opts(parser)
    return parser


def main():
    parser = _get_parser()
    # parse_known_args tolerates unrecognized options instead of erroring.
    opt, unknown = parser.parse_known_args()
    train(opt)
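

# Hedged invocation sketch (the config path is hypothetical; flag names
# come from onmt.opts.train_opts):
#     python train.py -config my_config.yaml -world_size 2 -gpu_ranks 0 1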


if __name__ == "__main__":
    main()