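"""Evaluate a Flax BigBird model on the Natural Questions validation set.

Decodes answer spans from start/end logits, expands gold answers into a
lenient set of aliases, and reports an exact-match (EM) score.
"""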
import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk

from transformers import BigBirdTokenizerFast


CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))


def get_sub_answers(answers, begin=0, end=None):
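    # For every multi-word answer, keep only words[begin:end]; used to build
    # alias variants with the first or last word stripped.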
    return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]


def expand_to_aliases(given_answers, make_sub_answers=False):
    if make_sub_answers:
        # if an answer is longer than one word, also accept a prediction that
        # matches its words[1:] or words[:-1] sub-answer,
        # e.g. when the gold answer carries a prefix such as "the" or "a"
        given_answers = (
            given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
        )
    answers = []
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
        answers.append(" ".join(alias.split()).strip())
    return set(answers)


def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
    # Consider the top_k start and top_k end candidates, mask out invalid spans
    # (end before start, or wider than max_size), and return the indices of the
    # highest-scoring valid pair.
    best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
    best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)

    widths = best_end_idx[:, None] - best_start_idx[None, :]
    mask = jnp.logical_or(widths < 0, widths > max_size)
    scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * mask)
    # `scores` has end candidates on rows and start candidates on columns, so
    # the flat argmax decodes as column -> start index, row -> end index.
    best_flat_idx = jnp.argmax(scores).item()

    return best_start_idx[best_flat_idx % top_k], best_end_idx[best_flat_idx // top_k]


def format_dataset(sample):
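    # Flatten a raw Natural Questions record into a question/context pair plus
    # short/long answer targets and an answer category.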
    question = sample["question"]["text"]
    context = sample["document"]["tokens"]["token"]
    is_html = sample["document"]["tokens"]["is_html"]
    long_answers = sample["annotations"]["long_answer"]
    short_answers = sample["annotations"]["short_answers"]

    context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])

    # 0 - No ; 1 - Yes
    for answer in sample["annotations"]["yes_no_answer"]:
        if answer == 0 or answer == 1:
            return {
                "question": question,
                "context": context_string,
                "short": [],
                "long": [],
                "category": "no" if answer == 0 else "yes",
            }

    short_targets = []
    for s in short_answers:
        short_targets.extend(s["text"])
    short_targets = list(set(short_targets))

    long_targets = []
    for s in long_answers:
        if s["start_token"] == -1:
            continue
        answer = context[s["start_token"] : s["end_token"]]
        html = is_html[s["start_token"] : s["end_token"]]
        new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
        if new_answer not in long_targets:
            long_targets.append(new_answer)

    category = "long_short" if len(short_targets + long_targets) > 0 else "null"

    return {
        "question": question,
        "context": context_string,
        "short": short_targets,
        "long": long_targets,
        "category": category,
    }


def main():
    dataset = load_from_disk("natural-questions-validation")
    dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])
    print(dataset)

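    # Keep only examples likely to fit BigBird's 4096-token window (a rough
    # heuristic of ~4 characters per token), and drop unanswerable examples.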
    short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
    short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")
    print(short_validation_dataset)

    model_id = "vasudevgupta/flax-bigbird-natural-questions"
    model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
    tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)

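    # jit-compile the forward pass; pooled logits are reduced on device to a
    # predicted answer category (see CATEGORY_MAPPING).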
    @jax.jit
    def forward(*args, **kwargs):
        start_logits, end_logits, pooled_logits = model(*args, **kwargs)
        return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)

    def evaluate(example):
        # encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
        inputs = tokenizer(
            example["question"],
            example["context"],
            return_tensors="np",
            max_length=4096,
            padding="max_length",
            truncation=True,
        )

        start_scores, end_scores, category = forward(**inputs)

        predicted_category = CATEGORY_MAPPING[category.item()]

        example["targets"] = example["long"] + example["short"]
        if example["category"] in ["yes", "no", "null"]:
            example["targets"] = [example["category"]]
        example["has_tgt"] = example["category"] != "null"
        # Now target can be: "yes", "no", "null", "list of long & short answers"

        if predicted_category in ["yes", "no", "null"]:
            example["output"] = [predicted_category]
            example["match"] = example["output"] == example["targets"]
            example["has_pred"] = predicted_category != "null"
            return example

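        # Allow spans of at most 38 tokens for short answers, 1024 for long ones.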
        max_size = 38 if predicted_category == "short" else 1024
        start_idx, end_idx = get_best_valid_start_end_idx(
            start_scores[0], end_scores[0], top_k=8, max_size=max_size
        )

        input_ids = inputs["input_ids"][0].tolist()
        example["output"] = [tokenizer.decode(input_ids[start_idx : end_idx + 1])]

        answers = expand_to_aliases(example["targets"], make_sub_answers=True)
        predictions = expand_to_aliases(example["output"])

        # some preprocessing to both prediction and answer
        answers = {"".join(a.split()) for a in answers}
        predictions = {"".join(p.split()) for p in predictions}
        predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}

        # if there is a common element, it's an exact match
        example["match"] = len(answers & predictions) > 0
        example["has_pred"] = predicted_category != "null" and len(predictions) > 0

        return example

    short_validation_dataset = short_validation_dataset.map(evaluate)

    total = len(short_validation_dataset)
    matched = len(short_validation_dataset.filter(lambda x: x["match"]))
    print("EM score:", (matched / total) * 100, "%")


if __name__ == "__main__":
    main()