import logging
import os
import time
from argparse import Namespace
from pathlib import Path

import datasets
import torch
from accelerate import Accelerator, DistributedType
from accelerate.utils import ProjectConfiguration
from arguments import TrainingArguments
from datasets import load_dataset
from huggingface_hub import Repository
from torch.optim import AdamW
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.

    Args:
        tokenizer (Tokenizer): The tokenizer used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True, the iterator restarts from the beginning of the dataset once it is
            exhausted; otherwise iteration stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
        tokenized (bool): If True, the dataset is assumed to be pretokenized.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        tokenized=False,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.epoch = 0
        self.infinite = infinite
        self.current_size = 0
        self.tokenized = tokenized
        if self.tokenized:
            self.max_buffer_size = seq_length * num_of_sequences
            self.content_field = "input_ids"
        else:
            self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
            self.content_field = "content"

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        self.epoch += 1
                        logger.info(f"Dataset epoch: {self.epoch}")
                    else:
                        more_examples = False
                        break
            if self.tokenized:
                tokenized_inputs = buffer
            else:
                tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield torch.tensor(input_ids)

    def shuffle(self, buffer_size=1000):
        return ShufflerIterDataPipe(self, buffer_size=buffer_size)
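
# A rough sketch of what ConstantLengthDataset yields, with hypothetical toy
# values (seq_length=8, bos_token_id=0). Two documents tokenized to
# [5, 6, 7] and [8, 9, 10, 11, 12] are joined with the BOS separator:
#
#     [5, 6, 7, 0, 8, 9, 10, 11, 12, 0]
#
# and then sliced into dense 8-token chunks, here [5, 6, 7, 0, 8, 9, 10, 11];
# the trailing remainder [12, 0] is shorter than seq_length and is dropped,
# so every example is fully packed and needs no padding.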


def setup_logging(args):
    project_name = args.model_ckpt.split("/")[-1]
    logger = logging.getLogger(__name__)
    log_dir = Path(args.save_dir) / "log/"
    log_dir.mkdir(exist_ok=True)
    filename = f"debug_{accelerator.process_index}.log"
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
    )
    if accelerator.is_main_process:  # we only want to set up logging once
        accelerator.init_trackers(project_name, vars(args))
        run_name = accelerator.trackers[0].run.name
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        run_name = ""
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    return logger, run_name


def create_dataloaders(args):
    ds_kwargs = {"streaming": True}
    train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
    train_dataset = ConstantLengthDataset(
        tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
    )
    train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer)
    # `shuffle=True` is honored here because train_dataset is an IterDataPipe
    # (wrapped by ShufflerIterDataPipe above), not a plain IterableDataset.
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader


def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        accelerator.log(metrics, step)


def compute_tflops(elapsed_time, accelerator, args):
    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
    config_model = accelerator.unwrap_model(model).config
    checkpoint_factor = 4 if args.gradient_checkpointing else 3
    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
    flops_per_iteration = factor * (
        1.0
        + (args.seq_length / (6.0 * config_model.n_embd))
        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
    )
    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
    return tflops
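
# Reading compute_tflops above (Equation 3 of the Megatron-LM paper cited in
# the comment): 24 * c * B * s * l * h^2 counts the dense matmul FLOPs of l
# transformer layers with hidden size h at global batch size B and sequence
# length s, where c = 3 covers forward plus backward and c = 4 adds the extra
# forward recomputation done under gradient checkpointing; the bracketed terms
# correct for attention scores (s / 6h) and the vocabulary projection
# (V / 16lh, with V the vocabulary size).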


def evaluate(args):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    losses = torch.cat(losses)
    loss = losses[: eval_dataloader.dataset.current_size].mean()
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = torch.tensor(float("inf"))  # keep it a tensor so .item() below works
    return loss.item(), perplexity.item()


# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()

# Accelerator
config = ProjectConfiguration(project_dir=args.save_dir, logging_dir="log")
accelerator = Accelerator(log_with=["wandb", "tensorboard"], project_config=config)
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

# Clone model repository
if accelerator.is_main_process:
    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

# Logging
logger, run_name = setup_logging(args)
logger.info(accelerator.state)

# Checkout new branch on repo
if accelerator.is_main_process:
    hf_repo.git_checkout(run_name, create_branch_ok=True)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
accelerator.register_for_checkpointing(lr_scheduler)


def get_lr():
    return optimizer.param_groups[0]["lr"]


# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
        # Sort folders by creation time; the most recent checkpoint is the last
        dirs.sort(key=lambda d: os.path.getctime(os.path.join(args.save_dir, d)))
        path = dirs[-1]
    # Extract the step of the checkpoint to continue from there
    training_difference = os.path.splitext(path)[0]
    resume_step = int(training_difference.replace("step_", ""))
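
# Note on the accumulation pattern in the loop below: under DDP, every backward
# pass normally triggers a gradient all-reduce across processes. Wrapping the
# non-boundary micro-steps in `model.no_sync()` keeps gradients local, so the
# all-reduce happens only once per optimizer step. A minimal sketch of the idea
# (hypothetical names, not part of this script):
#
#     for i, micro_batch in enumerate(loader, start=1):
#         loss = compute_loss(micro_batch) / accumulation_steps
#         if i % accumulation_steps != 0:
#             with model.no_sync():   # accumulate locally, skip communication
#                 loss.backward()
#         else:
#             loss.backward()         # gradients are synchronized here
#             optimizer.step()
#             optimizer.zero_grad()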

# Train model
model.train()
completed_steps = 0
t_start = time.time()
loss_tracking = 0
for step, batch in enumerate(train_dataloader, start=1):
    if args.resume_from_checkpoint and step < resume_step:
        continue  # we need to skip steps until we reach the resumed step
    loss = model(batch, labels=batch, use_cache=False).loss
    avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
    loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
    log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
    loss = loss / args.gradient_accumulation_steps
    if step % args.gradient_accumulation_steps != 0:
        # Prevent backward from doing gradient all_reduce in every step
        if accelerator.distributed_type == DistributedType.MULTI_GPU:
            with model.no_sync():
                accelerator.backward(loss)
        else:
            accelerator.backward(loss)
    else:
        lr = get_lr()
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        elapsed_time = time.time() - t_start
        tflops = compute_tflops(elapsed_time, accelerator, args)
        log_metrics(
            step,
            {
                "steps": completed_steps,
                "loss/train": loss_tracking,
                "lr": lr,
                "tflops": tflops,
                "time_per_iteration": elapsed_time,
            },
        )
        t_start = time.time()
        loss_tracking = 0
        completed_steps += 1
    if step % args.save_checkpoint_steps == 0:
        logger.info("Evaluating and saving model checkpoint")
        eval_loss, perplexity = evaluate(args)
        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
        accelerator.wait_for_everyone()
        save_dir = os.path.join(args.save_dir, f"step_{step}")
        accelerator.save_state(save_dir)
        if accelerator.is_main_process:
            hf_repo.push_to_hub(commit_message=f"step {step}")
        model.train()
    if completed_steps >= args.max_train_steps:
        break

# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
save_dir = os.path.join(args.save_dir, f"step_{step}")
accelerator.save_state(save_dir)
if accelerator.is_main_process:
    hf_repo.push_to_hub(commit_message="final model")
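
# A typical launch, sketched under the assumption that this file is saved as
# codeparrot_training.py and that TrainingArguments (in arguments.py) exposes
# the flags referenced above; the checkpoint name and values are illustrative:
#
#     accelerate launch codeparrot_training.py \
#         --model_ckpt codeparrot/codeparrot-small \
#         --save_dir ./checkpoints \
#         --train_batch_size 12 \
#         --seq_length 1024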