import logging

from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, EvalPrediction, default_data_collator

from tasks.qa.utils_qa import postprocess_qa_predictions


class SQuAD:
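    """Prepare SQuAD / SQuAD v2 data for extractive question answering.

    Tokenizes raw examples into model features (handling long contexts via
    overflowing windows) and post-processes span predictions back into text.
    """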
def __init__(self, tokenizer: AutoTokenizer, data_args, training_args, qa_args) -> None:
self.data_args = data_args
self.training_args = training_args
self.qa_args = qa_args
self.version_2 = data_args.dataset_name == "squad_v2"
raw_datasets = load_dataset(data_args.dataset_name)
column_names = raw_datasets['train'].column_names
self.question_column_name = "question"
self.context_column_name = "context"
self.answer_column_name = "answers"
self.tokenizer = tokenizer
        self.pad_on_right = tokenizer.padding_side == "right"  # True for BERT-style tokenizers
        self.max_seq_len = 384  # hardcoded; overrides data_args.max_seq_length
if training_args.do_train:
self.train_dataset = raw_datasets['train']
self.train_dataset = self.train_dataset.map(
self.prepare_train_dataset,
batched=True,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on train dataset",
)
if data_args.max_train_samples is not None:
self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
if training_args.do_eval:
self.eval_examples = raw_datasets['validation']
if data_args.max_eval_samples is not None:
self.eval_examples = self.eval_examples.select(range(data_args.max_eval_samples))
self.eval_dataset = self.eval_examples.map(
self.prepare_eval_dataset,
batched=True,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on validation dataset",
)
            if data_args.max_eval_samples is not None:
                # Tokenization can split one example into several features, so re-select.
                self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
self.predict_dataset = None
self.data_collator = default_data_collator
self.metric = load_metric(data_args.dataset_name)
def prepare_train_dataset(self, examples):
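        """Tokenize training examples and label each feature with the
        token-level start/end positions of the answer (CLS token if the
        question is unanswerable or the answer falls outside the window)."""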
        # Leading whitespace makes the tokenizer think the question is very long,
        # which can break context truncation, so strip it.
        examples['question'] = [q.lstrip() for q in examples['question']]
tokenized = self.tokenizer(
examples['question' if self.pad_on_right else 'context'],
examples['context' if self.pad_on_right else 'question'],
truncation='only_second' if self.pad_on_right else 'only_first',
max_length=self.max_seq_len,
stride=128,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
        # Long contexts overflow into several features; this maps each feature
        # back to the example it came from.
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized.pop("offset_mapping")
tokenized["start_positions"] = []
tokenized["end_positions"] = []
for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized['input_ids'][i]
            # Impossible answers are labeled on the CLS token.
            cls_index = input_ids.index(self.tokenizer.cls_token_id)
            sequence_ids = tokenized.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples['answers'][sample_index]
            if len(answers['answer_start']) == 0:
                # Unanswerable question: point both positions at CLS.
tokenized["start_positions"].append(cls_index)
tokenized["end_positions"].append(cls_index)
else:
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
                # Walk to the first and last tokens that belong to the context.
                token_start_index = 0
while sequence_ids[token_start_index] != (1 if self.pad_on_right else 0):
token_start_index += 1
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != (1 if self.pad_on_right else 0):
token_end_index -= 1
# Detect if the answer is out of the span
# (in which case this feature is labeled with the CLS index).
if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
tokenized["start_positions"].append(cls_index)
tokenized["end_positions"].append(cls_index)
else:
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
# Note: we could go after the last offset if the answer is the last word (edge case).
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
token_start_index += 1
tokenized["start_positions"].append(token_start_index - 1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized["end_positions"].append(token_end_index + 1)
return tokenized
def prepare_eval_dataset(self, examples):
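        """Tokenize validation examples, keeping each feature's example id and
        an offset mapping restricted to context tokens for post-processing."""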
        # Strip leading question whitespace, mirroring the train preprocessing.
        examples['question'] = [q.lstrip() for q in examples['question']]
tokenized = self.tokenizer(
examples['question' if self.pad_on_right else 'context'],
examples['context' if self.pad_on_right else 'question'],
truncation='only_second' if self.pad_on_right else 'only_first',
max_length=self.max_seq_len,
stride=128,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
sample_mapping = tokenized.pop("overflow_to_sample_mapping")
tokenized["example_id"] = []
for i in range(len(tokenized["input_ids"])):
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized.sequence_ids(i)
context_index = 1 if self.pad_on_right else 0
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
tokenized["example_id"].append(examples["id"][sample_index])
# Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
# position is part of the context or not.
tokenized["offset_mapping"][i] = [
(o if sequence_ids[k] == context_index else None)
for k, o in enumerate(tokenized["offset_mapping"][i])
]
return tokenized
def compute_metrics(self, p: EvalPrediction):
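        """Compute the SQuAD metric on post-processed predictions and references."""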
return self.metric.compute(predictions=p.predictions, references=p.label_ids)
def post_processing_function(self, examples, features, predictions, stage='eval'):
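        """Convert start/end logits into text answers via postprocess_qa_predictions
        and format predictions/references for the metric."""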
predictions = postprocess_qa_predictions(
examples=examples,
features=features,
predictions=predictions,
version_2_with_negative=self.version_2,
n_best_size=self.qa_args.n_best_size,
max_answer_length=self.qa_args.max_answer_length,
null_score_diff_threshold=self.qa_args.null_score_diff_threshold,
output_dir=self.training_args.output_dir,
prefix=stage,
log_level=logging.INFO
)
        if self.version_2:
            # SQuAD v2 scoring expects a no-answer probability for each prediction.
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
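

# --- Usage sketch (illustrative, not from the original file) ---
# The three argument objects below are hypothetical stand-ins for the
# HfArgumentParser dataclasses this class expects; only the fields actually
# read above are filled in, and their names are assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    data_args = SimpleNamespace(
        dataset_name="squad",  # or "squad_v2"
        max_train_samples=None,
        max_eval_samples=None,
    )
    training_args = SimpleNamespace(do_train=True, do_eval=True, output_dir="./out")
    qa_args = SimpleNamespace(
        n_best_size=20, max_answer_length=30, null_score_diff_threshold=0.0
    )

    # A fast tokenizer is required for return_offsets_mapping / sequence_ids.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    squad = SQuAD(tokenizer, data_args, training_args, qa_args)
    print(squad.train_dataset)
    print(squad.eval_dataset)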