lvwerra HF Staff committed on
Commit c8dd7f5
1 Parent(s): c9c953d

add training script

Files changed (1)
  1. codeparrot_training.py +202 -0
codeparrot_training.py ADDED
@@ -0,0 +1,202 @@
+ from transformers import GPT2LMHeadModel, AutoTokenizer
+ from transformers import AdamW, get_scheduler, set_seed
+ from datasets import load_dataset
+ from accelerate import Accelerator
+ import datasets, transformers
+ from huggingface_hub import Repository
+
+ from torch.utils.data import IterableDataset
+ from torch.utils.data.dataloader import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from argparse import Namespace
+ import torch
+ import logging
+ import wandb
+
+
+ class ConstantLengthDataset(IterableDataset):
+
+     def __init__(self, tokenizer, dataset, seq_length=1024, batch_size=3,
+                  num_of_sequences=1024, chars_per_token=3.6):
+         self.tokenizer = tokenizer
+         self.concatenation_token = tokenizer.bos_token
+         self.dataset = dataset
+         self.seq_length = seq_length
+         self.batch_size = batch_size
+         self.input_characters = seq_length * chars_per_token * num_of_sequences
+
+     def __iter__(self):
+         iterator = iter(self.dataset)
+         more_examples = True
+         batch = []
+         while more_examples:
+             buffer = ''
+             while True:
+                 if len(buffer) >= self.input_characters:
+                     break
+                 try:
+                     next_example = next(iterator)['content']
+                     buffer = buffer + self.concatenation_token + next_example
+                 except StopIteration:
+                     more_examples = False
+                     break
+
+             tokenized_input = self.tokenizer(buffer, truncation=True,
+                                              max_length=self.seq_length,
+                                              return_overflowing_tokens=True)
+
+             for input_ids in tokenized_input['input_ids']:
+                 if len(input_ids) == self.seq_length:
+                     batch.append(input_ids)
+                 if len(batch) == self.batch_size:
+                     yield torch.tensor(batch)
+                     batch = []
+
+ def setup_logging(project_name):
+     logger = logging.getLogger(__name__)
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
+     if accelerator.is_main_process: # we only want to set up logging once
+         wandb.init(project=project_name, config=args)
+         tb_writer = SummaryWriter()
+         tb_writer.add_hparams(vars(args), {'0': 0})
+         logger.setLevel(logging.INFO)
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         logger.setLevel(logging.ERROR)
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+     return logger, tb_writer
+
+ def create_dataloaders(dataset_name):
+     train_data = load_dataset(dataset_name+'-train', split="train",
+                               streaming=True)
+     train_data = train_data.shuffle(buffer_size=args.shuffle_buffer)
+     valid_data = load_dataset(dataset_name+'-valid', split="train",
+                               streaming=True)
+
+     train_dataset = ConstantLengthDataset(tokenizer, train_data,
+                                           seq_length=args.seq_length,
+                                           batch_size=args.train_batch_size)
+     valid_dataset = ConstantLengthDataset(tokenizer, valid_data,
+                                           seq_length=args.seq_length,
+                                           batch_size=args.valid_batch_size)
+
+     train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
+     eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
+     return train_dataloader, eval_dataloader
+
+ def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
+     params_with_wd, params_without_wd = [], []
+     for n, p in model.named_parameters():
+         if any(nd in n for nd in no_decay): params_without_wd.append(p)
+         else: params_with_wd.append(p)
+     return [{'params': params_with_wd, 'weight_decay': args.weight_decay},
+             {'params': params_without_wd, 'weight_decay': 0.0}]
+
+ def log_metrics(step, metrics):
+     logger.info(f"Step {step}: {metrics}")
+     if accelerator.is_main_process:
+         wandb.log(metrics)
+         [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
+
+ def evaluate():
+     model.eval()
+     losses = []
+     for step, batch in enumerate(eval_dataloader):
+         with torch.no_grad():
+             outputs = model(batch[0], labels=batch[0])
+         loss = outputs.loss.repeat(args.valid_batch_size)
+         losses.append(accelerator.gather(loss))
+         if args.max_eval_steps > 0 and step >= args.max_eval_steps: break
+     loss = torch.mean(torch.cat(losses))
+     try: perplexity = torch.exp(loss)
+     except OverflowError: perplexity = float("inf")
+     return loss.item(), perplexity.item()
+
+ # Hyperparameters
+ project_name = 'transformersbook/codeparrot-small'
+ dataset_name = 'transformersbook/codeparrot'
+ config = {"train_batch_size": 4,
+           "valid_batch_size": 4,
+           "weight_decay": 0.1,
+           "shuffle_buffer": 1000,
+           "learning_rate": 5e-4,
+           "lr_scheduler_type": "cosine",
+           "num_warmup_steps": 1000,
+           "gradient_accumulation_steps": 8,
+           "max_train_steps": 4096,
+           "max_eval_steps": 1024,
+           "seq_length": 1024,
+           "seed": 1,
+           "save_checkpoint_steps": 4096}
+ args = Namespace(**config)
+ set_seed(args.seed)
+
+ # Accelerator
+ accelerator = Accelerator()
+ samples_per_step = accelerator.state.num_processes * args.train_batch_size
+
+ # Logging
+ logger, tb_writer = setup_logging(project_name.split("/")[1])
+ logger.info(accelerator.state)
+ run_name = wandb.run.name
+
+ # Load model and tokenizer
+ hf_repo = Repository("./", clone_from=project_name, revision=run_name)
+ model = GPT2LMHeadModel.from_pretrained("./")
+ tokenizer = AutoTokenizer.from_pretrained("./")
+
+ # Load dataset and dataloader
+ train_dataloader, eval_dataloader = create_dataloaders(dataset_name)
+
+ # Prepare the optimizer and learning rate scheduler
+ optimizer = AdamW(get_grouped_params(model), lr=args.learning_rate)
+ lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer,
+                              num_warmup_steps=args.num_warmup_steps,
+                              num_training_steps=args.max_train_steps,)
+ def get_lr(): return optimizer.param_groups[0]['lr']
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+     model, optimizer, train_dataloader, eval_dataloader)
+
+ # Train model
+ model.train()
+ completed_steps = 0
+ for step, batch in enumerate(train_dataloader, start=1):
+     loss = model(batch[0], labels=batch[0]).loss
+     log_metrics(step, {'lr': get_lr(), 'samples': step*samples_per_step,
+                        'steps': completed_steps, 'loss/train': loss.item()})
+     loss = loss / args.gradient_accumulation_steps
+     accelerator.backward(loss)
+     if step % args.gradient_accumulation_steps == 0:
+         optimizer.step()
+         lr_scheduler.step()
+         optimizer.zero_grad()
+         completed_steps += 1
+     if step % args.save_checkpoint_steps == 0:
+         logger.info('Evaluating and saving model checkpoint')
+         eval_loss, perplexity = evaluate()
+         log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
+         accelerator.wait_for_everyone()
+         unwrapped_model = accelerator.unwrap_model(model)
+         unwrapped_model.save_pretrained("./")
+         if accelerator.is_main_process:
+             hf_repo.push_to_hub(commit_message=f'step {step}')
+         model.train()
+     if completed_steps >= args.max_train_steps:
+         break
+
+ # Evaluate and save the last checkpoint
+ logger.info('Evaluating and saving model after training')
+ eval_loss, perplexity = evaluate()
+ log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained("./")
+ if accelerator.is_main_process:
+     try: hf_repo.push_to_hub(commit_message='final model')
+     except Exception: logger.info('No changes to previously saved model.')
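
Once training completes, the script pushes the final checkpoint to the transformersbook/codeparrot-small repo on the Hub. A minimal sketch of how that checkpoint could then be sampled with the standard transformers text-generation pipeline; the prompt below is only an illustrative assumption and is not part of this commit:

from transformers import pipeline

# Load the checkpoint this training script pushes to the Hub
# (assumes a finished run; the repo name comes from project_name above).
generator = pipeline('text-generation', model='transformersbook/codeparrot-small')

# Hypothetical prompt, used purely for illustration.
prompt = 'def add_numbers(a, b):'
print(generator(prompt, num_return_sequences=1)[0]['generated_text'])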