jiang
init commit
650c5f6
raw
history blame
3.34 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from contextlib import ExitStack
from itertools import zip_longest
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replace out-of-vocabulary words in source and target text with <unk-N>.

    N is the 0-based token position of the word's first occurrence in the
    source line. Every occurrence of the same OOV word in a source line maps
    to the same <unk-N>. Target tokens that equal one of that source line's
    OOV words get the same marker. All other target tokens are copied
    unchanged.

    Args:
        source_in: iterable of source lines (e.g. an open text file).
        target_in: iterable of target lines aligned with ``source_in``, or
            ``None`` when there is no target side.
        vocabulary: collection of in-vocabulary tokens. It is converted to a
            frozenset once so each membership test is O(1).
        source_out: writable text stream receiving one rewritten line per
            source line.
        target_out: writable text stream for rewritten target lines, or
            ``None`` to skip them. When given, one line is written per source
            line, and that line is empty if there is no matching target line.

    Raises:
        ValueError: if ``target_in`` yields more lines than ``source_in``.
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

    # Build the lookup set once: membership is tested for every source token,
    # and on a list each test would scan the whole vocabulary.
    vocab = frozenset(vocabulary)

    if target_in is None:
        target_in = []

    for line_num, (source_seq, target_seq) in enumerate(
        zip_longest(source_in, target_in), start=1
    ):
        # zip_longest pads the shorter input with None; a missing source line
        # would otherwise surface as an opaque AttributeError on .strip().
        if source_seq is None:
            raise ValueError(
                "target has more lines than source "
                "(no source line for target line {})".format(line_num)
            )

        # Maps each OOV word of this source line to its first position.
        word_to_pos = {}
        source_seq_out = []
        for position, token in enumerate(source_seq.strip().split()):
            if token in vocab:
                source_seq_out.append(token)
            else:
                # Repeated OOV words reuse the position of their first occurrence.
                oov_pos = word_to_pos.setdefault(token, position)
                source_seq_out.append(format_unk(oov_pos))
        source_out.write(" ".join(source_seq_out) + "\n")

        target_seq_out = []
        if target_seq is not None:
            # Only target tokens that are OOV words of this source line are
            # replaced; everything else is kept verbatim.
            target_seq_out = [
                format_unk(word_to_pos[token]) if token in word_to_pos else token
                for token in target_seq.strip().split()
            ]
        if target_out is not None:
            # Keep target output line-aligned with the source output.
            target_out.write(" ".join(target_seq_out) + "\n")
def main():
    """Parse command-line arguments and rewrite OOV words in the given files.

    Reads the vocabulary file (one token per line). It then streams
    ``--source``, plus ``--target`` when given, through ``replace_oovs``. The
    results go to ``--source-out``, plus ``--target-out`` when given. All
    files are opened as UTF-8 and are closed even if processing raises.
    """
    parser = argparse.ArgumentParser(
        description="Replaces out-of-vocabulary words in both source and target "
        "sequences with tokens that indicate the position of the word "
        "in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", default=None
    )
    parser.add_argument("--vocab", type=str, help="vocabulary file", required=True)
    parser.add_argument(
        "--source-out",
        type=str,
        help="where to write source sequences with <unk-N> entries",
        required=True,
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences with <unk-N> entries",
        default=None,
    )
    args = parser.parse_args()

    with open(args.vocab, encoding="utf-8") as vocab:
        # A set gives O(1) membership tests for every source token.
        vocabulary = set(vocab.read().splitlines())

    # ExitStack closes every file opened below, including the optional target
    # files, even if opening a later file or replace_oovs raises.
    with ExitStack() as stack:
        target_in = (
            stack.enter_context(open(args.target, "r", encoding="utf-8"))
            if args.target is not None
            else None
        )
        target_out = (
            stack.enter_context(open(args.target_out, "w", encoding="utf-8"))
            if args.target_out is not None
            else None
        )
        source_in = stack.enter_context(open(args.source, "r", encoding="utf-8"))
        source_out = stack.enter_context(
            open(args.source_out, "w", encoding="utf-8")
        )
        replace_oovs(source_in, target_in, vocabulary, source_out, target_out)
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module just
    # defines replace_oovs and main.
    main()