import configargparse as cfargparse
import os
import torch
import onmt.opts as opts
from onmt.utils.logging import logger
from onmt.constants import CorpusName, ModelTask
from onmt.transforms import AVAILABLE_TRANSFORMS
class DataOptsCheckerMixin(object):
"""Checker with methods for validate data related options."""
@staticmethod
def _validate_file(file_path, info):
"""Check `file_path` is valid or raise `IOError`."""
if not os.path.isfile(file_path):
raise IOError(f"Please check path of your {info} file!")
@classmethod
def _validate_data(cls, opt):
"""Parse corpora specified in data field of YAML file."""
import yaml
default_transforms = opt.transforms
if len(default_transforms) != 0:
logger.info(f"Default transforms: {default_transforms}.")
corpora = yaml.safe_load(opt.data)
for cname, corpus in corpora.items():
# Check Transforms
_transforms = corpus.get('transforms', None)
if _transforms is None:
logger.info(f"Missing transforms field for {cname} data, "
f"set to default: {default_transforms}.")
corpus['transforms'] = default_transforms
# Check path
path_src = corpus.get('path_src', None)
path_tgt = corpus.get('path_tgt', None)
if path_src is None:
                raise ValueError(f'Corpus {cname} src path is required. '
                                 'tgt path is also required for non '
                                 'language modeling tasks.')
else:
opt.data_task = ModelTask.SEQ2SEQ
if path_tgt is None:
logger.warning(
"path_tgt is None, it should be set unless the task"
" is language modeling"
)
opt.data_task = ModelTask.LANGUAGE_MODEL
# tgt is src for LM task
corpus["path_tgt"] = path_src
corpora[cname] = corpus
path_tgt = path_src
cls._validate_file(path_src, info=f'{cname}/path_src')
cls._validate_file(path_tgt, info=f'{cname}/path_tgt')
path_align = corpus.get('path_align', None)
if path_align is None:
if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
                    raise ValueError(f'Corpus {cname} alignment file path is '
                                     'required when lambda_align > 0.0')
corpus['path_align'] = None
else:
cls._validate_file(path_align, info=f'{cname}/path_align')
# Check prefix: will be used when use prefix transform
src_prefix = corpus.get('src_prefix', None)
tgt_prefix = corpus.get('tgt_prefix', None)
if src_prefix is None or tgt_prefix is None:
if 'prefix' in corpus['transforms']:
                    raise ValueError(
                        f'Corpus {cname} src_prefix and tgt_prefix are '
                        'required when using the prefix transform.')
# Check weight
weight = corpus.get('weight', None)
if weight is None:
if cname != CorpusName.VALID:
logger.warning(f"Corpus {cname}'s weight should be given."
" We default it to 1 for you.")
corpus['weight'] = 1
# Check features
src_feats = corpus.get("src_feats", None)
if src_feats is not None:
for feature_name, feature_file in src_feats.items():
cls._validate_file(
feature_file, info=f'{cname}/path_{feature_name}')
if 'inferfeats' not in corpus["transforms"]:
raise ValueError(
"'inferfeats' transform is required "
"when setting source features")
if 'filterfeats' not in corpus["transforms"]:
raise ValueError(
"'filterfeats' transform is required "
"when setting source features")
else:
corpus["src_feats"] = None
logger.info(f"Parsed {len(corpora)} corpora from -data.")
opt.data = corpora
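    # A minimal sketch of the YAML corpora layout `_validate_data` parses;
    # the corpus names, paths, transforms and weights below are hypothetical:
    #
    #   data:
    #       corpus_1:
    #           path_src: data/src-train.txt
    #           path_tgt: data/tgt-train.txt
    #           transforms: [filtertoolong]
    #           weight: 1
    #       valid:
    #           path_src: data/src-val.txt
    #           path_tgt: data/tgt-val.txt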
@classmethod
def _validate_transforms_opts(cls, opt):
"""Check options used by transforms."""
for name, transform_cls in AVAILABLE_TRANSFORMS.items():
if name in opt._all_transform:
transform_cls._validate_options(opt)
@classmethod
def _get_all_transform(cls, opt):
"""Should only called after `_validate_data`."""
all_transforms = set(opt.transforms)
for cname, corpus in opt.data.items():
_transforms = set(corpus['transforms'])
if len(_transforms) != 0:
all_transforms.update(_transforms)
if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
if not all_transforms.isdisjoint(
{'sentencepiece', 'bpe', 'onmt_tokenize'}):
raise ValueError('lambda_align is not compatible with'
' on-the-fly tokenization.')
if not all_transforms.isdisjoint(
{'tokendrop', 'prefix', 'bart'}):
raise ValueError('lambda_align is not compatible yet with'
                                 ' potential token deletion/addition.')
opt._all_transform = all_transforms
@classmethod
def _get_all_transform_translate(cls, opt):
opt._all_transform = opt.transforms
@classmethod
def _validate_fields_opts(cls, opt, build_vocab_only=False):
"""Check options relate to vocab and fields."""
for cname, corpus in opt.data.items():
if cname != CorpusName.VALID and corpus["src_feats"] is not None:
assert opt.src_feats_vocab, \
"-src_feats_vocab is required if using source features."
if isinstance(opt.src_feats_vocab, str):
import yaml
opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab)
for feature in corpus["src_feats"].keys():
assert feature in opt.src_feats_vocab, \
f"No vocab file set for feature {feature}"
if build_vocab_only:
if not opt.share_vocab:
assert opt.tgt_vocab, \
"-tgt_vocab is required if not -share_vocab."
return
        # validation when training:
cls._validate_file(opt.src_vocab, info='src vocab')
if not opt.share_vocab:
cls._validate_file(opt.tgt_vocab, info='tgt vocab')
if opt.dump_fields or opt.dump_transforms:
assert opt.save_data, "-save_data should be set if set \
-dump_fields or -dump_transforms."
# Check embeddings stuff
if opt.both_embeddings is not None:
assert (opt.src_embeddings is None
and opt.tgt_embeddings is None), \
"You don't need -src_embeddings or -tgt_embeddings \
if -both_embeddings is set."
if any([opt.both_embeddings is not None,
opt.src_embeddings is not None,
opt.tgt_embeddings is not None]):
assert opt.embeddings_type is not None, \
"You need to specify an -embedding_type!"
assert opt.save_data, "-save_data should be set if use \
pretrained embeddings."
@classmethod
def _validate_language_model_compatibilities_opts(cls, opt):
if opt.model_task != ModelTask.LANGUAGE_MODEL:
return
logger.info("encoder is not used for LM task")
assert opt.share_vocab and (
opt.tgt_vocab is None
), "vocab must be shared for LM task"
assert (
opt.decoder_type == "transformer"
), "Only transformer decoder is supported for LM task"
@classmethod
def validate_prepare_opts(cls, opt, build_vocab_only=False):
"""Validate all options relate to prepare (data/transform/vocab)."""
if opt.n_sample != 0:
assert opt.save_data, "-save_data should be set if \
want save samples."
cls._validate_data(opt)
cls._get_all_transform(opt)
cls._validate_transforms_opts(opt)
cls._validate_fields_opts(opt, build_vocab_only=build_vocab_only)
@classmethod
def validate_model_opts(cls, opt):
cls._validate_language_model_compatibilities_opts(opt)
class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin):
"""OpenNMT option parser powered with option check methods."""
def __init__(
self,
config_file_parser_class=cfargparse.YAMLConfigFileParser,
formatter_class=cfargparse.ArgumentDefaultsHelpFormatter,
**kwargs):
super(ArgumentParser, self).__init__(
config_file_parser_class=config_file_parser_class,
formatter_class=formatter_class,
**kwargs)
@classmethod
def defaults(cls, *args):
"""Get default arguments added to a parser by all ``*args``."""
dummy_parser = cls()
for callback in args:
callback(dummy_parser)
defaults = dummy_parser.parse_known_args([])[0]
return defaults
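    # A minimal usage sketch of `defaults` (assuming `onmt.opts` provides
    # `model_opts`, as it is used further down in `ckpt_model_opts`):
    #
    #   model_defaults = ArgumentParser.defaults(opts.model_opts)
    #   print(model_defaults.enc_layers)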
@classmethod
def update_model_opts(cls, model_opt):
if model_opt.word_vec_size > 0:
model_opt.src_word_vec_size = model_opt.word_vec_size
model_opt.tgt_word_vec_size = model_opt.word_vec_size
# Backward compatibility with "fix_word_vecs_*" opts
if hasattr(model_opt, 'fix_word_vecs_enc'):
model_opt.freeze_word_vecs_enc = model_opt.fix_word_vecs_enc
if hasattr(model_opt, 'fix_word_vecs_dec'):
model_opt.freeze_word_vecs_dec = model_opt.fix_word_vecs_dec
if model_opt.layers > 0:
model_opt.enc_layers = model_opt.layers
model_opt.dec_layers = model_opt.layers
if model_opt.rnn_size > 0:
model_opt.enc_rnn_size = model_opt.rnn_size
model_opt.dec_rnn_size = model_opt.rnn_size
model_opt.brnn = model_opt.encoder_type == "brnn"
if model_opt.copy_attn_type is None:
model_opt.copy_attn_type = model_opt.global_attention
if model_opt.alignment_layer is None:
model_opt.alignment_layer = -2
model_opt.lambda_align = 0.0
model_opt.full_context_alignment = False
@classmethod
def validate_model_opts(cls, model_opt):
assert model_opt.model_type in ["text"], \
"Unsupported model type %s" % model_opt.model_type
# encoder and decoder should be same sizes
same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
assert same_size, \
"The encoder and decoder rnns must be the same size for now"
assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
"Using SRU requires -gpu_ranks set."
if model_opt.share_embeddings:
if model_opt.model_type != "text":
raise AssertionError(
"--share_embeddings requires --model_type text.")
if model_opt.lambda_align > 0.0:
assert model_opt.decoder_type == 'transformer', \
"Only transformer is supported to joint learn alignment."
assert model_opt.alignment_layer < model_opt.dec_layers and \
model_opt.alignment_layer >= -model_opt.dec_layers, \
"N° alignment_layer should be smaller than number of layers."
logger.info("Joint learn alignment at layer [{}] "
"with {} heads in full_context '{}'.".format(
model_opt.alignment_layer,
model_opt.alignment_heads,
model_opt.full_context_alignment))
@classmethod
def ckpt_model_opts(cls, ckpt_opt):
# Load default opt values, then overwrite with the opts in
# the checkpoint. That way, if there are new options added,
# the defaults are used.
opt = cls.defaults(opts.model_opts)
opt.__dict__.update(ckpt_opt.__dict__)
return opt
@classmethod
def validate_train_opts(cls, opt):
if opt.epochs:
raise AssertionError(
"-epochs is deprecated please use -train_steps.")
if opt.truncated_decoder > 0 and max(opt.accum_count) > 1:
raise AssertionError("BPTT is not compatible with -accum > 1")
if opt.gpuid:
raise AssertionError(
"gpuid is deprecated see world_size and gpu_ranks")
if torch.cuda.is_available() and not opt.gpu_ranks:
            logger.warning(
                "You have a CUDA device, should run with -gpu_ranks")
if opt.world_size < len(opt.gpu_ranks):
raise AssertionError(
"parameter counts of -gpu_ranks must be less or equal "
"than -world_size.")
if opt.world_size == len(opt.gpu_ranks) and \
min(opt.gpu_ranks) > 0:
raise AssertionError(
"-gpu_ranks should have master(=0) rank "
"unless -world_size is greater than len(gpu_ranks).")
assert len(opt.dropout) == len(opt.dropout_steps), \
"Number of dropout values must match accum_steps values"
assert len(opt.attention_dropout) == len(opt.dropout_steps), \
"Number of attention_dropout values must match accum_steps values"
assert len(opt.accum_count) == len(opt.accum_steps), \
'Number of accum_count values must match number of accum_steps'
if opt.update_vocab:
assert opt.train_from, \
"-update_vocab needs -train_from option"
assert opt.reset_optim in ['states', 'all'], \
'-update_vocab needs -reset_optim "states" or "all"'
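    # A hypothetical config satisfying the schedule-length checks above:
    # dropout, attention_dropout and dropout_steps share one length, and
    # accum_count matches accum_steps (values are illustrative only):
    #
    #   dropout: [0.3, 0.1]
    #   attention_dropout: [0.2, 0.1]
    #   dropout_steps: [0, 50000]
    #   accum_count: [2, 4]
    #   accum_steps: [0, 50000]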
@classmethod
def validate_translate_opts(cls, opt):
opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}
@classmethod
def validate_translate_opts_dynamic(cls, opt):
        # `share_vocab` comes from training
        # TODO: needs to be added as an inference opt
opt.share_vocab = False
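# A rough sketch of how this parser is typically driven for training; the
# option callbacks from `onmt.opts` are assumptions based on their use
# elsewhere in OpenNMT (only `opts.model_opts` appears in this module):
#
#   parser = ArgumentParser(description='train.py')
#   opts.config_opts(parser)
#   opts.model_opts(parser)
#   opts.train_opts(parser)
#   opt = parser.parse_args()
#   ArgumentParser.validate_train_opts(opt)
#   ArgumentParser.update_model_opts(opt)
#   ArgumentParser.validate_model_opts(opt)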