nreimers committed
Commit: b034ae1
Parent: 164bae3
README.md ADDED
@@ -0,0 +1,6 @@
# DistilBERT with 256k token embeddings

This model was initialized with a word2vec token embedding matrix with 256k entries, but these token embeddings were updated during MLM. The word2vec embeddings were trained for 3 epochs on 100GB of text from C4, MSMARCO, News, Wikipedia, and S2ORC.

The model was then trained on this data with MLM for 250k steps (batch size 64); the token embeddings were updated during MLM.
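As an illustration (not part of the original model card): a minimal sketch of loading this checkpoint for masked-token prediction with the `transformers` API. The path `./distilbert-256k` is a placeholder for a local clone of this repository.

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

# "./distilbert-256k" is a placeholder for a local clone of this repository.
model_path = "./distilbert-256k"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMaskedLM.from_pretrained(model_path)

print(model.config.vocab_size)  # 256000 -- the 256k-entry token embedding matrix

# Masked language modeling inference with the [MASK] token
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
print(fill_mask("Paris is the [MASK] of France."))
```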
config.json ADDED
@@ -0,0 +1,24 @@
{
  "_name_or_path": "train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.17.0",
  "vocab_size": 256000
}
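A rough sanity check (my own, not part of the repository): from the values in this config, the 256k-entry embedding matrix dominates the parameter count, and a back-of-the-envelope float32 estimate lands close to the size of pytorch_model.bin below.

```python
# Approximate parameter count from config.json (biases, LayerNorms and the MLM head
# transform are ignored; the output projection is tied to the input embeddings).
vocab_size, dim, hidden_dim, n_layers, max_pos = 256_000, 768, 3072, 6, 512

emb_params = (vocab_size + max_pos) * dim          # word + position embeddings
per_layer = 4 * dim * dim + 2 * dim * hidden_dim   # attention (q, k, v, out) + FFN weights
total_params = emb_params + n_layers * per_layer

print(f"embedding parameters ~ {emb_params / 1e6:.1f}M")          # ~197.0M
print(f"total parameters     ~ {total_params / 1e6:.1f}M")        # ~239.5M
print(f"float32 size         ~ {total_params * 4 / 1e9:.2f} GB")  # ~0.96 GB, close to the checkpoint size
```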
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:889a1a38de7adf5179d6161fecec2df90d29a984e6a9aabcd2b7b2e4dc2ca91c
size 961553391
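This is a Git LFS pointer, not the weights themselves. As a small sketch (assuming the actual pytorch_model.bin has been downloaded into the working directory), the file can be checked against the pointer's oid and size:

```python
import hashlib
import os

path = "pytorch_model.bin"  # the downloaded weights file
expected_oid = "889a1a38de7adf5179d6161fecec2df90d29a984e6a9aabcd2b7b2e4dc2ca91c"
expected_size = 961553391

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        sha.update(chunk)

assert os.path.getsize(path) == expected_size, "size mismatch"
assert sha.hexdigest() == expected_oid, "sha256 mismatch"
print("pytorch_model.bin matches the LFS pointer")
```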
special_tokens_map.json ADDED
@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
{"model_max_length": 512, "unk_token": "[UNK]", "cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]", "mask_token": "[MASK]", "model_input_names": ["input_ids", "attention_mask"], "special_tokens_map_file": "c4_msmarco_news_s2orc_wiki/tokenizer-256k/special_tokens_map.json", "name_or_path": "train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/", "tokenizer_class": "PreTrainedTokenizerFast"}
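A short inspection sketch (again with a placeholder local path) showing how the fields configured above surface once the tokenizer is loaded:

```python
from transformers import AutoTokenizer

# "./distilbert-256k" is a placeholder for a local clone of this repository
# (tokenizer.json plus the two config files above).
tok = AutoTokenizer.from_pretrained("./distilbert-256k")

print(tok.model_max_length)   # 512
print(tok.model_input_names)  # ['input_ids', 'attention_mask']
print(tok.mask_token, tok.unk_token, tok.cls_token, tok.sep_token, tok.pad_token)
print(len(tok))               # vocabulary size, 256000 per config.json
```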
train_script.py ADDED
@@ -0,0 +1,398 @@

import argparse
import logging
import math
import os
from datetime import datetime
import datasets
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import sys
import transformers
from accelerate import Accelerator, DistributedType
from shutil import copyfile
import wandb
import numpy as np

from transformers import (
    MODEL_MAPPING,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    SchedulerType,
    get_scheduler
)
from transformers.utils.versions import require_version


class TrainDataset(torch.utils.data.IterableDataset):
    def __init__(self, filepath, tokenizer, max_length, batch_size, train_samples):
        self.tokenizer = tokenizer
        self.fIn = open(filepath)
        self.max_length = max_length
        self.batch_size = batch_size
        self.train_samples = train_samples

    def __iter__(self):
        batch = []
        for sent in self.fIn:
            batch.append(sent.strip()[0:1000])

            if len(batch) >= self.batch_size:
                # Tokenize the whole batch at once (the fast tokenizer parallelizes this)
                encoded = self.tokenizer(batch, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True, padding=True)
                for idx in range(len(batch)):
                    single_sample = {key: encoded[key][idx] for key in encoded}
                    yield single_sample

                batch = []

    def __len__(self):
        return self.train_samples


## Dev dataset
class DevDataset(torch.utils.data.Dataset):
    def __init__(self, filepath, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(filepath) as fIn:
            sentences = [sent.strip() for sent in fIn]

        self.num_sentences = len(sentences)
        self.tokenized = self.tokenizer(sentences, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)

    def __getitem__(self, idx):
        return {key: self.tokenized[key][idx] for key in self.tokenized}

    def __len__(self):
        return self.num_sentences


logger = logging.getLogger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task")
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A text file with the training data (one text per line)."
    )
    parser.add_argument(
        "--dev_file", type=str, default=None, help="A text file with the dev data (one text per line)."
    )
    parser.add_argument(
        "--model_name",
        default="nicoladecao/msmarco-word2vec256000-distilbert-base-uncased",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--per_device_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=256,
        help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
    )
    parser.add_argument(
        "--line_by_line",
        type=bool,
        default=True,
        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
    )
    parser.add_argument(
        "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
    )
    parser.add_argument("--mixed_precision", default="fp16")
    parser.add_argument("--train_samples", required=True, type=int)
    parser.add_argument("--eval_steps", default=10000, type=int)
    parser.add_argument("--max_grad_norm", default=1.0, type=float)
    parser.add_argument("--project", default="bert-word2vec")
    parser.add_argument("--freeze_emb_layer", default=False, action='store_true')
    parser.add_argument("--log_interval", default=1000, type=int)
    parser.add_argument("--ckp_steps", default=50000, type=int)

    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    accelerator.wait_for_everyone()

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForMaskedLM.from_pretrained(args.model_name)

    # Optionally freeze the (word2vec-initialized) embedding layer
    if args.freeze_emb_layer:
        model.distilbert.embeddings.word_embeddings.requires_grad_(False)

    # Logging & Co on main process
    if accelerator.is_main_process:
        exp_name = f'{args.model_name.replace("/", "-")}-{"freeze_emb" if args.freeze_emb_layer else "update_emb"}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
        output_dir = os.path.join("output-mlm", exp_name)
        wandb.init(project=args.project, name=exp_name, config=args)

        os.makedirs(output_dir, exist_ok=False)

        # Save tokenizer
        tokenizer.save_pretrained(output_dir)

        # Save train script
        train_script_path = os.path.join(output_dir, 'train_script.py')
        copyfile(__file__, train_script_path)
        with open(train_script_path, 'a') as fOut:
            fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    total_batch_size = args.per_device_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    train_dataset = TrainDataset(args.train_file, tokenizer, args.max_seq_length, batch_size=total_batch_size, train_samples=args.train_samples)
    eval_dataset = DevDataset(args.dev_file, tokenizer, args.max_seq_length)

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability)

    # DataLoaders creation:
    train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)

    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # Note: the training dataloader needs to be prepared before we grab its length below
    # (because its length will be shorter in a multi-process setup).

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {args.train_samples}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, smoothing=0.05)
    completed_steps = 0
    train_loss_values = []

    best_eval_loss = 999999
    if accelerator.is_main_process:
        best_ckp_dir = os.path.join(output_dir, "best")
        tokenizer.save_pretrained(best_ckp_dir)

    for epoch in range(args.num_train_epochs):
        logger.info(f"Start epoch {epoch}")
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / args.gradient_accumulation_steps

            if accelerator.is_main_process:
                train_loss_values.append(loss.cpu().item())

            accelerator.backward(loss)
            accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            if step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                ### Do logging
                if accelerator.is_main_process:
                    if completed_steps % args.log_interval == 0:
                        wandb.log({"train/loss": np.mean(train_loss_values)}, step=completed_steps)
                        train_loss_values = []

                if completed_steps % args.eval_steps == 0:
                    model.eval()
                    losses = []
                    for step, batch in enumerate(eval_dataloader):
                        with torch.no_grad():
                            outputs = model(**batch)

                        loss = outputs.loss
                        losses.append(accelerator.gather(loss.repeat(args.per_device_batch_size)))

                    losses = torch.cat(losses)
                    losses = losses[: len(eval_dataset)]
                    try:
                        eval_loss = torch.mean(losses)
                    except OverflowError:
                        eval_loss = float("inf")

                    logger.info(f"step {completed_steps}: eval loss: {eval_loss}")
                    if accelerator.is_main_process:
                        wandb.log({"eval/loss": eval_loss}, step=completed_steps)

                    model.train()

                    # Save model
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
                        with open(os.path.join(output_dir, "train_steps.log"), 'a') as fOut:
                            fOut.write(f"{completed_steps}: {eval_loss}\n")

                        # Save best model
                        if eval_loss < best_eval_loss:
                            best_eval_loss = eval_loss
                            unwrapped_model.save_pretrained(best_ckp_dir, save_function=accelerator.save)
                            with open(os.path.join(best_ckp_dir, "train_steps.log"), 'a') as fOut:
                                fOut.write(f"{completed_steps}: {eval_loss}\n")

                if accelerator.is_main_process and completed_steps % args.ckp_steps == 0:
                    ckp_dir = os.path.join(output_dir, f"ckp-{int(completed_steps/1000)}k")
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.save_pretrained(ckp_dir, save_function=accelerator.save)
                    tokenizer.save_pretrained(ckp_dir)
                    with open(os.path.join(ckp_dir, "train_steps.log"), 'a') as fOut:
                        fOut.write(f"{completed_steps}: {eval_loss}\n")

            if completed_steps >= args.max_train_steps:
                break

    # Final save after training (output_dir is created on the main process above)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
        with open(os.path.join(output_dir, "train_steps.log"), 'a') as fOut:
            fOut.write(f"{completed_steps}\n")


if __name__ == "__main__":
    main()


# Script was called via:
# python train_mlm-iterable.py --train_file data/c4_msmarco_news_s2orc_wiki_train.txt --dev_file data/c4_msmarco_news_s2orc_wiki_dev.txt --train_samples 100000000 --model_name train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/
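The script reads plain-text files with one training text per line and relies on `DataCollatorForLanguageModeling` for the random masking. A condensed sketch of that pipeline for a single line of text (placeholder tokenizer path; 0.15 masking probability, as in the defaults above):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("./distilbert-256k")  # placeholder path
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

# One line of text, prepared the same way TrainDataset yields samples
enc = tokenizer(["one training text per line"], add_special_tokens=True,
                truncation=True, max_length=256, return_special_tokens_mask=True)
sample = {key: enc[key][0] for key in enc}

batch = collator([sample])  # randomly masks tokens and adds MLM "labels"
print(batch["input_ids"].shape, batch["labels"].shape)
```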
train_steps.log ADDED
@@ -0,0 +1,43 @@
10000: 2.9510703086853027
20000: 2.6985881328582764
30000: 2.6190617084503174
40000: 2.554388999938965
50000: 2.5151987075805664
60000: 2.4846017360687256
70000: 2.456629514694214
80000: 2.4532623291015625
90000: 2.4096949100494385
100000: 2.390449047088623
110000: 2.379936695098877
120000: 2.3351876735687256
160000: 2.326107978820801
170000: 2.308845043182373
190000: 2.269522190093994
210000: 2.256661891937256
240000: 2.2510366439819336
280000: 2.24812388420105
290000: 2.2429370880126953
330000: 2.2274234294891357
340000: 2.2205779552459717
370000: 2.2072439193725586
390000: 2.187239170074463
400000: 2.1820247173309326
420000: 2.1815905570983887
450000: 2.175816774368286
510000: 2.165159225463867
540000: 2.1477503776550293
590000: 2.1152608394622803
710000: 2.0839574337005615
870000: 2.0824830532073975
950000: 2.0809004306793213
970000: 2.0796029567718506
980000: 2.049144983291626
1050000: 2.0382091999053955
1090000: 2.0330984592437744
1160000: 2.029167652130127
1220000: 2.018043041229248
1250000: 2.005221366882324
1340000: 1.9879658222198486
1510000: 1.983816146850586
1530000: 1.9801620244979858
1550000: 1.9742683172225952
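Each line records `completed_steps: evaluation loss`, as written by train_script.py above. A small parsing sketch (mine, not part of the repository) that also converts the MLM losses to perplexities via exp(loss):

```python
import math

# Parse train_steps.log ("step: eval_loss" per line)
with open("train_steps.log") as f:
    records = [(int(step), float(loss))
               for step, loss in (line.split(":") for line in f if line.strip())]

for step, loss in records[-3:]:
    print(f"step {step:>8}: loss {loss:.4f}, perplexity {math.exp(loss):.2f}")
# The final entry (1550000: 1.9743) corresponds to a perplexity of about 7.2.
```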