|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function, unicode_literals |
|
|
|
|
|
import logging |
|
|
import argparse |
|
|
import subprocess |
|
|
import sys |
|
|
import os |
|
|
|
|
|
# Configure the root logger: emit DEBUG and above, with each record
# prefixed by a timestamp and its level name.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)s: %(message)s',
)
|
|
def _build_parser():
    """Create the command-line parser, with defaults, for this script.

    Options are registered from an ordered table so that the order shown
    by ``--help`` matches the order of the table.  Defaults are applied
    afterwards in a single ``set_defaults`` call; options absent from the
    defaults mapping (``--mmap``, ``--extra-settings``, ``--train-host``)
    keep argparse's own defaults.
    """
    # (flags, add_argument keyword settings), in registration order.
    specs = [
        (("-w", "--working-dir"), dict(dest="working_dir")),
        (("-c", "--corpus"), dict(dest="corpus_stem")),
        (("-l", "--nplm-home"), dict(dest="nplm_home")),
        (("-e", "--epochs"), dict(dest="epochs", type=int)),
        (("-n", "--ngram-size"), dict(dest="ngram_size", type=int)),
        (("-b", "--minibatch-size"), dict(dest="minibatch_size", type=int)),
        (("-s", "--noise"), dict(dest="noise", type=int)),
        (("-d", "--hidden"), dict(dest="hidden", type=int)),
        (("-i", "--input-embedding"),
         dict(dest="input_embedding", type=int)),
        (("-o", "--output-embedding"),
         dict(dest="output_embedding", type=int)),
        (("-t", "--threads"), dict(dest="threads", type=int)),
        (("-m", "--output-model"), dict(dest="output_model")),
        (("-r", "--output-dir"), dict(dest="output_dir")),
        (("-f", "--config-options-file"), dict(dest="config_options_file")),
        (("-g", "--log-file"), dict(dest="log_file")),
        (("-v", "--validation-ngrams"), dict(dest="validation_file")),
        (("-a", "--activation-function"), dict(dest="activation_fn")),
        # No type is given, so a value from the command line stays a string.
        (("-z", "--learning-rate"), dict(dest="learning_rate")),
        (("--input-words-file",), dict(dest="input_words_file")),
        (("--output-words-file",), dict(dest="output_words_file")),
        (("--input_vocab_size",), dict(dest="input_vocab_size", type=int)),
        (("--output_vocab_size",), dict(dest="output_vocab_size", type=int)),
        (("--mmap",),
         dict(dest="mmap", action="store_true",
              help="Use memory-mapped file (for lower memory consumption).")),
        (("--extra-settings",),
         dict(dest="extra_settings",
              help="Extra settings to be passed to NPLM")),
        (("--train-host",),
         dict(dest="train_host",
              help="Execute nplm training on this host, via ssh")),
    ]

    built = argparse.ArgumentParser()
    for flags, settings in specs:
        built.add_argument(*flags, **settings)

    built.set_defaults(
        working_dir="working",
        corpus_stem="train.10k",
        nplm_home="/home/bhaddow/tools/nplm",
        epochs=10,
        ngram_size=14,
        minibatch_size=1000,
        noise=100,
        hidden=0,
        input_embedding=150,
        output_embedding=750,
        threads=1,
        output_model="train.10k",
        output_dir=None,
        config_options_file="config",
        log_file="log",
        validation_file=None,
        activation_fn="rectifier",
        learning_rate=1,
        input_words_file=None,
        output_words_file=None,
        input_vocab_size=0,
        output_vocab_size=0,
    )
    return built


parser = _build_parser()
|
|
|
|
|
|
|
|
def main(options):
    """Run NPLM's ``trainNeuralNetwork`` with a command built from options.

    The trainer command line is assembled from ``options``, recorded in a
    config file, and then executed with its stdout and stderr redirected to
    a log file.  Both files are created in ``options.output_dir`` and are
    named ``<config_options_file>-<output_model>`` and
    ``<log_file>-<output_model>`` respectively.

    Side effect: when ``options.output_dir`` is ``None`` it is set to
    ``options.working_dir``; otherwise the directory is created if it does
    not exist yet.

    :param options: Namespace-like object providing the attributes read
        below (``working_dir``, ``corpus_stem``, ``nplm_home``, ``epochs``,
        ``ngram_size``, ``minibatch_size``, ``noise``, ``hidden``,
        ``input_embedding``, ``output_embedding``, ``threads``,
        ``output_model``, ``output_dir``, ``config_options_file``,
        ``log_file``, ``validation_file``, ``activation_fn``,
        ``learning_rate``, ``input_words_file``, ``output_words_file``,
        ``input_vocab_size``, ``output_vocab_size``, ``mmap``,
        ``extra_settings``, ``train_host``).
    :raises Exception: If the training command exits with a non-zero status.
    """
    # Vocabulary arguments are only forwarded when they were actually set
    # (None for the word files, 0 for the vocabulary sizes means "unset").
    vocab_command = []
    if options.input_words_file is not None:
        vocab_command += ['--input_words_file', options.input_words_file]
    if options.output_words_file is not None:
        vocab_command += ['--output_words_file', options.output_words_file]
    if options.input_vocab_size:
        vocab_command += ['--input_vocab_size', str(options.input_vocab_size)]
    if options.output_vocab_size:
        vocab_command += [
            '--output_vocab_size', str(options.output_vocab_size)]

    # The validation file is passed with a ".numberized" suffix appended.
    validations_command = []
    if options.validation_file is not None:
        validations_command = [
            "--validation_file", (options.validation_file + ".numberized")]

    # Fall back to the working directory when no output directory is given;
    # an explicitly requested output directory is created on demand.
    if options.output_dir is None:
        options.output_dir = options.working_dir
    elif not os.path.exists(options.output_dir):
        os.makedirs(options.output_dir)

    config_file = os.path.join(
        options.output_dir,
        options.config_options_file + '-' + options.output_model)
    log_file = os.path.join(
        options.output_dir, options.log_file + '-' + options.output_model)

    # Training input: "<working_dir>/<basename of corpus_stem>.numberized",
    # with an extra ".mmap" suffix when memory mapping is requested.
    in_file = os.path.join(
        options.working_dir,
        os.path.basename(options.corpus_stem) + ".numberized")
    mmap_command = []
    if options.mmap:
        in_file += '.mmap'
        mmap_command = ['--mmap_file', '1']

    model_prefix = os.path.join(
        options.output_dir, options.output_model + ".model.nplm")

    # Prefix the command with ssh when a training host was given.
    train_args = []
    if options.train_host:
        train_args = ["ssh", options.train_host]
    train_args += [
        options.nplm_home + "/src/trainNeuralNetwork",
        "--train_file", in_file,
        "--num_epochs", str(options.epochs),
        "--model_prefix", model_prefix,
        "--learning_rate", str(options.learning_rate),
        "--minibatch_size", str(options.minibatch_size),
        "--num_noise_samples", str(options.noise),
        "--num_hidden", str(options.hidden),
        "--input_embedding_dimension", str(options.input_embedding),
        "--output_embedding_dimension", str(options.output_embedding),
        "--num_threads", str(options.threads),
        "--activation_function", options.activation_fn,
        "--ngram_size", str(options.ngram_size),
    ] + validations_command + vocab_command + mmap_command
    if options.extra_settings:
        train_args += options.extra_settings.split()

    # The context managers guarantee both files are closed even when
    # something below raises (including the "Training failed" path).
    with open(log_file, 'w') as log_file_write:
        with open(config_file, 'w') as config_file_write:
            config_file_write.write("Called: " + ' '.join(sys.argv) + '\n\n')
            print("Train model command: ")
            print(', '.join(train_args))
            config_file_write.write(
                "Training step:\n" + ' '.join(train_args) + '\n')

        log_file_write.write("Training output:\n")
        # The child process writes straight to the underlying descriptor,
        # bypassing Python's buffer; flush first so the header precedes the
        # trainer's output instead of being appended after it on close.
        log_file_write.flush()
        ret = subprocess.call(
            train_args, stdout=log_file_write, stderr=log_file_write)

    if ret:
        raise Exception("Training failed")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line and pass the resulting namespace to main().
    main(parser.parse_args())
|
|
|