File size: 6,301 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
#!/usr/bin/env python
#
# This file is part of moses. Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
from __future__ import print_function, unicode_literals
import logging
import argparse
import subprocess
import sys
import os
logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)

# Command-line interface.  main() reads each ``dest`` below as an attribute
# of the options object, so dest names and defaults form its contract.
parser = argparse.ArgumentParser(
    description="Train a neural language model with NPLM's "
                "trainNeuralNetwork.")
parser.add_argument(
    "-w", "--working-dir", dest="working_dir",
    help="Directory holding the numberized corpus; also the default "
         "output directory.")
parser.add_argument(
    "-c", "--corpus", dest="corpus_stem",
    help="Corpus stem; training reads "
         "<working-dir>/<basename(stem)>.numberized.")
parser.add_argument(
    "-l", "--nplm-home", dest="nplm_home",
    help="NPLM installation directory (default: $NPLM_HOME if set).")
parser.add_argument(
    "-e", "--epochs", dest="epochs", type=int,
    help="Number of training epochs (--num_epochs).")
parser.add_argument(
    "-n", "--ngram-size", dest="ngram_size", type=int,
    help="N-gram size (--ngram_size).")
parser.add_argument(
    "-b", "--minibatch-size", dest="minibatch_size", type=int,
    help="Minibatch size (--minibatch_size).")
parser.add_argument(
    "-s", "--noise", dest="noise", type=int,
    help="Number of noise samples (--num_noise_samples).")
parser.add_argument(
    "-d", "--hidden", dest="hidden", type=int,
    help="Value passed to NPLM as --num_hidden.")
parser.add_argument(
    "-i", "--input-embedding", dest="input_embedding", type=int,
    help="Input embedding dimension (--input_embedding_dimension).")
parser.add_argument(
    "-o", "--output-embedding", dest="output_embedding", type=int,
    help="Output embedding dimension (--output_embedding_dimension).")
parser.add_argument(
    "-t", "--threads", dest="threads", type=int,
    help="Number of threads (--num_threads).")
parser.add_argument(
    "-m", "--output-model", dest="output_model",
    help="Model name; used for the model prefix and as the suffix of the "
         "config and log file names.")
parser.add_argument(
    "-r", "--output-dir", dest="output_dir",
    help="Output directory, created if missing (default: working dir).")
parser.add_argument(
    "-f", "--config-options-file", dest="config_options_file",
    help="Name prefix of the file recording the training command.")
parser.add_argument(
    "-g", "--log-file", dest="log_file",
    help="Name prefix of the training log file.")
parser.add_argument(
    "-v", "--validation-ngrams", dest="validation_file",
    help="Validation file stem; '.numberized' is appended.")
parser.add_argument(
    "-a", "--activation-function", dest="activation_fn",
    help="Activation function passed to NPLM (--activation_function).")
# type=float rejects a malformed value here instead of inside NPLM.
parser.add_argument(
    "-z", "--learning-rate", dest="learning_rate", type=float,
    help="Learning rate (--learning_rate).")
parser.add_argument(
    "--input-words-file", dest="input_words_file",
    help="Input vocabulary file (--input_words_file).")
parser.add_argument(
    "--output-words-file", dest="output_words_file",
    help="Output vocabulary file (--output_words_file).")
# The underscore spellings are kept for backward compatibility; the
# hyphenated forms match every other option of this script.
parser.add_argument(
    "--input-vocab-size", "--input_vocab_size", dest="input_vocab_size",
    type=int, help="Input vocabulary size; 0 means the flag is not passed.")
parser.add_argument(
    "--output-vocab-size", "--output_vocab_size", dest="output_vocab_size",
    type=int, help="Output vocabulary size; 0 means the flag is not passed.")
parser.add_argument(
    "--mmap", dest="mmap", action="store_true",
    help="Use memory-mapped file (for lower memory consumption).")
parser.add_argument(
    "--extra-settings", dest="extra_settings",
    help="Extra settings to be passed to NPLM")
parser.add_argument(
    "--train-host", dest="train_host",
    help="Execute nplm training on this host, via ssh")
parser.set_defaults(
    working_dir="working",
    corpus_stem="train.10k",
    # Overridable through the environment instead of requiring -l on every
    # machine; falls back to the historical location.
    nplm_home=os.environ.get("NPLM_HOME", "/home/bhaddow/tools/nplm"),
    epochs=10,
    ngram_size=14,
    minibatch_size=1000,
    noise=100,
    hidden=0,
    input_embedding=150,
    output_embedding=750,
    threads=1,
    output_model="train.10k",
    output_dir=None,
    config_options_file="config",
    log_file="log",
    validation_file=None,
    activation_fn="rectifier",
    learning_rate=1,
    input_words_file=None,
    output_words_file=None,
    input_vocab_size=0,
    output_vocab_size=0,
)
def _vocab_args(options):
    """Return the optional vocabulary flags for trainNeuralNetwork."""
    args = []
    if options.input_words_file is not None:
        args += ['--input_words_file', options.input_words_file]
    if options.output_words_file is not None:
        args += ['--output_words_file', options.output_words_file]
    # A size of 0 (or None) is falsy, so the flag is simply omitted.
    if options.input_vocab_size:
        args += ['--input_vocab_size', str(options.input_vocab_size)]
    if options.output_vocab_size:
        args += ['--output_vocab_size', str(options.output_vocab_size)]
    return args


def _ensure_dir(path):
    """Create directory *path* (with parents) unless it already exists."""
    # Python 2 compatible stand-in for os.makedirs(path, exist_ok=True);
    # unlike an exists()-then-makedirs() check it tolerates a concurrent
    # creation of the same directory.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


def _build_train_args(options):
    """Return the argv list that runs trainNeuralNetwork for *options*.

    Expects ``options.output_dir`` to be resolved (not ``None``).
    """
    in_file = os.path.join(
        options.working_dir,
        os.path.basename(options.corpus_stem) + ".numberized")
    mmap_args = []
    if options.mmap:
        # Train from the '.mmap' variant of the numberized corpus.
        in_file += '.mmap'
        mmap_args = ['--mmap_file', '1']

    validation_args = []
    if options.validation_file is not None:
        validation_args = [
            "--validation_file", options.validation_file + ".numberized"]

    model_prefix = os.path.join(
        options.output_dir, options.output_model + ".model.nplm")

    # Optionally run the trainer on another host.  NOTE(review): ssh joins
    # its arguments into one remote shell command, so paths containing
    # spaces or shell metacharacters are not passed through safely.
    train_args = ["ssh", options.train_host] if options.train_host else []
    train_args += [
        options.nplm_home + "/src/trainNeuralNetwork",
        "--train_file", in_file,
        "--num_epochs", str(options.epochs),
        "--model_prefix", model_prefix,
        "--learning_rate", str(options.learning_rate),
        "--minibatch_size", str(options.minibatch_size),
        "--num_noise_samples", str(options.noise),
        "--num_hidden", str(options.hidden),
        "--input_embedding_dimension", str(options.input_embedding),
        "--output_embedding_dimension", str(options.output_embedding),
        "--num_threads", str(options.threads),
        "--activation_function", options.activation_fn,
        "--ngram_size", str(options.ngram_size),
    ]
    train_args += validation_args + _vocab_args(options) + mmap_args
    if options.extra_settings:
        train_args += options.extra_settings.split()
    return train_args


def main(options):
    """Train an NPLM model by running NPLM's ``trainNeuralNetwork``.

    The command is printed, then recorded together with ``sys.argv`` in
    ``<output_dir>/<config_options_file>-<output_model>``. The trainer's
    stdout and stderr go to ``<output_dir>/<log_file>-<output_model>``.

    :param options: object carrying the attributes defined by this module's
        argument parser. When ``options.output_dir`` is ``None`` it is set
        to ``options.working_dir`` in place; otherwise the directory is
        created if missing.
    :raises Exception: if ``trainNeuralNetwork`` exits with a non-zero status.
    """
    # Separate output directories let several models be trained from the
    # same preparation step; by default everything goes to the working dir.
    if options.output_dir is None:
        options.output_dir = options.working_dir
    else:
        _ensure_dir(options.output_dir)

    config_file = os.path.join(
        options.output_dir,
        options.config_options_file + '-' + options.output_model)
    log_file = os.path.join(
        options.output_dir, options.log_file + '-' + options.output_model)

    train_args = _build_train_args(options)
    print("Train model command: ")
    print(' '.join(train_args))

    with open(config_file, 'w') as config_out:
        config_out.write("Called: " + ' '.join(sys.argv) + '\n\n')
        config_out.write("Training step:\n" + ' '.join(train_args) + '\n')

    # `with` closes the log even when training fails.
    with open(log_file, 'w') as log_out:
        log_out.write("Training output:\n")
        # The child writes straight to the file descriptor, bypassing this
        # object's buffer; flush so the header precedes the child's output.
        log_out.flush()
        ret = subprocess.call(train_args, stdout=log_out, stderr=log_out)
    if ret:
        raise Exception("Training failed (exit status %d)" % ret)
if __name__ == "__main__":
    # Script entry point: parse the command line and run training.
    main(parser.parse_args())
|