# load_phrase_data.py
# author: Julie Kallini
#
# Extracts spans of constituent phrases from constituency parses for the
# edge probing experiments (see the argparse description below).

# For importing utils
import sys
sys.path.append("..")
from utils import PERTURBATIONS

import argparse
import os

import pandas as pd
from nltk import Tree
from numpy.random import default_rng
from tqdm import tqdm

def get_span(tokens, sub_tokens):
    """Return the (start, end) token indices of sub_tokens within tokens,
    or None if sub_tokens is not a contiguous subsequence of tokens."""
    for i in range(len(tokens)):
        if tokens[i:i + len(sub_tokens)] == sub_tokens:
            return i, i + len(sub_tokens)
    return None
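
# Quick sanity check for get_span (illustrative, not part of the original
# script):
#   get_span([10, 20, 30, 40], [20, 30]) -> (1, 3)
#   get_span([10, 20, 30, 40], [99])     -> None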

def extract_phrases(tree, categories):
    """Extract multi-word phrases that belong to the specified categories."""
    results = []
    for subtree in tree.subtrees():
        if subtree.label() in categories:
            words = subtree.leaves()
            # Keep multi-word phrases only
            if len(words) > 1:
                phrase = ' '.join(words)
                results.append((phrase, subtree.label()))
    return results
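
# Worked example for extract_phrases (illustrative, not part of the original
# script): for the parse
#   (S (NP (DT the) (NN dog)) (VP (VBZ barks)))
# with categories ["NP", "VP"], the result is [("the dog", "NP")]; the
# one-word VP "barks" is dropped by the len(words) > 1 filter.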

def process_file(file):
    """Read a parse file, extract phrases, and return
    (sentence, phrase, category) triples."""
    results = []
    categories = ["NP", "VP", "ADJP", "ADVP", "PP"]
    # Read all lines from the given file
    with open(file, 'r') as f:
        lines = f.readlines()
    # Process every two lines: a raw sentence followed by its parse
    for i in tqdm(range(0, len(lines) - 1, 2)):
        sentence = lines[i].strip()
        # Skip very short sentences
        if len(sentence.split()) < 5:
            continue
        tree_str = lines[i + 1].strip()
        tree = Tree.fromstring(tree_str)
        phrases = extract_phrases(tree, categories)
        for phrase, category in phrases:
            if phrase in sentence:
                results.append((sentence, phrase, category))
    return results
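
# The parse file is assumed (inferred from the loop above) to alternate
# between a raw sentence line and its bracketed constituency parse, e.g.:
#   the dog barks at the mailman
#   (S (NP (DT the) (NN dog)) (VP (VBZ barks) (PP (IN at) (NP (DT the) (NN mailman)))))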

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Get phrase spans for edge probing',
        description='Get spans of text from constituency parses '
                    'for edge probing experiments')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform '
                             'the BabyLM dataset')

    # Get args
    args = parser.parse_args()

    # Get the class of perturbations from the perturbation name
    if "reverse" in args.perturbation_type:
        perturbation_class = "reverse"
    elif "hop" in args.perturbation_type:
        perturbation_class = "hop"
    else:
        raise NotImplementedError(
            f"No perturbation class for '{args.perturbation_type}'")

    # Process constituency parses
    print("Extracting phrases from constituency parses")
    data = process_file(
        f"test_constituency_parses/{perturbation_class}_parses.test")

    # Shuffle the data to draw a sufficiently large random sample of phrases
    SAMPLE_SIZE = 10000
    RANDOM_STATE = 62
    rng = default_rng(RANDOM_STATE)
    rng.shuffle(data)

    # Find the start and end indices of each phrase's tokens within its
    # sentence's tokens
    tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"]
    span_sample_data = []
    print("Getting spans of tokens for constituents")
    for sentence, phrase, category in tqdm(data):
        # Tokenize both the full sentence and the phrase
        tokens = tokenizer.encode(sentence)
        # Skip sentences longer than GPT-2's 1024-token context window
        if len(tokens) > 1024:
            continue
        sub_tokens = tokenizer.encode(phrase)
        span = get_span(tokens, sub_tokens)
        # If the span is not found, retry with a leading space: GPT-2's BPE
        # encodes sentence-internal words with a preceding space, so "dog"
        # and " dog" map to different tokens
        if span is None:
            sub_tokens = tokenizer.encode(" " + phrase)
            span = get_span(tokens, sub_tokens)
        if span is not None:
            start_idx, end_idx = span
            # Mirror the half-open span [start, end) for reversed token
            # order: it becomes [len - end, len - start)
            rev_start_index = len(tokens) - end_idx
            rev_end_index = len(tokens) - start_idx
            span_sample_data.append(
                (sentence, phrase, category,
                 " ".join([str(t) for t in tokens]),
                 start_idx, end_idx, rev_start_index, rev_end_index))

    # Create DataFrame and write a stratified random sample
    sample_df = pd.DataFrame(data=span_sample_data, columns=[
        "Sentence", "Phrase", "Category",
        "Sentence Tokens", "Start Index", "End Index",
        "Rev Start Index", "Rev End Index"])
    # Sample an equal number of phrases from each of the 5 categories,
    # then shuffle the combined sample
    final_sample_df = sample_df.groupby('Category', group_keys=False).apply(
        lambda x: x.sample(SAMPLE_SIZE // 5, random_state=RANDOM_STATE))
    final_sample_df = final_sample_df.sample(frac=1).reset_index(drop=True)
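    # Note: DataFrame.sample(n) raises a ValueError if a group has fewer rows
    # than n, so each category is assumed to yield at least
    # SAMPLE_SIZE // 5 = 2000 phrases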

    # Create the output directory and write the CSV
    phrases_directory = "phrase_data/"
    os.makedirs(phrases_directory, exist_ok=True)
    phrases_file = phrases_directory + f"{perturbation_class}_phrase_data.csv"
    final_sample_df.to_csv(phrases_file, index=False)
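
# Example invocation (the valid perturbation names are the keys of
# PERTURBATIONS in utils.py; the name must contain "reverse" or "hop"):
#   python load_phrase_data.py <perturbation_type>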