Spaces:
Running
Running
import sys | |
from tqdm import tqdm | |
from typing import List, Tuple | |
def remove_large_sentences(
    src_path: str, tgt_path: str, max_len: int = 200
) -> Tuple[int, List[str], List[str]]:
    """
    Removes large sentences from a parallel dataset of source and target data.

    A sentence pair is dropped when either side has more than ``max_len``
    tokens, where tokens are obtained by stripping the line and splitting it
    on single space characters.

    Args:
        src_path (str): path to the file containing the source language data.
        tgt_path (str): path to the file containing the target language data.
        max_len (int): maximum number of tokens allowed on either side of a
            pair. Defaults to 200.

    Returns:
        Tuple[int, List[str], List[str]]: a tuple of
            - an integer representing the number of sentence pairs removed
            - a list of strings containing the source language data after removing large sentences
            - a list of strings containing the target language data after removing large sentences

        Kept lines are returned exactly as read from the files (including any
        trailing newline).

    Raises:
        AssertionError: if the two files do not have the same number of lines.
    """
    # Count lines first: needed both to validate that the files are parallel
    # and to give tqdm a total. Context managers make sure the handles used
    # for counting are closed instead of being left to the garbage collector.
    with open(src_path, "r", encoding="utf-8") as src_file:
        src_num_lines = sum(1 for _ in src_file)
    with open(tgt_path, "r", encoding="utf-8") as tgt_file:
        tgt_num_lines = sum(1 for _ in tgt_file)

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; the exception type is the same as before.
    if src_num_lines != tgt_num_lines:
        raise AssertionError(
            f"line count mismatch: {src_path} has {src_num_lines} lines, "
            f"{tgt_path} has {tgt_num_lines} lines"
        )

    count = 0
    new_src_lines, new_tgt_lines = [], []
    with open(src_path, encoding="utf-8") as f1, open(tgt_path, encoding="utf-8") as f2:
        for src_line, tgt_line in tqdm(zip(f1, f2), total=src_num_lines):
            src_tokens = src_line.strip().split(" ")
            tgt_tokens = tgt_line.strip().split(" ")
            if len(src_tokens) > max_len or len(tgt_tokens) > max_len:
                count += 1
                continue
            new_src_lines.append(src_line)
            new_tgt_lines.append(tgt_line)
    return count, new_src_lines, new_tgt_lines
def create_txt(out_file: str, lines: List[str]):
    """
    Creates a text file and writes the given list of lines to file.

    Whether a newline is appended is decided once, from the first line: if the
    first line contains no newline character, every line is written followed
    by a newline; otherwise all lines are written unchanged. An empty list
    produces an empty file.

    Args:
        out_file (str): path to the output file to be created.
        lines (List[str]): a list of strings to be written to the output file.
    """
    # Guard against an empty list: indexing lines[0] would raise IndexError.
    add_newline = bool(lines) and "\n" not in lines[0]
    # The context manager closes the file even if a write fails part-way.
    with open(out_file, "w", encoding="utf-8") as outfile:
        if add_newline:
            outfile.writelines(line + "\n" for line in lines)
        else:
            outfile.writelines(lines)
if __name__ == "__main__":
    # Expected arguments: <src_path> <tgt_path> <new_src_path> <new_tgt_path>.
    # Fail with a usage message instead of a bare IndexError traceback when
    # too few are given; any extra arguments are ignored, as before.
    if len(sys.argv) < 5:
        sys.exit(
            f"usage: {sys.argv[0]} <src_path> <tgt_path> <new_src_path> <new_tgt_path>"
        )
    src_path, tgt_path, new_src_path, new_tgt_path = sys.argv[1:5]

    # Filter the parallel corpus, then write the surviving pairs to the
    # two output files.
    count, new_src_lines, new_tgt_lines = remove_large_sentences(src_path, tgt_path)
    print(f"{count} lines removed due to seq_len > 200")
    create_txt(new_src_path, new_src_lines)
    create_txt(new_tgt_path, new_tgt_lines)