# Scraped file-viewer header (file size: 5,125 bytes; revision d44849f); line-number gutter removed.
import os
import sys
from tqdm import tqdm
from typing import Iterator, List, Tuple
from remove_train_devtest_overlaps import pair_dedup_files
def read_file(fname: str) -> Iterator[str]:
    """
    Lazily iterates over a UTF-8 text file, producing one line at a time.

    Args:
        fname (str): path of the file to read.

    Yields:
        Iterator[str]: every line of the file, in order, with leading and
            trailing whitespace (including the newline) stripped.
    """
    with open(fname, "r", encoding="utf-8") as handle:
        # The file stays open only while the generator is being consumed.
        yield from map(str.strip, handle)
def extract_non_english_pairs(in_dir: str, out_dir: str, pivot_lang: str, langs: List[str]):
    """
    Extracts non-English language pairs from parallel corpora using pivot-translation.

    For every pair `(lang1, lang2)` with `lang1` appearing before `lang2` in `langs`,
    the pivot-side sentences found in both `{in_dir}/{pivot_lang}-{lang1}` and
    `{in_dir}/{pivot_lang}-{lang2}` are used to align the two non-pivot sides. The
    aligned lines are written to `{out_dir}/{lang1}-{lang2}/train.{lang1}` and
    `{out_dir}/{lang1}-{lang2}/train.{lang2}`, and both files are then handed to
    `pair_dedup_files`.

    Only one `lang1` translation is kept per pivot sentence (the last one read).
    Pivot sentences whose `lang1` translation is empty are skipped.

    Args:
        in_dir (str): path of the directory where the input files are stored.
        out_dir (str): path of the directory where the output files are stored.
        pivot_lang (str): pivot language that the input files are translated to.
        langs (List[str]): a list of language codes for the non-English languages.
    """
    for i in tqdm(range(len(langs) - 1)):
        print()
        lang1 = langs[i]

        # Everything that depends on lang1 alone is computed once per outer
        # iteration instead of once per (lang1, lang2) pair.
        fname1 = "{}/{}-{}/train.{}".format(in_dir, pivot_lang, lang1, pivot_lang)
        il_fname1 = "{}/{}-{}/train.{}".format(in_dir, pivot_lang, lang1, lang1)
        enset_l1 = set(read_file(fname1))

        for j in range(i + 1, len(langs)):
            lang2 = langs[j]
            print("{} {}".format(lang1, lang2))

            fname2 = "{}/{}-{}/train.{}".format(in_dir, pivot_lang, lang2, pivot_lang)
            # Use the pivot language for the lang2 directory as well; a
            # hard-coded "en" prefix would break for any other pivot code.
            il_fname2 = "{}/{}-{}/train.{}".format(in_dir, pivot_lang, lang2, lang2)

            # Pivot sentences present in both corpora.
            common_en_set = enset_l1.intersection(read_file(fname2))

            # Map each shared pivot sentence to its lang1 translation. If a
            # pivot sentence occurs more than once, the last translation wins.
            # Supporting multiple translations would require list values here.
            en_lang1_dict = {}
            for en_line, il_line in zip(read_file(fname1), read_file(il_fname1)):
                if en_line in common_en_set:
                    en_lang1_dict[en_line] = il_line

            os.makedirs("{}/{}-{}".format(out_dir, lang1, lang2), exist_ok=True)
            out_l1_fname = "{o}/{l1}-{l2}/train.{l1}".format(o=out_dir, l1=lang1, l2=lang2)
            out_l2_fname = "{o}/{l1}-{l2}/train.{l2}".format(o=out_dir, l1=lang1, l2=lang2)

            with open(out_l1_fname, "w", encoding="utf-8") as out_l1_file, open(
                out_l2_fname, "w", encoding="utf-8"
            ) as out_l2_file:
                for en_line, il_line in zip(read_file(fname2), read_file(il_fname2)):
                    lang1_line = en_lang1_dict.get(en_line)
                    # Skip pivot sentences with no lang1 counterpart, and those
                    # whose lang1 translation is empty.
                    if not lang1_line:
                        continue
                    # Write the aligned pair exactly once (not once per
                    # character of the lang1 line).
                    out_l1_file.write(lang1_line + "\n")
                    out_l2_file.write(il_line + "\n")

            # Both output files are closed at this point.
            pair_dedup_files(out_l1_fname, out_l2_fname)
def get_extracted_stats(out_dir: str, langs: List[str]) -> List[Tuple[str, str, int]]:
    """
    Collects line counts for the extracted non-English language pairs.

    For each pair `(lang1, lang2)` with `lang1` appearing before `lang2` in `langs`,
    the number of lines in `{out_dir}/{lang1}-{lang2}/train.{lang1}` is counted and
    reported for both directions of the pair.

    Args:
        out_dir (str): path of the directory where the output files are stored.
        langs (List[str]): a list of language codes.

    Returns:
        List[Tuple[str, str, int]]: a list of `(lang1, lang2, count)` tuples; every
            pair contributes two consecutive entries, `(lang1, lang2, count)` and
            `(lang2, lang1, count)`.
    """
    stats = []
    for first_idx in tqdm(range(len(langs) - 1)):
        src = langs[first_idx]
        for tgt in langs[first_idx + 1:]:
            src_path = "{o}/{l1}-{l2}/train.{l1}".format(o=out_dir, l1=src, l2=tgt)
            # Count lines by consuming the line generator.
            num_lines = 0
            for _ in read_file(src_path):
                num_lines += 1
            stats.extend([(src, tgt, num_lines), (tgt, src, num_lines)])
    return stats
if __name__ == "__main__":
    # TODO: need to fix this
    # Usage: <script> <in_dir> <out_dir> <comma-separated langs> [pivot_lang]
    in_dir = sys.argv[1]
    out_dir = sys.argv[2]
    langs = sorted([lang.strip() for lang in sys.argv[3].split(",")])
    # The pivot language is optional and defaults to "eng_Latn".
    pivot_lang = sys.argv[4] if len(sys.argv) > 4 else "eng_Latn"

    # Rename every "<src>-<tgt>" directory whose first component is not the
    # pivot language to "<pivot>-<src>". os.listdir returns a snapshot list,
    # so renaming entries while looping over it is safe.
    for pair in os.listdir(in_dir):
        parts = pair.split("-")
        if len(parts) != 2:
            # Entries not named "<src>-<tgt>" (e.g. hidden files) are skipped
            # instead of aborting the run with a ValueError on unpacking.
            print("skipping unexpected entry: {}".format(pair), file=sys.stderr)
            continue
        src_lang = parts[0]
        if src_lang == pivot_lang:
            continue
        tmp_in_dir = os.path.join(in_dir, pair)
        tmp_out_dir = os.path.join(in_dir, "{}-{}".format(pivot_lang, src_lang))
        os.rename(tmp_in_dir, tmp_out_dir)

    # Extraction and stats reporting are currently disabled (see the TODO above):
    # extract_non_english_pairs(in_dir, out_dir, pivot_lang, langs)
    # stats = get_extracted_stats(out_dir, langs)
    # with open("{}/lang_pairs.txt".format(out_dir), "w") as f:
    #     for stat in stats:
    #         stat = list(map(str, stat))
    #         f.write("\t".join(stat) + "\n")