|
from datasets import concatenate_datasets, load_dataset

# Load the raw pre-training corpora: BookCorpus and the English Wikipedia dump.
bookcorpus = load_dataset("bookcorpus", split="train")
wiki = load_dataset("wikipedia", "20220301.en", split="train")
# Keep only the "text" column so that both datasets share the same schema.
wiki = wiki.remove_columns([col for col in wiki.column_names if col != "text"])

# Both datasets need identical features before they can be concatenated.
assert bookcorpus.features.type == wiki.features.type
raw_datasets = concatenate_datasets([bookcorpus, wiki])

print(raw_datasets)
|
from transformers import AutoTokenizer
import multiprocessing

# Load the tokenizer and use multiple processes to speed up preprocessing.
tokenizer = AutoTokenizer.from_pretrained("cat_tokenizer")
num_proc = min(multiprocessing.cpu_count(), 8)
print(f"The max length for the tokenizer is: {tokenizer.model_max_length}")


def group_texts(examples):
    # Tokenize the raw text, keeping the special-tokens mask and truncating anything
    # longer than the model's maximum input length.
    tokenized_inputs = tokenizer(
        examples["text"], return_special_tokens_mask=True, truncation=True, max_length=tokenizer.model_max_length
    )
    return tokenized_inputs


tokenized_datasets = raw_datasets.map(group_texts, batched=True, remove_columns=["text"], num_proc=num_proc)
print(tokenized_datasets.features)
|
from itertools import chain


# Main data processing function that concatenates all texts from the dataset and
# generates chunks of tokenizer.model_max_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder of each batch; padding could be used instead if needed.
    if total_length >= tokenizer.model_max_length:
        total_length = (total_length // tokenizer.model_max_length) * tokenizer.model_max_length
    # Split the concatenated sequences into chunks of model_max_length.
    result = {
        k: [t[i : i + tokenizer.model_max_length] for i in range(0, total_length, tokenizer.model_max_length)]
        for k, t in concatenated_examples.items()
    }
    return result


tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=num_proc)

# Shuffle so that chunks from the same document are no longer adjacent.
tokenized_datasets = tokenized_datasets.shuffle(seed=34)

print(tokenized_datasets)
print(f"The dataset contains in total {len(tokenized_datasets)*tokenizer.model_max_length} tokens")
|
# Push the processed dataset to the Hugging Face Hub so it can be reused for training.
user_id = "chaoyan"
dataset_id = f"{user_id}/processed_bert_dataset"
tokenized_datasets.push_to_hub(dataset_id)
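
# A minimal sketch, not part of the original steps: the processed dataset can later be
# loaded back from the Hub (for example on the training machine). This assumes the
# repository id defined above and that you are authenticated if the repository is
# private; the variable name "reloaded_datasets" is only illustrative.
reloaded_datasets = load_dataset(dataset_id, split="train")
print(reloaded_datasets)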