TintinMeimei committed
Commit 8df8155 · 1 Parent(s): 744ae9f

Delete nlp/pretrain_bert_mlm.py

Files changed (1)
  nlp/pretrain_bert_mlm.py +0 -153
nlp/pretrain_bert_mlm.py DELETED
@@ -1,153 +0,0 @@
-import codecs
-import collections
-from datetime import datetime
-import json
-import numpy as np
-import random
-
-from datasets import Dataset, DatasetDict, load_dataset
-import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForMaskedLM,
-    default_data_collator,
-    DataCollatorForLanguageModeling,
-    TrainingArguments,
-    Trainer,
-)
-
-
-model_path = '../models/bert-base-uncased'
-tokenizer = AutoTokenizer.from_pretrained(model_path, padding=True, truncation=True, max_length=512)
-model = AutoModelForMaskedLM.from_pretrained(model_path)
-
-
-# Data
-def chunkize(text, n_words=300, overlap=150):
-    words = text.split()
-    if len(words) < n_words:
-        return [' '.join(words)]
-    else:
-        return [' '.join(words[i: i+n_words]) for i in range(0, len(words)-n_words+1, n_words-overlap)]
-
-def tokenize_function(examples):
-    result = tokenizer(examples["text"])
-    return result
-
-def group_tokens(examples):
-    # Take a batch of tokens and group them into lines with same chunk size
-    chunk_size = 384
-    n = chunk_size - 2
-    input_ids_all = []
-    cls_id = 101
-    sep_id = 102
-    for item in examples['input_ids']:
-        input_ids_each = item[1: -1]  # Get rid of the first [CLS] and the last [SEP] token
-        input_ids_all += input_ids_each
-    result = {
-        'input_ids': [],
-        'token_type_ids': [],
-        'attention_mask': [],
-    }
-    chunk = []
-    for i in range(len(input_ids_all)):
-        chunk.append(input_ids_all[i])
-        if (i+1) % n == 0:  # complete a chunk
-            result['input_ids'].append([101]+chunk.copy()+[102])
-            result['token_type_ids'].append([0 for j in range(len(chunk)+2)])
-            result['attention_mask'].append([1 for j in range(len(chunk)+2)])
-            chunk = []
-    if len(chunk) > 0:
-        result['input_ids'].append([101]+chunk.copy()+[102])
-        result['token_type_ids'].append([0 for j in range(len(chunk)+2)])
-        result['attention_mask'].append([1 for j in range(len(chunk)+2)])
-    return result
-
-
-def load_data(data_path):
-    with codecs.open(f'{data_path}/train.json', 'r', encoding='utf-8') as f:
-        data = json.load(f)
-    train_data = Dataset.from_dict({'text': data})
-    with codecs.open(f'{data_path}/eval.json', 'r', encoding='utf-8') as f:
-        data = json.load(f)
-    test_data = Dataset.from_dict({'text': data})
-    data_hf = DatasetDict({'train': train_data, 'test': test_data})
-    # Tokenize
-    dataset_tokens = data_hf.map(tokenize_function, batched=True, remove_columns=["text"])
-    # Make each text item have the same length
-    dataset_tokens_group = dataset_tokens.map(group_tokens, batched=True)
-    return dataset_tokens_group
-
-def build_train_valid_data(data_path, target_path, test_size=0.2, n_words=300, overlap=150):
-    # Load raw data
-    with codecs.open(data_path, 'r', encoding='utf-8') as f:
-        data_raw = json.load(f)
-    data_clean = []
-    for text in data_raw:
-        # there are floats in data_raw
-        try:
-            nouse = len(text)
-            data_clean.append(text)
-        except:
-            continue
-    # Chunkize
-    data = []
-    for text in data_clean:
-        list_text = chunkize(text, n_words=n_words, overlap=overlap)
-        for r in list_text:
-            data.append(r)
-    n_train = int(len(data)*(1-test_size))
-    set_data = set(data)
-    set_train_data = set(random.sample(set_data, n_train))
-    set_eval_data = set_data - set_train_data
-    train_data = list(set_train_data)
-    eval_data = list(set_eval_data)
-    with codecs.open(f'{target_path}/train.json', 'w', encoding='utf-8') as w:
-        json.dump(train_data, w, ensure_ascii=False)
-    with codecs.open(f'{target_path}/eval.json', 'w', encoding='utf-8') as w:
-        json.dump(eval_data, w, ensure_ascii=False)
-
-# Model
-def train(checkpoints_dir, target_dir, dataset_tokens_group):
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, pad_to_multiple_of=32)
-    batch_size = 64
-    # model_name = 'bert-base-uncased-finetune-mlm-hashtag'
-    training_args = TrainingArguments(
-        output_dir=checkpoints_dir,
-        overwrite_output_dir=True,
-        evaluation_strategy='steps',
-        eval_steps=1000,
-        learning_rate=3e-5,
-        weight_decay=0.01,
-        per_device_train_batch_size=batch_size,
-        per_device_eval_batch_size=batch_size,
-        fp16=True,
-        save_strategy='steps',
-        save_steps=20000,
-        logging_strategy='steps',
-        report_to="none",
-        num_train_epochs=160,
-    )
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=dataset_tokens_group["train"],
-        eval_dataset=dataset_tokens_group["test"],
-        data_collator=data_collator,
-        tokenizer=tokenizer,
-    )
-    trainer.train()
-    trainer.save_model(target_dir)
-
-
-if __name__ == '__main__':
-    data_path = '../data/hashtags/dataset_hashtag_english_pretrain.json'
-    target_data_hf_path = './data_hf_for_pretrain_bert_mlm'
-    # build_train_valid_data(data_path, target_data_hf_path, test_size=0.2, n_words=300, overlap=150)
-    # exit()
-    dataset_tokens_group = load_data(target_data_hf_path)
-
-    checkpoints_dir = f'../checkpoints/bert-base-uncased-finetune-mlm-hashtag'
-    now = datetime.strftime(datetime.now(), '%Y_%m_%d_%H_%M')
-    target_dir = f'models_pretrain_bert_mlm/bert-base-uncased-finetune-mlm-hashtag-{now}'
-    train(checkpoints_dir, target_dir, dataset_tokens_group)
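
For context on the deleted script: group_tokens() concatenates the tokenized examples in a batch and re-slices them into fixed windows of chunk_size = 384 token ids, re-attaching the BERT [CLS] (101) and [SEP] (102) ids around each window. Below is a minimal standalone sketch of that repacking step, using toy token ids and an illustrative helper name (group_into_windows) rather than the script's own function, so it runs without the tokenizer or datasets dependencies:

# Minimal sketch of the repacking idea (toy ids; helper name is illustrative).
def group_into_windows(tokenized_examples, chunk_size=384, cls_id=101, sep_id=102):
    n = chunk_size - 2                # room left after re-adding [CLS]/[SEP]
    flat = []
    for ids in tokenized_examples:
        flat += ids[1:-1]             # strip each example's own [CLS]/[SEP]
    windows = [flat[i:i + n] for i in range(0, len(flat), n)]
    return [[cls_id] + w + [sep_id] for w in windows]

toy = [[101, 7, 8, 9, 102], [101, 10, 11, 102]]
print(group_into_windows(toy, chunk_size=6))
# -> [[101, 7, 8, 9, 10, 102], [101, 11, 102]]

The script's version additionally builds matching token_type_ids and attention_mask lists for each window; the resulting windows are what DataCollatorForLanguageModeling masks at mlm_probability=0.15 during training.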