File size: 2,388 Bytes
ceedef8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import itertools

import datasets
from datasets import Dataset, load_dataset, concatenate_datasets
from tokenizers.processors import TemplateProcessing
from transformers import GPT2TokenizerFast

input_dir = "dataset_location"
tokenizer_file = "path/to/file"
output_dir = "output/dir"
tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_file)

# Add EOS tokens to the tokenization pipeline, as they are not added otherwise:
# the post-processor appends EOS after a single sequence, and after each half of a pair.
# The special token must be registered with its real vocabulary id. Hard-coding 0 is
# only correct for tokenizers whose EOS happens to be id 0; e.g. the stock GPT-2
# vocabulary maps "<|endoftext|>" to 50256.
tokenizer._tokenizer.post_processor = TemplateProcessing(
    single=f"$0 {tokenizer.eos_token}",
    pair=f"$A {tokenizer.eos_token} $B:1 {tokenizer.eos_token}",
    special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id)],
)

def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch of examples.

    Meant to be passed to ``Dataset.map(..., batched=True)``, so
    ``examples["text"]`` holds the batch's texts. Returns whatever the
    module-level ``tokenizer`` returns for them.
    """
    batch_texts = examples["text"]
    encoded = tokenizer(batch_texts)
    return encoded


def group_texts(examples, block_size=1024):
    """Concatenate a batch of tokenized examples and re-cut it into fixed-size blocks.

    Based on the Hugging Face causal-language-modeling example; intended for
    ``Dataset.map(group_texts, batched=True)``.

    Args:
        examples: mapping from column name (e.g. ``input_ids``,
            ``attention_mask``) to a list of per-example token lists.
            Must contain ``input_ids``.
        block_size: length of every output block. Defaults to 1024, the value
            this function previously hard-coded.

    Returns:
        A dict with the same columns, each a list of ``block_size``-long lists,
        plus ``labels`` (a shallow copy of the ``input_ids`` block list).
        Trailing tokens that do not fill a complete block are dropped, so a
        batch with fewer than ``block_size`` tokens in total yields no blocks.
    """
    # chain.from_iterable flattens in linear time. sum(lists, []) re-copies the
    # growing accumulator on every addition and is quadratic in the batch size.
    concatenated_examples = {
        k: list(itertools.chain.from_iterable(v)) for k, v in examples.items()
    }
    total_len = len(concatenated_examples["input_ids"])
    # Round down to a whole number of blocks; the remainder is discarded.
    total_len = (total_len // block_size) * block_size
    result = {
        k: [t[i:i + block_size] for i in range(0, total_len, block_size)]
        for k, t in concatenated_examples.items()
    }
    # Labels equal the inputs. For causal-LM training the shift is presumably
    # done inside the model, as in the HF CLM example; verify for your model.
    result["labels"] = result["input_ids"].copy()
    return result

def main():
    """Tokenize, filter and chunk the corpus, then save a train/test split.

    Steps (every step bypasses the datasets cache, so each run recomputes):
      1. load the dataset previously saved at ``input_dir``;
      2. shuffle it with a fixed seed;
      3. tokenize the ``text`` column, dropping all original columns;
      4. drop examples with 20 or fewer tokens;
      5. concatenate and cut into fixed-size blocks with ``group_texts``;
      6. split off 5% as a test set and save the result to ``output_dir``.
    """
    num_proc = 12  # number of worker processes; set to something appropriate
    # Options shared by every step: never reuse cached files, and write rows in
    # batches of 100000.
    common = {"load_from_cache_file": False, "writer_batch_size": 100000}

    # load_from_disk loads a dataset object saved with save_to_disk. You could
    # instead create a dataset from an iterable, or load one from the Hugging
    # Face Hub, e.g.:
    # dataset = load_dataset("Finnish-NLP/mc4_fi_cleaned", split="train").remove_columns(["timestamp", "url"])
    dataset = datasets.load_from_disk(input_dir)

    shuffled = dataset.shuffle(seed=42, **common)
    tokenized = shuffled.map(
        tokenize_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        **common,
    )
    # Filter out very short texts.
    filtered = tokenized.filter(
        lambda e: len(e["input_ids"]) > 20, num_proc=num_proc, **common
    )
    # Group texts into blocks of the attention size.
    grouped = filtered.map(group_texts, batched=True, num_proc=num_proc, **common)
    # Seed the split as well. Otherwise train_test_split reshuffles with a fresh
    # random seed and the train/test partition differs between runs.
    splits = grouped.train_test_split(test_size=0.05, seed=42, **common)
    splits.save_to_disk(output_dir)
    # Report the processed splits that were written, not the raw input dataset.
    print(splits)


if __name__ == "__main__":
    main()