import os
import sys
import string
from tqdm import tqdm
from collections import defaultdict
from typing import List, Tuple, Dict


def read_lines(fname: str) -> List[str]:
    """
    Reads all lines from an input file and returns them as a list of strings.

    Args:
        fname (str): path to the input file to read.

    Returns:
        List[str]: a list of strings, where each string is a line from the file.
            Returns an empty list if the file does not exist.
    """
    # if the path doesn't exist, return an empty list
    if not os.path.exists(fname):
        return []
    with open(fname, "r") as f:
        lines = f.readlines()
    return lines


def create_txt(out_file: str, lines: List[str]):
    """
    Creates a text file and writes the given list of lines to it.

    Args:
        out_file (str): path to the output file to be created.
        lines (List[str]): a list of strings to be written to the output file.
    """
    # heuristic: look at the first line to decide whether the lines already
    # carry trailing newlines; also guards against an empty list
    add_newline = lines and "\n" not in lines[0]
    with open(out_file, "w", encoding="utf-8") as outfile:
        for line in lines:
            if add_newline:
                outfile.write(line + "\n")
            else:
                outfile.write(line)


def pair_dedup_lists(src_list: List[str], tgt_list: List[str]) -> Tuple[List[str], List[str]]:
    """
    Removes duplicates from two parallel lists by zipping their elements into
    pairs and deduplicating the pairs.

    Args:
        src_list (List[str]): a list of strings from the source language data.
        tgt_list (List[str]): a list of strings from the target language data.

    Returns:
        Tuple[List[str], List[str]]: a deduplicated version of `(src_list, tgt_list)`.
    """
    src_tgt = list(set(zip(src_list, tgt_list)))
    src_deduped, tgt_deduped = map(list, zip(*src_tgt))
    return src_deduped, tgt_deduped
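
# Illustrative only: the order of the surviving pairs is arbitrary because the
# implementation round-trips through a set.
#     pair_dedup_lists(["a", "a", "b"], ["x", "x", "y"])
#     # -> (["a", "b"], ["x", "y"]), up to ordering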


def pair_dedup_files(src_file: str, tgt_file: str):
    """
    Removes duplicate sentence pairs from two parallel files, rewriting both
    files in place.

    Args:
        src_file (str): path to the source language file to deduplicate.
        tgt_file (str): path to the target language file to deduplicate.
    """
    src_lines = read_lines(src_file)
    tgt_lines = read_lines(tgt_file)
    len_before = len(src_lines)
    src_deduped, tgt_deduped = pair_dedup_lists(src_lines, tgt_lines)
    len_after = len(src_deduped)
    num_duplicates = len_before - len_after
    print(f"Dropped duplicate pairs in {src_file}. Num duplicates -> {num_duplicates}")
    create_txt(src_file, src_deduped)
    create_txt(tgt_file, tgt_deduped)


def strip_and_normalize(line: str) -> str:
    """
    Normalizes a string by lowercasing it and removing spaces and punctuation.

    Args:
        line (str): string to strip and normalize.

    Returns:
        str: stripped and normalized version of the input string.
    """
    # lowercase the line, remove spaces and strip punctuation.
    # str.translate is one of the fastest ways to remove an exclusion
    # list of characters from a string:
    # https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
    # the exclusion list is ASCII punctuation plus the Devanagari danda (U+0964)
    exclist = string.punctuation + "\u0964"
    table_ = str.maketrans("", "", exclist)
    line = line.replace(" ", "").lower()
    # don't use this method, it is painfully slow
    # line = "".join([i for i in line if i not in string.punctuation])
    line = line.translate(table_)
    return line
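
# Illustrative only: spaces, ASCII punctuation and the Devanagari danda are
# all stripped, and the text is lowercased.
#     strip_and_normalize("Hello, World!")  # -> "helloworld"
#     strip_and_normalize("वाह। वाह।")        # -> "वाहवाह"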


def expand_tupled_list(list_of_tuples: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    """
    Expands a list of tuples into two lists by extracting the first and second
    elements of the tuples.

    Args:
        list_of_tuples (List[Tuple[str, str]]): a list of tuples, where each tuple contains two strings.

    Returns:
        Tuple[List[str], List[str]]: a tuple containing two lists, the first being the first elements of the
            tuples in `list_of_tuples` and the second being the second elements.
    """
    # convert a list of tuples into two lists
    # https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists
    # [(en, as), (as, bn), (bn, gu)] -> [en, as, bn], [as, bn, gu]
    list_a, list_b = map(list, zip(*list_of_tuples))
    return list_a, list_b


def normalize_and_gather_all_benchmarks(devtest_dir: str) -> Dict[str, Dict[str, List[str]]]:
    """
    Normalizes and gathers all benchmark datasets from a directory into a dictionary.

    Args:
        devtest_dir (str): path to the directory containing one subdirectory per
            benchmark dataset, where each benchmark holds one subdirectory per
            language pair named "src_lang-tgt_lang", each containing four files:
            `dev.src_lang`, `dev.tgt_lang`, `test.src_lang`, and `test.tgt_lang`,
            the dev and test sets for that language pair.

    Returns:
        Dict[str, Dict[str, List[str]]]: a dictionary mapping language pairs (in the
            format "src_lang-tgt_lang") to dictionaries with keys "src" and "tgt",
            holding the normalized source and target lines gathered across all
            benchmark datasets.
    """
    devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
    for benchmark in os.listdir(devtest_dir):
        print(f"{devtest_dir}/{benchmark}")
        for pair in tqdm(os.listdir(f"{devtest_dir}/{benchmark}")):
            src_lang, tgt_lang = pair.split("-")
            src_dev = read_lines(f"{devtest_dir}/{benchmark}/{pair}/dev.{src_lang}")
            tgt_dev = read_lines(f"{devtest_dir}/{benchmark}/{pair}/dev.{tgt_lang}")
            src_test = read_lines(f"{devtest_dir}/{benchmark}/{pair}/test.{src_lang}")
            tgt_test = read_lines(f"{devtest_dir}/{benchmark}/{pair}/test.{tgt_lang}")
            # if the tgt_pair data doesn't exist for a particular test set,
            # read_lines returns an empty list
            if tgt_test == [] or tgt_dev == []:
                print(f"{benchmark} does not have {src_lang}-{tgt_lang} data")
                continue
            # combine both dev and test sets into one
            src_devtest = src_dev + src_test
            tgt_devtest = tgt_dev + tgt_test
            src_devtest = [strip_and_normalize(line) for line in src_devtest]
            tgt_devtest = [strip_and_normalize(line) for line in tgt_devtest]
            devtest_pairs_normalized[pair]["src"].extend(src_devtest)
            devtest_pairs_normalized[pair]["tgt"].extend(tgt_devtest)
    # dedup the merged benchmark datasets
    for pair in devtest_pairs_normalized:
        src_devtest = devtest_pairs_normalized[pair]["src"]
        tgt_devtest = devtest_pairs_normalized[pair]["tgt"]
        src_devtest, tgt_devtest = pair_dedup_lists(src_devtest, tgt_devtest)
        devtest_pairs_normalized[pair]["src"] = src_devtest
        devtest_pairs_normalized[pair]["tgt"] = tgt_devtest
    return devtest_pairs_normalized
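
# Directory layout this function expects (illustrative; the benchmark and
# language-pair names below are hypothetical):
#     devtest_dir/
#         benchmark_a/
#             en-hi/
#                 dev.en  dev.hi  test.en  test.hi
#             en-ta/
#                 dev.en  dev.ta  test.en  test.ta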


def remove_train_devtest_overlaps(train_dir: str, devtest_dir: str):
    """
    Removes overlapping data between the training and dev/test (benchmark)
    datasets for all language pairs.

    Args:
        train_dir (str): path of the directory containing the training data.
        devtest_dir (str): path of the directory containing the dev/test data.
    """
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(devtest_dir)
    all_src_sentences_normalized = []
    for key in devtest_pairs_normalized:
        all_src_sentences_normalized.extend(devtest_pairs_normalized[key]["src"])
    # remove duplicates across all test benchmarks and all language pairs.
    # this might not be the most optimal way, but it is a tradeoff for
    # generalizing the code at the moment
    all_src_sentences_normalized = list(set(all_src_sentences_normalized))

    src_overlaps = []
    tgt_overlaps = []
    pairs = os.listdir(train_dir)
    for pair in pairs:
        src_lang, tgt_lang = pair.split("-")
        new_src_train, new_tgt_train = [], []
        src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
        tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
        len_before = len(src_train)
        if len_before == 0:
            continue
        src_train_normalized = [strip_and_normalize(line) for line in src_train]
        tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]
        src_devtest_normalized = all_src_sentences_normalized
        tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]

        # compute all src and tgt super-strict overlaps for a lang pair
        overlaps = set(src_train_normalized) & set(src_devtest_normalized)
        src_overlaps.extend(list(overlaps))
        overlaps = set(tgt_train_normalized) & set(tgt_devtest_normalized)
        tgt_overlaps.extend(list(overlaps))
        # dictionaries offer O(1) lookup
        src_overlaps_dict, tgt_overlaps_dict = {}, {}
        for line in src_overlaps:
            src_overlaps_dict[line] = 1
        for line in tgt_overlaps:
            tgt_overlaps_dict[line] = 1

        # loop to remove the overlapped data; enumerate keeps idx in sync with
        # the original (unnormalized) train lines even when a pair is skipped
        for idx, (src_line_norm, tgt_line_norm) in tqdm(
            enumerate(zip(src_train_normalized, tgt_train_normalized)), total=len_before
        ):
            if src_overlaps_dict.get(src_line_norm, None):
                continue
            if tgt_overlaps_dict.get(tgt_line_norm, None):
                continue
            new_src_train.append(src_train[idx])
            new_tgt_train.append(tgt_train[idx])
        len_after = len(new_src_train)
        print(
            f"Detected overlaps between train and devtest for {pair}: {len_before - len_after}"
        )
        print(f"saving new files at {train_dir}/{pair}/")
        create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
        create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)


if __name__ == "__main__":
    train_data_dir = sys.argv[1]
    # the benchmarks directory should contain all the test sets
    devtest_data_dir = sys.argv[2]
    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir)
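
# Example invocation (hypothetical paths and script name):
#     python remove_train_devtest_overlaps.py data/train data/benchmarks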