from datasets import load_dataset
from torch.utils.data import DataLoader
import lightning as L
from transformers import AutoTokenizer


class Squad_v2(L.LightningDataModule):
    """DataModule that turns SQuAD v2 into word-level extractive-QA training examples."""

    ds = None
    data_model_name_or_path = ""
    tokenizer_model_name_or_path = ""
    batch_size = 32

    def __init__(
        self,
        *,
        data_model_name_or_path: str = "rajpurkar/squad_v2",
        tokenizer_model_name_or_path: str = "google-bert/bert-base-uncased",
        batch_size: int = 32,
        data_from_hf: str = "eming/squad_v2_processed",
    ):
        super().__init__()
        self.data_model_name_or_path = data_model_name_or_path
        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name_or_path)
        # If non-empty, load this already-processed dataset from the Hub
        # instead of rebuilding it from the raw SQuAD v2 release.
        self.data_hf = data_from_hf

    def _handle1(self, x):
        """Build word-level inputs and answer spans for a batch of raw examples."""
        # Initialise the output dictionary.
        new_examples = {
            "id": [],
            "input_text": [],
            "start_positions": [],
            "end_positions": [],
        }
        for i in range(len(x["id"])):
            new_examples["id"].append(x["id"][i])
            context = x["context"][i]
            question = x["question"][i]
            answers = x["answers"][i]["text"]
            answer_start = x["answers"][i]["answer_start"]

            # Step 1: split the context into words (whitespace tokenisation).
            context_list = context.split(" ")

            # Step 2: handle unanswerable examples.
            if not answers:
                # No answer: both positions point at index 0, i.e. [CLS].
                start_positions = 0
                end_positions = 0
            else:
                # Use the first annotated answer.
                answer_split = answers[0].split(" ")

                # Step 3: locate the answer span in context_list from its character offsets.
                start_char = answer_start[0]
                end_char = start_char + len(" ".join(answer_split)) - 1  # last character of the answer

                start_positions = None
                end_positions = None
                char_count = 0
                for idx, word in enumerate(context_list):
                    char_count += len(word) + 1  # +1 for the following space
                    if char_count > start_char and start_positions is None:
                        start_positions = idx
                    if char_count > end_char and end_positions is None:
                        end_positions = idx
                        break
                assert start_positions is not None
                assert end_positions is not None

            # Step 4: build the BERT-style input: [CLS] question [SEP] context.
            question_split = question.split(" ")
            input_text = ["[CLS]"] + question_split + ["[SEP]"] + context_list

            # Step 5: shift the answer positions so they index into input_text.
            # [CLS], the question words and [SEP] precede the context, so the
            # context-relative offsets move right. Unanswerable examples keep
            # (0, 0) and still point at [CLS].
            if answers:
                start_positions += 1 + len(question_split) + 1  # +1 for [CLS], +1 for [SEP]
                end_positions += 1 + len(question_split) + 1

            # Step 6: append the results to the output dictionary.
            new_examples["input_text"].append(input_text)
            new_examples["start_positions"].append(start_positions)
            new_examples["end_positions"].append(end_positions)
        return new_examples

    def _handle2(self, x):
        """Normalise and tokenise the word-level input_text for a batch of examples from _handle1."""
        new_examples = {
            "id": [],
            "input_text": [],
            "start_positions": [],
            "end_positions": [],
            "input_ids": [],
            "attention_mask": [],
        }
        for i in range(len(x["id"])):
            # Drop examples whose answer ends beyond the 512-token window.
            if x["end_positions"][i] >= 512:
                continue
            # Normalise each word with the tokenizer's own normaliser,
            # leaving special tokens ([CLS], [SEP]) untouched.
            input_text = [
                (
                    self.tokenizer.backend_tokenizer.normalizer.normalize_str(t)
                    if t not in self.tokenizer.all_special_tokens
                    else t
                )
                for t in x["input_text"][i]
            ]
            new_examples["id"].append(x["id"][i])
            new_examples["input_text"].append(input_text)
            new_examples["start_positions"].append(x["start_positions"][i])
            new_examples["end_positions"].append(x["end_positions"][i])
            # Tokenise the raw words; the fast tokenizer applies the same
            # normalisation internally.
            tkn = self.tokenizer(
                x["input_text"][i],
                padding="max_length",
                truncation=True,
                max_length=512,
                is_split_into_words=True,
            )
            new_examples["input_ids"].append(tkn["input_ids"])
            new_examples["attention_mask"].append(tkn["attention_mask"])
        return new_examples

    def setup(self, stage=None):
        if self.data_hf != "":
            # Reuse the already-processed dataset from the Hub.
            self.ds = load_dataset(self.data_hf)
        else:
            # Build the processed dataset from the raw SQuAD v2 release.
            self.ds = load_dataset(self.data_model_name_or_path)
            self.ds = self.ds.map(
                self._handle1,
                batched=True,
                remove_columns=self.ds["train"].column_names,
            )
            self.ds = self.ds.map(
                self._handle2,
                batched=True,
                remove_columns=self.ds["train"].column_names,
            )
        self.squad_train = self.ds["train"]
        self.squad_train.set_format(type="torch")
        self.squad_test = self.ds["validation"]
        self.squad_test.set_format(type="torch")

    def train_dataloader(self):
        return DataLoader(
            self.squad_train.remove_columns(["id", "input_text"]),
            batch_size=self.batch_size,
            num_workers=8,
        )

    def val_dataloader(self):
        return DataLoader(
            self.squad_test.remove_columns(["id", "input_text"]),
            batch_size=self.batch_size,
            num_workers=8,
        )


if __name__ == "__main__":
    # Build the processed dataset from the raw SQuAD v2 release
    # (data_from_hf="") so there is a freshly generated dataset to publish.
    data = Squad_v2(data_from_hf="")
    data.setup()
    # Push the processed dataset to the Hugging Face Hub.
    data.ds.push_to_hub(
        "eming/squad_v2_processed",
        private=False,
    )
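

# --------------------------------------------------------------------------- #
# Hedged usage sketch (not part of the original module): a minimal example of
# how this DataModule could feed a Lightning Trainer via a standard Hugging
# Face extractive-QA head. The `BertQA` wrapper, its hyperparameters, and the
# Trainer wiring below are illustrative assumptions, not the author's training
# code. Note that the start/end positions produced above index words in
# input_text, so a model consuming them should interpret them at that
# word-level granularity.
# --------------------------------------------------------------------------- #
class BertQA(L.LightningModule):
    """Thin, illustrative wrapper around a Hugging Face question-answering head."""

    def __init__(self, model_name_or_path: str = "google-bert/bert-base-uncased"):
        super().__init__()
        from transformers import AutoModelForQuestionAnswering

        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)

    def training_step(self, batch, batch_idx):
        # Batches from Squad_v2 carry input_ids, attention_mask,
        # start_positions and end_positions, which the QA head accepts
        # directly and uses to compute a span-prediction loss.
        outputs = self.model(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self):
        import torch

        return torch.optim.AdamW(self.parameters(), lr=3e-5)


# Example wiring (assumed, not from the original repo):
#   trainer = L.Trainer(max_epochs=1, accelerator="auto")
#   trainer.fit(BertQA(), datamodule=Squad_v2(data_from_hf="eming/squad_v2_processed"))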