# File size: 3,747 Bytes
# Revision: d44849f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Module-level setup: configure the Indic NLP library, then import the
# tokenization / detokenization helpers used by postprocess() below.

# Paths handed to the Indic NLP library: the first is appended to sys.path,
# the second is passed to common.set_resources_path().
INDIC_NLP_LIB_HOME = "indic_nlp_library"
INDIC_NLP_RESOURCES = "indic_nlp_resources"
import sys

# NOTE(review): this import runs before INDIC_NLP_LIB_HOME is appended to
# sys.path on the next statement, so `indicnlp` must already be importable
# here. The name `transliterate` is also shadowed by the `transliterate`
# parameter inside postprocess().
from indicnlp import transliterate

sys.path.append(r"{}".format(INDIC_NLP_LIB_HOME))
from indicnlp import common

# The resources path is set before loader.load() is called; keep this order.
common.set_resources_path(INDIC_NLP_RESOURCES)
from indicnlp import loader

loader.load()
from sacremoses import MosesPunctNormalizer
from sacremoses import MosesTokenizer
from sacremoses import MosesDetokenizer
from collections import defaultdict

import indicnlp
from indicnlp.tokenize import indic_tokenize
from indicnlp.tokenize import indic_detokenize
from indicnlp.normalize import indic_normalize
from indicnlp.transliterate import unicode_transliterate

# Maps the language codes accepted by postprocess() to the codes it compares
# against "en" and passes to the indicnlp helpers.
from flores_codes_map_indic import flores_codes
import sentencepiece as spm

import re

# Shared English detokenizer, used by postprocess() when the output is English.
en_detok = MosesDetokenizer(lang="en")


def _parse_fairseq_hypotheses(lines, input_size):
    """
    Collect the hypothesis text of every ``H-<id>`` line, ordered by sentence id.

    Args:
        lines: iterable of raw lines from the fairseq interactive output.
        input_size (int): number of sentences that were translated.

    Returns:
        list of ``input_size`` strings; ids without an ``H-`` line stay "".

    Raises:
        IndexError: if a hypothesis id is not smaller than ``input_size``.
    """
    hypotheses = [""] * input_size
    for line in lines:
        if not line.startswith("H-"):
            continue
        # Expected layout: "H-<id>\t<score>\t<hypothesis>". The score column
        # is not needed for postprocessing and is ignored.
        fields = line.strip().split("\t")
        sid = int(fields[0].split("-")[1])
        if sid >= input_size:
            raise IndexError(
                "hypothesis id {} is out of range for input_size={}".format(
                    sid, input_size
                )
            )
        # An empty hypothesis leaves a trailing tab that strip() removes, so
        # the third column can be missing; treat that as an empty sentence
        # instead of failing with an IndexError.
        hypotheses[sid] = fields[2] if len(fields) > 2 else ""
    return hypotheses


def postprocess(
    infname: str,
    outfname: str,
    input_size: int,
    lang: str,
    transliterate: bool = False,
    spm_model_path: str = None,
) -> None:
    """
    Postprocess the output of a machine translation model in the following order:
        - parse fairseq interactive output
        - decode the sentence piece tokens
        - convert script back to native Indic script (in case of Indic languages)
        - detokenize

    Args:
        infname (str): path to the input file containing the machine translation output.
        outfname (str): path to the output file where the postprocessed output will be written.
        input_size (int): number of sentences in the input file.
        lang (str): language code of the output language (a key of ``flores_codes``).
        transliterate (bool, optional): whether to transliterate the decoded text from
            "hi" to the script of the output language before detokenizing. Only used
            for non-English output (default: False).
        spm_model_path (str): path of the sentence piece model.

    Raises:
        ValueError: if ``spm_model_path`` is not provided.
        KeyError: if ``lang`` is not present in ``flores_codes``.
        IndexError: if the model output contains a sentence id >= ``input_size``.
    """
    if spm_model_path is None:
        raise ValueError("Please provide sentence piece model path for decoding")

    sp = spm.SentencePieceProcessor(model_file=spm_model_path)

    iso_lang = flores_codes[lang]

    # Only parsing needs the input file; close it before the decoding work.
    with open(infname, "r", encoding="utf-8") as infile:
        hypotheses = _parse_fairseq_hypotheses(infile, input_size)

    decoded_sents = [sp.decode(hyp.split(" ")) for hyp in hypotheses]

    if iso_lang == "en":
        with open(outfname, "w", encoding="utf-8") as outfile:
            for sent in decoded_sents:
                outfile.write(en_detok.detokenize(sent.split(" ")) + "\n")
        return

    xliterator = unicode_transliterate.UnicodeIndicTransliterator()
    with open(outfname, "w", encoding="utf-8") as outfile:
        for sent in decoded_sents:
            if transliterate:
                # Convert from the "hi" script to the output language's script.
                sent = xliterator.transliterate(sent, "hi", iso_lang)
            outfile.write(indic_detokenize.trivial_detokenize(sent, iso_lang) + "\n")


if __name__ == "__main__":
    # Command-line entry point:
    #   <infname> <outfname> <input_size> <lang> <transliterate> <spm_model_path>
    if len(sys.argv) < 7:
        sys.exit(
            "Usage: {} <infname> <outfname> <input_size> <lang> "
            "<transliterate> <spm_model_path>".format(sys.argv[0])
        )

    infname = sys.argv[1]
    outfname = sys.argv[2]
    input_size = int(sys.argv[3])
    lang = sys.argv[4]
    # Command-line values are strings, and any non-empty string (including
    # "false" or "0") is truthy. Treat the usual "off" spellings as False and
    # every other value as True, so values that enabled transliteration
    # before still do.
    transliterate_flag = sys.argv[5].strip().lower() not in (
        "",
        "0",
        "false",
        "f",
        "no",
        "n",
        "off",
        "none",
    )
    spm_model_path = sys.argv[6]

    postprocess(
        infname, outfname, input_size, lang, transliterate_flag, spm_model_path
    )