# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Main entry point to run the experiments. Contains general setup and the proper training code.
"""

import argparse
import datetime as dt
import gc
import json
import os
import random
import sys
import textwrap
import time
from contextlib import AbstractContextManager, nullcontext
from functools import partial
from typing import Any, Callable, Literal, Optional

import torch
from torch import nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import GenerationConfig, set_seed
from utils import (
    FILE_NAME_TRAIN_PARAMS,
    BucketIterator,
    TrainResult,
    TrainStatus,
    get_accuracy,
    get_base_model_info,
    get_dataset_info,
    get_file_size,
    get_model,
    get_optimizer_and_scheduler,
    get_peft_branch,
    get_tokenizer,
    get_train_config,
    init_cuda,
    log_results,
    validate_experiment_path,
)

from data import get_train_valid_test_datasets
from peft import AdaLoraConfig, PeftConfig
from peft.utils import CONFIG_NAME


# # suppress all warnings
# warnings.filterwarnings("ignore") # FIXME?

dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5} # if lr scheduler with warmup is used, the ratio of warmup steps to total steps BUCKET_FACTOR = 20 # number of batches per bucket, increasing this further has diminishing returns def get_generation_config(*, seq_len, generate_kwargs) -> GenerationConfig: # filter out None values so that we don't depend on setting correct defaults in the config generation_kwargs = {k: v for k, v in generate_kwargs.items() if v is not None} if ("max_length" in generation_kwargs) and ("max_new_tokens" in generation_kwargs): # transformers does not support setting both max_length and max_new_tokens, but what we want in this case is to # take the smaller of the two values new_max_length = min(generation_kwargs["max_new_tokens"] + seq_len, generation_kwargs["max_length"]) del generation_kwargs["max_new_tokens"] generation_kwargs["max_length"] = new_max_length generation_config = GenerationConfig(**generate_kwargs) return generation_config def evaluate(model, tokenizer, ds, batch_size, generate_kwargs, use_tqdm: bool = False) -> tuple[list[str], list[str]]: with torch.inference_mode(): predictions = [] responses = [] pbar = range(0, len(ds), batch_size) if use_tqdm: pbar = tqdm(pbar) for j in pbar: sliced = ds[j : j + batch_size] responses += sliced.pop("response") batch = tokenizer.pad(sliced, return_tensors="pt", padding_side="left").to(model.device) seq_len = batch["input_ids"].shape[1] generation_config = get_generation_config(seq_len=seq_len, generate_kwargs=generate_kwargs) outputs = model.generate(**batch, generation_config=generation_config, pad_token_id=tokenizer.eos_token_id) predictions += tokenizer.batch_decode(outputs, skip_special_tokens=True) return predictions, responses class DummyGradScaler: # if no mixed precision is being used def scale(self, loss): return loss def unscale_(self, optimizer): pass def step(self, optimizer): optimizer.step() def update(self): pass def train( *, 
model: nn.Module, max_steps: int, batch_size: int, batch_size_eval: int, tokenizer: Any, cuda_memory_init: int, eval_steps: int, generation_kwargs: dict[str, Any], grad_norm_clip: float, optimizer_type: str, optimizer_kwargs: dict[str, Any], query_template: str, lr_scheduler_arg: Optional[Literal["cosine"]], use_amp: bool, is_adalora: bool, ) -> TrainResult: cuda_memory_allocated_log = [] cuda_memory_reserved_log = [] losses = [] durations = [] metrics = [] sample = 0 # keep count of the current sample total_samples = 0 # total number of samples over all epochs total_tokens = [] # total number of tokens over all epochs if use_amp: grad_scaler: GradScaler | DummyGradScaler = GradScaler(device="cuda") autocast_ctx: Callable[[], ContextManager[Any]] = partial(autocast, device_type="cuda") else: grad_scaler = DummyGradScaler() autocast_ctx = nullcontext optimizer, lr_scheduler = get_optimizer_and_scheduler( model, optimizer_type=optimizer_type, max_steps=max_steps, lr_scheduler_arg=lr_scheduler_arg, **optimizer_kwargs, ) # print this after getting the optimizer, in case it modifies requires_gard if hasattr(model, "get_nb_trainable_parameters"): num_trainable_params, num_params = model.get_nb_trainable_parameters() else: num_params = model.num_parameters() num_trainable_params = num_params print_verbose( f"trainable params: {num_trainable_params:,d} || all params: {num_params:,d} || " f"trainable: {100 * num_trainable_params / num_params:.4f}%" ) status = TrainStatus.FAILED tic_train = time.perf_counter() eval_time = 0.0 error_msg = "" ds_train, ds_valid, ds_test = get_train_valid_test_datasets( tokenizer=tokenizer, query_template=query_template, print_fn=print_verbose ) # note: bucketing by length is only really worth it for the train dataset, since it's length is big compared to the # batch size iterator_train = BucketIterator( ds_train, batch_size=batch_size, bucket_factor=BUCKET_FACTOR, delete_cols=["response"], ) try: pbar = tqdm(range(1, max_steps + 1)) for step, 
batch in zip(pbar, iterator_train): tic = time.perf_counter() # create the batch tokens_per_sample = [len(i) for i in batch["input_ids"]] total_tokens.append(sum(tokens_per_sample) + len(tokens_per_sample)) # add EOS token batch = tokenizer.pad(batch, return_tensors="pt").to(model.device) actual_batch_size = len(batch["input_ids"]) total_samples += actual_batch_size sample += batch_size if sample >= len(ds_train): # new epoch sample = 0 # add labels, they are automatically shifted by transformers labels = batch["input_ids"].clone() # We want to ignore the padding tokens except for the first EOS token; if we don't ignore them, the loss # will be dominated by padding tokens; if we ignore all, the model will not learn to predict the EOS token. # TODO: Note that the longest sequence in the batch won't have any PAD/EOS token at the end, this is fine if # the batch size is > 1 but should still be fixed eventually. for i, num_tokens in enumerate(tokens_per_sample): labels[i, num_tokens + 1 :] = -100 batch["labels"] = labels num_items_in_batch = batch["attention_mask"].sum().item() # train step optimizer.zero_grad() with autocast_ctx(): outputs = model(**batch, num_items_in_batch=num_items_in_batch) loss = outputs.loss grad_scaler.scale(loss).backward() if grad_norm_clip: grad_scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip) grad_scaler.step(optimizer) grad_scaler.update() lr_scheduler.step() if is_adalora: model.base_model.update_and_allocate(step) losses.append(loss.item()) pbar.set_postfix({"loss": loss.item()}) cuda_memory_allocated_log.append(torch.cuda.memory_allocated() - cuda_memory_init) cuda_memory_reserved_log.append(torch.cuda.memory_reserved() - cuda_memory_init) toc = time.perf_counter() durations.append(toc - tic) # every couple of steps, evaluate; this can be slow due to generation if step % eval_steps == 0: tic_eval = time.perf_counter() loss_avg = sum(losses[-eval_steps:]) / eval_steps memory_allocated_avg = 
sum(cuda_memory_allocated_log[-eval_steps:]) / eval_steps memory_reserved_avg = sum(cuda_memory_reserved_log[-eval_steps:]) / eval_steps token_sum = sum(total_tokens[-eval_steps:]) dur_train = sum(durations[-eval_steps:]) tokens_per_sec = token_sum / dur_train model.eval() predictions, responses = evaluate( model=model, tokenizer=tokenizer, ds=ds_valid, batch_size=batch_size_eval, generate_kwargs={**generation_kwargs}, ) model.train() example = random.choice(predictions) example = textwrap.shorten(example, width=750) example = textwrap.indent(example, " ") print_verbose(f"\nExample prediction:\n{example}\n") accuracy = get_accuracy(predictions=predictions, responses=responses) num_tokens_generated = sum(sum(mask) for mask in tokenizer(predictions)["attention_mask"]) toc_eval = time.perf_counter() dur_eval = toc_eval - tic_eval eval_time += toc_eval - tic_eval elapsed = time.perf_counter() - tic_train metrics.append( { "step": step, "valid accuracy": accuracy, "train loss": loss_avg, "train samples": total_samples, "train time": dur_train, "eval time": dur_eval, "tokens / sec": tokens_per_sec, "mem allocated avg": memory_allocated_avg, "mem reserved avg": memory_reserved_avg, "elapsed time": elapsed, } ) log_dict = { "step": f"{step:5d}", "samples": f"{total_samples:7d}", "lr": f"{lr_scheduler.get_last_lr()[0]:.2e}", "loss avg": f"{loss_avg:.4f}", "valid acc": f"{accuracy:.3f}", "gen valid tokens": num_tokens_generated, "train time": f"{dur_train:.1f}s", "eval time": f"{dur_eval:.1f}s", "train tokens / sec": f"{tokens_per_sec:.0f}", "mem allocated": f"{memory_allocated_avg:.0f}", "mem reserved": f"{memory_reserved_avg:.0f}", "elapsed time": f"{elapsed // 60:.0f}min {elapsed % 60:.0f}s", } print_verbose(json.dumps(log_dict)) # # TODO is this needed? 
torch.cuda.empty_cache() gc.collect() print_verbose(f"Training finished after {max_steps} steps, evaluation on test set follows.") # test set evaluation model.eval() predictions, responses = evaluate( model=model, tokenizer=tokenizer, ds=ds_test, batch_size=batch_size_eval, generate_kwargs={**generation_kwargs, "pad_token_id": tokenizer.eos_token_id}, use_tqdm=len(ds_test) > 100, ) accuracy = get_accuracy(predictions=predictions, responses=responses) metrics.append( { "step": step, "test accuracy": accuracy, "train loss": sum(losses[-eval_steps:]) / eval_steps, "train samples": total_samples, "train total tokens": sum(total_tokens), } ) print_verbose(f"Test accuracy: {accuracy:.3f}") except KeyboardInterrupt: print_verbose("canceled training") status = TrainStatus.CANCELED error_msg = "manually canceled" except torch.OutOfMemoryError as exc: # ouch, still let's try to log some results print_verbose("out of memory error encountered") status = TrainStatus.CANCELED error_msg = str(exc) except Exception as exc: print_verbose(f"encountered an error: {exc}") status = TrainStatus.CANCELED error_msg = str(exc) toc_train = time.perf_counter() train_time = toc_train - tic_train - eval_time if status != TrainStatus.CANCELED: status = TrainStatus.SUCCESS train_result = TrainResult( status=status, train_time=train_time, cuda_memory_reserved_log=cuda_memory_reserved_log, losses=losses, metrics=metrics, error_msg=error_msg, num_trainable_params=num_trainable_params, num_total_params=num_params, ) return train_result def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None: tic_total = time.perf_counter() start_date = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat() peft_branch = get_peft_branch() if peft_branch == "main": print_verbose("===== This experiment is categorized as a MAIN run because the PEFT branch is 'main' ======") else: print_verbose( f"===== This experiment is categorized as a TEST run because the PEFT branch is 
'{peft_branch}' ======" ) # load configs peft_config: Optional[PeftConfig] = None if os.path.exists(os.path.join(path_experiment, CONFIG_NAME)): peft_config = PeftConfig.from_pretrained(path_experiment) else: print_verbose(f"Could not find PEFT config at {path_experiment}, performing FULL FINETUNING") path_train_config = os.path.join(path_experiment, FILE_NAME_TRAIN_PARAMS) train_config = get_train_config(path_train_config) set_seed(train_config.seed) # initialize objects cuda_memory_init = init_cuda() tokenizer = get_tokenizer(model_id=train_config.model_id, max_seq_length=train_config.max_seq_length) model_info = get_base_model_info(train_config.model_id) metamath_info = get_dataset_info("meta-math/MetaMathQA") gsm8k_info = get_dataset_info("openai/gsm8k") model = get_model( model_id=train_config.model_id, dtype=train_config.dtype, compile=train_config.compile, attn_implementation=train_config.attn_implementation, peft_config=peft_config, autocast_adapter_dtype=train_config.autocast_adapter_dtype, ) print_verbose(model) # train model train_result = train( model=model, max_steps=train_config.max_steps, batch_size=train_config.batch_size, batch_size_eval=train_config.batch_size_eval, tokenizer=tokenizer, cuda_memory_init=cuda_memory_init, eval_steps=train_config.eval_steps, generation_kwargs=train_config.generation_kwargs, grad_norm_clip=train_config.grad_norm_clip, optimizer_type=train_config.optimizer_type, optimizer_kwargs=train_config.optimizer_kwargs, query_template=train_config.query_template, lr_scheduler_arg=train_config.lr_scheduler, use_amp=train_config.use_amp, is_adalora=isinstance(peft_config, AdaLoraConfig), ) if train_result.status == TrainStatus.FAILED: print_verbose("Training failed, not logging results") sys.exit(1) file_size = get_file_size( model, peft_config=peft_config, clean=clean, print_fn=print_verbose, ) time_total = time.perf_counter() - tic_total # log results: print and save to file log_results( experiment_name=experiment_name, 
train_result=train_result, cuda_memory_init=cuda_memory_init, time_total=time_total, file_size=file_size, model_info=model_info, datasets_info={"metamath": metamath_info, "gsm8k": gsm8k_info}, start_date=start_date, train_config=train_config, peft_config=peft_config, print_fn=print_verbose, ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output") parser.add_argument("path_experiment", type=str, help="Path to the experiment directory") parser.add_argument( "--clean", action="store_true", help="Delete training artifacts after run finishes (logs are still saved)", ) args = parser.parse_args() experiment_name = validate_experiment_path(args.path_experiment) if args.verbose: def print_verbose(*args, **kwargs) -> None: kwargs["file"] = sys.stderr print(*args, **kwargs) else: def print_verbose(*args, **kwargs) -> None: pass main( path_experiment=args.path_experiment, experiment_name=experiment_name, clean=args.clean, )