import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk
from transformers import BigBirdTokenizerFast


CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
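

# Helpers for building a lenient set of reference-answer aliases: lowercase, replace
# underscores and punctuation with spaces, and optionally add the answer without its
# first or last word.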
def get_sub_answers(answers, begin=0, end=None):
    return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]


def expand_to_aliases(given_answers, make_sub_answers=False):
    if make_sub_answers:
        # if an answer is longer than one word, also count a prediction as correct if it corresponds
        # to the `1:` or `:-1` sub-answer, *e.g.* when the reference starts with a prefix such as "the" or "a"
        given_answers = (
            given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
        )
    answers = []
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
        answers.append(" ".join(alias.split()).strip())
    return set(answers)
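

# Pick the highest-scoring (start, end) span among the top-k start and top-k end logits,
# keeping only spans with 0 <= end - start <= max_size.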
def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
    best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
    best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)
    widths = best_end_idx[:, None] - best_start_idx[None, :]
    mask = jnp.logical_or(widths < 0, widths > max_size)
    scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * mask)
    best_score = jnp.argmax(scores).item()
    return best_start_idx[best_score % top_k], best_end_idx[best_score // top_k]
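

# Flatten a raw Natural Questions example into a question/context pair plus its gold
# short answers, long answers, and answer category.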
def format_dataset(sample):
    question = sample["question"]["text"]
    context = sample["document"]["tokens"]["token"]
    is_html = sample["document"]["tokens"]["is_html"]
    long_answers = sample["annotations"]["long_answer"]
    short_answers = sample["annotations"]["short_answers"]

    context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])

    # 0 - No ; 1 - Yes
    for answer in sample["annotations"]["yes_no_answer"]:
        if answer == 0 or answer == 1:
            return {
                "question": question,
                "context": context_string,
                "short": [],
                "long": [],
                "category": "no" if answer == 0 else "yes",
            }

    short_targets = []
    for s in short_answers:
        short_targets.extend(s["text"])
    short_targets = list(set(short_targets))

    long_targets = []
    for s in long_answers:
        if s["start_token"] == -1:
            continue
        answer = context[s["start_token"] : s["end_token"]]
        html = is_html[s["start_token"] : s["end_token"]]
        new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
        if new_answer not in long_targets:
            long_targets.append(new_answer)

    category = "long_short" if len(short_targets + long_targets) > 0 else "null"

    return {
        "question": question,
        "context": context_string,
        "short": short_targets,
        "long": long_targets,
        "category": category,
    }
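

# End-to-end evaluation: load the preprocessed validation split, run the model on every
# example, and report the exact-match (EM) score.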
def main():
    dataset = load_from_disk("natural-questions-validation")
    dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])
    print(dataset)

    # keep only answerable examples whose question + context are short enough for a single pass
    # (rough character-count proxy for the 4096-token limit)
    short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
    short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")
    print(short_validation_dataset)

    model_id = "vasudevgupta/flax-bigbird-natural-questions"
    model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
    tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
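
    # forward pass returns start/end logits plus the argmax over the 5-way answer-category head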
    def forward(*args, **kwargs):
        start_logits, end_logits, pooled_logits = model(*args, **kwargs)
        return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)
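
    # score a single example: compare the predicted category / decoded span against the gold answers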
    def evaluate(example):
        # encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
        inputs = tokenizer(
            example["question"],
            example["context"],
            return_tensors="np",
            max_length=4096,
            padding="max_length",
            truncation=True,
        )

        start_scores, end_scores, category = forward(**inputs)
        predicted_category = CATEGORY_MAPPING[category.item()]

        example["targets"] = example["long"] + example["short"]
if example["category"] in ["yes", "no", "null"]: | |
example["targets"] = [example["category"]] | |
example["has_tgt"] = example["category"] != "null" | |
# Now target can be: "yes", "no", "null", "list of long & short answers" | |
if predicted_category in ["yes", "no", "null"]: | |
example["output"] = [predicted_category] | |
example["match"] = example["output"] == example["targets"] | |
example["has_pred"] = predicted_category != "null" | |
return example | |
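
        # span-prediction branch: decode the best (start, end) token span, with a tighter
        # length cap for "short" answers than for "long" ones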
        max_size = 38 if predicted_category == "short" else 1024
        start_score, end_score = get_best_valid_start_end_idx(
            start_scores[0], end_scores[0], top_k=8, max_size=max_size
        )
        input_ids = inputs["input_ids"][0].tolist()
        example["output"] = [tokenizer.decode(input_ids[start_score : end_score + 1])]
        answers = expand_to_aliases(example["targets"], make_sub_answers=True)
        predictions = expand_to_aliases(example["output"])

        # some preprocessing of both predictions and answers: remove whitespace and stray quote tokens
        answers = {"".join(a.split()) for a in answers}
        predictions = {"".join(p.split()) for p in predictions}
        predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}

        # if there is a common element, it's an exact match
        example["match"] = len(list(answers & predictions)) > 0
        example["has_pred"] = predicted_category != "null" and len(predictions) > 0
        return example

    short_validation_dataset = short_validation_dataset.map(evaluate)

    total = len(short_validation_dataset)
    matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1))
    print("EM score:", (matched / total) * 100, "%")


if __name__ == "__main__":
    main()