|
|
|
|
|
|
|
|
|
|
|
""" |
|
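Split a large file into a train and valid set while respecting document
boundaries. Documents should be separated by a single empty line.

The first output file receives a uniform random sample of ``-k`` documents
(or lines, with ``--lines``); the second output file receives the rest.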
|
""" |
|
|
|
import argparse |
|
import random |
|
import sys |
|
|
|
|
|
def _write_docs(path, docs, split_lines):
    """Write *docs* to *path* as UTF-8 text.

    Args:
        path: destination file path (overwritten if it exists).
        docs: list of documents, each a list of newline-terminated lines.
        split_lines: when true every unit is a single line and no separator
            is written; otherwise documents are separated by one empty line.
    """
    with open(path, "w", encoding="utf-8") as out:
        for index, doc in enumerate(docs):
            if index > 0 and not split_lines:
                out.write("\n")
            out.writelines(doc)


def main():
    """Split the input file into a random sample of ``k`` units and the rest.

    Command-line arguments (read from ``sys.argv``):
        input: path of the UTF-8 text file to split.
        sample_output: file that receives the ``k`` sampled units.
        remainder_output: file that receives every other unit.
        -k: required, non-negative number of units to sample.
        --lines: treat every non-blank line as its own unit instead of
            grouping lines into blank-line-separated documents.

    The sample is drawn uniformly with reservoir sampling, so the input is
    read in a single pass. Progress is reported on stderr.

    Raises:
        SystemExit: if ``-k`` is missing or negative (reported by argparse).
        ValueError: if the input contains fewer than ``k`` units.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("sample_output", help="train output file")
    parser.add_argument("remainder_output", help="valid output file")
    # required=True replaces an assert, which would vanish under ``python -O``.
    parser.add_argument(
        "-k",
        type=int,
        required=True,
        help="sample size (number of docs/lines written to sample_output)",
    )
    parser.add_argument(
        "--lines", action="store_true", help="split lines instead of docs"
    )
    args = parser.parse_args()

    if args.k < 0:
        parser.error("-k must be non-negative")

    sample = []
    remainder = []
    num_seen = 0

    def add_unit(unit):
        """Offer one non-empty unit to the reservoir (takes ownership)."""
        nonlocal num_seen
        if len(sample) < args.k:
            sample.append(unit)
        else:
            # Reservoir sampling: the unit with 0-based index ``num_seen``
            # replaces a random slot with probability k / (num_seen + 1).
            slot = random.randrange(num_seen + 1)
            if slot < args.k:
                remainder.append(sample[slot])
                sample[slot] = unit
            else:
                remainder.append(unit)
        num_seen += 1

    with open(args.input, "r", encoding="utf-8") as h:
        doc = []
        for i, line in enumerate(h):
            is_blank = line.strip() == ""
            if not is_blank:
                # Only the last line of a file can lack a newline; terminate
                # it so it cannot fuse with whatever is written after it.
                if not line.endswith("\n"):
                    line += "\n"
                doc.append(line)
            # A blank line closes the current document; in --lines mode every
            # line closes it. Empty units are skipped so that repeated or
            # leading blank lines never occupy a sample slot.
            if (is_blank or args.lines) and doc:
                add_unit(doc)
                doc = []
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        if doc:
            add_unit(doc)
    print(file=sys.stderr, flush=True)

    if len(sample) != args.k:
        raise ValueError(
            "input contains only {} unit(s), fewer than the requested "
            "sample size -k={}".format(len(sample), args.k)
        )

    _write_docs(args.sample_output, sample, args.lines)
    _write_docs(args.remainder_output, remainder, args.lines)
|
|
|
|
|
# Script entry point: run the split only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|