Spaces:
Sleeping
Sleeping
import torch | |
from torch.utils.data.sampler import RandomSampler, SequentialSampler | |
from torch.utils.data import DataLoader | |
from datasets.arrow_dataset import Dataset as HFDataset | |
from datasets.load import load_metric, load_dataset | |
from transformers import AutoTokenizer, DataCollatorForTokenClassification, BertConfig | |
from transformers import default_data_collator, EvalPrediction | |
import numpy as np | |
import logging | |
from tasks.qa.utils_qa import postprocess_qa_predictions | |
class SQuAD:
    """SQuAD / SQuAD v2 data pipeline for extractive question answering.

    Loads the dataset named by ``data_args.dataset_name``, tokenizes it into
    fixed-length (possibly overlapping) windows, and turns model predictions
    back into the format expected by the dataset's metric.

    Attributes populated by ``__init__``:
        train_dataset: tokenized training features (only if ``training_args.do_train``).
        eval_examples: raw validation examples (only if ``training_args.do_eval``).
        eval_dataset: tokenized validation features (only if ``training_args.do_eval``).
        predict_dataset: always ``None``.
        data_collator: ``default_data_collator``.
        metric: metric loaded via ``load_metric(data_args.dataset_name)``.
    """

    # Window length in tokens. Fixed here rather than taken from
    # ``data_args.max_seq_length`` (the original code had that commented out).
    MAX_SEQ_LEN = 384
    # Number of overlapping tokens between consecutive windows of a long context.
    DOC_STRIDE = 128

    def __init__(self, tokenizer: "AutoTokenizer", data_args, training_args, qa_args) -> None:
        """Load the raw dataset and build the tokenized train/eval splits.

        Args:
            tokenizer: tokenizer used for both splits; its ``padding_side``
                decides whether the question or the context comes first.
            data_args: must provide ``dataset_name``, ``max_train_samples``
                and ``max_eval_samples`` (the latter two may be ``None``).
            training_args: must provide ``do_train``, ``do_eval`` and
                ``output_dir``.
            qa_args: must provide ``n_best_size``, ``max_answer_length`` and
                ``null_score_diff_threshold`` (used in post-processing).
        """
        self.data_args = data_args
        self.training_args = training_args
        self.qa_args = qa_args
        self.version_2 = data_args.dataset_name == "squad_v2"
        raw_datasets = load_dataset(data_args.dataset_name)

        self.question_column_name = "question"
        self.context_column_name = "context"
        self.answer_column_name = "answers"

        self.tokenizer = tokenizer
        # Right padding => (question, context) order and the context is truncated.
        self.pad_on_right = tokenizer.padding_side == "right"
        self.max_seq_len = self.MAX_SEQ_LEN

        if training_args.do_train:
            # Every example yields at least one feature and features are emitted
            # in example order, so the first N features always come from the first
            # N examples. Trimming examples before tokenizing yields the same
            # features while avoiding tokenizing the whole split.
            train_examples = self._take_first(raw_datasets["train"], data_args.max_train_samples)
            self.train_dataset = train_examples.map(
                self.prepare_train_dataset,
                batched=True,
                remove_columns=train_examples.column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on train dataset",
            )
            self.train_dataset = self._take_first(self.train_dataset, data_args.max_train_samples)

        if training_args.do_eval:
            self.eval_examples = self._take_first(raw_datasets["validation"], data_args.max_eval_samples)
            self.eval_dataset = self.eval_examples.map(
                self.prepare_eval_dataset,
                batched=True,
                remove_columns=self.eval_examples.column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on validation dataset",
            )
            # Also caps the feature count; note this can drop trailing windows
            # of long examples when a limit is set.
            self.eval_dataset = self._take_first(self.eval_dataset, data_args.max_eval_samples)

        self.predict_dataset = None
        self.data_collator = default_data_collator
        self.metric = load_metric(data_args.dataset_name)

    @staticmethod
    def _take_first(dataset, max_samples):
        """Return the first ``max_samples`` rows of ``dataset``, or ``dataset`` itself if ``None``.

        The limit is clamped to ``len(dataset)`` so that asking for more rows
        than exist does not make ``select`` fail on out-of-range indices.
        """
        if max_samples is None:
            return dataset
        return dataset.select(range(min(len(dataset), max_samples)))

    def _tokenize(self, examples):
        """Tokenize a batch of question/context pairs into overlapping windows.

        Leading whitespace is stripped from questions (without mutating
        ``examples``). The output keeps ``overflow_to_sample_mapping`` and
        ``offset_mapping`` for the callers to consume.
        """
        questions = [q.lstrip() for q in examples[self.question_column_name]]
        contexts = examples[self.context_column_name]
        if self.pad_on_right:
            first, second, truncation = questions, contexts, "only_second"
        else:
            first, second, truncation = contexts, questions, "only_first"
        return self.tokenizer(
            first,
            second,
            truncation=truncation,
            max_length=self.max_seq_len,
            stride=self.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    @staticmethod
    def _answer_token_span(offsets, sequence_ids, answers, context_index, cls_index):
        """Return ``(start_token, end_token)`` of the first answer within one window.

        Returns ``(cls_index, cls_index)`` when the example has no answer or
        when the answer is not fully contained in this window's context span.
        """
        if len(answers["answer_start"]) == 0:
            return cls_index, cls_index

        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # Token range covered by the context inside this window.
        token_start = 0
        while sequence_ids[token_start] != context_index:
            token_start += 1
        token_end = len(sequence_ids) - 1
        while sequence_ids[token_end] != context_index:
            token_end -= 1

        if not (offsets[token_start][0] <= start_char and offsets[token_end][1] >= end_char):
            return cls_index, cls_index

        # Walk inward to the tokens bordering the answer. The start pointer may
        # run one past the last offset when the answer ends the context.
        while token_start < len(offsets) and offsets[token_start][0] <= start_char:
            token_start += 1
        while offsets[token_end][1] >= end_char:
            token_end -= 1
        return token_start - 1, token_end + 1

    def prepare_train_dataset(self, examples):
        """Tokenize a training batch and label each window with answer token positions.

        Adds ``start_positions`` / ``end_positions`` (CLS index for windows
        without the answer) and removes ``overflow_to_sample_mapping`` and
        ``offset_mapping`` from the tokenizer output.
        """
        tokenized = self._tokenize(examples)
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized.pop("offset_mapping")
        context_index = 1 if self.pad_on_right else 0

        start_positions, end_positions = [], []
        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized["input_ids"][i]
            cls_index = input_ids.index(self.tokenizer.cls_token_id)
            answers = examples[self.answer_column_name][sample_mapping[i]]
            start, end = self._answer_token_span(
                offsets, tokenized.sequence_ids(i), answers, context_index, cls_index
            )
            start_positions.append(start)
            end_positions.append(end)

        tokenized["start_positions"] = start_positions
        tokenized["end_positions"] = end_positions
        return tokenized

    def prepare_eval_dataset(self, examples):
        """Tokenize a validation batch, keeping what post-processing needs.

        Adds ``example_id`` (the source example's ``id`` for each window) and
        replaces offsets of non-context tokens with ``None`` so later code can
        tell which tokens belong to the context.
        """
        tokenized = self._tokenize(examples)
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
        context_index = 1 if self.pad_on_right else 0

        example_ids = []
        for i, sample_index in enumerate(sample_mapping):
            sequence_ids = tokenized.sequence_ids(i)
            example_ids.append(examples["id"][sample_index])
            tokenized["offset_mapping"][i] = [
                offset if sequence_ids[k] == context_index else None
                for k, offset in enumerate(tokenized["offset_mapping"][i])
            ]

        tokenized["example_id"] = example_ids
        return tokenized

    def compute_metrics(self, p: "EvalPrediction"):
        """Score formatted predictions against references with ``self.metric``."""
        return self.metric.compute(predictions=p.predictions, references=p.label_ids)

    def post_processing_function(self, examples, features, predictions, stage='eval'):
        """Convert raw model outputs into metric-ready predictions and references.

        Delegates span selection to ``postprocess_qa_predictions`` (which is
        given ``output_dir`` and ``stage`` as prefix). For SQuAD v2 every
        prediction carries ``no_answer_probability`` fixed at ``0.0``.

        Returns:
            ``EvalPrediction`` whose ``predictions`` are ``{"id", "prediction_text"[, ...]}``
            dicts and whose ``label_ids`` are ``{"id", "answers"}`` dicts.
        """
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=self.version_2,
            n_best_size=self.qa_args.n_best_size,
            max_answer_length=self.qa_args.max_answer_length,
            null_score_diff_threshold=self.qa_args.null_score_diff_threshold,
            output_dir=self.training_args.output_dir,
            prefix=stage,
            log_level=logging.INFO,
        )
        if self.version_2:
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": 0.0}
                for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
        references = [{"id": ex["id"], "answers": ex[self.answer_column_name]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)