"""Lightning DataModule that prepares SQuAD v2 for extractive QA with BERT."""

from datasets import load_dataset
from torch.utils.data import DataLoader
import lightning as L
from transformers import AutoTokenizer


class Squad_v2(L.LightningDataModule):
    """LightningDataModule for SQuAD v2 extractive question answering."""

    def __init__(
        self,
        *,
        data_model_name_or_path: str = "rajpurkar/squad_v2",
        tokenizer_model_name_or_path: str = "google-bert/bert-base-uncased",
        batch_size: int = 32,
        data_from_hf: str = "eming/squad_v2_processed",
    ):
        super().__init__()
        self.ds = None
        self.data_model_name_or_path = data_model_name_or_path
        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name_or_path)
        self.data_hf = data_from_hf

    def _handle1(self, x):
        """Map a batch of raw SQuAD v2 examples to whitespace-split inputs
        with word-level answer spans.

        Unanswerable questions keep start/end positions of 0, i.e. the
        [CLS] slot.
        """
        new_examples = {
            "id": [],
            "input_text": [],
            "start_positions": [],
            "end_positions": [],
        }

        for i in range(len(x["id"])):
            new_examples["id"].append(x["id"][i])
            context = x["context"][i]
            question = x["question"][i]
            answers = x["answers"][i]["text"]
            answer_start = x["answers"][i]["answer_start"]

            context_list = context.split(" ")
            question_split = question.split(" ")
            input_text = ["[CLS]"] + question_split + ["[SEP]"] + context_list

            if not answers:
                # Unanswerable: both positions point at [CLS] (index 0).
                start_positions = 0
                end_positions = 0
            else:
                # Character span of the first answer within the context.
                start_char = answer_start[0]
                end_char = start_char + len(answers[0]) - 1

                # Walk the whitespace-split words, tracking the running
                # character offset (+1 for the space after each word), and
                # record the word indices that cover the answer span.
                start_positions = None
                end_positions = None
                char_count = 0
                for idx, word in enumerate(context_list):
                    char_count += len(word) + 1
                    if char_count > start_char and start_positions is None:
                        start_positions = idx
                    if char_count > end_char and end_positions is None:
                        end_positions = idx
                        break

                assert start_positions is not None
                assert end_positions is not None

                # Shift the context-relative word indices past
                # "[CLS] <question words> [SEP]". Only answerable examples
                # are shifted, so 0 keeps meaning "no answer".
                offset = 1 + len(question_split) + 1
                start_positions += offset
                end_positions += offset

            new_examples["input_text"].append(input_text)
            new_examples["start_positions"].append(start_positions)
            new_examples["end_positions"].append(end_positions)

        return new_examples
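
    # Worked example of the span mapping above (a sketch, not executed):
    # context "the cat sat", answer "cat", answer_start=4 gives start_char=4
    # and end_char=4+3-1=6. Scanning words: "the" raises char_count to 4
    # (not > 4); "cat" raises it to 8, which exceeds both 4 and 6, so
    # start_positions = end_positions = 1, the word index of "cat". The
    # final offset then shifts that index past "[CLS] <question> [SEP]".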

    def _handle2(self, x):
        """Tokenize a batch of input_text word lists.

        Examples whose word-level end position falls at or beyond the
        512-token window are dropped. Positions are word indices while
        truncation is token-based, so this filter is a heuristic.
        """
        new_examples = {
            "id": [],
            "input_text": [],
            "start_positions": [],
            "end_positions": [],
            "input_ids": [],
            "attention_mask": [],
        }
        for i in range(len(x["id"])):
            if x["end_positions"][i] >= 512:
                continue

            # Normalize each word the way the tokenizer would (lowercasing,
            # accent stripping for an uncased BERT), leaving special tokens
            # untouched, so the stored words match the tokenizer's view.
            input_text = [
                (
                    self.tokenizer.backend_tokenizer.normalizer.normalize_str(t)
                    if t not in self.tokenizer.all_special_tokens
                    else t
                )
                for t in x["input_text"][i]
            ]

            new_examples["id"].append(x["id"][i])
            new_examples["input_text"].append(input_text)
            new_examples["start_positions"].append(x["start_positions"][i])
            new_examples["end_positions"].append(x["end_positions"][i])
            tkn = self.tokenizer(
                x["input_text"][i],
                padding="max_length",
                truncation=True,
                max_length=512,
                is_split_into_words=True,
                # input_text already carries its own [CLS]/[SEP]; without
                # this flag the tokenizer would prepend a second [CLS] (and
                # append a trailing [SEP]), shifting every position by one.
                add_special_tokens=False,
            )
            new_examples["input_ids"].append(tkn["input_ids"])
            new_examples["attention_mask"].append(tkn["attention_mask"])
        return new_examples
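
    # Note: the start/end positions stored above are word indices into
    # input_text, not token indices into input_ids. A downstream consumer
    # has to map words to tokens itself; a minimal sketch using the fast
    # tokenizer's alignment API (assuming `tkn` as returned inside _handle2
    # for one example):
    #
    #   word_ids = tkn.word_ids()                      # one word id per token
    #   token_start = word_ids.index(start_positions)  # first sub-token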

    def setup(self, stage=None):
        if self.data_hf != "":
            # A pre-processed dataset: load it as-is. It no longer has the
            # raw columns, so the map passes below must not run on it.
            self.ds = load_dataset(self.data_hf)
        else:
            # Raw SQuAD v2: load it and run both preprocessing passes.
            self.ds = load_dataset(self.data_model_name_or_path)
            self.ds = self.ds.map(
                self._handle1,
                batched=True,
                remove_columns=self.ds["train"].column_names,
            )
            self.ds = self.ds.map(
                self._handle2,
                batched=True,
                remove_columns=self.ds["train"].column_names,
            )

        self.squad_train = self.ds["train"]
        self.squad_train.set_format(type="torch")

        self.squad_test = self.ds["validation"]
        self.squad_test.set_format(type="torch")

    def train_dataloader(self):
        return DataLoader(
            self.squad_train.remove_columns(["id", "input_text"]),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
        )

    def val_dataloader(self):
        return DataLoader(
            self.squad_test.remove_columns(["id", "input_text"]),
            batch_size=self.batch_size,
            num_workers=8,
        )
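

# Minimal training sketch (BertQA is a hypothetical LightningModule, not
# defined in this file):
#
#   dm = Squad_v2(batch_size=16)
#   L.Trainer(max_epochs=2).fit(BertQA(), datamodule=dm)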


if __name__ == "__main__":
    # Build the processed dataset from raw SQuAD v2 (an empty data_from_hf
    # forces the raw branch in setup) and publish it to the Hub.
    data = Squad_v2(data_from_hf="")
    data.setup()
    data.ds.push_to_hub(
        "eming/squad_v2_processed",
        private=False,
    )