# BertSquard_v2/data.py
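"""LightningDataModule for SQuAD v2 question answering with a BERT tokenizer.

Builds word-level start/end answer positions, wraps each example as
[CLS] <question> [SEP] <context>, tokenizes it to fixed-length inputs, and
exposes train/validation dataloaders. Run as a script, the module pushes the
processed dataset to the Hugging Face Hub.
"""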
from datasets import load_dataset
from torch.utils.data import DataLoader
import lightning as L
from transformers import AutoTokenizer
class Squad_v2(L.LightningDataModule):
ds = None
data_model_name_or_path = ""
tokenizer_model_name_or_path = ""
batch_size = 32
def __init__(
self,
*,
data_model_name_or_path: str = "rajpurkar/squad_v2",
        tokenizer_model_name_or_path: str = "google-bert/bert-base-uncased",
batch_size: int = 32,
data_from_hf: str = "eming/squad_v2_processed",
):
super().__init__()
self.data_model_name_or_path = data_model_name_or_path
self.batch_size = batch_size
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name_or_path)
self.data_hf = data_from_hf
    def _handle1(self, x):
        """Build word-level answer spans and BERT-style input_text for a batch."""
        # Initialize the dictionary that will be returned
new_examples = {
"id": [],
"input_text": [],
"start_positions": [],
"end_positions": [],
}
for i in range(len(x["id"])):
new_examples["id"].append(x["id"][i])
context = x["context"][i]
question = x["question"][i]
answers = x["answers"][i]["text"]
answer_start = x["answers"][i]["answer_start"]
# Step 1: Split the context into smaller chunks
            context_list = context.split(" ")  # assume words are separated by single spaces
# Step 2: Handle the case where there is no answer
            if not answers:  # the answer list is empty (unanswerable question)
start_positions = 0
end_positions = 0
else:
                # Split the answer text into words
answer_split = answers[0].split(" ")
# Step 3: Calculate the answer_start position in the context_list
start_char = answer_start[0]
                end_char = start_char + len(" ".join(answer_split)) - 1  # character index where the answer ends
start_positions = None
end_positions = None
char_count = 0
for idx, word in enumerate(context_list):
                    char_count += len(word) + 1  # +1 accounts for the trailing space
# assert context[char_count - 1] == " "
if char_count > start_char and start_positions is None:
start_positions = idx
if char_count > end_char and end_positions is None:
end_positions = idx
break
# if end_char == len(context):
# end_positions = idx
assert start_positions is not None
assert end_positions is not None
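                # Worked example (hypothetical, for illustration only): with
                # context = "the cat sat on the mat" and answer = "the mat"
                # (answer_start = 15), end_char = 21; char_count first exceeds
                # 15 at word index 4 ("the") and first exceeds 21 at word
                # index 5 ("mat"), so start_positions = 4, end_positions = 5.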
# Step 4: Create the input format for BERT [CLS] <question> [SEP] <context>
question_split = question.split(" ")
input_text = ["[CLS]"] + question_split + ["[SEP]"] + context_list
# Step 5: Adjust the answer positions relative to the input_text
# Since `[CLS]`, `<question>`, and `[SEP]` are part of the input, we need to offset answer positions
            if answers:  # unanswerable questions keep (0, 0), i.e. they point at [CLS]
                # Add the number of tokens in the question plus the two special tokens
                start_positions += 1 + len(question_split) + 1  # +1 for [CLS] and +1 for [SEP]
                end_positions += 1 + len(question_split) + 1
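                # Offset arithmetic example (hypothetical numbers): with a
                # 10-word question, a context word at word index 3 lands at
                # 1 ([CLS]) + 10 + 1 ([SEP]) + 3 = 15 in input_text.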
# Step 6: Append the results to the dictionary
new_examples["input_text"].append(input_text)
new_examples["start_positions"].append(start_positions)
new_examples["end_positions"].append(end_positions)
return new_examples
    def _handle2(self, x):
        """
        Normalize and tokenize the input_text of each example in the batch;
        examples whose answer end position falls beyond 512 tokens are dropped.
        """
new_examples = {
"id": [],
"input_text": [],
"start_positions": [],
"end_positions": [],
"input_ids": [],
"attention_mask": [],
}
for i in range(len(x["id"])):
            if x["end_positions"][i] >= 512:
                # Drop examples whose answer span falls beyond the 512-token limit
                continue
            # Normalize each word with the tokenizer's own normalizer; special tokens are kept as-is
input_text = [
# re.sub(r"[^\w\s]", "", t)
(
self.tokenizer.backend_tokenizer.normalizer.normalize_str(t)
if t not in self.tokenizer.all_special_tokens
else t
)
for t in x["input_text"][i]
]
# input_text = self.tokenizer.backend_tokenizer.normalizer.normalize_str(
# x["input_text"][i]
# )
new_examples["id"].append(x["id"][i])
new_examples["input_text"].append(input_text)
new_examples["start_positions"].append(x["start_positions"][i])
new_examples["end_positions"].append(x["end_positions"][i])
tkn = self.tokenizer(
x["input_text"][i],
padding="max_length",
truncation=True,
max_length=512,
is_split_into_words=True,
)
new_examples["input_ids"].append(tkn["input_ids"])
new_examples["attention_mask"].append(tkn["attention_mask"])
return new_examples
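
    # Shape of one processed record (illustrative only; values are made up):
    #   id (str), input_text (list of normalized words, incl. "[CLS]"/"[SEP]"),
    #   start_positions / end_positions (word-level ints),
    #   input_ids / attention_mask (length-512 lists from the tokenizer).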
    def setup(self, stage=None):
if self.data_hf != "":
self.ds = load_dataset(self.data_hf)
else:
self.ds = load_dataset(self.data_model_name_or_path)
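            # Two preprocessing passes: _handle1 builds word-level answer spans
            # and the [CLS]/<question>/[SEP]/<context> word list, _handle2
            # normalizes and tokenizes it into fixed-length model inputs.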
self.ds = self.ds.map(
self._handle1,
batched=True,
remove_columns=self.ds["train"].column_names,
)
self.ds = self.ds.map(
self._handle2,
batched=True,
remove_columns=self.ds["train"].column_names,
)
self.squad_train = self.ds["train"]
self.squad_train.set_format(type="torch")
self.squad_test = self.ds["validation"]
self.squad_test.set_format(type="torch")
def train_dataloader(self):
return DataLoader(
self.squad_train.remove_columns(["id", "input_text"]),
batch_size=self.batch_size,
num_workers=8,
)
def val_dataloader(self):
return DataLoader(
self.squad_test.remove_columns(["id", "input_text"]),
batch_size=self.batch_size,
num_workers=8,
)
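
# Minimal usage sketch (illustrative only; `BertSquadModule` is a hypothetical
# LightningModule name, not defined in this file):
#
#     dm = Squad_v2(batch_size=16)
#     trainer = L.Trainer(max_epochs=2)
#     trainer.fit(BertSquadModule(), datamodule=dm)
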
if __name__ == "__main__":
    # Process the raw SQuAD v2 data locally (data_from_hf="" forces preprocessing)
    data = Squad_v2(data_from_hf="")
    data.setup()
    # Push the processed dataset to the Hugging Face Hub
data.ds.push_to_hub(
"eming/squad_v2_processed",
private=False,
)