# load_phrase_data.py
# author: Julie Kallini

import argparse
import os
import sys

import pandas as pd
from nltk import Tree
from numpy.random import default_rng
from tqdm import tqdm

# For importing utils
sys.path.append("..")
from utils import PERTURBATIONS

def get_span(tokens, sub_tokens):
    """Return the (start, end) indices of sub_tokens within tokens,
    or None if the subsequence does not occur."""
    for i in range(len(tokens)):
        if tokens[i:i + len(sub_tokens)] == sub_tokens:
            return i, i + len(sub_tokens)
    return None

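# Example with hypothetical token IDs: get_span([5, 7, 9, 11], [9, 11])
# returns (2, 4), since tokens[2:4] == [9, 11]; an absent subsequence
# returns None.
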
def extract_phrases(tree, categories):
    """Extract multi-word phrases whose constituent label is in categories."""
    results = []
    for subtree in tree.subtrees():
        if subtree.label() in categories:
            words = subtree.leaves()
            # Keep only multi-word phrases
            if len(words) > 1:
                phrase = ' '.join(words)
                results.append((phrase, subtree.label()))
    return results

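# Example: for Tree.fromstring("(S (NP (DT the) (NN dog)) (VP (VBZ barks)))"),
# extract_phrases(tree, ["NP", "VP"]) yields [("the dog", "NP")]; the
# single-leaf VP "barks" is filtered out.
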
def process_file(file):
    """Read a parse file, extract phrases, and return results."""
    results = []
    categories = ["NP", "VP", "ADJP", "ADVP", "PP"]
    # Read all lines from the given file
    with open(file, 'r') as f:
        lines = f.readlines()
    # Process line pairs: a sentence followed by its constituency parse
    for i in tqdm(range(0, len(lines) - 1, 2)):
        sentence = lines[i].strip()
        # Skip very short sentences
        if len(sentence.split()) < 5:
            continue
        tree_str = lines[i + 1].strip()
        tree = Tree.fromstring(tree_str)
        phrases = extract_phrases(tree, categories)
        for phrase, category in phrases:
            if phrase in sentence:
                results.append((sentence, phrase, category))
    return results

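# The parse file is assumed to alternate raw sentences with their bracketed
# parses on consecutive lines, e.g. (illustrative contents):
#     the dog chased the cat across the yard
#     (S (NP (DT the) (NN dog)) (VP (VBD chased) ...))
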
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Get phrase spans for edge probing',
        description='Get spans of text from constituency parses for edge probing experiments')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform the BabyLM dataset')
    # Get args
    args = parser.parse_args()

    # Get class of perturbations
    perturbation_class = None
    if "reverse" in args.perturbation_type:
        perturbation_class = "reverse"
    elif "hop" in args.perturbation_type:
        perturbation_class = "hop"
    else:
        raise Exception("Perturbation class not implemented")
    # Process constituency parses
    print("Extracting phrases from constituency parses")
    data = process_file(
        f"test_constituency_parses/{perturbation_class}_parses.test")
    # Get a sufficiently large sample of phrases
    SAMPLE_SIZE = 10000
    RANDOM_STATE = 62
    rng = default_rng(RANDOM_STATE)
    rng.shuffle(data)
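    # Note: Generator.shuffle permutes the list in place; seeding with
    # RANDOM_STATE keeps the sample reproducible across runs.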
    # Find the start and end indices of the phrase's tokens within the
    # sentence's tokens
    tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"]
    span_sample_data = []
    print("Getting spans of tokens for constituents")
    for sentence, phrase, category in tqdm(data):
        # Tokenize both the full sentence and the phrase
        tokens = tokenizer.encode(sentence)
        # Skip sentences that exceed GPT-2's 1024-token context
        if len(tokens) > 1024:
            continue
        sub_tokens = tokenizer.encode(phrase)
        span = get_span(tokens, sub_tokens)
        # If the span is not found, prepend a space to the phrase
        if span is None:
            sub_tokens = tokenizer.encode(" " + phrase)
            span = get_span(tokens, sub_tokens)
        if span is not None:
            start_idx, end_idx = span
            # Mirror the span for reversed token sequences
            rev_start_index, rev_end_index = \
                len(tokens) - end_idx, len(tokens) - start_idx
            span_sample_data.append(
                (sentence, phrase, category,
                 " ".join([str(t) for t in tokens]),
                 start_idx, end_idx, rev_start_index, rev_end_index))
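    # GPT-2's BPE is whitespace-sensitive: "dog" and " dog" encode to
    # different token IDs, so a phrase that appears mid-sentence typically
    # only matches after a leading space is added.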
    # Create DataFrame and write a stratified random sample
    sample_df = pd.DataFrame(data=span_sample_data, columns=[
        "Sentence", "Phrase", "Category",
        "Sentence Tokens", "Start Index", "End Index",
        "Rev Start Index", "Rev End Index"])
    final_sample_df = sample_df.groupby('Category', group_keys=False).apply(
        lambda x: x.sample(SAMPLE_SIZE // 5, random_state=RANDOM_STATE))
    final_sample_df = final_sample_df.sample(frac=1).reset_index(drop=True)
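    # With five categories, each contributes SAMPLE_SIZE // 5 = 2000 rows;
    # x.sample raises ValueError if a category has fewer rows than that.
    # The final .sample(frac=1) shuffles the stratified rows together.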
    # Create directory and write CSV
    phrases_directory = "phrase_data/"
    os.makedirs(phrases_directory, exist_ok=True)
    phrases_file = phrases_directory + f"{perturbation_class}_phrase_data.csv"
    final_sample_df.to_csv(phrases_file, index=False)
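
# Usage sketch (the argument must be a key of PERTURBATIONS; the exact key
# names live in utils.py, so "reverse_full" below is illustrative):
#     python load_phrase_data.py reverse_full
# This writes phrase_data/reverse_phrase_data.csv with 10,000 sampled spans.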