#!/usr/bin/env python
# Written by Michael Denkowski
#
# This file is part of moses. Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
"""Parallelize decoding with simulated post-editing via moses XML input.
(XML entities need to be escaped in tokenization). Memory mapped
dynamic phrase tables (Ulrich Germann,
www.statmt.org/moses/?n=Moses.AdvancedFeatures#ntoc40) and language models
(Kenneth Heafield,
http://www.statmt.org/moses/?n=FactoredTraining.BuildingLanguageModel#ntoc19)
facilitate memory efficient multi process decoding. Input is divided into
batches, each of which is decoded sequentially. Each batch pre-loads the
data from previous batches.
To use in tuning, run mert-moses.pl with --sim-pe=SYMAL where SYMAL is the
alignment from input to references. Specify the number of jobs with
--decoder-flags="-threads N".
"""
import gzip
import itertools
import math
import os
import shutil
import subprocess
import sys
import tempfile
import threading
HELP = '''Moses with simulated post-editing
Usage:
{} moses-cmd -config moses.ini -input-file text.src -ref text.tgt \
-symal text.src-tgt.symal [options] [decoder flags]
Options:
-threads N: number of decoders to run in parallel \
(default read from moses.ini, 1 if not present)
-n-best-list nbest.out N [distinct]: location and size of N-best list
-show-weights: for mert-moses.pl, just call moses and exit
-tmp: location of temp directory (default /tmp)
Other options (decoder flags) are passed through to moses-cmd\n'''
class ProgramFailure(Exception):
"""Known kind of failure, with a known presentation to the user.
Error message will be printed, and the program will return an error,
but no traceback will be shown to the user.
"""
class Progress:
"""Provides progress bar."""
def __init__(self):
self.i = 0
self.lock = threading.Lock()
def inc(self):
self.lock.acquire()
self.i += 1
if self.i % 100 == 0:
sys.stderr.write('.')
if self.i % 1000 == 0:
sys.stderr.write(' [{}]\n'.format(self.i))
sys.stderr.flush()
self.lock.release()
def done(self):
self.lock.acquire()
if self.i % 1000 != 0:
sys.stderr.write('\n')
self.lock.release()
def atomic_io(cmd, in_file, out_file, err_file, prog=None):
"""Run with atomic (synchronous) I/O."""
with open(in_file, 'r') as inp, open(out_file, 'w') as out, open(err_file, 'w') as err:
p = subprocess.Popen(
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=err)
while True:
line = inp.readline()
if not line:
break
p.stdin.write(line)
out.write(p.stdout.readline())
out.flush()
if prog:
prog.inc()
p.stdin.close()
p.wait()
def gzopen(f):
"""Open plain or gzipped text."""
return gzip.open(f, 'rb') if f.endswith('.gz') else open(f, 'r')
def wc(f):
"""Word count."""
i = 0
for line in gzopen(f):
i += 1
return i
def write_gzfile(lines, f):
"""Write lines to gzipped file."""
out = gzip.open(f, 'wb')
for line in lines:
out.write('{}\n'.format(line))
out.close()
def main(argv):
# Defaults
moses_ini = None
moses_ini_lines = None
text_src = None
text_tgt = None
text_symal = None
text_len = None
threads_found = False
threads = 1
n_best_out = None
n_best_size = None
n_best_distinct = False
hg_ext = None
hg_dir = None
tmp_dir = '/tmp'
xml_found = False
xml_input = 'exclusive'
show_weights = False
mmsapt_dynamic = []
mmsapt_static = []
mmsapt_l1 = None
mmsapt_l2 = None
# Decoder command
cmd = argv[1:]
# Parse special options and remove from cmd
i = 1
while i < len(cmd):
if cmd[i] in ('-f', '-config'):
moses_ini = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
elif cmd[i] in ('-i', '-input-file'):
text_src = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
elif cmd[i] == '-ref':
text_tgt = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
elif cmd[i] == '-symal':
text_symal = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
elif cmd[i] in ('-th', '-threads'):
threads_found = True
threads = int(cmd[i + 1])
cmd = cmd[:i] + cmd[i + 2:]
elif cmd[i] == '-n-best-list':
n_best_out = cmd[i + 1]
n_best_size = cmd[i + 2]
# Optional "distinct"
if i + 3 < len(cmd) and cmd[i + 3] == 'distinct':
n_best_distinct = True
cmd = cmd[:i] + cmd[i + 4:]
else:
cmd = cmd[:i] + cmd[i + 3:]
elif cmd[i] == '-output-search-graph-hypergraph':
            # cmd[i + 1] is the literal "true" flag; cmd[i + 2] is the file extension
hg_ext = cmd[i + 2]
if i + 3 < len(cmd) and cmd[i + 3][0] != '-':
hg_dir = cmd[i + 3]
cmd = cmd[:i] + cmd[i + 4:]
else:
hg_dir = 'hypergraph'
cmd = cmd[:i] + cmd[i + 3:]
elif cmd[i] == '-tmp':
tmp_dir = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
# Handled specially to make sure XML input is turned on somewhere
elif cmd[i] in ('-xi', '-xml-input'):
xml_found = True
xml_input = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
# Handled specially for mert-moses.pl
elif cmd[i] == '-show-weights':
show_weights = True
# Do not remove from cmd
i += 1
else:
i += 1
# Read moses.ini
if moses_ini:
moses_ini_lines = [line.strip() for line in open(moses_ini, 'r')]
i = 0
while i < len(moses_ini_lines):
# PhraseDictionaryBitextSampling name=TranslationModel0
# output-factor=0 num-features=7 path=corpus. L1=src L2=tgt
# pfwd=g pbwd=g smooth=0 sample=1000 workers=1
if moses_ini_lines[i].startswith('PhraseDictionaryBitextSampling'):
for (k, v) in (pair.split('=') for pair in moses_ini_lines[i].split()[1:]):
if k == 'name':
# Dynamic means update this model
if v.startswith('Dynamic'):
mmsapt_dynamic.append(v)
moses_ini_lines[i] += '{mmsapt_extra}'
else:
mmsapt_static.append(v)
elif k == 'L1':
if mmsapt_l1 and v != mmsapt_l1:
raise ProgramFailure(
'Error: All PhraseDictionaryBitextSampling '
'entries should have same L1: '
'{} != {}\n'.format(v, mmsapt_l1))
mmsapt_l1 = v
elif k == 'L2':
if mmsapt_l2 and v != mmsapt_l2:
raise ProgramFailure(
'Error: All PhraseDictionaryBitextSampling '
'entries should have same L2: '
'{} != {}\n'.format(v, mmsapt_l2))
mmsapt_l2 = v
# [threads]
# 8
elif moses_ini_lines[i] == '[threads]':
# Prefer command line over moses.ini
if not threads_found:
threads = int(moses_ini_lines[i + 1])
i += 1
# [xml-input]
# exclusive
elif moses_ini_lines[i] == '[xml-input]':
# Prefer command line over moses.ini
if not xml_found:
xml_found = True
xml_input = moses_ini_lines[i + 1]
i += 1
i += 1
# If mert-moses.pl passes -show-weights, just call moses
if show_weights:
# re-append original moses.ini
cmd.append('-config')
cmd.append(moses_ini)
sys.stdout.write(subprocess.check_output(cmd))
sys.stdout.flush()
sys.exit(0)
# Input length
if text_src:
text_len = wc(text_src)
# Check inputs
if not (len(cmd) > 0 and all((moses_ini, text_src, text_tgt, text_symal))):
sys.stderr.write(HELP.format(argv[0]))
sys.exit(2)
if not (os.path.isfile(cmd[0]) and os.access(cmd[0], os.X_OK)):
raise ProgramFailure(
'Error: moses-cmd "{}" is not executable\n'.format(cmd[0]))
if not mmsapt_dynamic:
raise ProgramFailure((
'Error: no PhraseDictionaryBitextSampling entries named '
'"Dynamic..." found in {}. See '
'http://www.statmt.org/moses/?n=Moses.AdvancedFeatures#ntoc40\n'
).format(moses_ini))
if wc(text_tgt) != text_len or wc(text_symal) != text_len:
raise ProgramFailure(
'Error: length mismatch between "{}", "{}", and "{}"\n'.format(
text_src, text_tgt, text_symal))
# Setup
work_dir = tempfile.mkdtemp(prefix='moses.', dir=os.path.abspath(tmp_dir))
threads = min(threads, text_len)
batch_size = int(math.ceil(float(text_len) / threads))
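    # Each worker gets one contiguous batch of ceil(text_len / threads) lines;
    # the final batch may be shorter.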
# Report settings
sys.stderr.write(
'Moses flags: {}\n'.format(
' '.join('\'{}\''.format(s) if ' ' in s else s for s in cmd[1:])))
for (i, n) in enumerate(mmsapt_dynamic):
sys.stderr.write(
'Dynamic mmsapt {}: {} {} {}\n'.format(
i, n, mmsapt_l1, mmsapt_l2))
for (i, n) in enumerate(mmsapt_static):
sys.stderr.write(
'Static mmsapt {}: {} {} {}\n'.format(i, n, mmsapt_l1, mmsapt_l2))
sys.stderr.write('XML mode: {}\n'.format(xml_input))
sys.stderr.write(
'Inputs: {} {} {} ({})\n'.format(
text_src, text_tgt, text_symal, text_len))
sys.stderr.write('Jobs: {}\n'.format(threads))
sys.stderr.write('Batch size: {}\n'.format(batch_size))
if n_best_out:
sys.stderr.write(
'N-best list: {} ({}{})\n'.format(
n_best_out, n_best_size,
', distinct' if n_best_distinct else ''))
if hg_dir:
sys.stderr.write('Hypergraph dir: {} ({})\n'.format(hg_dir, hg_ext))
sys.stderr.write('Temp dir: {}\n'.format(work_dir))
# Accumulate seen lines
src_lines = []
tgt_lines = []
symal_lines = []
# Current XML source file
xml_out = None
# Split into batches. Each batch after 0 gets extra files with data from
# previous batches.
# Data from previous lines in the current batch is added using XML input.
job = -1
lc = -1
lines = itertools.izip(
gzopen(text_src), gzopen(text_tgt), gzopen(text_symal))
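    # Iterate over the three parallel files together, one sentence at a time.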
for (src, tgt, symal) in lines:
(src, tgt, symal) = (src.strip(), tgt.strip(), symal.strip())
lc += 1
if lc % batch_size == 0:
job += 1
xml_file = os.path.join(work_dir, 'input.{}.xml'.format(job))
extra_src_file = os.path.join(
work_dir, 'extra.{}.{}.txt.gz'.format(job, mmsapt_l1))
extra_tgt_file = os.path.join(
work_dir, 'extra.{}.{}.txt.gz'.format(job, mmsapt_l2))
extra_symal_file = os.path.join(
work_dir, 'extra.{}.{}-{}.symal.gz'.format(
job, mmsapt_l1, mmsapt_l2))
if job > 0:
xml_out.close()
write_gzfile(src_lines, extra_src_file)
write_gzfile(tgt_lines, extra_tgt_file)
write_gzfile(symal_lines, extra_symal_file)
xml_out = open(xml_file, 'w')
ini_file = os.path.join(work_dir, 'moses.{}.ini'.format(job))
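            # Write a per-batch moses.ini, filling the {mmsapt_extra}
            # placeholder (appended above to each dynamic phrase table line)
            # with the path prefix of this batch's extra corpus files.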
with open(ini_file, 'w') as moses_ini_out:
if job == 0:
extra = ''
else:
extra = ' extra={}'.format(
os.path.join(work_dir, 'extra.{}.'.format(job)))
moses_ini_out.write(
'{}\n'.format(
'\n'.join(moses_ini_lines).format(mmsapt_extra=extra)))
src_lines.append(src)
tgt_lines.append(tgt)
symal_lines.append(symal)
# Lines after first start with update tag including previous
# translation.
# Translation of last line of each batch is included in extra for
# next batch.
xml_tags = []
if lc % batch_size != 0:
tag_template = (
'<update '
'name="{}" source="{}" target="{}" alignment="{}" /> ')
for n in mmsapt_dynamic:
# Note: space after tag.
xml_tags.append(
tag_template.format(
n, src_lines[-2], tgt_lines[-2], symal_lines[-2]))
xml_out.write('{}{}\n'.format(''.join(xml_tags), src))
xml_out.close()
# Run decoders in parallel
workers = []
prog = Progress()
for i in range(threads):
work_cmd = cmd[:]
work_cmd.append('-config')
work_cmd.append(os.path.join(work_dir, 'moses.{}.ini'.format(i)))
# Workers use 1 CPU each
work_cmd.append('-threads')
work_cmd.append('1')
if not xml_found:
work_cmd.append('-xml-input')
work_cmd.append(xml_input)
if n_best_out:
work_cmd.append('-n-best-list')
work_cmd.append(os.path.join(work_dir, 'nbest.{}'.format(i)))
work_cmd.append(str(n_best_size))
if n_best_distinct:
work_cmd.append('distinct')
if hg_dir:
work_cmd.append('-output-search-graph-hypergraph')
work_cmd.append('true')
work_cmd.append(hg_ext)
work_cmd.append(os.path.join(work_dir, 'hg.{}'.format(i)))
in_file = os.path.join(work_dir, 'input.{}.xml'.format(i))
out_file = os.path.join(work_dir, 'out.{}'.format(i))
err_file = os.path.join(work_dir, 'err.{}'.format(i))
t = threading.Thread(
target=atomic_io,
args=(work_cmd, in_file, out_file, err_file, prog))
workers.append(t)
t.start()
# Wait for all to finish
for t in workers:
t.join()
prog.done()
# Gather N-best lists
if n_best_out:
with open(n_best_out, 'w') as out:
for i in range(threads):
path = os.path.join(work_dir, 'nbest.{}'.format(i))
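                # Worker N-best ids are batch-local (each worker starts at 0);
                # shift by i * batch_size to restore global sentence numbers.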
for line in open(path, 'r'):
entry = line.partition(' ')
out.write(
'{} {}'.format(
int(entry[0]) + (i * batch_size), entry[2]))
# Gather hypergraphs
if hg_dir:
if not os.path.exists(hg_dir):
os.mkdir(hg_dir)
shutil.copy(
os.path.join(work_dir, 'hg.0', 'weights'),
os.path.join(hg_dir, 'weights'))
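        # Per-worker hypergraph files are numbered from 0; renumber them with
        # the same i * batch_size offset used for the N-best lists.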
for i in range(threads):
for j in range(batch_size):
shutil.copy(
os.path.join(
work_dir, 'hg.{}'.format(i),
'{}.{}'.format(j, hg_ext)),
os.path.join(
hg_dir, '{}.{}'.format((i * batch_size) + j, hg_ext)))
# Gather stdout
for i in range(threads):
for line in open(os.path.join(work_dir, 'out.{}'.format(i)), 'r'):
sys.stdout.write(line)
# Cleanup
shutil.rmtree(work_dir)
if __name__ == '__main__':
try:
main(sys.argv)
except ProgramFailure as error:
sys.stderr.write("%s\n" % error)
sys.exit(1)