# tag.py
# Author: Julie Kallini
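#
# Usage (illustrative; actual paths depend on your local data layout):
#   python tag.py babylm_data/babylm_10M/*.train
#   python tag.py babylm_data/babylm_10M/*.train --parse
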
# For importing utils
import sys
sys.path.append("..")
import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json

test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]
test_cases = list(zip(test_original_files, test_json_files))
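
# Each test case pairs a raw text file with its parsed JSON counterpart,
# e.g. (illustrative): babylm_data/babylm_10M/switchboard.train with
# babylm_data/babylm_10M/switchboard_parsed.json.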

@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):
    """Check that the text in the parsed JSON matches the original file."""

    # Read lines of the original file and remove all whitespace
    with open(original_file) as f:
        original_data = "".join(f.readlines())
    original_data = "".join(original_data.split())

    # Concatenate the sentence texts stored in the JSON annotations
    with open(json_file) as f:
        json_lines = json.load(f)
    json_data = ""
    for line in json_lines:
        for sent in line["sent_annotations"]:
            json_data += sent["sent_text"]
    json_data = "".join(json_data.split())

    # Test equivalence
    assert original_data == json_data


def __get_constituency_parse(sent, nlp):
    """Return a bracketed constituency parse of `sent`, or None on failure."""

    # Try parsing the sentence text
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None

    # Get the constituency parse trees of the resulting sentences
    parse_trees = [str(s.constituency) for s in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse
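
# Illustrative return value of __get_constituency_parse: a PTB-style
# bracketing such as
#   "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NNS dogs))) (. .)))"
# (the exact tree shape depends on the Stanza model).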


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize,pos,lemma',
        package="default_accurate",
        use_gpu=True)

    # If constituency parse is needed, init second Stanza parser
    if args.parse:
        nlp2 = stanza.Pipeline(
            lang='en',
            processors='tokenize,pos,constituency',
            package="default_accurate",
            use_gpu=True)

    # BATCH_SIZE = 5000
    BATCH_SIZE = 100

    # Iterate over BabyLM files
    for file in args.path:
        print(file.name)
        lines = file.readlines()

        # Strip lines and group them into batches of BATCH_SIZE lines
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]
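
        # Illustrative: with BATCH_SIZE = 3, lines ["a", "b", "c", "d"] become
        # line_batches [["a", "b", "c"], ["d"]] and text_batches ["a b c", "d"]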

        # Iterate over text batches and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):

            # Tokenize, tag, and lemmatize the batch with Stanza
            doc = nlp1(text)

            # Iterate over sents in the batch and track annotations
            sent_annotations = []
            for sent in doc.sentences:

                # Iterate over words in sent and track annotations
                word_annotations = []
                for token, word in zip(sent.tokens, sent.words):
                    wa = {
                        'id': word.id,
                        'text': word.text,
                        'lemma': word.lemma,
                        'upos': word.upos,
                        'xpos': word.xpos,
                        'feats': word.feats,
                        'start_char': token.start_char,
                        'end_char': token.end_char
                    }
                    word_annotations.append(wa)  # Track word annotation
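
                # Illustrative word annotation (values depend on the model):
                #   {'id': 3, 'text': 'dogs', 'lemma': 'dog', 'upos': 'NOUN',
                #    'xpos': 'NNS', 'feats': 'Number=Plur',
                #    'start_char': 7, 'end_char': 11}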

                # Get constituency parse if needed
                if args.parse:
                    constituency_parse = __get_constituency_parse(sent, nlp2)
                    sa = {
                        'sent_text': sent.text,
                        'constituency_parse': constituency_parse,
                        'word_annotations': word_annotations,
                    }
                else:
                    sa = {
                        'sent_text': sent.text,
                        'word_annotations': word_annotations,
                    }
                sent_annotations.append(sa)  # Track sent annotation

            la = {
                'sent_annotations': sent_annotations
            }
            line_annotations.append(la)  # Track line annotation

        # Write annotations to file as JSON
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
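
        # Illustrative structure of the written JSON file:
        #   [{"sent_annotations": [{"sent_text": "...",
        #                           "word_annotations": [{...}, ...]},
        #                          ...]},
        #    ...]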