#!/usr/bin/env python

# Written by Michael Denkowski
#
# This file is part of moses.  Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.

"""Parallelize decoding with simulated post-editing via moses XML input.

Note that XML entities need to be escaped in tokenization.  Memory mapped
dynamic phrase tables (Ulrich Germann,
www.statmt.org/moses/?n=Moses.AdvancedFeatures#ntoc40) and language models
(Kenneth Heafield,
http://www.statmt.org/moses/?n=FactoredTraining.BuildingLanguageModel#ntoc19)
facilitate memory efficient multi process decoding.  Input is divided into
batches, each of which is decoded sequentially.  Each batch pre-loads the
data from previous batches.

To use in tuning, run mert-moses.pl with --sim-pe=SYMAL where SYMAL is the
alignment from input to references.  Specify the number of jobs with
--decoder-flags="-threads N".
"""

import gzip
import itertools
import math
import os
import shutil
import subprocess
import sys
import tempfile
import threading

# Usage text printed on bad invocation; the {} placeholder is filled
# with the program name (argv[0]) when formatted.
HELP = '''Moses with simulated post-editing

Usage:
    {} moses-cmd -config moses.ini -input-file text.src -ref text.tgt \
    -symal text.src-tgt.symal [options] [decoder flags]

Options:
    -threads N: number of decoders to run in parallel \
(default read from moses.ini, 1 if not present)
    -n-best-list nbest.out N [distinct]: location and size of N-best list
    -show-weights: for mert-moses.pl, just call moses and exit
    -tmp: location of temp directory (default /tmp)

Other options (decoder flags) are passed through to moses-cmd\n'''


class ProgramFailure(Exception):
    """Anticipated failure with a user-facing message.

    When raised, the message is written to stderr and the program exits
    with a non-zero status, but no Python traceback is shown.
    """


class Progress:
    """Thread-safe progress bar written to stderr.

    Prints a dot every 100 increments and a running total every 1000.
    """

    def __init__(self):
        # Number of completed units (inc() calls) so far.
        self.i = 0
        self.lock = threading.Lock()

    def inc(self):
        """Count one unit of work, emitting progress marks as needed."""
        # `with` guarantees the lock is released even if a stderr write
        # raises (the manual acquire/release pattern would leak it).
        with self.lock:
            self.i += 1
            if self.i % 100 == 0:
                sys.stderr.write('.')
                if self.i % 1000 == 0:
                    sys.stderr.write(' [{}]\n'.format(self.i))
                sys.stderr.flush()

    def done(self):
        """Terminate the progress line if it was left mid-row."""
        with self.lock:
            if self.i % 1000 != 0:
                sys.stderr.write('\n')


def atomic_io(cmd, in_file, out_file, err_file, prog=None):
    """Run cmd with atomic (synchronous) line-by-line I/O.

    Each input line is written to the subprocess and its corresponding
    output line is read back before the next line is sent, keeping
    decoding strictly sequential within one worker.

    cmd: argument list for the subprocess.
    in_file: path read line by line and fed to the subprocess stdin.
    out_file: path receiving one output line per input line.
    err_file: path receiving the subprocess stderr.
    prog: optional Progress instance, incremented once per line.
    """
    with open(in_file, 'r') as inp, open(out_file, 'w') as out, open(err_file, 'w') as err:
        # universal_newlines makes the pipes explicit text streams
        # (also keeps this working when the pipes are block-buffered).
        p = subprocess.Popen(
            cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=err,
            universal_newlines=True)
        for line in iter(inp.readline, ''):
            p.stdin.write(line)
            # Flush so the child sees the line even on a buffered pipe;
            # otherwise the readline below can deadlock.
            p.stdin.flush()
            out.write(p.stdout.readline())
            out.flush()
            if prog:
                prog.inc()
        p.stdin.close()
        p.wait()


def gzopen(f):
    """Open a text file, transparently handling gzip (.gz) compression."""
    if f.endswith('.gz'):
        return gzip.open(f, 'rb')
    return open(f, 'r')


def wc(f):
    """Count the lines in a plain or gzipped text file (like ``wc -l``).

    The original docstring said "Word count", but the function has
    always counted lines, which is what the callers rely on.
    """
    return sum(1 for _ in gzopen(f))


def write_gzfile(lines, f):
    """Write each line, newline-terminated, to gzipped file f.

    Uses a context manager so the file is closed (and the gzip trailer
    written) even if a write raises.
    """
    with gzip.open(f, 'wb') as out:
        for line in lines:
            out.write('{}\n'.format(line))


def main(argv):
    """Drive parallel decoding with simulated post-editing.

    Splits the input into one batch per thread, writes a per-batch
    moses.ini (batches after the first pre-load earlier batches' data
    as "extra" bitext for the dynamic phrase tables), feeds each
    previous translation back through <update> XML tags, runs one
    decoder per batch in parallel threads, then merges the outputs
    (stdout, N-best lists, hypergraphs) back into input order.
    """
    # Defaults
    moses_ini = None
    moses_ini_lines = None
    text_src = None
    text_tgt = None
    text_symal = None
    text_len = None
    threads_found = False
    threads = 1
    n_best_out = None
    n_best_size = None
    n_best_distinct = False
    hg_ext = None
    hg_dir = None
    tmp_dir = '/tmp'
    xml_found = False
    xml_input = 'exclusive'
    show_weights = False
    mmsapt_dynamic = []
    mmsapt_static = []
    mmsapt_l1 = None
    mmsapt_l2 = None

    # Decoder command
    cmd = argv[1:]

    # Parse special options and remove from cmd.
    # cmd[0] is the moses binary itself, so scanning starts at index 1.
    # Options consumed here are spliced out; unknown flags pass through.
    i = 1
    while i < len(cmd):
        if cmd[i] in ('-f', '-config'):
            moses_ini = cmd[i + 1]
            cmd = cmd[:i] + cmd[i + 2:]
        elif cmd[i] in ('-i', '-input-file'):
            text_src = cmd[i + 1]
            cmd = cmd[:i] + cmd[i + 2:]
        elif cmd[i] == '-ref':
            text_tgt = cmd[i + 1]
            cmd = cmd[:i] + cmd[i + 2:]
        elif cmd[i] == '-symal':
            text_symal = cmd[i + 1]
            cmd = cmd[:i] + cmd[i + 2:]
        elif cmd[i] in ('-th', '-threads'):
            threads_found = True
            threads = int(cmd[i + 1])
            cmd = cmd[:i] + cmd[i + 2:]
        elif cmd[i] == '-n-best-list':
            n_best_out = cmd[i + 1]
            n_best_size = cmd[i + 2]
            # Optional "distinct"
            if i + 3 < len(cmd) and cmd[i + 3] == 'distinct':
                n_best_distinct = True
                cmd = cmd[:i] + cmd[i + 4:]
            else:
                cmd = cmd[:i] + cmd[i + 3:]
        elif cmd[i] == '-output-search-graph-hypergraph':
            # cmd[i + 1] == true
            hg_ext = cmd[i + 2]
            # Third argument (output dir) is optional; anything starting
            # with '-' is taken to be the next flag instead.
            if i + 3 < len(cmd) and cmd[i + 3][0] != '-':
                hg_dir = cmd[i + 3]
                cmd = cmd[:i] + cmd[i + 4:]
            else:
                hg_dir = 'hypergraph'
                cmd = cmd[:i] + cmd[i + 3:]
        elif cmd[i] == '-tmp':
            tmp_dir = cmd[i + 1]
            cmd = cmd[:i] + cmd[i + 2:]
        # Handled specially to make sure XML input is turned on somewhere
        elif cmd[i] in ('-xi', '-xml-input'):
            xml_found = True
            xml_input = cmd[i + 1]
            cmd = cmd[:i] + cmd[i + 2:]
        # Handled specially for mert-moses.pl
        elif cmd[i] == '-show-weights':
            show_weights = True
            # Do not remove from cmd
            i += 1
        else:
            i += 1

    # Read moses.ini
    if moses_ini:
        moses_ini_lines = [line.strip() for line in open(moses_ini, 'r')]
        i = 0
        while i < len(moses_ini_lines):
            # PhraseDictionaryBitextSampling name=TranslationModel0
            # output-factor=0 num-features=7 path=corpus. L1=src L2=tgt
            # pfwd=g pbwd=g smooth=0 sample=1000 workers=1
            if moses_ini_lines[i].startswith('PhraseDictionaryBitextSampling'):
                for (k, v) in (pair.split('=') for pair in moses_ini_lines[i].split()[1:]):
                    if k == 'name':
                        # Dynamic means update this model
                        if v.startswith('Dynamic'):
                            mmsapt_dynamic.append(v)
                            # Placeholder filled per batch with the
                            # "extra" file prefix (str.format below).
                            moses_ini_lines[i] += '{mmsapt_extra}'
                        else:
                            mmsapt_static.append(v)
                    elif k == 'L1':
                        if mmsapt_l1 and v != mmsapt_l1:
                            raise ProgramFailure(
                                'Error: All PhraseDictionaryBitextSampling '
                                'entries should have same L1: '
                                '{} != {}\n'.format(v, mmsapt_l1))
                        mmsapt_l1 = v
                    elif k == 'L2':
                        if mmsapt_l2 and v != mmsapt_l2:
                            raise ProgramFailure(
                                'Error: All PhraseDictionaryBitextSampling '
                                'entries should have same L2: '
                                '{} != {}\n'.format(v, mmsapt_l2))
                        mmsapt_l2 = v
            # [threads]
            # 8
            elif moses_ini_lines[i] == '[threads]':
                # Prefer command line over moses.ini
                if not threads_found:
                    threads = int(moses_ini_lines[i + 1])
                i += 1
            # [xml-input]
            # exclusive
            elif moses_ini_lines[i] == '[xml-input]':
                # Prefer command line over moses.ini
                if not xml_found:
                    xml_found = True
                    xml_input = moses_ini_lines[i + 1]
                i += 1
            i += 1

    # If mert-moses.pl passes -show-weights, just call moses
    if show_weights:
        # re-append original moses.ini
        cmd.append('-config')
        cmd.append(moses_ini)
        sys.stdout.write(subprocess.check_output(cmd))
        sys.stdout.flush()
        sys.exit(0)

    # Input length
    if text_src:
        text_len = wc(text_src)

    # Check inputs
    if not (len(cmd) > 0 and all((moses_ini, text_src, text_tgt, text_symal))):
        sys.stderr.write(HELP.format(argv[0]))
        sys.exit(2)
    if not (os.path.isfile(cmd[0]) and os.access(cmd[0], os.X_OK)):
        raise ProgramFailure(
            'Error: moses-cmd "{}" is not executable\n'.format(cmd[0]))
    if not mmsapt_dynamic:
        raise ProgramFailure((
            'Error: no PhraseDictionaryBitextSampling entries named '
            '"Dynamic..." found in {}.  See '
            'http://www.statmt.org/moses/?n=Moses.AdvancedFeatures#ntoc40\n'
            ).format(moses_ini))
    if wc(text_tgt) != text_len or wc(text_symal) != text_len:
        raise ProgramFailure(
            'Error: length mismatch between "{}", "{}", and "{}"\n'.format(
                text_src, text_tgt, text_symal))

    # Setup
    work_dir = tempfile.mkdtemp(prefix='moses.', dir=os.path.abspath(tmp_dir))
    threads = min(threads, text_len)
    # NOTE(review): assumes text_len > 0; an empty input file would make
    # threads 0 and divide by zero here -- confirm upstream guarantees.
    batch_size = int(math.ceil(float(text_len) / threads))

    # Report settings
    sys.stderr.write(
        'Moses flags: {}\n'.format(
            ' '.join('\'{}\''.format(s) if ' ' in s else s for s in cmd[1:])))
    for (i, n) in enumerate(mmsapt_dynamic):
        sys.stderr.write(
            'Dynamic mmsapt {}: {} {} {}\n'.format(
                i, n, mmsapt_l1, mmsapt_l2))
    for (i, n) in enumerate(mmsapt_static):
        sys.stderr.write(
            'Static mmsapt {}: {} {} {}\n'.format(i, n, mmsapt_l1, mmsapt_l2))
    sys.stderr.write('XML mode: {}\n'.format(xml_input))
    sys.stderr.write(
        'Inputs: {} {} {} ({})\n'.format(
            text_src, text_tgt, text_symal, text_len))
    sys.stderr.write('Jobs: {}\n'.format(threads))
    sys.stderr.write('Batch size: {}\n'.format(batch_size))
    if n_best_out:
        sys.stderr.write(
            'N-best list: {} ({}{})\n'.format(
                n_best_out, n_best_size,
                ', distinct' if n_best_distinct else ''))
    if hg_dir:
        sys.stderr.write('Hypergraph dir: {} ({})\n'.format(hg_dir, hg_ext))
    sys.stderr.write('Temp dir: {}\n'.format(work_dir))

    # Accumulate seen lines
    src_lines = []
    tgt_lines = []
    symal_lines = []

    # Current XML source file
    xml_out = None

    # Split into batches.  Each batch after 0 gets extra files with data from
    # previous batches.
    # Data from previous lines in the current batch is added using XML input.
    job = -1
    lc = -1
    # NOTE: itertools.izip exists only in Python 2; this script targets
    # Python 2.
    lines = itertools.izip(
        gzopen(text_src), gzopen(text_tgt), gzopen(text_symal))
    for (src, tgt, symal) in lines:
        (src, tgt, symal) = (src.strip(), tgt.strip(), symal.strip())
        lc += 1
        if lc % batch_size == 0:
            # Batch boundary: close out the previous batch's files and
            # start a new XML input and per-batch moses.ini.
            job += 1
            xml_file = os.path.join(work_dir, 'input.{}.xml'.format(job))
            extra_src_file = os.path.join(
                work_dir, 'extra.{}.{}.txt.gz'.format(job, mmsapt_l1))
            extra_tgt_file = os.path.join(
                work_dir, 'extra.{}.{}.txt.gz'.format(job, mmsapt_l2))
            extra_symal_file = os.path.join(
                work_dir, 'extra.{}.{}-{}.symal.gz'.format(
                    job, mmsapt_l1, mmsapt_l2))
            if job > 0:
                xml_out.close()
                # Everything seen so far (all previous batches plus the
                # current one) becomes pre-load data for this batch.
                write_gzfile(src_lines, extra_src_file)
                write_gzfile(tgt_lines, extra_tgt_file)
                write_gzfile(symal_lines, extra_symal_file)
            xml_out = open(xml_file, 'w')
            ini_file = os.path.join(work_dir, 'moses.{}.ini'.format(job))
            with open(ini_file, 'w') as moses_ini_out:
                if job == 0:
                    extra = ''
                else:
                    extra = ' extra={}'.format(
                        os.path.join(work_dir, 'extra.{}.'.format(job)))
                moses_ini_out.write(
                    '{}\n'.format(
                        '\n'.join(moses_ini_lines).format(mmsapt_extra=extra)))
        src_lines.append(src)
        tgt_lines.append(tgt)
        symal_lines.append(symal)
        # Lines after first start with update tag including previous
        # translation.
        # Translation of last line of each batch is included in extra for
        # next batch.
        xml_tags = []
        if lc % batch_size != 0:
            tag_template = (
                '<update '
                'name="{}" source="{}" target="{}" alignment="{}" /> ')
            for n in mmsapt_dynamic:
                # Note: space after tag.
                xml_tags.append(
                    tag_template.format(
                        n, src_lines[-2], tgt_lines[-2], symal_lines[-2]))
        xml_out.write('{}{}\n'.format(''.join(xml_tags), src))
    xml_out.close()

    # Run decoders in parallel
    workers = []
    prog = Progress()
    for i in range(threads):
        work_cmd = cmd[:]
        work_cmd.append('-config')
        work_cmd.append(os.path.join(work_dir, 'moses.{}.ini'.format(i)))
        # Workers use 1 CPU each
        work_cmd.append('-threads')
        work_cmd.append('1')
        if not xml_found:
            work_cmd.append('-xml-input')
            work_cmd.append(xml_input)
        if n_best_out:
            work_cmd.append('-n-best-list')
            work_cmd.append(os.path.join(work_dir, 'nbest.{}'.format(i)))
            work_cmd.append(str(n_best_size))
            if n_best_distinct:
                work_cmd.append('distinct')
        if hg_dir:
            work_cmd.append('-output-search-graph-hypergraph')
            work_cmd.append('true')
            work_cmd.append(hg_ext)
            work_cmd.append(os.path.join(work_dir, 'hg.{}'.format(i)))
        in_file = os.path.join(work_dir, 'input.{}.xml'.format(i))
        out_file = os.path.join(work_dir, 'out.{}'.format(i))
        err_file = os.path.join(work_dir, 'err.{}'.format(i))
        t = threading.Thread(
            target=atomic_io,
            args=(work_cmd, in_file, out_file, err_file, prog))
        workers.append(t)
        t.start()
    # Wait for all to finish
    for t in workers:
        t.join()
    prog.done()

    # Gather N-best lists
    if n_best_out:
        with open(n_best_out, 'w') as out:
            for i in range(threads):
                path = os.path.join(work_dir, 'nbest.{}'.format(i))
                for line in open(path, 'r'):
                    # Renumber per-batch hypothesis ids back to global
                    # input line numbers.
                    entry = line.partition(' ')
                    out.write(
                        '{} {}'.format(
                            int(entry[0]) + (i * batch_size), entry[2]))

    # Gather hypergraphs
    if hg_dir:
        if not os.path.exists(hg_dir):
            os.mkdir(hg_dir)
        shutil.copy(
            os.path.join(work_dir, 'hg.0', 'weights'),
            os.path.join(hg_dir, 'weights'))
        # NOTE(review): the last batch holds fewer than batch_size lines
        # when text_len is not divisible by threads, so some hg files may
        # not exist for the last job -- confirm shutil.copy cannot fail
        # here in that case.
        for i in range(threads):
            for j in range(batch_size):
                shutil.copy(
                    os.path.join(
                        work_dir, 'hg.{}'.format(i),
                        '{}.{}'.format(j, hg_ext)),
                    os.path.join(
                        hg_dir, '{}.{}'.format((i * batch_size) + j, hg_ext)))

    # Gather stdout
    for i in range(threads):
        for line in open(os.path.join(work_dir, 'out.{}'.format(i)), 'r'):
            sys.stdout.write(line)

    # Cleanup
    shutil.rmtree(work_dir)

if __name__ == '__main__':
    try:
        main(sys.argv)
    except ProgramFailure as error:
        # Known failure mode: show only the message, no traceback.
        sys.stderr.write("%s\n" % error)
        sys.exit(1)