Spaces · Commit 8df8155
Parent(s): 744ae9f
Delete nlp/pretrain_bert_mlm.py

nlp/pretrain_bert_mlm.py  DELETED  (+0 -153)
@@ -1,153 +0,0 @@
import codecs
import collections
from datetime import datetime
import json
import numpy as np
import random

from datasets import Dataset, DatasetDict, load_dataset
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    default_data_collator,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)


model_path = '../models/bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_path, padding=True, truncation=True, max_length=512)
model = AutoModelForMaskedLM.from_pretrained(model_path)


# Data
def chunkize(text, n_words=300, overlap=150):
    # Split a text into overlapping windows of n_words words
    words = text.split()
    if len(words) < n_words:
        return [' '.join(words)]
    else:
        return [' '.join(words[i: i+n_words]) for i in range(0, len(words)-n_words+1, n_words-overlap)]

def tokenize_function(examples):
    result = tokenizer(examples["text"])
    return result

def group_tokens(examples):
    # Take a batch of tokenized examples and regroup them into chunks of the same size
    chunk_size = 384
    n = chunk_size - 2  # leave room for [CLS] and [SEP]
    input_ids_all = []
    cls_id = 101  # bert-base-uncased [CLS] token id
    sep_id = 102  # bert-base-uncased [SEP] token id
    for item in examples['input_ids']:
        input_ids_each = item[1: -1]  # Get rid of the first [CLS] and the last [SEP] token
        input_ids_all += input_ids_each
    result = {
        'input_ids': [],
        'token_type_ids': [],
        'attention_mask': [],
    }
    chunk = []
    for i in range(len(input_ids_all)):
        chunk.append(input_ids_all[i])
        if (i+1) % n == 0:  # complete a chunk
            result['input_ids'].append([cls_id] + chunk.copy() + [sep_id])
            result['token_type_ids'].append([0 for j in range(len(chunk)+2)])
            result['attention_mask'].append([1 for j in range(len(chunk)+2)])
            chunk = []
    if len(chunk) > 0:  # keep the trailing, shorter chunk
        result['input_ids'].append([cls_id] + chunk.copy() + [sep_id])
        result['token_type_ids'].append([0 for j in range(len(chunk)+2)])
        result['attention_mask'].append([1 for j in range(len(chunk)+2)])
    return result


def load_data(data_path):
    with codecs.open(f'{data_path}/train.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    train_data = Dataset.from_dict({'text': data})
    with codecs.open(f'{data_path}/eval.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    test_data = Dataset.from_dict({'text': data})
    data_hf = DatasetDict({'train': train_data, 'test': test_data})
    # Tokenize
    dataset_tokens = data_hf.map(tokenize_function, batched=True, remove_columns=["text"])
    # Make each text item have the same length
    dataset_tokens_group = dataset_tokens.map(group_tokens, batched=True)
    return dataset_tokens_group

def build_train_valid_data(data_path, target_path, test_size=0.2, n_words=300, overlap=150):
    # Load raw data
    with codecs.open(data_path, 'r', encoding='utf-8') as f:
        data_raw = json.load(f)
    data_clean = []
    for text in data_raw:
        # there are floats in data_raw; keep only the items that support len()
        try:
            nouse = len(text)
            data_clean.append(text)
        except TypeError:
            continue
    # Chunkize
    data = []
    for text in data_clean:
        list_text = chunkize(text, n_words=n_words, overlap=overlap)
        for r in list_text:
            data.append(r)
    n_train = int(len(data)*(1-test_size))
    set_data = set(data)
    # random.sample requires a sequence, so convert the set to a list first
    set_train_data = set(random.sample(list(set_data), n_train))
    set_eval_data = set_data - set_train_data
    train_data = list(set_train_data)
    eval_data = list(set_eval_data)
    with codecs.open(f'{target_path}/train.json', 'w', encoding='utf-8') as w:
        json.dump(train_data, w, ensure_ascii=False)
    with codecs.open(f'{target_path}/eval.json', 'w', encoding='utf-8') as w:
        json.dump(eval_data, w, ensure_ascii=False)

# Model
def train(checkpoints_dir, target_dir, dataset_tokens_group):
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, pad_to_multiple_of=32)
    batch_size = 64
    # model_name = 'bert-base-uncased-finetune-mlm-hashtag'
    training_args = TrainingArguments(
        output_dir=checkpoints_dir,
        overwrite_output_dir=True,
        evaluation_strategy='steps',
        eval_steps=1000,
        learning_rate=3e-5,
        weight_decay=0.01,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=True,
        save_strategy='steps',
        save_steps=20000,
        logging_strategy='steps',
        report_to="none",
        num_train_epochs=160,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset_tokens_group["train"],
        eval_dataset=dataset_tokens_group["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model(target_dir)


if __name__ == '__main__':
    data_path = '../data/hashtags/dataset_hashtag_english_pretrain.json'
    target_data_hf_path = './data_hf_for_pretrain_bert_mlm'
    # build_train_valid_data(data_path, target_data_hf_path, test_size=0.2, n_words=300, overlap=150)
    # exit()
    dataset_tokens_group = load_data(target_data_hf_path)

    checkpoints_dir = '../checkpoints/bert-base-uncased-finetune-mlm-hashtag'
    now = datetime.strftime(datetime.now(), '%Y_%m_%d_%H_%M')
    target_dir = f'models_pretrain_bert_mlm/bert-base-uncased-finetune-mlm-hashtag-{now}'
    train(checkpoints_dir, target_dir, dataset_tokens_group)