# load_phrase_data.py
# author: Julie Kallini
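#
# Usage (run from this directory so the parent-level utils module resolves):
#   python load_phrase_data.py <perturbation_type>
# where <perturbation_type> is a key of PERTURBATIONS from utils.py.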

# For importing utils
import sys
sys.path.append("..")

from utils import PERTURBATIONS
import argparse
import pandas as pd
import os
from tqdm import tqdm
from nltk import Tree
from numpy.random import default_rng


def get_span(tokens, sub_tokens):
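    """Return the half-open token span (start, end) of sub_tokens within
    tokens, or None if sub_tokens does not occur as a contiguous run."""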
    for i in range(len(tokens) - len(sub_tokens) + 1):
        if tokens[i:i + len(sub_tokens)] == sub_tokens:
            return i, i + len(sub_tokens)
    return None


def extract_phrases(tree, categories):
    """Extract phrases that belong to the specified categories."""
    results = []

    for subtree in tree.subtrees():
        if subtree.label() in categories:
            words = subtree.leaves()
            if len(words) > 1:
                phrase = ' '.join(words)
                results.append((phrase, subtree.label()))

    return results


def process_file(file):
    """Read file, extract phrases and return results."""
    results = []

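    # Phrase categories (Penn Treebank labels) to extract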
    categories = ["NP", "VP", "ADJP", "ADVP", "PP"]

    # Read the parse file: each sentence is followed by its parse tree
    # on the next line
    with open(file, 'r') as f:
        lines = f.readlines()

        # Process every two lines
        for i in tqdm(range(0, len(lines) - 1, 2)):
            sentence = lines[i].strip()

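            # Skip very short sentences (fewer than five words)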
            if len(sentence.split()) < 5:
                continue

            tree_str = lines[i + 1].strip()
            tree = Tree.fromstring(tree_str)

            phrases = extract_phrases(tree, categories)
            for phrase, category in phrases:
                if phrase in sentence:
                    results.append((sentence, phrase, category))

    return results


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='load_phrase_data.py',
        description='Get spans of text from constituency parses for edge probing experiments')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')

    # Get args
    args = parser.parse_args()

    # Get class of perturbations
    perturbation_class = None
    if "reverse" in args.perturbation_type:
        perturbation_class = "reverse" 
    elif "hop" in args.perturbation_type:
        perturbation_class = "hop"
    else:
        raise Exception("Perturbation class not implemented")

    # Process constituency parses
    print("Extracting phrases from constituency parses")
    data = process_file(
        f"test_constituency_parses/{perturbation_class}_parses.test")

    # Shuffle phrases with a fixed seed so the stratified sample below is
    # reproducible
    SAMPLE_SIZE = 10000
    RANDOM_STATE = 62
    rng = default_rng(RANDOM_STATE)
    rng.shuffle(data)

    # Find the start and end indices of the substring's tokens within the sentence's tokens
    tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"]
    span_sample_data = []
    print("Getting spans of tokens for constituents")
    for sentence, phrase, category in tqdm(data):

        # Tokenize both the full sentence and the substring
        tokens = tokenizer.encode(sentence)
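        # Skip sentences longer than GPT-2's 1024-token context window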
        if len(tokens) > 1024:
            continue
        sub_tokens = tokenizer.encode(phrase)

        span = get_span(tokens, sub_tokens)

        # If the span is not found, retry with a leading space: GPT-2's BPE
        # tokenizes a word differently when it is preceded by a space
        if span is None:
            sub_tokens = tokenizer.encode(" " + phrase)
            span = get_span(tokens, sub_tokens)

        if span is not None:
            start_idx, end_idx = span
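            # Position of the span counted from the end of the sequence, i.e.
            # where the phrase falls if the token order is reversed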
            rev_start_idx, rev_end_idx = len(tokens) - end_idx, len(tokens) - start_idx
            span_sample_data.append(
                (sentence, phrase, category, " ".join(str(t) for t in tokens),
                 start_idx, end_idx, rev_start_idx, rev_end_idx))

    # Create DataFrame and write a stratified random sample
    sample_df = pd.DataFrame(data=span_sample_data, columns=[
                             "Sentence", "Phrase", "Category",
                             "Sentence Tokens", "Start Index", "End Index",
                             "Rev Start Index", "Rev End Index"])
    final_sample_df = sample_df.groupby('Category', group_keys=False).apply(
        lambda x: x.sample(SAMPLE_SIZE // 5, random_state=RANDOM_STATE))
    final_sample_df = final_sample_df.sample(frac=1).reset_index(drop=True)

    # Create directory and write CSV
    phrases_directory = "phrase_data/"
    os.makedirs(phrases_directory, exist_ok=True)
    phrases_file = phrases_directory + f"{perturbation_class}_phrase_data.csv"
    final_sample_df.to_csv(phrases_file, index=False)